re-init
This commit is contained in:
296
scripts/convert-keras-to-tfjs.py
Normal file
296
scripts/convert-keras-to-tfjs.py
Normal file
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Inspect and convert a .keras plant disease model to TF.js GraphModel format.
|
||||
|
||||
Uses tensorflowjs_converter CLI to avoid Keras version deserialization issues.
|
||||
|
||||
Usage:
|
||||
pip3 install tensorflowjs # also pulls tensorflow as dependency
|
||||
python3 scripts/convert-keras-to-tfjs.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"public",
|
||||
"models",
|
||||
"plant-disease-classifier",
|
||||
"best_mnv2_pv_original.keras",
|
||||
)
|
||||
|
||||
OUTPUT_DIR = os.path.join(
|
||||
os.path.dirname(MODEL_PATH),
|
||||
"tfjs_model",
|
||||
)
|
||||
|
||||
|
||||
def inspect_keras_metadata():
|
||||
"""Read .keras archive metadata without loading the model."""
|
||||
print("=" * 60)
|
||||
print("MODEL INSPECTION (metadata only)")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
import zipfile
|
||||
except ImportError:
|
||||
print("ERROR: zipfile not available")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
print(f"ERROR: Model not found at {MODEL_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\nModel file: {MODEL_PATH}")
|
||||
print(
|
||||
f"File size: {os.path.getsize(MODEL_PATH):,} bytes ({os.path.getsize(MODEL_PATH) / 1024 / 1024:.1f} MB)"
|
||||
)
|
||||
|
||||
# .keras files are ZIP archives
|
||||
with zipfile.ZipFile(MODEL_PATH) as zf:
|
||||
names = zf.namelist()
|
||||
print(f"\nArchive contents ({len(names)} entries):")
|
||||
for name in names:
|
||||
info = zf.getinfo(name)
|
||||
print(f" {name:<40s} {info.file_size:>10,} bytes")
|
||||
|
||||
# Read config.json for model architecture info
|
||||
config_path = None
|
||||
for name in names:
|
||||
if name.endswith("config.json"):
|
||||
config_path = name
|
||||
break
|
||||
|
||||
if config_path:
|
||||
print(f"\nReading {config_path}...")
|
||||
with zf.open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Extract key info
|
||||
model_type = config.get("class_name", "unknown")
|
||||
print(f"Model class: {model_type}")
|
||||
|
||||
# Try to find output layer info
|
||||
if "config" in config:
|
||||
inner_config = config["config"]
|
||||
|
||||
# Look for output shape in config
|
||||
if "output_shape" in inner_config:
|
||||
print(f"Output shape: {inner_config['output_shape']}")
|
||||
|
||||
# Look through layers for the final dense layer
|
||||
if "layers" in inner_config:
|
||||
layers = inner_config["layers"]
|
||||
print(f"\nLayers ({len(layers)} total):")
|
||||
for layer in layers:
|
||||
layer_name = layer.get("config", {}).get("name", "?")
|
||||
layer_class = layer.get("class_name", "?")
|
||||
layer_module = layer.get("module", "?")
|
||||
|
||||
# Extract units/activation for dense layers
|
||||
layer_config = layer.get("config", {})
|
||||
units = layer_config.get("units")
|
||||
activation = layer_config.get("activation")
|
||||
|
||||
detail = ""
|
||||
if units:
|
||||
detail = f" units={units}"
|
||||
if activation:
|
||||
detail += f" activation={activation}"
|
||||
|
||||
print(f" {layer_name:<30s} {layer_class:<20s}{detail}")
|
||||
|
||||
# Find last dense layer for class count
|
||||
for layer in reversed(layers):
|
||||
if layer.get("class_name") == "Dense":
|
||||
units = layer.get("config", {}).get("units")
|
||||
activation = layer.get("config", {}).get("activation")
|
||||
print("\nClassification head:")
|
||||
print(f" Units (classes): {units}")
|
||||
print(f" Activation: {activation}")
|
||||
print(
|
||||
f" Layer name: {layer.get('config', {}).get('name', '?')}"
|
||||
)
|
||||
break
|
||||
|
||||
# Check compile config
|
||||
if "compile_config" in config:
|
||||
compile_cfg = config["compile_config"]
|
||||
optimizer = compile_cfg.get("optimizer", {})
|
||||
if isinstance(optimizer, dict):
|
||||
opt_name = optimizer.get("class_name", "?")
|
||||
lr = optimizer.get("config", {}).get("learning_rate")
|
||||
print("\nTraining config:")
|
||||
print(f" Optimizer: {opt_name}")
|
||||
if lr:
|
||||
print(f" Learning rate: {lr}")
|
||||
loss = compile_cfg.get("loss", "?")
|
||||
metrics = compile_cfg.get("metrics", [])
|
||||
print(f" Loss: {loss}")
|
||||
print(f" Metrics: {metrics}")
|
||||
|
||||
# Check input shape
|
||||
if "build_config" in config:
|
||||
build_cfg = config["build_config"]
|
||||
if "input_shape" in build_cfg:
|
||||
print(f"\nInput shape: {build_cfg['input_shape']}")
|
||||
|
||||
|
||||
def convert_to_tfjs():
|
||||
"""Convert using tensorflowjs_converter CLI."""
|
||||
print("\n" + "=" * 60)
|
||||
print("CONVERTING TO TF.JS GRAPH MODEL")
|
||||
print("=" * 60)
|
||||
|
||||
# Check tensorflowjs_converter CLI is available
|
||||
converter = shutil.which("tensorflowjs_converter")
|
||||
if not converter:
|
||||
print("ERROR: tensorflowjs_converter not found in PATH.")
|
||||
print(" pip3 install tensorflowjs")
|
||||
sys.exit(1)
|
||||
|
||||
# Clean output dir
|
||||
if os.path.exists(OUTPUT_DIR):
|
||||
print(f"Removing existing output dir: {OUTPUT_DIR}")
|
||||
shutil.rmtree(OUTPUT_DIR)
|
||||
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print(f"\nConverting {MODEL_PATH} -> {OUTPUT_DIR}/")
|
||||
print("(this may take a minute...)")
|
||||
|
||||
# Use the venv's python to run the converter (avoids import issues)
|
||||
python_exe = sys.executable # the python running this script
|
||||
result = subprocess.run(
|
||||
[
|
||||
python_exe,
|
||||
"-m",
|
||||
"tensorflowjs.converters.converter",
|
||||
"--input_format=keras",
|
||||
"--output_format=tfjs_graph_model",
|
||||
MODEL_PATH,
|
||||
OUTPUT_DIR,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
print("\nERROR: Conversion failed!")
|
||||
print(f"stdout: {result.stdout}")
|
||||
print(f"stderr: {result.stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
if result.stdout:
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
# Some warnings are normal
|
||||
print(f"Converter output: {result.stderr}")
|
||||
|
||||
# Verify output
|
||||
model_json_path = os.path.join(OUTPUT_DIR, "model.json")
|
||||
if not os.path.exists(model_json_path):
|
||||
print("ERROR: Conversion did not produce model.json")
|
||||
sys.exit(1)
|
||||
|
||||
# List output files
|
||||
files = os.listdir(OUTPUT_DIR)
|
||||
total_size = sum(
|
||||
os.path.getsize(os.path.join(OUTPUT_DIR, f))
|
||||
for f in files
|
||||
if os.path.isfile(os.path.join(OUTPUT_DIR, f))
|
||||
)
|
||||
|
||||
print("\nConversion complete!")
|
||||
print(f"Output directory: {OUTPUT_DIR}/")
|
||||
print(f"Files: {len(files)}")
|
||||
for f in sorted(files):
|
||||
fpath = os.path.join(OUTPUT_DIR, f)
|
||||
if os.path.isfile(fpath):
|
||||
size = os.path.getsize(fpath)
|
||||
print(f" {f:<30s} {size:>10,} bytes")
|
||||
print(f"Total size: {total_size:,} bytes ({total_size / 1024 / 1024:.1f} MB)")
|
||||
|
||||
# Read model.json to check config
|
||||
with open(model_json_path) as f:
|
||||
model_json = json.load(f)
|
||||
|
||||
print(f"\nTF.js model format: {model_json.get('format', 'unknown')}")
|
||||
print(f"Generated by: {model_json.get('generatedBy', 'unknown')}")
|
||||
|
||||
# Inspect model topology
|
||||
if "modelTopology" in model_json:
|
||||
topology = model_json["modelTopology"]
|
||||
print("\nModel topology:")
|
||||
print(f" Name: {topology.get('model_name', 'unnamed')}")
|
||||
print(f" Ops: {len(topology.get('node', []))} nodes")
|
||||
|
||||
# Input/output nodes
|
||||
inputs = topology.get("inputs", {})
|
||||
outputs = topology.get("outputs", {})
|
||||
print(f" Inputs: {list(inputs.keys())}")
|
||||
for name, info in inputs.items():
|
||||
shape = info.get("tensorShape", {})
|
||||
print(f" {name}: shape={shape.get('dim', 'unknown')}")
|
||||
print(f" Outputs: {list(outputs.keys())}")
|
||||
for name, info in outputs.items():
|
||||
shape = info.get("tensorShape", {})
|
||||
print(f" {name}: shape={shape.get('dim', 'unknown')}")
|
||||
|
||||
# Check weights specification
|
||||
if "weightsManifest" in model_json:
|
||||
manifest = model_json["weightsManifest"]
|
||||
print(f"\nWeight manifests: {len(manifest)}")
|
||||
for i, m in enumerate(manifest):
|
||||
shards = m.get("shards", [])
|
||||
print(f" Manifest {i}: {len(shards)} shard(s)")
|
||||
|
||||
return OUTPUT_DIR
|
||||
|
||||
|
||||
def main():
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
print(f"ERROR: Model not found at {MODEL_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 1: Inspect metadata
|
||||
inspect_keras_metadata()
|
||||
|
||||
# Step 2: Convert
|
||||
output_dir = convert_to_tfjs()
|
||||
|
||||
# Step 3: Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("NEXT STEPS")
|
||||
print("=" * 60)
|
||||
print(f"""
|
||||
1. Move the TF.js model to the expected location:
|
||||
The model-loader expects model.json at:
|
||||
public/models/plant-disease-classifier/model.json
|
||||
|
||||
Move files:
|
||||
mv {output_dir}/model.json public/models/plant-disease-classifier/
|
||||
mv {output_dir}/group1-shard* public/models/plant-disease-classifier/
|
||||
|
||||
2. IMPORTANT: This model has 38 output classes (original PlantVillage).
|
||||
Your labels.ts expects 95 classes (93 diseases + healthy + unknown).
|
||||
You'll need to either:
|
||||
a) Fine-tune the model with your 95-class dataset, OR
|
||||
b) Map the 38 PlantVillage classes to your disease IDs
|
||||
|
||||
3. Install @tensorflow/tfjs in your project:
|
||||
npm install @tensorflow/tfjs
|
||||
|
||||
4. Test with your API:
|
||||
npm run dev
|
||||
POST /api/identify with an uploaded image
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user