# 03. TensorFlow.js Model Loading Verification and Fixes

meta:
id: production-ml-pipeline-03
feature: production-ml-pipeline
priority: P0
depends_on: []
tags: [implementation, model, tests-required]

objective:

- Verify the converted TF.js GraphModel loads successfully on the Node.js server
- Fix input tensor format handling (NCHW pipeline input → NHWC model input)
- Determine whether model output is logits or pre-computed softmax probabilities
- Ensure inference produces valid [1, 38] output without errors
- Install `@tensorflow/tfjs-node` for server-side native acceleration

deliverables:

- `src/lib/ml/model-loader.ts`: fixed and verified for real model loading
- `src/lib/ml/model-loader.test.ts`: updated integration tests
- `package.json`: `@tensorflow/tfjs-node` added as dependency (if needed)
- `src/lib/ml/inference.ts`: fixed output interpretation (logits vs probabilities)
- `src/lib/ml/inference.test.ts`: updated for real model inference

steps:

1. **Determine output interpretation**: inspect the graph topology to resolve whether `Identity:0` holds pre-softmax logits or post-softmax probabilities.
- The model graph contains a `Softmax` node at `StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax`
- The output `Identity:0` may sit after that Softmax node (probabilities) or before it (logits)
- Test: run inference on a zero tensor. If every value lies in [0, 1] and the output sums to ~1.0, it is already probabilities; if any value is negative or the sum is not ~1.0, it is logits
- Fix: if the output is already probabilities, remove the `softmaxFloat32()` call in `inference.ts` and use the raw output directly
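
The zero-tensor test above comes down to a small check on the 38 output values. The following is a minimal sketch of that check, assuming the values have already been copied out of the output tensor; `classifyOutput` and its `tolerance` parameter are hypothetical names, not part of `inference.ts`.

```typescript
// Hypothetical helper (not from inference.ts): decides whether a model output
// vector looks like softmax probabilities or like raw logits.
type OutputKind = "probabilities" | "logits";

function classifyOutput(values: ArrayLike<number>, tolerance = 0.01): OutputKind {
  let sum = 0;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    // A probability is always finite and lies in [0, 1].
    if (!Number.isFinite(v) || v < 0 || v > 1) return "logits";
    sum += v;
  }
  // Probabilities must also sum to 1 (within a small tolerance).
  return Math.abs(sum - 1) <= tolerance ? "probabilities" : "logits";
}
```

Logits can by coincidence fall in [0, 1] and sum to about 1, so it is safer to run the check on a few different inputs rather than on the zero tensor alone.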

2. **Fix input tensor format**: the model expects NHWC `[1, 160, 160, 3]`, but our pipeline produces NCHW `[3, 160, 160]` (channels first, no batch dimension):

```typescript
// In model-loader.ts tryLoadTFJS().
// Use tensor3d here: the shape [3, 160, 160] is rank 3, and tf.tensor4d
// rejects a three-element shape.
const inputTensor = tf
  .tensor3d(Array.from(tensor), [3, 160, 160]) // [C, H, W]
  .transpose([1, 2, 0]) // [160, 160, 3]
  .expandDims(0); // [1, 160, 160, 3] NHWC
```

- Verify this transpose is correct (NCHW → NHWC)
- Verify the tensor values are in the expected range (ImageNet-normalized: roughly -2.1 to +2.6 for the mean/std listed in the notes)
- Alternative: reshape directly as `[1, 160, 160, 3]` if the identify endpoint produces NHWC data
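
To verify the transpose independently of TF.js, the same NCHW → NHWC reordering can be done on the flat buffer and compared element by element. This is a minimal sketch; `chwToHwc` is a hypothetical helper name and is not taken from the codebase.

```typescript
// Hypothetical reference implementation of the layout change on a flat buffer.
function chwToHwc(
  chw: Float32Array,
  channels: number,
  height: number,
  width: number,
): Float32Array {
  const expected = channels * height * width;
  if (chw.length !== expected) {
    throw new RangeError(`expected ${expected} values, got ${chw.length}`);
  }
  const hwc = new Float32Array(expected);
  for (let c = 0; c < channels; c++) {
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        // CHW offset: (c*H + y)*W + x   HWC offset: (y*W + x)*C + c
        hwc[(y * width + x) * channels + c] = chw[(c * height + y) * width + x];
      }
    }
  }
  return hwc;
}

// 2 channels, 1x2 image: channel 0 = [1, 2], channel 1 = [10, 20]
const interleaved = chwToHwc(new Float32Array([1, 2, 10, 20]), 2, 1, 2);
// interleaved is [1, 10, 2, 20]: each pixel now carries its channels together
```

A unit test can feed the same small buffer through the `transpose([1, 2, 0])` path and assert that both results are identical.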

3. **Install `@tensorflow/tfjs-node`** for server-side native acceleration:

```bash
npm install @tensorflow/tfjs-node
```

- Browser tfjs works on the server but is significantly slower (pure-JavaScript CPU backend, no native BLAS)
- `@tensorflow/tfjs-node` binds to the libtensorflow C library for an expected ~10-100x speedup (benchmark on the target server to confirm)
- Verify native bindings install correctly (may need `@tensorflow/tfjs-node-gpu` for GPU, but CPU is fine for this use case)
- Fallback chain remains: tfjs-node → tfjs (browser) → mock

4. **Verify model loads from filesystem**:

```typescript
const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
console.log("Model loaded:", model.inputs, model.outputs);
// Expected:
// inputs: [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
// outputs: [{ shape: [-1, 38], dtype: 'float32' }]
```

- Verify `model.inputs[0].shape` matches `[-1, 160, 160, 3]` (the batch dimension is unknown; a GraphModel normally reports it as `-1`, so treat `null` the same way if it appears)
- Verify `model.outputs[0].shape` matches `[-1, 38]`
- Verify the model has a `predict()` method (a GraphModel exposes both `predict()` and `execute()`; `predict()` is enough for this single-input, single-output model)
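
The shape checks above can be written so that an unknown batch dimension matches whether it is reported as `-1` or as `null`. This is a minimal sketch under that assumption; `shapeMatches` and `Dim` are hypothetical names.

```typescript
// Hypothetical helper: compares a reported tensor shape with the expected one.
// A dynamic dimension (-1, null or undefined) only matches another dynamic one.
type Dim = number | null | undefined;

function isDynamic(d: Dim): boolean {
  return d == null || d < 0;
}

function shapeMatches(
  actual: readonly Dim[] | undefined,
  expected: readonly Dim[],
): boolean {
  if (!actual || actual.length !== expected.length) return false;
  return expected.every((want, i) =>
    isDynamic(want) ? isDynamic(actual[i]) : actual[i] === want,
  );
}
```

In an integration test this could be used as `shapeMatches(model.inputs[0].shape, [-1, 160, 160, 3])`, which also flags a model that was exported with a fixed batch size.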

5. **Run inference smoke test**:

```typescript
// Create a test tensor (random values in [-1, 1))
const testTensor = new Float32Array(3 * 160 * 160);
for (let i = 0; i < testTensor.length; i++) {
  testTensor[i] = (Math.random() - 0.5) * 2;
}
// Values are random, so the buffer can be read directly as NHWC for TF.js
const input = tf.tensor4d(
  testTensor,
  [1, 160, 160, 3], // NHWC
);
// predict() is typed as Tensor | Tensor[] | NamedTensorMap; this model has one output
const output = model.predict(input) as tf.Tensor;
const data = Array.from(await output.data());
console.log("Output shape:", output.shape);
console.log(
  "Output sum:",
  data.reduce((a, b) => a + b, 0),
);
console.log("Output max:", Math.max(...data));
console.log("Output min:", Math.min(...data));
// Release tensor memory once the values have been read
input.dispose();
output.dispose();
```

- Output should be [1, 38] with 38 float values
- If values are probabilities: sum ≈ 1.0, all values ≥ 0
- If values are logits: the sum is generally not 1.0 and some values may be negative
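
If the smoke test shows logits, they still need to be converted to probabilities before use. The following is a minimal sketch of a numerically stable softmax; `stableSoftmax` is a hypothetical name and is not the project's `softmaxFloat32()`, whose implementation is not shown here.

```typescript
// Hypothetical sketch of a numerically stable softmax over one logits vector.
function stableSoftmax(logits: ArrayLike<number>): Float32Array {
  const out = new Float32Array(logits.length);
  if (logits.length === 0) return out;

  // Subtracting the maximum keeps Math.exp from overflowing on large logits.
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);

  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    const e = Math.exp(logits[i] - max);
    out[i] = e;
    sum += e;
  }
  for (let i = 0; i < out.length; i++) out[i] /= sum;
  return out;
}
```

Applying softmax to values that are already probabilities flattens the distribution, which is why the logits-versus-probabilities question has to be settled before choosing whether to call it.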

6. **Fix `model-loader.ts` `getStatus()` to report the real class count**:

```typescript
getStatus(): ModelStatus {
  return {
    loaded: true,
    backend: "tfjs",
    modelId: MODEL_ID,
    numClasses: 38, // PlantVillage, not 95
  };
}
```

7. **Add memory management**: dispose tensors after use to prevent memory leaks:

```typescript
// In predict():
// tf.tidy disposes every tensor created inside the callback (input and output here);
// the typed array returned by dataSync() is plain data and survives.
const values = tf.tidy(() => {
  const input = tf.tensor4d(...);
  const output = model.predict(input) as tf.Tensor;
  return output.dataSync();
});
```

- Or dispose manually: `inputTensor.dispose()`, `outputTensor.dispose()`
- Use `tf.memory()` to monitor the tensor count (`numTensors`) during development
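
A leak check can be reduced to comparing the live tensor count before and after repeated inference calls. The sketch below takes the counter and the inference call as parameters, so it makes no assumption about how `predict()` is wired; `tensorDelta` is a hypothetical helper name.

```typescript
// Hypothetical helper: returns how many tensors remain allocated after
// `iterations` runs, relative to the count after one warm-up run.
function tensorDelta(
  countTensors: () => number,
  runOnce: () => void,
  iterations = 10,
): number {
  // Warm-up: the first call may allocate long-lived tensors such as weights.
  runOnce();
  const before = countTensors();
  for (let i = 0; i < iterations; i++) runOnce();
  return countTensors() - before;
}
```

With TF.js the counter would typically be `() => tf.memory().numTensors`, and a result of `0` indicates that the tensor count stays stable across calls.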

8. **Handle model load failures gracefully**:
- If model files are corrupted, log the specific error
- If tfjs-node native bindings fail, fall back to browser tfjs with a warning
- Never crash the server on model load failure: fall back to mock mode with clear logging
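
The fallback behaviour described above can be sketched as a loop over candidate loaders that logs each failure and moves on. This is an illustrative sketch only: `BackendCandidate`, `loadFirstAvailable` and the warning text are hypothetical and do not describe the existing `model-loader.ts`.

```typescript
// Hypothetical sketch of a "first backend that loads wins" chain.
interface BackendCandidate<T> {
  name: string;
  load: () => Promise<T>;
}

async function loadFirstAvailable<T>(
  candidates: BackendCandidate<T>[],
  warn: (message: string) => void = console.warn,
): Promise<{ name: string; value: T }> {
  for (const candidate of candidates) {
    try {
      return { name: candidate.name, value: await candidate.load() };
    } catch (err) {
      // Log the specific error, then try the next backend instead of crashing.
      const reason = err instanceof Error ? err.message : String(err);
      warn(`[model-loader] ${candidate.name} failed: ${reason}`);
    }
  }
  throw new Error("no backend could be loaded");
}
```

Placing a mock loader that cannot fail as the last candidate keeps the final `throw` unreachable in practice, which matches the tfjs-node → tfjs → mock order.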

tests:

- Integration: model loads from `public/models/plant-disease-classifier/model.json` without errors
- Integration: `model.inputs[0].shape` is `[-1, 160, 160, 3]`
- Integration: `model.outputs[0].shape` is `[-1, 38]`
- Integration: inference on a random tensor produces 38 float values (output shape `[1, 38]`)
- Integration: if output is probabilities, sum is within 0.99–1.01
- Integration: `getStatus()` returns an object that includes `{ loaded: true, backend: "tfjs", numClasses: 38 }`
- Unit: `validateInput()` correctly rejects tensors with the wrong length
- Unit: NCHW → NHWC transpose produces the correct layout
- Performance: inference completes in < 500ms on a typical server (with tfjs-node)
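
For the input-length unit test, the rule under test is that a tensor must hold exactly 3 × 160 × 160 values. The sketch below shows one way such a check could look; `checkInput` is a hypothetical stand-in, since the real `validateInput()` signature is not shown in this document, and the finite-value check is an added assumption.

```typescript
// Hypothetical stand-in for validateInput(): length and finiteness checks only.
const EXPECTED_LENGTH = 3 * 160 * 160; // 76,800 values

type InputCheck = { ok: true } | { ok: false; reason: string };

function checkInput(tensor: Float32Array): InputCheck {
  if (tensor.length !== EXPECTED_LENGTH) {
    return {
      ok: false,
      reason: `expected ${EXPECTED_LENGTH} values, got ${tensor.length}`,
    };
  }
  for (let i = 0; i < tensor.length; i++) {
    // Assumption: NaN or Infinity means preprocessing went wrong upstream.
    if (!Number.isFinite(tensor[i])) {
      return { ok: false, reason: `non-finite value at index ${i}` };
    }
  }
  return { ok: true };
}
```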

acceptance_criteria:

- `getModel()` returns a model with `loaded: true` and `backend: "tfjs"`
- `model.predict()` on a valid [1, 160, 160, 3] input returns [1, 38] output without errors
- Output interpretation is correctly determined (logits vs probabilities) and handled
- `@tensorflow/tfjs-node` is installed and used as primary backend
- No memory leaks: tensor count stays stable after repeated inference calls
- Fallback chain works: tfjs-node → tfjs → mock (each failure logs a warning)
- Model load time < 30 seconds on first request
- Inference time < 500ms per image on server

validation:

- `npm install @tensorflow/tfjs-node`: native bindings install successfully
- `npx vitest run src/lib/ml/model-loader.test.ts`: all loading tests pass
- `npx vitest run src/lib/ml/inference.test.ts`: all inference tests pass
- Manual: `curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}'` returns real predictions (no `demo_mode: true`)
- Check server logs for `[model-loader] Loaded TF.js model` (not mock fallback)

notes:

- The model file `best_mnv2_pv_original.keras` is the original Keras file; the TF.js conversion is already done (`model.json` + 3 weight shards)
- The `.keras` file can be deleted after confirming TF.js works, saving ~27MB
- `@tensorflow/tfjs-node` requires libtensorflow, which is downloaded automatically during `npm install`
- The `file://` protocol for `loadGraphModel` works with `@tensorflow/tfjs-node` but may not work with browser tfjs (which uses fetch). If using the browser tfjs fallback, read the files from disk and pass a custom `tf.io.IOHandler` (for example one built with `tf.io.fromMemory`) to `tf.loadGraphModel`; there is no `tf.io.loadGraphModel`
- ImageNet normalization in `preprocessImageBuffer()` uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Verify this matches what the PlantVillage model was trained with: Keras's own `mobilenet_v2.preprocess_input` scales pixels to [-1, 1] instead of applying mean/std, so a mismatch is possible if the training pipeline used it
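
The two candidate preprocessing schemes give visibly different value ranges, which makes a mismatch easy to spot. The sketch below applies both to a single 8-bit RGB pixel; the function names are hypothetical and the code is not taken from `preprocessImageBuffer()`.

```typescript
// Mean/std from the note above (per channel, for pixel values scaled to [0, 1]).
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

// Hypothetical helper: ImageNet mean/std normalization of one 8-bit RGB pixel.
function imagenetNormalize(rgb: [number, number, number]): number[] {
  return rgb.map((v, c) => (v / 255 - MEAN[c]) / STD[c]);
}

// Hypothetical helper: the [-1, 1] scaling used by Keras
// mobilenet_v2.preprocess_input.
function scaleToUnitRange(rgb: [number, number, number]): number[] {
  return rgb.map((v) => v / 127.5 - 1);
}

// A black pixel maps to about [-2.118, -2.036, -1.804] with mean/std,
// but to [-1, -1, -1] with the [-1, 1] scaling.
const blackMeanStd = imagenetNormalize([0, 0, 0]);
const blackUnit = scaleToUnitRange([0, 0, 0]);
```

Comparing a few known pixels against the values produced by the training pipeline is usually enough to tell which scheme the model expects.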