297 lines
9.9 KiB
Python
297 lines
9.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Inspect and convert a .keras plant disease model to TF.js GraphModel format.
|
|
|
|
Uses tensorflowjs_converter CLI to avoid Keras version deserialization issues.
|
|
|
|
Usage:
|
|
pip3 install tensorflowjs # also pulls tensorflow as dependency
|
|
python3 scripts/convert-keras-to-tfjs.py
|
|
"""
|
|
|
|
import importlib.util
import json
import os
import shutil
import subprocess
import sys
import tempfile
import zipfile
|
|
|
|
# Absolute path of the source .keras archive. It is resolved from the parent of
# the directory that contains this script (the repo root when the script lives
# in scripts/, as the usage line in the module docstring shows).
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "public", "models", "plant-disease-classifier", "best_mnv2_pv_original.keras",
)

# The converted TF.js model goes into a "tfjs_model" directory beside the archive.
OUTPUT_DIR = os.path.join(os.path.dirname(MODEL_PATH), "tfjs_model")
|
|
|
|
|
|
def inspect_keras_metadata(model_path=None):
    """Read .keras archive metadata without loading the model.

    A ``.keras`` file is a ZIP archive. This function lists its entries and
    parses the embedded ``config.json`` with ``json``, so Keras/TensorFlow
    are never imported.

    Args:
        model_path: Path to the ``.keras`` file. Defaults to ``MODEL_PATH``.

    Exits with status 1 (via ``sys.exit``) when the file is missing, is not
    a valid ZIP archive, or its config entry is not valid JSON.
    """
    path = MODEL_PATH if model_path is None else model_path

    print("=" * 60)
    print("MODEL INSPECTION (metadata only)")
    print("=" * 60)

    if not os.path.exists(path):
        print(f"ERROR: Model not found at {path}")
        sys.exit(1)

    size = os.path.getsize(path)
    print(f"\nModel file: {path}")
    print(f"File size: {size:,} bytes ({size / 1024 / 1024:.1f} MB)")

    try:
        # .keras files are ZIP archives
        with zipfile.ZipFile(path) as zf:
            names = zf.namelist()
            _print_archive_contents(zf, names)

            config_path = _find_config_entry(names)
            if config_path is None:
                print("\nNo config.json entry found; skipping architecture summary.")
                return

            print(f"\nReading {config_path}...")
            with zf.open(config_path) as f:
                config = json.load(f)
    except zipfile.BadZipFile:
        # Typical causes: a legacy HDF5 file renamed to .keras, or a
        # placeholder/pointer file (e.g. an un-fetched Git LFS object).
        print(f"ERROR: {path} is not a valid ZIP archive (.keras files are ZIP-based).")
        sys.exit(1)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"ERROR: Could not parse model config in {path}: {err}")
        sys.exit(1)

    _print_model_config(config)


def _print_archive_contents(zf, names):
    """Print every archive entry with its uncompressed size."""
    print(f"\nArchive contents ({len(names)} entries):")
    for name in names:
        print(f" {name:<40s} {zf.getinfo(name).file_size:>10,} bytes")


def _find_config_entry(names):
    """Return the archive entry holding the model config, or None.

    Prefers an exact top-level ``config.json``. Otherwise it falls back to
    the first entry whose name ends with ``config.json``.
    """
    if "config.json" in names:
        return "config.json"
    return next((name for name in names if name.endswith("config.json")), None)


def _print_model_config(config):
    """Print architecture, classification-head, compile and build info from a parsed config."""
    print(f"Model class: {config.get('class_name', 'unknown')}")

    inner_config = config.get("config")
    if isinstance(inner_config, dict):
        # Look for output shape in config
        if "output_shape" in inner_config:
            print(f"Output shape: {inner_config['output_shape']}")
        layers = inner_config.get("layers")
        if isinstance(layers, list):
            _print_layers(layers)

    compile_cfg = config.get("compile_config")
    if isinstance(compile_cfg, dict):
        _print_compile_config(compile_cfg)

    build_cfg = config.get("build_config")
    if isinstance(build_cfg, dict) and "input_shape" in build_cfg:
        print(f"\nInput shape: {build_cfg['input_shape']}")


def _print_layers(layers):
    """Print one line per layer, then the last Dense layer as the classification head."""
    print(f"\nLayers ({len(layers)} total):")
    for layer in layers:
        layer_config = layer.get("config", {})
        units = layer_config.get("units")
        activation = layer_config.get("activation")

        # Only layers that define units/activation (e.g. Dense) get details.
        detail = ""
        if units:
            detail = f" units={units}"
        if activation:
            detail += f" activation={activation}"

        layer_name = str(layer_config.get("name", "?"))
        layer_class = str(layer.get("class_name", "?"))
        print(f" {layer_name:<30s} {layer_class:<20s}{detail}")

    # The last Dense layer's unit count is reported as the class count.
    head = next((layer for layer in reversed(layers) if layer.get("class_name") == "Dense"), None)
    if head is not None:
        head_config = head.get("config", {})
        print("\nClassification head:")
        print(f" Units (classes): {head_config.get('units')}")
        print(f" Activation: {head_config.get('activation')}")
        print(f" Layer name: {head_config.get('name', '?')}")


def _print_compile_config(compile_cfg):
    """Print optimizer, learning rate, loss and metrics from a compile config dict."""
    optimizer = compile_cfg.get("optimizer", {})
    if isinstance(optimizer, dict):
        print("\nTraining config:")
        print(f" Optimizer: {optimizer.get('class_name', '?')}")
        lr = optimizer.get("config", {}).get("learning_rate")
        if lr is not None:
            print(f" Learning rate: {lr}")
        print(f" Loss: {compile_cfg.get('loss', '?')}")
        print(f" Metrics: {compile_cfg.get('metrics', [])}")
|
|
|
|
|
|
def convert_to_tfjs(model_path=None, output_dir=None, timeout=300):
    """Convert the .keras model to a TF.js graph model and print a summary.

    The tensorflowjs converter runs as a subprocess. If the ``tensorflowjs``
    package is importable by the interpreter running this script, it is run
    with ``python -m`` on that interpreter. Otherwise a
    ``tensorflowjs_converter`` executable on PATH is used.

    Output is first written to a temporary staging directory beside
    ``output_dir``. It replaces the existing output only after ``model.json``
    has been produced, so a failed run leaves a previous conversion intact.

    Args:
        model_path: Source ``.keras`` file. Defaults to ``MODEL_PATH``.
        output_dir: Destination directory. Defaults to ``OUTPUT_DIR``.
        timeout: Seconds to allow the converter subprocess.

    Returns:
        The output directory path.

    Exits with status 1 (via ``sys.exit``) in any of these cases:
        - the model file is missing;
        - no converter is available;
        - the conversion fails or times out;
        - no ``model.json`` is produced.
    """
    model_path = MODEL_PATH if model_path is None else model_path
    output_dir = OUTPUT_DIR if output_dir is None else output_dir

    print("\n" + "=" * 60)
    print("CONVERTING TO TF.JS GRAPH MODEL")
    print("=" * 60)

    if not os.path.exists(model_path):
        print(f"ERROR: Model not found at {model_path}")
        sys.exit(1)

    base_cmd = _converter_command()
    if base_cmd is None:
        print(
            f"ERROR: tensorflowjs is not installed for {sys.executable} "
            "and tensorflowjs_converter was not found in PATH."
        )
        print(" pip3 install tensorflowjs")
        sys.exit(1)

    print(f"\nConverting {model_path} -> {output_dir}/")
    print("(this may take a minute...)")

    parent_dir = os.path.dirname(os.path.abspath(output_dir))
    os.makedirs(parent_dir, exist_ok=True)
    # Same parent directory, so the final os.replace is a cheap same-filesystem rename.
    staging_dir = tempfile.mkdtemp(prefix=".tfjs_model-", dir=parent_dir)
    try:
        _run_converter(base_cmd, model_path, staging_dir, timeout)

        # Verify output before touching the existing output directory.
        if not os.path.exists(os.path.join(staging_dir, "model.json")):
            print("ERROR: Conversion did not produce model.json")
            sys.exit(1)

        if os.path.exists(output_dir):
            print(f"Removing existing output dir: {output_dir}")
            shutil.rmtree(output_dir)
        os.replace(staging_dir, output_dir)
    finally:
        # Still present only if the conversion failed or the swap raised.
        if os.path.isdir(staging_dir):
            shutil.rmtree(staging_dir, ignore_errors=True)

    _print_output_files(output_dir)
    _print_model_json_summary(os.path.join(output_dir, "model.json"))
    return output_dir


def _converter_command():
    """Return the argv prefix that launches the tensorflowjs converter, or None if unavailable."""
    # find_spec on a top-level name locates the package without importing it.
    if importlib.util.find_spec("tensorflowjs") is not None:
        # Use the python running this script so the converter sees the same environment.
        return [sys.executable, "-m", "tensorflowjs.converters.converter"]
    cli = shutil.which("tensorflowjs_converter")
    return [cli] if cli else None


def _run_converter(base_cmd, model_path, output_dir, timeout):
    """Run the converter subprocess and echo its output; exit(1) on failure or timeout."""
    cmd = base_cmd + [
        "--input_format=keras",
        "--output_format=tfjs_graph_model",
        model_path,
        output_dir,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f"\nERROR: Conversion timed out after {timeout} seconds.")
        sys.exit(1)
    except OSError as err:
        print(f"\nERROR: Could not start converter: {err}")
        sys.exit(1)

    if result.returncode != 0:
        print("\nERROR: Conversion failed!")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        sys.exit(1)

    if result.stdout:
        print(result.stdout)
    if result.stderr:
        # Some warnings are normal
        print(f"Converter output: {result.stderr}")


def _print_output_files(output_dir):
    """List the regular files in output_dir with their sizes and the total size."""
    entries = sorted(os.listdir(output_dir))
    file_sizes = {}
    for name in entries:
        fpath = os.path.join(output_dir, name)
        if os.path.isfile(fpath):
            file_sizes[name] = os.path.getsize(fpath)
    total_size = sum(file_sizes.values())

    print("\nConversion complete!")
    print(f"Output directory: {output_dir}/")
    print(f"Files: {len(entries)}")
    for name, size in file_sizes.items():
        print(f" {name:<30s} {size:>10,} bytes")
    print(f"Total size: {total_size:,} bytes ({total_size / 1024 / 1024:.1f} MB)")


def _print_model_json_summary(model_json_path):
    """Print format, topology, input/output tensor specs and weight manifests from model.json."""
    with open(model_json_path, encoding="utf-8") as f:
        model_json = json.load(f)

    print(f"\nTF.js model format: {model_json.get('format', 'unknown')}")
    print(f"Generated by: {model_json.get('generatedBy', 'unknown')}")

    topology = model_json.get("modelTopology")
    if isinstance(topology, dict):
        print("\nModel topology:")
        print(f" Name: {topology.get('model_name', 'unnamed')}")
        print(f" Ops: {len(topology.get('node', []))} nodes")

        # Graph-model JSON usually keeps input/output tensor specs under a
        # top-level "signature" key rather than inside modelTopology.
        # NOTE(review): confirm against the installed converter version.
        # The topology is kept as a fallback.
        signature = model_json.get("signature")
        if not isinstance(signature, dict):
            signature = {}
        inputs = signature.get("inputs") or topology.get("inputs", {})
        outputs = signature.get("outputs") or topology.get("outputs", {})
        _print_tensor_specs("Inputs", inputs)
        _print_tensor_specs("Outputs", outputs)

    manifest = model_json.get("weightsManifest")
    if manifest is not None:
        print(f"\nWeight manifests: {len(manifest)}")
        for i, group in enumerate(manifest):
            shards = group.get("shards", [])
            print(f" Manifest {i}: {len(shards)} shard(s)")


def _print_tensor_specs(label, specs):
    """Print tensor names followed by each tensor's raw tensorShape dims."""
    print(f" {label}: {list(specs.keys())}")
    for name, info in specs.items():
        shape = info.get("tensorShape", {})
        print(f" {name}: shape={shape.get('dim', 'unknown')}")
|
|
|
|
|
|
def main():
    """Script entry point: inspect the model's metadata, convert it, then print next steps.

    Exits with status 1 when the source model file does not exist.
    """
    model_missing = not os.path.exists(MODEL_PATH)
    if model_missing:
        print(f"ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)

    inspect_keras_metadata()          # step 1: metadata summary
    output_dir = convert_to_tfjs()    # step 2: conversion

    # Step 3: tell the user what to do with the converted files.
    banner = "=" * 60
    print("\n" + banner)
    print("NEXT STEPS")
    print(banner)
    print(f"""
1. Move the TF.js model to the expected location:
   The model-loader expects model.json at:
   public/models/plant-disease-classifier/model.json

   Move files:
   mv {output_dir}/model.json public/models/plant-disease-classifier/
   mv {output_dir}/group1-shard* public/models/plant-disease-classifier/

2. IMPORTANT: This model has 38 output classes (original PlantVillage).
   Your labels.ts expects 95 classes (93 diseases + healthy + unknown).
   You'll need to either:
     a) Fine-tune the model with your 95-class dataset, OR
     b) Map the 38 PlantVillage classes to your disease IDs

3. Install @tensorflow/tfjs in your project:
   npm install @tensorflow/tfjs

4. Test with your API:
   npm run dev
   POST /api/identify with an uploaded image
""")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not on import. main() returns None,
    # so the process exits with status 0 unless main() itself calls sys.exit.
    sys.exit(main())
|