03. TensorFlow.js Model Loading Verification and Fixes
meta:
- id: production-ml-pipeline-03
- feature: production-ml-pipeline
- priority: P0
- depends_on: []
- tags: [implementation, model, tests-required]
objective:
- Verify the converted TF.js GraphModel loads successfully on the Node.js server
- Fix input tensor format handling (NCHW pipeline input → NHWC model input)
- Determine whether model output is logits or pre-computed softmax probabilities
- Ensure inference produces valid [1, 38] output without errors
- Install `@tensorflow/tfjs-node` for server-side native acceleration
deliverables:
- `src/lib/ml/model-loader.ts`: fixed and verified for real model loading
- `src/lib/ml/model-loader.test.ts`: updated integration tests
- `package.json`: `@tensorflow/tfjs-node` added as a dependency (if needed)
- `src/lib/ml/inference.ts`: fixed output interpretation (logits vs probabilities)
- `src/lib/ml/inference.test.ts`: updated for real model inference
steps:
- Determine output interpretation: inspect the graph topology to resolve whether `Identity:0` holds pre-softmax logits or post-softmax probabilities.
  - The model graph contains a `Softmax` node at `StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax`.
  - The output `Identity:0` may come after that Softmax (probabilities) or before it (logits).
  - Test: run inference on a zero tensor. If the output sums to ~1.0 with no negative values, it is already probabilities; if it has negative values or does not sum to 1.0, it is logits.
  - Fix: if the output is already probabilities, remove the `softmaxFloat32()` call in `inference.ts` and use the raw output directly.
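The zero-tensor check above can be expressed as two small helpers. `looksLikeProbabilities()` and `softmax()` are hypothetical names (this is not the existing `softmaxFloat32()` in `inference.ts`), and the default tolerance of 0.01 is an assumption chosen to match the 0.99–1.01 band in the tests.

```typescript
// Hypothetical helpers for the logits-vs-probabilities check; names and the
// default tolerance are assumptions, not the inference.ts API.
function looksLikeProbabilities(values: ArrayLike<number>, tolerance = 0.01): boolean {
  let sum = 0;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    // Probabilities are finite and never negative
    if (!Number.isFinite(v) || v < 0) return false;
    sum += v;
  }
  return Math.abs(sum - 1) <= tolerance;
}

// Numerically stable softmax: subtracting the max logit avoids overflow in exp()
function softmax(logits: ArrayLike<number>): Float32Array {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  const out = new Float32Array(logits.length);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    out[i] = Math.exp(logits[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= sum;
  return out;
}
```

Run this check once during the investigation and hard-code the outcome in `inference.ts`; a logit vector can occasionally pass the heuristic by chance, so it should not decide per request.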
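- Fix input tensor format: the model expects NHWC `[1, 160, 160, 3]`, but our pipeline produces NCHW `[3, 160, 160]`:

```typescript
import * as tf from "@tensorflow/tfjs-node";

// model-loader.ts tryLoadTFJS(); `tensor` is the pipeline's CHW Float32Array.
// tensor3d, not tensor4d: tensor4d() throws when given the 3-D shape [3, 160, 160].
const inputTensor = tf
  .tensor3d(tensor, [3, 160, 160]) // [3, 160, 160] CHW
  .transpose([1, 2, 0]) // [160, 160, 3] HWC
  .expandDims(0); // [1, 160, 160, 3] NHWC
```

  - Verify this transpose is correct (NCHW → NHWC): `transpose([1, 2, 0])` moves the channel axis last.
  - Verify the tensor values are in the expected range (ImageNet-normalized with the mean/std listed in the notes: roughly -2.12 to +2.64).
  - Alternative: reshape directly as `[1, 160, 160, 3]` if the identify endpoint already produces NHWC data.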
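For the transpose unit test, a plain-array version of the same reordering is easy to verify by hand. `chwToHwc()` and `inImageNetRange()` are hypothetical helpers; the range bounds are derived from the ImageNet mean/std in the notes, assuming pixels are first scaled to [0, 1].

```typescript
// Hypothetical helper: reorders a planar CHW buffer into interleaved HWC order,
// the same layout change that tf.transpose([1, 2, 0]) performs.
function chwToHwc(chw: Float32Array, channels: number, height: number, width: number): Float32Array {
  if (chw.length !== channels * height * width) {
    throw new Error(`expected ${channels * height * width} values, got ${chw.length}`);
  }
  const hwc = new Float32Array(chw.length);
  for (let c = 0; c < channels; c++) {
    for (let h = 0; h < height; h++) {
      for (let w = 0; w < width; w++) {
        // CHW index c*H*W + h*W + w  ->  HWC index (h*W + w)*C + c
        hwc[(h * width + w) * channels + c] = chw[c * height * width + h * width + w];
      }
    }
  }
  return hwc;
}

// Bounds implied by mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
// for pixels scaled to [0, 1]: the extremes come from channels 0 and 2.
const MIN_NORMALIZED = (0 - 0.485) / 0.229; // about -2.12
const MAX_NORMALIZED = (1 - 0.406) / 0.225; // about 2.64

function inImageNetRange(values: Float32Array, slack = 1e-3): boolean {
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    // The negated comparison also rejects NaN
    if (!(v >= MIN_NORMALIZED - slack && v <= MAX_NORMALIZED + slack)) return false;
  }
  return true;
}
```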
- Install `@tensorflow/tfjs-node` for server-side native acceleration:

```shell
npm install @tensorflow/tfjs-node
```

  - Browser tfjs works on the server but is significantly slower (its CPU kernels run in plain JavaScript, with no native BLAS).
  - `@tensorflow/tfjs-node` binds to the libtensorflow C library; expect roughly a 10-100x speedup, but measure it on the target server.
  - Verify the native bindings install correctly (`@tensorflow/tfjs-node-gpu` is the GPU variant, but CPU is fine for this use case).
  - The fallback chain remains: tfjs-node → tfjs (browser) → mock.
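To put real numbers on the speedup and on the 500 ms target, the same inference call can be timed under each backend. `medianLatencyMs()` is a hypothetical helper, and its warm-up and iteration counts are arbitrary defaults; warm-up runs are discarded because early calls often include one-time setup.

```typescript
import { performance } from "perf_hooks";

// Hypothetical benchmark helper: runs `run` a few times to warm up, then
// returns the median wall-clock time of `iterations` further runs in ms.
async function medianLatencyMs(
  run: () => Promise<unknown>,
  iterations = 10,
  warmup = 2,
): Promise<number> {
  if (iterations < 1) throw new Error("iterations must be at least 1");

  // Warm-up runs are excluded from the measurement
  for (let i = 0; i < warmup; i++) await run();

  const samples: number[] = [];
  for (let i = 0; i < iterations; i++) {
    const start = performance.now();
    await run();
    samples.push(performance.now() - start);
  }

  // Median is less sensitive to GC pauses than the mean
  samples.sort((a, b) => a - b);
  const mid = Math.floor(samples.length / 2);
  return samples.length % 2 === 1
    ? samples[mid]
    : (samples[mid - 1] + samples[mid]) / 2;
}
```

In practice `run` would wrap `model.predict()` plus `await output.data()` (disposing both tensors), executed once with `@tensorflow/tfjs-node` and once with plain `@tensorflow/tfjs`.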
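- Verify the model loads from the filesystem:

```typescript
import * as tf from "@tensorflow/tfjs-node";

// MODEL_JSON_PATH: absolute path to public/models/plant-disease-classifier/model.json
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
console.log("Model loaded:", model.inputs, model.outputs);
// Expected:
// inputs:  [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
// outputs: [{ shape: [-1, 38], dtype: 'float32' }]
```

  - Verify `model.inputs[0].shape` matches `[-1, 160, 160, 3]` (-1 is the dynamic batch dimension).
  - Verify `model.outputs[0].shape` matches `[-1, 38]`.
  - Verify the model has a `predict()` method. GraphModel also provides `execute()` and `executeAsync()` for requesting specific output nodes; `predict()` is sufficient as long as the graph contains no dynamic control-flow ops, which require `executeAsync()`.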
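Since a dynamic batch dimension may be reported as -1 (converted graph models) or `null` (Keras-style shapes), a small comparison helper keeps the shape checks readable. `shapeMatches()` is hypothetical and treats both values as "any size".

```typescript
// Hypothetical helper: compares a reported shape against an expected one.
// -1 or null in either position means a dynamic (unknown) dimension.
type Shape = ReadonlyArray<number | null>;

function isDynamic(dim: number | null | undefined): boolean {
  return dim === -1 || dim === null;
}

function shapeMatches(actual: Shape | undefined, expected: Shape): boolean {
  if (!actual || actual.length !== expected.length) return false;
  return expected.every((dim, i) => {
    const got = actual[i];
    // Dynamic dims only match other dynamic dims; fixed dims must be equal
    return isDynamic(dim) ? isDynamic(got) : got === dim;
  });
}
```

For example, the integration test could assert `shapeMatches(model.inputs[0].shape, [-1, 160, 160, 3])`.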
- Run an inference smoke test:

```typescript
// Random values in [-1, 1), standing in for a normalized image
const testTensor = new Float32Array(1 * 160 * 160 * 3);
for (let i = 0; i < testTensor.length; i++) {
  testTensor[i] = (Math.random() - 0.5) * 2;
}

// Build the input directly in NHWC layout for TF.js (no transpose needed)
const input = tf.tensor4d(testTensor, [1, 160, 160, 3]);
// predict() is typed Tensor | Tensor[] | NamedTensorMap; this model has one output
const output = model.predict(input) as tf.Tensor;
const data = (await output.data()) as Float32Array;

console.log("Output shape:", output.shape);
console.log("Output sum:", data.reduce((a, b) => a + b, 0));
console.log("Output max:", Math.max(...data));
console.log("Output min:", Math.min(...data));

input.dispose();
output.dispose();
```

  - Output should have shape [1, 38], i.e. 38 float values.
  - If the values are probabilities: sum ≈ 1.0 and all values ≥ 0.
  - If the values are logits: sum ≠ 1.0 and values may be negative.
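The smoke-test checks can be folded into a single pass over the output data. `summarizeOutput()` and `isValidOutput()` are hypothetical helpers; the 38-class constant is the PlantVillage class count used throughout this task.

```typescript
// Hypothetical summary of one inference result (e.g. from `await output.data()`).
interface OutputSummary {
  length: number;
  sum: number;
  min: number;
  max: number;
  allFinite: boolean;
}

const EXPECTED_NUM_CLASSES = 38; // PlantVillage classes

function summarizeOutput(values: ArrayLike<number>): OutputSummary {
  let sum = 0;
  let min = Infinity;
  let max = -Infinity;
  let allFinite = true;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    if (!Number.isFinite(v)) allFinite = false;
    sum += v;
    min = Math.min(min, v);
    max = Math.max(max, v);
  }
  return { length: values.length, sum, min, max, allFinite };
}

// A usable output has exactly one finite score per class.
function isValidOutput(summary: OutputSummary): boolean {
  return summary.length === EXPECTED_NUM_CLASSES && summary.allFinite;
}
```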
- Fix `getStatus()` in `model-loader.ts` to report the real class count:

```typescript
getStatus(): ModelStatus {
  return {
    loaded: true,
    backend: "tfjs",
    modelId: MODEL_ID,
    numClasses: 38, // PlantVillage, not 95
  };
}
```
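Instead of keeping the literal 38 in sync by hand, the class count could be read from the loaded model's output shape. `numClassesFromOutputShape()` is a hypothetical helper that would be called with `model.outputs[0].shape`; it throws rather than guessing when the last dimension is missing or dynamic.

```typescript
// Hypothetical helper: derives the class count from an output shape such as
// [-1, 38], so a stale constant (like the old 95) cannot drift from the model.
function numClassesFromOutputShape(shape: ReadonlyArray<number | null> | undefined): number {
  const last = shape && shape.length > 0 ? shape[shape.length - 1] : undefined;
  // Reject missing, null, or dynamic (-1) class dimensions
  if (typeof last !== "number" || last <= 0) {
    throw new Error(`cannot infer class count from output shape ${JSON.stringify(shape)}`);
  }
  return last;
}
```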
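- Add memory management: dispose of tensors after use to prevent memory leaks:

```typescript
// In predict(): tf.tidy() disposes every tensor created inside the callback
const scores = tf.tidy(() => {
  const input = tf.tensor4d(...);
  const output = model.predict(input) as tf.Tensor;
  return output.dataSync(); // a plain TypedArray, so no tensor escapes tidy()
});
```

  - Or dispose manually: `inputTensor.dispose()`, `outputTensor.dispose()`. This is required when reading results with `await output.data()`, because `tf.tidy()` does not accept async callbacks.
  - Use `tf.memory()` to monitor the tensor count during development.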
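The "tensor count stays stable" criterion can be checked by comparing tensor counts before and after repeated calls. `tensorGrowth()` is a hypothetical helper; in a real test `countTensors` would be `() => tf.memory().numTensors` and `infer` a synchronous inference call such as the `tf.tidy()` version in this step.

```typescript
// Hypothetical helper: returns how many more tensors are allocated after
// `iterations` extra calls to `infer`. A non-zero result suggests a leak.
function tensorGrowth(
  infer: () => void,
  countTensors: () => number,
  iterations = 20,
): number {
  // Warm-up call: the first inference may allocate long-lived state
  // that is not a per-call leak, so it is excluded from the baseline.
  infer();
  const before = countTensors();
  for (let i = 0; i < iterations; i++) infer();
  return countTensors() - before;
}
```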
- Handle model load failures gracefully:
  - If the model files are corrupted, log the specific error.
  - If the tfjs-node native bindings fail, fall back to browser tfjs with a warning.
  - Never crash the server on model load failure; fall back to mock mode with clear logging.
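The fallback chain can be sketched as a loop over loader functions. `loadWithFallback()`, `Loader`, `LoadResult`, and `Logger` are hypothetical names; in `model-loader.ts` the loaders would wrap `tf.loadGraphModel()` from `@tensorflow/tfjs-node` and from browser tfjs, which this sketch does not import so that it stays testable. Mapping `source` onto the `backend` field reported by `getStatus()` is left to the caller.

```typescript
// Hypothetical shapes: each loader names its backend and resolves to a model.
interface Loader<M> {
  name: string; // e.g. "tfjs-node", "tfjs"
  load: () => Promise<M>;
}

interface LoadResult<M> {
  source: string; // loader name, or "mock"
  model: M | null; // null only in mock mode
}

interface Logger {
  info(message: string): void;
  warn(message: string): void;
}

// Tries each loader in order. Every failure is logged with its specific error,
// and the function resolves to mock mode instead of throwing, so a corrupt
// model or missing native binding never takes the server down.
async function loadWithFallback<M>(
  loaders: Loader<M>[],
  log: Logger = console,
): Promise<LoadResult<M>> {
  for (const { name, load } of loaders) {
    try {
      const model = await load();
      log.info(`[model-loader] Loaded TF.js model via ${name}`);
      return { source: name, model };
    } catch (err) {
      const reason = err instanceof Error ? err.message : String(err);
      log.warn(`[model-loader] ${name} failed: ${reason}`);
    }
  }
  log.warn("[model-loader] All backends failed; falling back to mock mode");
  return { source: "mock", model: null };
}
```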
tests:
- Integration: model loads from `public/models/plant-disease-classifier/model.json` without errors
- Integration: `model.inputs[0].shape` is `[-1, 160, 160, 3]`
- Integration: `model.outputs[0].shape` is `[-1, 38]`
- Integration: inference on a random tensor produces a [1, 38] float output (38 values)
- Integration: if the output is probabilities, its sum is within 0.99–1.01
- Integration: `getStatus()` returns `{ loaded: true, backend: "tfjs", numClasses: 38 }`
- Unit: `validateInput()` correctly rejects tensors with the wrong length
- Unit: NCHW → NHWC transpose produces the correct layout
- Performance: inference completes in < 500 ms on a typical server (with tfjs-node)
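For the `validateInput()` unit test, the rules under test can be sketched as follows. `validateInputSketch()` and its result type are hypothetical; the real `validateInput()` may have a different signature, but the expected length follows from the 3 × 160 × 160 input.

```typescript
// Hypothetical sketch of the checks validateInput() is expected to enforce.
const EXPECTED_INPUT_LENGTH = 3 * 160 * 160; // 76,800 values (CHW or HWC order)

type ValidationResult = { ok: true } | { ok: false; reason: string };

function validateInputSketch(tensor: Float32Array): ValidationResult {
  if (tensor.length !== EXPECTED_INPUT_LENGTH) {
    return { ok: false, reason: `expected ${EXPECTED_INPUT_LENGTH} values, got ${tensor.length}` };
  }
  // NaN or Infinity would silently poison every downstream score
  for (let i = 0; i < tensor.length; i++) {
    if (!Number.isFinite(tensor[i])) {
      return { ok: false, reason: `non-finite value at index ${i}` };
    }
  }
  return { ok: true };
}
```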
acceptance_criteria:
- `getModel()` returns a model with `loaded: true` and `backend: "tfjs"`
- `model.predict()` on a valid [1, 160, 160, 3] input returns a [1, 38] output without errors
- Output interpretation is correctly determined (logits vs probabilities) and handled
- `@tensorflow/tfjs-node` is installed and used as the primary backend
- No memory leaks: tensor count stays stable after repeated inference calls
- Fallback chain works: tfjs-node → tfjs → mock (each failure logs a warning)
- Model load time < 30 seconds on first request
- Inference time < 500 ms per image on the server
validation:
- `npm install @tensorflow/tfjs-node`: native bindings install successfully
- `npx vitest run src/lib/ml/model-loader.test.ts`: all loading tests pass
- `npx vitest run src/lib/ml/inference.test.ts`: all inference tests pass
- Manual: `curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}'` returns real predictions (no `demo_mode: true`)
- Check the server logs for `[model-loader] Loaded TF.js model` (not the mock fallback)
notes:
- The model file `best_mnv2_pv_original.keras` is the original Keras file; the TF.js conversion is already done (`model.json` + 3 weight shards).
- The `.keras` file can be deleted after confirming TF.js works, saving ~27 MB.
- `@tensorflow/tfjs-node` requires libtensorflow; it is downloaded automatically during `npm install`.
- The `file://` protocol for `loadGraphModel` works with `@tensorflow/tfjs-node` but not out of the box with browser tfjs, which loads models over HTTP via `fetch`. For the browser tfjs fallback, read the files yourself and pass a custom `tf.io.IOHandler` (for example one built with `tf.io.fromMemory()`) to `tf.loadGraphModel`.
- ImageNet normalization in `preprocessImageBuffer()` uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Verify this matches what the PlantVillage model was trained with rather than assuming it: Keras's `tf.keras.applications.mobilenet_v2.preprocess_input` scales pixels to [-1, 1] instead of applying ImageNet mean/std, so check the training pipeline.
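For the browser tfjs fallback described in the `file://` note, the files can be read from disk and assembled into an in-memory artifacts object. `readGraphModelArtifacts()` and the `ModelJson` interface are hypothetical and copy only common `model.json` fields (others such as `userDefinedMetadata` or `modelInitializer` are omitted); the returned object is meant to follow the `tf.io.ModelArtifacts` layout, but check the field names against the installed tfjs version.

```typescript
import { readFile } from "fs/promises";
import { dirname, join } from "path";

// Only the model.json fields this sketch copies.
interface ModelJson {
  format?: string;
  generatedBy?: string;
  convertedBy?: string | null;
  modelTopology: unknown;
  signature?: unknown;
  weightsManifest: { paths: string[]; weights: unknown[] }[];
}

// Hypothetical helper: reads model.json and its weight shards and builds an
// artifacts object. `read` is injectable so it can be tested without real files.
async function readGraphModelArtifacts(
  modelJsonPath: string,
  read: (path: string) => Promise<Uint8Array> = (p) => readFile(p),
) {
  const json = JSON.parse(Buffer.from(await read(modelJsonPath)).toString("utf8")) as ModelJson;
  const dir = dirname(modelJsonPath);

  // Specs and shard bytes stay in manifest order so each spec lines up with
  // its slice of the concatenated weight buffer.
  const weightSpecs: unknown[] = [];
  const shards: Uint8Array[] = [];
  for (const group of json.weightsManifest) {
    weightSpecs.push(...group.weights);
    for (const shardPath of group.paths) {
      shards.push(await read(join(dir, shardPath)));
    }
  }

  const weightData = new Uint8Array(shards.reduce((n, s) => n + s.byteLength, 0));
  let offset = 0;
  for (const shard of shards) {
    weightData.set(shard, offset);
    offset += shard.byteLength;
  }

  return {
    modelTopology: json.modelTopology,
    format: json.format,
    generatedBy: json.generatedBy,
    convertedBy: json.convertedBy,
    signature: json.signature,
    weightSpecs,
    weightData: weightData.buffer,
  };
}
```

With browser tfjs, the result would then be passed as `tf.loadGraphModel(tf.io.fromMemory(artifacts))`.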