This commit is contained in:
2026-06-06 15:09:46 -04:00
parent 78220d3568
commit 06295c83ca
56 changed files with 12018 additions and 440 deletions

View File

@@ -0,0 +1,170 @@
# 03. TensorFlow.js Model Loading Verification and Fixes
meta:
id: production-ml-pipeline-03
feature: production-ml-pipeline
priority: P0
depends_on: []
tags: [implementation, model, tests-required]
objective:
- Verify the converted TF.js GraphModel loads successfully on the Node.js server
- Fix input tensor format handling (NCHW pipeline input → NHWC model input)
- Determine whether model output is logits or pre-computed softmax probabilities
- Ensure inference produces valid [1, 38] output without errors
- Install `@tensorflow/tfjs-node` for server-side native acceleration
deliverables:
- `src/lib/ml/model-loader.ts` — fixed and verified for real model loading
- `src/lib/ml/model-loader.test.ts` — updated integration tests
- `package.json``@tensorflow/tfjs-node` added as dependency (if needed)
- `src/lib/ml/inference.ts` — fixed output interpretation (logits vs probabilities)
- `src/lib/ml/inference.test.ts` — updated for real model inference
steps:
1. **Determine output interpretation** — inspect the graph topology to resolve whether `Identity:0` is pre-softmax logits or post-softmax probabilities:
- The model graph contains a `Softmax` node at `StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax`
- The output `Identity:0` may be after Softmax (probabilities) or before (logits)
- Test: run inference on a zero tensor — if output sums to ~1.0, it's already probabilities; if output has negative values or doesn't sum to 1.0, it's logits
- Fix: if output is already probabilities, remove the `softmaxFloat32()` call in `inference.ts` and use the raw output directly
2. **Fix input tensor format** — the model expects NHWC `[1, 160, 160, 3]` but our pipeline produces NCHW `[3, 160, 160]`:
```typescript
// Current code in model-loader.ts tryLoadTFJS():
const inputTensor = tf
.tensor4d(Array.from(tensor), [3, 160, 160])
.transpose([1, 2, 0]) // [160, 160, 3]
.expandDims(0); // [1, 160, 160, 3] NHWC
```
- Verify this transpose is correct (NCHW → NHWC)
- Verify the tensor values are in the expected range (ImageNet-normalized: roughly -2.5 to +2.5)
- Alternative: reshape directly as `[1, 160, 160, 3]` if the identify endpoint produces NHWC data
3. **Install `@tensorflow/tfjs-node`** for server-side native acceleration:
```bash
npm install @tensorflow/tfjs-node
```
- Browser tfjs works on server but is significantly slower (no native BLAS)
- `@tensorflow/tfjs-node` uses libtensorflow C library for ~10-100x speedup
- Verify native bindings install correctly (may need `@tensorflow/tfjs-node-gpu` for GPU, but CPU is fine for this use case)
- Fallback chain remains: tfjs-node → tfjs (browser) → mock
4. **Verify model loads from filesystem**:
```typescript
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
console.log("Model loaded:", model.inputs, model.outputs);
// Expected:
// inputs: [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
// outputs: [{ shape: [-1, 38], dtype: 'float32' }]
```
- Verify `model.inputs[0].shape` matches `[null, 160, 160, 3]`
- Verify `model.outputs[0].shape` matches `[null, 38]`
- Verify model has `predict()` method (GraphModel uses `predict()`, not `execute()`)
5. **Run inference smoke test**:
```typescript
// Create a test tensor (random normalized values)
const testTensor = new Float32Array(3 * 160 * 160);
for (let i = 0; i < testTensor.length; i++) {
testTensor[i] = (Math.random() - 0.5) * 2;
}
// Reshape to NHWC for TF.js
const input = tf.tensor4d(
Array.from(testTensor),
[1, 160, 160, 3], // NHWC
);
const output = model.predict(input);
const data = await output.data();
console.log("Output shape:", output.shape);
console.log(
"Output sum:",
data.reduce((a, b) => a + b, 0),
);
console.log("Output max:", Math.max(...data));
console.log("Output min:", Math.min(...data));
```
- Output should be [1, 38] with 38 float values
- If values are probabilities: sum ≈ 1.0, all values ≥ 0
- If values are logits: sum ≠ 1.0, may have negative values
6. **Fix `model-loader.ts` `getStatus()` to report real class count**:
```typescript
getStatus(): ModelStatus {
return {
loaded: true,
backend: "tfjs",
modelId: MODEL_ID,
numClasses: 38, // PlantVillage, not 95
};
}
```
7. **Add memory management** — dispose tensors after use to prevent memory leaks:
```typescript
// In predict():
tf.tidy(() => {
const input = tf.tensor4d(...);
const output = model.predict(input);
return output.dataSync();
});
```
- Or manually dispose: `inputTensor.dispose()`, `outputTensor.dispose()`
- Use `tf.memory()` to monitor tensor count during development
8. **Handle model load failures gracefully**:
- If model files are corrupted, log the specific error
- If tfjs-node native bindings fail, fall back to browser tfjs with a warning
- Never crash the server on model load failure — fall back to mock mode with clear logging
tests:
- Integration: model loads from `public/models/plant-disease-classifier/model.json` without errors
- Integration: `model.inputs[0].shape` is `[-1, 160, 160, 3]`
- Integration: `model.outputs[0].shape` is `[-1, 38]`
- Integration: inference on random tensor produces [38] float output
- Integration: if output is probabilities, sum is within 0.991.01
- Integration: `getStatus()` returns `{ loaded: true, backend: "tfjs", numClasses: 38 }`
- Unit: `validateInput()` correctly rejects tensors with wrong length
- Unit: NCHW → NHWC transpose produces correct layout
- Performance: inference completes in < 500ms on a typical server (with tfjs-node)
acceptance_criteria:
- `getModel()` returns a model with `loaded: true` and `backend: "tfjs"`
- `model.predict()` on a valid [1, 160, 160, 3] input returns [1, 38] output without errors
- Output interpretation is correctly determined (logits vs probabilities) and handled
- `@tensorflow/tfjs-node` is installed and used as primary backend
- No memory leaks: tensor count stays stable after repeated inference calls
- Fallback chain works: tfjs-node → tfjs → mock (each failure logs warning)
- Model load time < 30 seconds on first request
- Inference time < 500ms per image on server
validation:
- `npm install @tensorflow/tfjs-node` — native bindings install successfully
- `npx vitest run src/lib/ml/model-loader.test.ts` — all loading tests pass
- `npx vitest run src/lib/ml/inference.test.ts` — all inference tests pass
- Manual: `curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}'` — returns real predictions (no `demo_mode: true`)
- Check server logs for `[model-loader] Loaded TF.js model` (not mock fallback)
notes:
- The model file `best_mnv2_pv_original.keras` is the original Keras file — the TF.js conversion is already done (model.json + 3 weight shards)
- The `.keras` file can be deleted after confirming TF.js works, saving ~27MB
- `@tensorflow/tfjs-node` requires libtensorflow — it downloads automatically during npm install
- The `file://` protocol for `loadGraphModel` works with `@tensorflow/tfjs-node` but may not work with browser tfjs (which uses fetch) — if using browser tfjs fallback, need to read file and use `tf.io.loadGraphModel` with a custom loader
- ImageNet normalization in `preprocessImageBuffer()` uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] — verify this matches what the PlantVillage model expects (it should, since MobileNetV2 is typically trained with ImageNet preprocessing)