# 03. TensorFlow.js Model Loading Verification and Fixes

meta:
id: production-ml-pipeline-03
feature: production-ml-pipeline
priority: P0
depends_on: []
tags: [implementation, model, tests-required]

objective:

- Verify the converted TF.js GraphModel loads successfully on the Node.js server
- Fix input tensor format handling (NCHW pipeline input → NHWC model input)
- Determine whether model output is logits or pre-computed softmax probabilities
- Ensure inference produces valid `[1, 38]` output without errors
- Install `@tensorflow/tfjs-node` for server-side native acceleration

deliverables:

- `src/lib/ml/model-loader.ts`: fixed and verified for real model loading
- `src/lib/ml/model-loader.test.ts`: updated integration tests
- `package.json`: `@tensorflow/tfjs-node` added as dependency (if needed)
- `src/lib/ml/inference.ts`: fixed output interpretation (logits vs probabilities)
- `src/lib/ml/inference.test.ts`: updated for real model inference

steps:

1. **Determine output interpretation**: inspect the graph topology to resolve whether `Identity:0` is pre-softmax logits or post-softmax probabilities:
- The model graph contains a `Softmax` node at `StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax`, which suggests (but does not prove) that the final `Dense` layer was built with a softmax activation
- The output `Identity:0` may come after this Softmax (probabilities) or before it (logits); check which node the `Identity` node lists as its `input` in the `modelTopology` section of `model.json`
- Test: run inference on a zero tensor. If the output sums to ~1.0 and has no negative values, it is already probabilities; if it has negative values or does not sum to 1.0, it is logits
- Fix: if the output is already probabilities, remove the `softmaxFloat32()` call in `inference.ts` and use the raw output directly (applying softmax a second time flattens the scores and understates confidence)

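
The step 1 heuristic and the softmax fallback can be kept as plain functions so they are unit-testable without loading the model. A minimal sketch: `interpretOutput()` and `stableSoftmax()` are hypothetical names (the project's existing function is `softmaxFloat32()`, whose signature isn't shown here), and the 0.01 tolerance mirrors the 0.99–1.01 range used in the tests.

```typescript
type OutputKind = "probabilities" | "logits";

// Hypothetical helper implementing the step 1 heuristic: post-softmax output
// is non-negative and sums to ~1.0; anything else is treated as logits.
function interpretOutput(
  scores: ArrayLike<number>,
  tolerance = 0.01,
): OutputKind {
  let sum = 0;
  for (let i = 0; i < scores.length; i++) {
    if (scores[i] < 0) return "logits";
    sum += scores[i];
  }
  return Math.abs(sum - 1) <= tolerance ? "probabilities" : "logits";
}

// Numerically stable softmax: subtracting the max logit before exp()
// avoids overflow and does not change the result.
function stableSoftmax(logits: ArrayLike<number>): Float32Array {
  const out = new Float32Array(logits.length);
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    out[i] = Math.exp(logits[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= sum;
  return out;
}
```

Design note: softmax should only run when `interpretOutput()` reports logits. Applied to probabilities it flattens the distribution; across 38 classes a 0.99 top score drops to roughly 0.07.
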
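2. **Fix input tensor format**: the model expects NHWC `[1, 160, 160, 3]`, but our pipeline produces NCHW `[3, 160, 160]` (batch dimension omitted):

```typescript
// In model-loader.ts tryLoadTFJS(). The current code calls
// tf.tensor4d(Array.from(tensor), [3, 160, 160]), which throws because
// tensor4d() requires a 4-element shape; build a 3-D tensor instead.
// Wrap this in tf.tidy() (step 7) so the intermediate tensors are freed.
const inputTensor = tf
  .tensor3d(tensor, [3, 160, 160]) // [3, 160, 160] CHW from the pipeline
  .transpose([1, 2, 0]) // [160, 160, 3] HWC
  .expandDims(0); // [1, 160, 160, 3] NHWC
```

- Verify this transpose is correct (NCHW → NHWC): the permutation `[1, 2, 0]` moves the channel axis from first to last
- Verify the tensor values are in the expected range (ImageNet-normalized 8-bit input: roughly -2.1 to +2.6)
- Alternative: create the tensor directly as `[1, 160, 160, 3]` if the `/api/identify` preprocessing already produces NHWC (interleaved RGB) data
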
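
For the NCHW → NHWC unit test, a plain-array reference implementation gives something to compare the `transpose([1, 2, 0])` result against element by element. A sketch; `chwToHwc()` is a hypothetical helper, and the index formulas are the standard row-major layouts for CHW and HWC.

```typescript
// Hypothetical reference implementation of CHW -> HWC for a single image.
// CHW index: c * H * W + y * W + x    HWC index: (y * W + x) * C + c
function chwToHwc(
  chw: Float32Array,
  channels: number,
  height: number,
  width: number,
): Float32Array {
  const expected = channels * height * width;
  if (chw.length !== expected) {
    throw new Error(`expected ${expected} values, got ${chw.length}`);
  }
  const hwc = new Float32Array(chw.length);
  for (let c = 0; c < channels; c++) {
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        hwc[(y * width + x) * channels + c] =
          chw[c * height * width + y * width + x];
      }
    }
  }
  return hwc;
}
```

A 3×2×2 input holding 0..11 makes layout mistakes easy to spot: the output must start with `[0, 4, 8]`, one value from each channel plane.
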
3. **Install `@tensorflow/tfjs-node`** for server-side native acceleration:

```bash
npm install @tensorflow/tfjs-node
```

- Browser (pure-JS) tfjs works on the server, but only with its plain-JavaScript `cpu` backend, which is significantly slower (no native BLAS)
- `@tensorflow/tfjs-node` runs ops through the libtensorflow C library; the speedup (estimated here at ~10-100x) depends on the model and CPU, so measure it on the target server
- Verify native bindings install correctly (`@tensorflow/tfjs-node-gpu` exists for GPU, but CPU is fine for this use case)
- Fallback chain remains: tfjs-node → tfjs (browser) → mock

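
The tfjs-node → tfjs → mock order can be implemented with dynamic `import()`, so a broken native binding becomes a caught error instead of a startup crash. A minimal sketch under these assumptions: `loadTf()`, `TfChoice` and `unwrap()` are hypothetical names, and both packages are treated as exposing the same `tf` namespace (`@tensorflow/tfjs-node` re-exports `@tensorflow/tfjs`).

```typescript
type Tf = typeof import("@tensorflow/tfjs");

type TfChoice =
  | { kind: "tfjs-node"; tf: Tf }
  | { kind: "tfjs"; tf: Tf }
  | { kind: "mock" };

// CommonJS packages loaded through import() may expose their exports on
// `default`; otherwise the namespace object itself is used.
function unwrap(mod: unknown): Tf {
  const withDefault = mod as { default?: Tf };
  return withDefault.default ?? (mod as Tf);
}

// Hypothetical loader: native backend first, then pure-JS tfjs, then mock.
// Each failure is logged instead of thrown, so the server keeps starting.
async function loadTf(): Promise<TfChoice> {
  try {
    return { kind: "tfjs-node", tf: unwrap(await import("@tensorflow/tfjs-node")) };
  } catch (err) {
    console.warn("[model-loader] tfjs-node unavailable, falling back to tfjs:", err);
  }
  try {
    return { kind: "tfjs", tf: unwrap(await import("@tensorflow/tfjs")) };
  } catch (err) {
    console.warn("[model-loader] tfjs unavailable, falling back to mock:", err);
  }
  return { kind: "mock" };
}
```

When the `tfjs-node` branch is taken, `tf.getBackend()` should report `"tensorflow"`; `"cpu"` means the pure-JS backend is doing the work.
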
4. **Verify model loads from filesystem**:

```typescript
import * as tf from "@tensorflow/tfjs-node"; // registers the file:// loader

const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
console.log("Model loaded:", model.inputs, model.outputs);
// Expected (-1 marks the dynamic batch dimension):
// inputs: [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
// outputs: [{ shape: [-1, 38], dtype: 'float32' }]
```

- Verify `model.inputs[0].shape` matches `[-1, 160, 160, 3]`
- Verify `model.outputs[0].shape` matches `[-1, 38]`
- Run inference with `model.predict()`; `GraphModel` also provides `execute()` (and `executeAsync()` for graphs with control-flow ops), which are only needed to fetch specific named nodes

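
The two shape checks can be collapsed into one assertion that runs right after the `loadGraphModel()` call. A sketch assuming the dynamic batch dimension is reported as `-1`, as in the expected output above; `assertSignature()` and `SignatureLike` are hypothetical, and the structural type lets the check be unit-tested without a real model (a `tf.GraphModel` satisfies it).

```typescript
// Minimal structural view of tf.GraphModel's `inputs` / `outputs` TensorInfo.
interface SignatureLike {
  inputs: { name?: string; shape?: number[]; dtype?: string }[];
  outputs: { name?: string; shape?: number[]; dtype?: string }[];
}

// Hypothetical helper: throws with a descriptive message when a reported
// shape is missing or differs from the expected one.
function assertSignature(
  model: SignatureLike,
  expectedInput: number[] = [-1, 160, 160, 3],
  expectedOutput: number[] = [-1, 38],
): void {
  const check = (label: string, actual: number[] | undefined, expected: number[]) => {
    const same =
      actual !== undefined &&
      actual.length === expected.length &&
      actual.every((dim, i) => dim === expected[i]);
    if (!same) {
      throw new Error(
        `[model-loader] ${label} shape ${JSON.stringify(actual)} != ${JSON.stringify(expected)}`,
      );
    }
  };
  check("input", model.inputs[0]?.shape, expectedInput);
  check("output", model.outputs[0]?.shape, expectedOutput);
}
```

An NCHW-shaped model (`[-1, 3, 160, 160]`) fails loudly here instead of producing meaningless predictions later.
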
5. **Run inference smoke test**:

```typescript
import * as tf from "@tensorflow/tfjs-node";

// `model` is the GraphModel loaded in step 4.
// Create a test tensor (uniform random values in [-1, 1))
const testTensor = new Float32Array(3 * 160 * 160);
for (let i = 0; i < testTensor.length; i++) {
  testTensor[i] = (Math.random() - 0.5) * 2;
}

// Random values have no spatial layout, so use them directly as NHWC
const input = tf.tensor4d(testTensor, [1, 160, 160, 3]);
// predict() is typed Tensor | Tensor[] | NamedTensorMap; this model has one output
const output = model.predict(input) as tf.Tensor;
const data = (await output.data()) as Float32Array;

console.log("Output shape:", output.shape);
console.log("Output sum:", data.reduce((a, b) => a + b, 0));
console.log("Output max:", Math.max(...data));
console.log("Output min:", Math.min(...data));

// Free the memory held by both tensors
input.dispose();
output.dispose();
```

- Output should be `[1, 38]` with 38 float values
- If values are probabilities: sum ≈ 1.0, all values ≥ 0
- If values are logits: sum ≠ 1.0, may have negative values

6. **Fix `model-loader.ts` `getStatus()` to report real class count**:

```typescript
getStatus(): ModelStatus {
  return {
    loaded: true,
    backend: "tfjs",
    modelId: MODEL_ID,
    numClasses: 38, // PlantVillage, not 95
  };
}
```

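7. **Add memory management**: dispose tensors after use to prevent memory leaks:

```typescript
// In predict(): tf.tidy() frees every tensor created in the callback except
// the returned one. It cannot wrap async code, so read the data afterwards.
const output = tf.tidy(() => {
  const input = tf
    .tensor3d(tensor, [3, 160, 160]) // NCHW pipeline output
    .transpose([1, 2, 0]) // [160, 160, 3]
    .expandDims(0); // [1, 160, 160, 3] NHWC
  return model.predict(input) as tf.Tensor;
});
const scores = await output.data();
output.dispose(); // tidy() does not free the tensor it returns
```

- Or dispose manually: `inputTensor.dispose()`, `outputTensor.dispose()` (every intermediate tensor, including the result of each chained op, must be disposed)
- Use `tf.memory().numTensors` to monitor the live tensor count during development
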
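
The "tensor count stays stable" acceptance criterion can be checked by sampling the live tensor count around repeated calls. A sketch that takes the counter as a parameter, so production code passes `() => tf.memory().numTensors` and unit tests can pass a fake; `assertNoTensorLeak()` is a hypothetical helper and 10 iterations is an arbitrary default.

```typescript
// Hypothetical helper: runs `infer` repeatedly and throws if the live
// tensor count reported by `countTensors` grows across the calls.
async function assertNoTensorLeak(
  infer: () => Promise<unknown>,
  countTensors: () => number,
  iterations = 10,
): Promise<void> {
  await infer(); // untimed warm-up so one-time allocations are not counted
  const before = countTensors();
  for (let i = 0; i < iterations; i++) {
    await infer();
  }
  const after = countTensors();
  if (after > before) {
    throw new Error(
      `[model-loader] tensor leak: ${after - before} extra tensors after ${iterations} calls`,
    );
  }
}
```

Usage, assuming `predict(tensor)` from step 7 returns a Promise: `await assertNoTensorLeak(() => predict(tensor), () => tf.memory().numTensors);`.
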
8. **Handle model load failures gracefully**:
- If model files are corrupted, log the specific error
- If tfjs-node native bindings fail, fall back to browser tfjs with a warning
- Never crash the server on model load failure; fall back to mock mode with clear logging

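
One way to enforce the never-crash rule is a single wrapper that logs the specific error and hands back a mock instead of rethrowing. A sketch only: `loadModelOrMock()`, `createMock()` and the `Classifier` interface are illustrative assumptions rather than the project's existing types, and the uniform mock scores are an arbitrary choice that makes demo output easy to recognize.

```typescript
// Illustrative interface; the real loader's types may differ.
interface Classifier {
  backend: "tfjs" | "mock";
  predict(input: Float32Array): Promise<Float32Array>;
}

const NUM_CLASSES = 38;

// Mock used after a failed load: uniform scores over all classes.
function createMock(): Classifier {
  return {
    backend: "mock",
    predict: async () => new Float32Array(NUM_CLASSES).fill(1 / NUM_CLASSES),
  };
}

// Hypothetical wrapper: any error from `load` (corrupted model.json, missing
// weight shard, failed native binding) is logged and replaced by the mock.
async function loadModelOrMock(load: () => Promise<Classifier>): Promise<Classifier> {
  try {
    const model = await load();
    console.info(`[model-loader] Loaded TF.js model (backend: ${model.backend})`);
    return model;
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err);
    console.error(`[model-loader] Model load failed, falling back to mock: ${reason}`);
    return createMock();
  }
}
```

The success log starts with the `[model-loader] Loaded TF.js model` text that the validation section checks for.
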
tests:

- Integration: model loads from `public/models/plant-disease-classifier/model.json` without errors
- Integration: `model.inputs[0].shape` is `[-1, 160, 160, 3]`
- Integration: `model.outputs[0].shape` is `[-1, 38]`
- Integration: inference on a random tensor produces a `[1, 38]` output (38 float values)
- Integration: if the output is probabilities, its sum is within 0.99–1.01
- Integration: `getStatus()` returns `{ loaded: true, backend: "tfjs", numClasses: 38 }`
- Unit: `validateInput()` correctly rejects tensors with the wrong length
- Unit: NCHW → NHWC transpose produces the correct layout
- Performance: inference completes in < 500 ms on a typical server (with tfjs-node)

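
For the < 500 ms checks, timing several calls after a warm-up is steadier than one measurement, since the first call includes one-time setup. A sketch using the global `performance.now()` (Node 16+); `medianLatencyMs()` is a hypothetical helper, and the median is chosen over the mean so a single slow outlier (for example a GC pause) does not fail the test.

```typescript
// Hypothetical helper: median wall-clock latency in milliseconds of `run`
// over `samples` timed calls, after `warmup` untimed calls.
async function medianLatencyMs(
  run: () => Promise<unknown>,
  samples = 5,
  warmup = 1,
): Promise<number> {
  if (samples < 1) throw new RangeError("samples must be at least 1");
  for (let i = 0; i < warmup; i++) await run();
  const times: number[] = [];
  for (let i = 0; i < samples; i++) {
    const start = performance.now();
    await run();
    times.push(performance.now() - start);
  }
  times.sort((a, b) => a - b);
  const mid = Math.floor(times.length / 2);
  return times.length % 2 === 1 ? times[mid] : (times[mid - 1] + times[mid]) / 2;
}
```

In the vitest performance test this could be asserted as `expect(await medianLatencyMs(() => runInference())).toBeLessThan(500)`, where `runInference` is a placeholder for whatever call the test already makes.
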
acceptance_criteria:

- `getModel()` returns a model with `loaded: true` and `backend: "tfjs"`
- `model.predict()` on a valid `[1, 160, 160, 3]` input returns `[1, 38]` output without errors
- Output interpretation is correctly determined (logits vs probabilities) and handled
- `@tensorflow/tfjs-node` is installed and used as primary backend
- No memory leaks: tensor count stays stable after repeated inference calls
- Fallback chain works: tfjs-node → tfjs → mock (each failure logs warning)
- Model load time < 30 seconds on first request
- Inference time < 500 ms per image on server

validation:

- `npm install @tensorflow/tfjs-node`: native bindings install successfully
- `npx vitest run src/lib/ml/model-loader.test.ts`: all loading tests pass
- `npx vitest run src/lib/ml/inference.test.ts`: all inference tests pass
- Manual: `curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}'` returns real predictions (no `demo_mode: true`)
- Check server logs for `[model-loader] Loaded TF.js model` (not mock fallback)

notes:

- The model file `best_mnv2_pv_original.keras` is the original Keras file; the TF.js conversion is already done (`model.json` + 3 weight shards)
- The `.keras` file can be deleted after confirming TF.js works, saving ~27MB
- `@tensorflow/tfjs-node` requires libtensorflow, which is downloaded automatically during `npm install`
- The `file://` protocol for `loadGraphModel` works with `@tensorflow/tfjs-node`, but pure-JS (browser) tfjs has no `file://` handler and falls back to `fetch`; for the browser tfjs fallback, read the files yourself and pass `tf.loadGraphModel()` a custom `tf.io.IOHandler` (there is no `tf.io.loadGraphModel`)
- ImageNet normalization in `preprocessImageBuffer()` uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Verify this matches what the PlantVillage model was trained with rather than assuming it: mean/std normalization is the PyTorch convention (Keras `preprocess_input` with `mode="torch"`), whereas Keras's own `mobilenet_v2.preprocess_input` scales pixels to [-1, 1], and some Keras models embed a `Rescaling` layer in the graph instead. Check the training code, or look for nodes under a `rescaling` name scope in `model.json`

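
For the pure-JS fallback, a custom loader can read `model.json` and its shards from disk and hand them to `tf.loadGraphModel()`. A minimal sketch assuming the standard converter layout (a `modelTopology` plus a `weightsManifest` whose shard `paths` are relative to `model.json`); `fileSystemHandler()` is a hypothetical name, while `tf.io.IOHandler` and `tf.loadGraphModel()` are the public TF.js APIs it plugs into.

```typescript
import * as fs from "node:fs/promises";
import * as path from "node:path";
import type * as tf from "@tensorflow/tfjs";

// Hypothetical IOHandler: reads model.json and its weight shards from disk,
// so pure-JS tfjs never has to fetch() a file:// URL.
function fileSystemHandler(modelJsonPath: string): tf.io.IOHandler {
  return {
    async load(): Promise<tf.io.ModelArtifacts> {
      const dir = path.dirname(modelJsonPath);
      const modelJson = JSON.parse(await fs.readFile(modelJsonPath, "utf8"));
      const manifest: tf.io.WeightsManifestConfig = modelJson.weightsManifest;

      // Specs and shard bytes are both concatenated in manifest order,
      // which is the layout TF.js expects for weightData.
      const weightSpecs = manifest.flatMap((group) => group.weights);
      const shardPaths = manifest.flatMap((group) => group.paths);
      const shards = await Promise.all(
        shardPaths.map((p) => fs.readFile(path.join(dir, p))),
      );
      const total = shards.reduce((n, s) => n + s.byteLength, 0);
      const weightData = new ArrayBuffer(total);
      const view = new Uint8Array(weightData);
      let offset = 0;
      for (const shard of shards) {
        view.set(shard, offset);
        offset += shard.byteLength;
      }

      return {
        modelTopology: modelJson.modelTopology,
        format: modelJson.format,
        generatedBy: modelJson.generatedBy,
        convertedBy: modelJson.convertedBy,
        signature: modelJson.signature,
        userDefinedMetadata: modelJson.userDefinedMetadata,
        modelInitializer: modelJson.modelInitializer,
        weightSpecs,
        weightData,
      };
    },
  };
}
```

Usage: `const model = await tf.loadGraphModel(fileSystemHandler(MODEL_JSON_PATH));`. With `@tensorflow/tfjs-node` the built-in `file://` handler already does this, so the custom handler only matters on the fallback path.
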
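
To compare the two normalization conventions on real images, a sketch that turns interleaved 8-bit RGB into the pipeline's CHW layout under either one. `toChwFloat32()` and the mode names are hypothetical; `"imagenet"` uses the mean/std quoted above, and `"keras-tf"` is the `x / 127.5 - 1` scaling applied by Keras's MobileNetV2 `preprocess_input`.

```typescript
type NormalizationMode = "imagenet" | "keras-tf";

const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

// Hypothetical helper: interleaved RGB bytes (H * W * 3) -> normalized CHW floats.
// "imagenet": (x / 255 - mean[c]) / std[c], range roughly [-2.1, 2.6]
// "keras-tf": x / 127.5 - 1, range [-1, 1]
function toChwFloat32(
  rgb: Uint8Array,
  height: number,
  width: number,
  mode: NormalizationMode,
): Float32Array {
  const plane = height * width;
  if (rgb.length !== plane * 3) {
    throw new Error(`expected ${plane * 3} bytes, got ${rgb.length}`);
  }
  const out = new Float32Array(plane * 3);
  for (let p = 0; p < plane; p++) {
    for (let c = 0; c < 3; c++) {
      const x = rgb[p * 3 + c];
      out[c * plane + p] =
        mode === "imagenet" ? (x / 255 - MEAN[c]) / STD[c] : x / 127.5 - 1;
    }
  }
  return out;
}
```

Running a few images with known labels through both modes and comparing top-1 predictions is a cheap way to find out which convention the model was trained with.
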