plant-disease-id/apps/web/tasks/production-ml-pipeline/03-model-loading-verification.md
2026-06-06 15:09:46 -04:00

# 03. TensorFlow.js Model Loading Verification and Fixes
meta:
id: production-ml-pipeline-03
feature: production-ml-pipeline
priority: P0
depends_on: []
tags: [implementation, model, tests-required]
objective:
- Verify the converted TF.js GraphModel loads successfully on the Node.js server
- Fix input tensor format handling (NCHW pipeline input → NHWC model input)
- Determine whether model output is logits or pre-computed softmax probabilities
- Ensure inference produces valid [1, 38] output without errors
- Install `@tensorflow/tfjs-node` for server-side native acceleration
deliverables:
- `src/lib/ml/model-loader.ts` — fixed and verified for real model loading
- `src/lib/ml/model-loader.test.ts` — updated integration tests
- `package.json`: `@tensorflow/tfjs-node` added as a dependency (if needed)
- `src/lib/ml/inference.ts` — fixed output interpretation (logits vs probabilities)
- `src/lib/ml/inference.test.ts` — updated for real model inference
steps:
1. **Determine output interpretation** — inspect the graph topology to resolve whether `Identity:0` is pre-softmax logits or post-softmax probabilities:
- The model graph contains a `Softmax` node at `StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax`
- The output `Identity:0` may be after Softmax (probabilities) or before (logits)
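- Test: run inference on a zero tensor. If every value is ≥ 0 and the values sum to ~1.0, the output is already probabilities; if any value is negative or the sum is clearly not 1.0, it is logits
- Fix: if the output is already probabilities, remove the `softmaxFloat32()` call in `inference.ts` and use the raw output directly. A second softmax keeps the argmax but flattens the distribution, so reported confidences would be far too low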
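
If the output turns out to be logits, `inference.ts` still needs a softmax; if it is already probabilities, a second softmax is harmful. The sketch below uses a hypothetical stand-in named `softmax` (not the existing `softmaxFloat32()`) to show the numerically stable form and to make the double-softmax effect easy to check.

```typescript
/** Numerically stable softmax: subtract the max before exponentiating to avoid overflow. */
function softmax(logits: ArrayLike<number>): Float32Array {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  const out = new Float32Array(logits.length);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    out[i] = Math.exp(logits[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= sum;
  return out;
}
```

For a confident 38-class prediction (a one-hot probability vector), `softmax(probabilities)` returns a top score of e / (e + 37) ≈ 0.068: the argmax survives, but the reported confidence is meaningless.
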
2. **Fix input tensor format** — the model expects NHWC `[1, 160, 160, 3]` but our pipeline produces NCHW `[3, 160, 160]`:
```typescript
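// model-loader.ts tryLoadTFJS(), corrected: tf.tensor4d() throws when given a
// rank-3 shape, so build the [C, H, W] tensor with tf.tensor3d() first.
// Wrap this in tf.tidy() (step 7) so the intermediate tensors are released.
const inputTensor = tf
  .tensor3d(tensor, [3, 160, 160]) // `tensor`: the NCHW Float32Array from the pipeline
  .transpose([1, 2, 0]) // [160, 160, 3] (HWC)
  .expandDims(0); // [1, 160, 160, 3] NHWC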
```
- Verify this transpose is correct (NCHW → NHWC)
- Verify the tensor values are in the expected range (ImageNet mean/std normalization of [0, 1] pixels gives roughly -2.12 to +2.64 across the three channels)
- Alternative: reshape directly as `[1, 160, 160, 3]` if the identify endpoint produces NHWC data
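
For the transpose unit test, or as a TF.js-free way to prepare NHWC data, the reordering can be written directly over the buffer. This is a sketch assuming the pipeline emits a contiguous channel-major `[3, 160, 160]` `Float32Array`; `nchwToNhwc` and `valueRange` are hypothetical helper names.

```typescript
/** Reorder a contiguous CHW buffer into HWC order (the batch axis is implicit). */
function nchwToNhwc(chw: Float32Array, channels: number, height: number, width: number): Float32Array {
  if (chw.length !== channels * height * width) {
    throw new Error(`expected ${channels * height * width} values, got ${chw.length}`);
  }
  const hwc = new Float32Array(chw.length);
  for (let c = 0; c < channels; c++) {
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        // CHW index c*H*W + y*W + x maps to HWC index (y*W + x)*C + c
        hwc[(y * width + x) * channels + c] = chw[c * height * width + y * width + x];
      }
    }
  }
  return hwc;
}

/** Min and max of a buffer, for the expected-range check. */
function valueRange(data: Float32Array): { min: number; max: number } {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < data.length; i++) {
    if (data[i] < min) min = data[i];
    if (data[i] > max) max = data[i];
  }
  return { min, max };
}
```

`nchwToNhwc(tensor, 3, 160, 160)` should match `tf.tensor3d(tensor, [3, 160, 160]).transpose([1, 2, 0]).dataSync()` element for element, which gives the transpose unit test a reference to compare against.
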
3. **Install `@tensorflow/tfjs-node`** for server-side native acceleration:
```bash
npm install @tensorflow/tfjs-node
```
- The pure-JS `@tensorflow/tfjs` package also runs on the server, but by default on its JavaScript CPU backend, which is significantly slower (no native kernels)
- `@tensorflow/tfjs-node` uses the libtensorflow C library for a ~10-100x speedup
- Verify native bindings install correctly (may need `@tensorflow/tfjs-node-gpu` for GPU, but CPU is fine for this use case)
- Fallback chain remains: tfjs-node → tfjs (browser) → mock
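
The fallback chain can be sketched as an ordered list of loaders tried in turn, with one warning per failure. This is a minimal sketch, not the existing `model-loader.ts` logic: `loadWithFallback` and the `Loader` shape are hypothetical, and each `load` function would wrap the backend-specific work (for example a dynamic `import("@tensorflow/tfjs-node")` followed by `tf.loadGraphModel()`).

```typescript
interface Loader<T> {
  name: string; // e.g. "tfjs-node", "tfjs", "mock"
  load: () => Promise<T>;
}

/** Try each loader in order; return the first success and the backend that produced it. */
async function loadWithFallback<T>(
  loaders: Loader<T>[],
  warn: (message: string) => void = console.warn,
): Promise<{ backend: string; value: T }> {
  for (const loader of loaders) {
    try {
      return { backend: loader.name, value: await loader.load() };
    } catch (err) {
      warn(`[model-loader] ${loader.name} failed, trying next: ${String(err)}`);
    }
  }
  throw new Error("[model-loader] all loaders failed");
}
```

Placing the mock loader last, with a `load` that never throws, means the function always resolves, which matches the "never crash the server" requirement in step 8.
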
4. **Verify model loads from filesystem**:
```typescript
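import * as tf from "@tensorflow/tfjs-node"; // also registers the file:// model loader

// MODEL_JSON_PATH: path to public/models/plant-disease-classifier/model.json
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
console.log("Model loaded:", model.inputs, model.outputs);
// Expected:
// inputs: [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
// outputs: [{ shape: [-1, 38], dtype: 'float32' }]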
```
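- Verify `model.inputs[0].shape` matches `[-1, 160, 160, 3]` (`-1` marks the dynamic batch dimension)
- Verify `model.outputs[0].shape` matches `[-1, 38]`
- Verify the model exposes `predict()`: a `GraphModel` provides `predict()`, `execute()` and `executeAsync()`; `predict()` is expected to suffice for a feed-forward graph like MobileNetV2, while graphs containing control-flow ops require `executeAsync()`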
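
The shape checks above can share one comparison that the integration tests reuse. A sketch: `matchesShape` is a hypothetical helper; it treats `-1` in the expected shape as a dynamic dimension (accepting `-1`, or `null` defensively, from the model) and fails if the model reports no shape at all.

```typescript
type ReportedShape = ReadonlyArray<number | null> | undefined;

/** True if `actual` has the expected rank and every dimension matches. */
function matchesShape(actual: ReportedShape, expected: ReadonlyArray<number>): boolean {
  if (!actual || actual.length !== expected.length) return false;
  return expected.every((dim, i) => {
    const a = actual[i];
    // Expected -1 means a dynamic dimension; TF.js reports it as -1 (null accepted defensively)
    if (dim === -1) return a === -1 || a === null;
    return a === dim;
  });
}
```

For example, `matchesShape(model.inputs[0].shape, [-1, 160, 160, 3])` and `matchesShape(model.outputs[0].shape, [-1, 38])`.
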
5. **Run inference smoke test**:
```typescript
import * as tf from "@tensorflow/tfjs-node";

// Create a test tensor of random values in [-1, 1]
const testTensor = new Float32Array(3 * 160 * 160);
for (let i = 0; i < testTensor.length; i++) {
  testTensor[i] = (Math.random() - 0.5) * 2;
}
// Interpret the buffer directly as NHWC; the layout does not matter for random data
const input = tf.tensor4d(testTensor, [1, 160, 160, 3]);
// predict() is typed Tensor | Tensor[] | NamedTensorMap; this model has a single output
const output = model.predict(input) as tf.Tensor;
const data = await output.data<"float32">();
console.log("Output shape:", output.shape);
console.log("Output sum:", data.reduce((a, b) => a + b, 0));
console.log("Output max:", Math.max(...data));
console.log("Output min:", Math.min(...data));
tf.dispose([input, output]); // release the tensors' memory
```
- Output should be [1, 38] with 38 float values
- If values are probabilities: sum ≈ 1.0, all values ≥ 0
- If values are logits: sum ≠ 1.0, may have negative values
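
The smoke-test statistics and the probabilities-versus-logits rule can be computed in a single pass. A sketch with a hypothetical `summarizeOutput` helper; the loop also avoids spreading the array into `Math.max(...)`/`Math.min(...)`, which is fine for 38 values but can exceed engine argument limits on very large arrays.

```typescript
interface OutputSummary {
  length: number;
  sum: number;
  min: number;
  max: number;
  argmax: number;
  kind: "probabilities" | "logits";
}

function summarizeOutput(values: ArrayLike<number>, tolerance = 0.01): OutputSummary {
  let sum = 0;
  let min = Infinity;
  let max = -Infinity;
  let argmax = -1;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    sum += v;
    if (v < min) min = v;
    if (v > max) {
      max = v;
      argmax = i;
    }
  }
  // Probabilities: every value >= 0 and the sum is ~1.0; anything else is treated as logits
  const kind = min >= 0 && Math.abs(sum - 1) <= tolerance ? "probabilities" : "logits";
  return { length: values.length, sum, min, max, argmax, kind };
}
```

Calling `summarizeOutput(data)` on the smoke-test output gives both the numbers to log and a verdict for step 1.
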
6. **Fix `model-loader.ts` `getStatus()` to report real class count**:
```typescript
getStatus(): ModelStatus {
  return {
    loaded: true,
    backend: "tfjs",
    modelId: MODEL_ID,
    numClasses: 38, // PlantVillage, not 95
  };
}
```
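
Hard-coding `numClasses: 38` fixes the status report, but a startup cross-check would catch the same kind of mismatch if the model or label list changes again. A sketch: `assertClassCount` is hypothetical, and the label list it receives is an assumption, since this task does not show where class names are stored.

```typescript
/**
 * Fail fast if the model's output width, the label list and the expected class count disagree
 * (e.g. 95 vs 38).
 */
function assertClassCount(
  outputShape: ReadonlyArray<number>,
  labels: ReadonlyArray<string>,
  expected = 38,
): void {
  const width = outputShape[outputShape.length - 1];
  if (width !== expected || labels.length !== expected) {
    throw new Error(
      `[model-loader] class count mismatch: model outputs ${width}, labels has ${labels.length}, expected ${expected}`,
    );
  }
}
```

It could run once after loading, e.g. `assertClassCount(model.outputs[0].shape ?? [], labels)`, with `labels` standing in for wherever the class names are loaded from.
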
7. **Add memory management** — dispose tensors after use to prevent memory leaks:
```typescript
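// In predict(): tf.tidy() disposes every tensor created inside the callback except
// tensors it returns; the callback must be synchronous, so read results with dataSync().
const scores = tf.tidy(() => {
  const input = tf.tensor4d(...); // [1, 160, 160, 3] NHWC input
  const output = model.predict(input) as tf.Tensor;
  return output.dataSync(); // a plain typed array, so no tensor outlives tidy()
});
```
- Or dispose manually, which is required when reading results with the async `data()` because `tf.tidy()` cannot wrap an async function: `inputTensor.dispose()`, `outputTensor.dispose()`
- Use `tf.memory().numTensors` to monitor the live tensor count during development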
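
The "no memory leaks" acceptance criterion can be checked by comparing the tensor count before and after repeated calls. A sketch: `assertNoTensorLeak` is hypothetical, and the counter is injected so the helper does not depend on TF.js; with TF.js it would be `() => tf.memory().numTensors`.

```typescript
/** Run `infer` repeatedly and verify the live tensor count returns to its baseline. */
async function assertNoTensorLeak(
  infer: () => Promise<unknown>,
  countTensors: () => number,
  iterations = 20,
): Promise<void> {
  await infer(); // warm-up call before taking the baseline
  const baseline = countTensors();
  for (let i = 0; i < iterations; i++) {
    await infer();
  }
  const after = countTensors();
  if (after !== baseline) {
    throw new Error(`tensor leak: ${after - baseline} tensors over ${iterations} calls`);
  }
}
```

The warm-up call assumes a first inference may allocate tensors that are intentionally kept; if that never happens for this model, the warm-up is simply redundant.
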
8. **Handle model load failures gracefully**:
- If model files are corrupted, log the specific error
- If tfjs-node native bindings fail, fall back to browser tfjs with a warning
- Never crash the server on model load failure — fall back to mock mode with clear logging
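
To log the specific cause rather than a generic load failure, the loader could check the files before calling `loadGraphModel()`. A sketch: `preflightModelFiles` is hypothetical and assumes the standard TF.js `model.json` layout, where `weightsManifest[].paths` lists shard files relative to `model.json`.

```typescript
import { existsSync, readFileSync } from "node:fs";
import { dirname, join } from "node:path";

/**
 * Check model.json and its weight shards before loading.
 * Returns a list of problems; an empty list means the files are present and model.json parses.
 */
function preflightModelFiles(modelJsonPath: string): string[] {
  if (!existsSync(modelJsonPath)) return [`missing ${modelJsonPath}`];
  let manifest: { weightsManifest?: { paths: string[] }[] };
  try {
    manifest = JSON.parse(readFileSync(modelJsonPath, "utf8"));
  } catch (err) {
    return [`corrupted ${modelJsonPath}: ${(err as Error).message}`];
  }
  if (!Array.isArray(manifest.weightsManifest)) return [`no weightsManifest in ${modelJsonPath}`];
  const problems: string[] = [];
  for (const group of manifest.weightsManifest) {
    for (const shard of group.paths) {
      const shardPath = join(dirname(modelJsonPath), shard);
      if (!existsSync(shardPath)) problems.push(`missing weight shard ${shardPath}`);
    }
  }
  return problems;
}
```

A non-empty result can be logged as-is before falling back to mock mode. A truncated or otherwise damaged shard is not detected here and would still surface as an error from TF.js itself.
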
tests:
- Integration: model loads from `public/models/plant-disease-classifier/model.json` without errors
- Integration: `model.inputs[0].shape` is `[-1, 160, 160, 3]`
- Integration: `model.outputs[0].shape` is `[-1, 38]`
- Integration: inference on a random tensor produces a [1, 38] float output
- Integration: if output is probabilities, the sum is between 0.99 and 1.01
- Integration: `getStatus()` returns `{ loaded: true, backend: "tfjs", numClasses: 38 }`
- Unit: `validateInput()` correctly rejects tensors with wrong length
- Unit: NCHW → NHWC transpose produces correct layout
- Performance: inference completes in < 500ms on a typical server (with tfjs-node)
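
The `validateInput()` unit test needs a precise contract. The existing function is not shown in this task, so the following is a hypothetical stand-in (`validateInputSketch`) assuming the input is a `Float32Array` of exactly 3 × 160 × 160 = 76,800 finite values; the real function may check more, or differently.

```typescript
const EXPECTED_LENGTH = 3 * 160 * 160; // 76,800 values for one 160x160 RGB image

/** Hypothetical stand-in for validateInput(): throws a descriptive error on bad input. */
function validateInputSketch(tensor: Float32Array): void {
  if (tensor.length !== EXPECTED_LENGTH) {
    throw new Error(`expected ${EXPECTED_LENGTH} values, got ${tensor.length}`);
  }
  for (let i = 0; i < tensor.length; i++) {
    if (!Number.isFinite(tensor[i])) {
      throw new Error(`non-finite value at index ${i}`);
    }
  }
}
```

The matching unit tests would assert that 76,799- and 76,801-value arrays are rejected and a 76,800-value array is accepted.
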
acceptance_criteria:
- `getModel()` returns a model with `loaded: true` and `backend: "tfjs"`
- `model.predict()` on a valid [1, 160, 160, 3] input returns [1, 38] output without errors
- Output interpretation is correctly determined (logits vs probabilities) and handled
- `@tensorflow/tfjs-node` is installed and used as primary backend
- No memory leaks: tensor count stays stable after repeated inference calls
- Fallback chain works: tfjs-node → tfjs → mock (each failure logs warning)
- Model load time < 30 seconds on first request
- Inference time < 500ms per image on server
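
The 500ms target is easier to check reliably as a median over several warm runs, because the first call can include one-off setup cost. A sketch: `measureLatencyMs` is hypothetical and uses the global `performance.now()` available in Node.js 16 and later.

```typescript
/** Median wall-clock latency of `fn` in milliseconds, after `warmup` untimed calls. */
async function measureLatencyMs(fn: () => Promise<unknown>, runs = 10, warmup = 1): Promise<number> {
  for (let i = 0; i < warmup; i++) await fn();
  const samples: number[] = [];
  for (let i = 0; i < runs; i++) {
    const start = performance.now();
    await fn();
    samples.push(performance.now() - start);
  }
  samples.sort((a, b) => a - b);
  const mid = Math.floor(samples.length / 2);
  return samples.length % 2 === 1 ? samples[mid] : (samples[mid - 1] + samples[mid]) / 2;
}
```

In the performance test this might read `expect(await measureLatencyMs(() => runInference(input))).toBeLessThan(500)`, where `runInference` stands for whatever function wraps `model.predict()`.
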
validation:
- `npm install @tensorflow/tfjs-node` — native bindings install successfully
- `npx vitest run src/lib/ml/model-loader.test.ts` — all loading tests pass
- `npx vitest run src/lib/ml/inference.test.ts` — all inference tests pass
- Manual: `curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}'` — returns real predictions (no `demo_mode: true`)
- Check server logs for `[model-loader] Loaded TF.js model` (not mock fallback)
notes:
- The model file `best_mnv2_pv_original.keras` is the original Keras file — the TF.js conversion is already done (model.json + 3 weight shards)
- The `.keras` file can be deleted after confirming TF.js works, saving ~27MB
- `@tensorflow/tfjs-node` requires libtensorflow — it downloads automatically during npm install
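- The `file://` scheme for `loadGraphModel` works with `@tensorflow/tfjs-node`, which registers a filesystem loader, but not with browser tfjs, which loads over `fetch` (and `fetch` does not support `file://` URLs). If using the browser tfjs fallback, read `model.json` and the weight shards from disk and pass a custom `IOHandler` (for example one built with `tf.io.fromMemory()`) to `tf.loadGraphModel()`
- ImageNet normalization in `preprocessImageBuffer()` uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Verify this matches the preprocessing the PlantVillage model was actually trained with rather than assuming it: Keras's `tf.keras.applications.mobilenet_v2.preprocess_input` scales pixels to [-1, 1] instead of applying mean/std, so a model trained that way would receive mismatched inputs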
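
For the browser-tfjs fallback described in the `file://` note, the model files can be read with `fs` and handed to TF.js as in-memory artifacts. A sketch, assuming the standard TF.js `model.json` layout (`modelTopology`, plus a `weightsManifest` whose shard `paths` are relative to `model.json`); `readModelArtifacts` is a hypothetical helper.

```typescript
import { readFileSync } from "node:fs";
import { dirname, join } from "node:path";

interface WeightGroup {
  paths: string[];
  weights: unknown[]; // per-tensor specs: name, shape, dtype
}

/** Read model.json plus its shards, concatenating shard bytes in manifest order. */
function readModelArtifacts(modelJsonPath: string) {
  const json = JSON.parse(readFileSync(modelJsonPath, "utf8"));
  const manifest: WeightGroup[] = json.weightsManifest ?? [];
  const buffers: Buffer[] = [];
  const weightSpecs: unknown[] = [];
  for (const group of manifest) {
    weightSpecs.push(...group.weights);
    for (const shard of group.paths) {
      buffers.push(readFileSync(join(dirname(modelJsonPath), shard)));
    }
  }
  const all = Buffer.concat(buffers);
  // Copy into a standalone ArrayBuffer: a small Buffer may be a view into a shared pool
  const weightData = all.buffer.slice(all.byteOffset, all.byteOffset + all.byteLength);
  return {
    modelTopology: json.modelTopology,
    format: json.format,
    generatedBy: json.generatedBy,
    convertedBy: json.convertedBy,
    signature: json.signature,
    userDefinedMetadata: json.userDefinedMetadata,
    weightSpecs,
    weightData,
  };
}
```

With browser tfjs the result would be passed as `tf.loadGraphModel(tf.io.fromMemory(readModelArtifacts(MODEL_JSON_PATH)))`; `@tensorflow/tfjs-node` does not need this, because it handles `file://` directly.
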