Files
plant-disease-id/scripts/convert-keras-to-tfjs.py
2026-06-08 16:42:04 -04:00

297 lines
9.9 KiB
Python

#!/usr/bin/env python3
"""
Inspect and convert a .keras plant disease model to TF.js GraphModel format.
Uses tensorflowjs_converter CLI to avoid Keras version deserialization issues.

Usage:
    pip3 install tensorflowjs  # also pulls in tensorflow as a dependency
    python3 scripts/convert-keras-to-tfjs.py
"""
import json
import os
import shutil
import subprocess
import sys
# Directory one level above the folder that holds this script; the model
# paths below are resolved relative to it.
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Source .keras model file that the script inspects and converts.
MODEL_PATH = os.path.join(
    _BASE_DIR,
    "public",
    "models",
    "plant-disease-classifier",
    "best_mnv2_pv_original.keras",
)

# Directory the converted TF.js model is written to, next to the source model.
OUTPUT_DIR = os.path.join(os.path.dirname(MODEL_PATH), "tfjs_model")
def inspect_keras_metadata():
    """Print a metadata summary of the .keras model without loading it.

    Opens ``MODEL_PATH`` as a ZIP archive, lists its entries and, when an
    entry whose name ends with ``config.json`` is present, prints the model
    class, the layer list, the last Dense layer, the compile settings and the
    input shape recorded in that JSON document.

    Exits the process with status 1 if the model file does not exist or is
    not a readable ZIP archive.
    """
    # Local import: zipfile is only needed by this function.
    import zipfile

    print("=" * 60)
    print("MODEL INSPECTION (metadata only)")
    print("=" * 60)

    if not os.path.exists(MODEL_PATH):
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    size_bytes = os.path.getsize(MODEL_PATH)
    print(f"\nModel file: {MODEL_PATH}")
    print(f"File size: {size_bytes:,} bytes ({size_bytes / 1024 / 1024:.1f} MB)")

    config = None
    try:
        # .keras files are ZIP archives
        with zipfile.ZipFile(MODEL_PATH) as zf:
            entries = zf.infolist()
            print(f"\nArchive contents ({len(entries)} entries):")
            for entry in entries:
                print(f" {entry.filename:<40s} {entry.file_size:>10,} bytes")

            # First entry whose name ends with config.json holds the
            # architecture description.
            config_path = next(
                (
                    entry.filename
                    for entry in entries
                    if entry.filename.endswith("config.json")
                ),
                None,
            )
            if config_path is not None:
                print(f"\nReading {config_path}...")
                with zf.open(config_path) as f:
                    config = json.load(f)
    except zipfile.BadZipFile as exc:
        # Report a corrupt / non-ZIP file the same way as the other failures
        # instead of surfacing a traceback.
        print(f"ERROR: {MODEL_PATH} is not a readable .keras (ZIP) archive: {exc}")
        sys.exit(1)

    if isinstance(config, dict):
        _print_model_config(config)


def _print_model_config(config):
    """Print model class, layers, compile settings and input shape from *config*."""
    print(f"Model class: {config.get('class_name', 'unknown')}")

    inner_config = config.get("config")
    if isinstance(inner_config, dict):
        if "output_shape" in inner_config:
            print(f"Output shape: {inner_config['output_shape']}")
        if "layers" in inner_config:
            _print_layers(inner_config["layers"])

    compile_cfg = config.get("compile_config")
    if isinstance(compile_cfg, dict):
        _print_compile_config(compile_cfg)

    build_cfg = config.get("build_config")
    if isinstance(build_cfg, dict) and "input_shape" in build_cfg:
        print(f"\nInput shape: {build_cfg['input_shape']}")


def _print_layers(layers):
    """Print one line per layer, then the details of the last Dense layer."""
    print(f"\nLayers ({len(layers)} total):")
    for layer in layers:
        layer_config = layer.get("config", {})
        layer_name = layer_config.get("name", "?")
        layer_class = layer.get("class_name", "?")

        # units / activation are only shown when the layer config has them.
        detail = ""
        units = layer_config.get("units")
        if units:
            detail = f" units={units}"
        activation = layer_config.get("activation")
        if activation:
            detail += f" activation={activation}"
        print(f" {layer_name:<30s} {layer_class:<20s}{detail}")

    # The last Dense layer is reported as the classification head; its unit
    # count is shown as the number of classes.
    head = next(
        (layer for layer in reversed(layers) if layer.get("class_name") == "Dense"),
        None,
    )
    if head is not None:
        head_config = head.get("config", {})
        print("\nClassification head:")
        print(f" Units (classes): {head_config.get('units')}")
        print(f" Activation: {head_config.get('activation')}")
        print(f" Layer name: {head_config.get('name', '?')}")


def _print_compile_config(compile_cfg):
    """Print optimizer, learning rate, loss and metrics from *compile_cfg*."""
    # Header is printed unconditionally so the Loss/Metrics lines below are
    # never left without a section title.
    print("\nTraining config:")

    optimizer = compile_cfg.get("optimizer", {})
    if isinstance(optimizer, dict):
        print(f" Optimizer: {optimizer.get('class_name', '?')}")
        lr = optimizer.get("config", {}).get("learning_rate")
        if lr:
            print(f" Learning rate: {lr}")

    print(f" Loss: {compile_cfg.get('loss', '?')}")
    print(f" Metrics: {compile_cfg.get('metrics', [])}")
def convert_to_tfjs():
    """Convert the .keras model to a TF.js graph model and summarize the result.

    Runs ``tensorflowjs.converters.converter`` as a module under the
    interpreter executing this script, writing into ``OUTPUT_DIR`` (which is
    deleted and recreated first). Afterwards prints the produced files and
    selected fields of the generated ``model.json``.

    Returns:
        The output directory path (``OUTPUT_DIR``).

    Exits the process with status 1 if the tensorflowjs package cannot be
    found, the converter fails or times out, or no ``model.json`` is produced.
    """
    # Local import: only this function needs it.
    import importlib.util

    print("\n" + "=" * 60)
    print("CONVERTING TO TF.JS GRAPH MODEL")
    print("=" * 60)

    # The converter is launched below via `sys.executable -m ...`, so what
    # matters is that the package is importable by this interpreter, not that
    # a `tensorflowjs_converter` script is on PATH. Checking this before
    # touching OUTPUT_DIR also keeps a previous conversion intact when the
    # prerequisite is missing.
    if importlib.util.find_spec("tensorflowjs") is None:
        print(f"ERROR: tensorflowjs is not installed for {sys.executable}.")
        print(" pip3 install tensorflowjs")
        sys.exit(1)

    # Clean output dir
    if os.path.exists(OUTPUT_DIR):
        print(f"Removing existing output dir: {OUTPUT_DIR}")
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"\nConverting {MODEL_PATH} -> {OUTPUT_DIR}/")
    print("(this may take a minute...)")

    result = _run_converter(timeout_seconds=300)
    if result.returncode != 0:
        print("\nERROR: Conversion failed!")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        sys.exit(1)
    if result.stdout:
        print(result.stdout)
    if result.stderr:
        # Some warnings are normal
        print(f"Converter output: {result.stderr}")

    # Verify output
    model_json_path = os.path.join(OUTPUT_DIR, "model.json")
    if not os.path.exists(model_json_path):
        print("ERROR: Conversion did not produce model.json")
        sys.exit(1)

    _print_output_files(OUTPUT_DIR)
    _print_model_json_summary(model_json_path)
    return OUTPUT_DIR


def _run_converter(timeout_seconds):
    """Run the converter module in a subprocess and return the completed process."""
    command = [
        sys.executable,  # the python running this script
        "-m",
        "tensorflowjs.converters.converter",
        "--input_format=keras",
        "--output_format=tfjs_graph_model",
        MODEL_PATH,
        OUTPUT_DIR,
    ]
    try:
        return subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired:
        # Report the timeout like every other failure instead of a traceback.
        print(f"\nERROR: Conversion timed out after {timeout_seconds} seconds.")
        sys.exit(1)


def _print_output_files(output_dir):
    """Print each regular file in *output_dir* with its size, plus the total."""
    entries = os.listdir(output_dir)
    sizes = {}
    for entry in entries:
        entry_path = os.path.join(output_dir, entry)
        if os.path.isfile(entry_path):
            sizes[entry] = os.path.getsize(entry_path)
    total_size = sum(sizes.values())

    print("\nConversion complete!")
    print(f"Output directory: {output_dir}/")
    print(f"Files: {len(entries)}")
    for entry in sorted(sizes):
        print(f" {entry:<30s} {sizes[entry]:>10,} bytes")
    print(f"Total size: {total_size:,} bytes ({total_size / 1024 / 1024:.1f} MB)")


def _print_model_json_summary(model_json_path):
    """Print format, topology, input/output and weight-manifest info from model.json."""
    with open(model_json_path, encoding="utf-8") as f:
        model_json = json.load(f)

    print(f"\nTF.js model format: {model_json.get('format', 'unknown')}")
    print(f"Generated by: {model_json.get('generatedBy', 'unknown')}")

    if "modelTopology" in model_json:
        topology = model_json["modelTopology"]
        # NOTE(review): graph-model files are expected to describe their
        # inputs/outputs under the top-level "signature" key - confirm against
        # the installed converter. The topology is kept as a fallback.
        signature = model_json.get("signature")
        if not isinstance(signature, dict):
            signature = {}

        print("\nModel topology:")
        print(f" Name: {topology.get('model_name', 'unnamed')}")
        print(f" Ops: {len(topology.get('node', []))} nodes")
        for label, key in (("Inputs", "inputs"), ("Outputs", "outputs")):
            tensors = signature.get(key) or topology.get(key, {})
            print(f" {label}: {list(tensors.keys())}")
            for name, info in tensors.items():
                shape = info.get("tensorShape", {})
                print(f" {name}: shape={shape.get('dim', 'unknown')}")

    if "weightsManifest" in model_json:
        manifest = model_json["weightsManifest"]
        print(f"\nWeight manifests: {len(manifest)}")
        for i, group in enumerate(manifest):
            shards = group.get("shards", [])
            print(f" Manifest {i}: {len(shards)} shard(s)")
def main():
    """Run the whole workflow: inspect the model, convert it, print next steps.

    Exits with status 1 when the model file at ``MODEL_PATH`` does not exist.
    """
    model_missing = not os.path.exists(MODEL_PATH)
    if model_missing:
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    # Inspection first, then conversion; the conversion returns the directory
    # that the follow-up instructions refer to.
    inspect_keras_metadata()
    tfjs_dir = convert_to_tfjs()

    rule = "=" * 60
    print("\n" + rule)
    print("NEXT STEPS")
    print(rule)

    next_steps = f"""
1. Move the TF.js model to the expected location:
The model-loader expects model.json at:
public/models/plant-disease-classifier/model.json
Move files:
mv {tfjs_dir}/model.json public/models/plant-disease-classifier/
mv {tfjs_dir}/group1-shard* public/models/plant-disease-classifier/
2. IMPORTANT: This model has 38 output classes (original PlantVillage).
Your labels.ts expects 95 classes (93 diseases + healthy + unknown).
You'll need to either:
a) Fine-tune the model with your 95-class dataset, OR
b) Map the 38 PlantVillage classes to your disease IDs
3. Install @tensorflow/tfjs in your project:
npm install @tensorflow/tfjs
4. Test with your API:
npm run dev
POST /api/identify with an uploaded image
"""
    print(next_steps)


# Only run the workflow when executed as a script, not when imported.
if __name__ == "__main__":
    main()