plant-disease-id/apps/web/tasks/production-ml-pipeline/03-model-loading-verification.md
2026-06-06 15:09:46 -04:00

03. TensorFlow.js Model Loading Verification and Fixes

meta:

  • id: production-ml-pipeline-03
  • feature: production-ml-pipeline
  • priority: P0
  • depends_on: []
  • tags: [implementation, model, tests-required]

objective:

  • Verify the converted TF.js GraphModel loads successfully on the Node.js server
  • Fix input tensor format handling (NCHW pipeline input → NHWC model input)
  • Determine whether model output is logits or pre-computed softmax probabilities
  • Ensure inference produces valid [1, 38] output without errors
  • Install @tensorflow/tfjs-node for server-side native acceleration

deliverables:

  • src/lib/ml/model-loader.ts — fixed and verified for real model loading
  • src/lib/ml/model-loader.test.ts — updated integration tests
  • package.json — @tensorflow/tfjs-node added as dependency (if needed)
  • src/lib/ml/inference.ts — fixed output interpretation (logits vs probabilities)
  • src/lib/ml/inference.test.ts — updated for real model inference

steps:

  1. Determine output interpretation — inspect the graph topology to resolve whether Identity:0 is pre-softmax logits or post-softmax probabilities:

    • The model graph contains a Softmax node at StatefulPartitionedCall/mnv2_pv_original_1/dense_1/Softmax
    • The output Identity:0 may be after Softmax (probabilities) or before (logits)
    • Test: run inference on a zero tensor. If every value is ≥ 0 and the output sums to ~1.0, it is already probabilities; if any value is negative or the sum is not ~1.0, it is logits
    • Fix: if output is already probabilities, remove the softmaxFloat32() call in inference.ts and use the raw output directly
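
The zero-tensor test above reduces to a small check on the 38 output values. The following is a minimal sketch, not code from the repository: the helper names `looksLikeProbabilities` and `softmax` and the 0.01 tolerance are assumptions made for illustration.

```typescript
// Hypothetical helper: decides whether a model output vector already looks
// like softmax probabilities (every value in [0, 1], sum close to 1).
function looksLikeProbabilities(
  values: ArrayLike<number>,
  tolerance = 0.01, // assumed tolerance, not taken from the codebase
): boolean {
  let sum = 0;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    if (!Number.isFinite(v) || v < 0 || v > 1) return false;
    sum += v;
  }
  return Math.abs(sum - 1) <= tolerance;
}

// Reference softmax, included only to demonstrate the check.
function softmax(logits: ArrayLike<number>): number[] {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) max = Math.max(max, logits[i]);
  const exps: number[] = [];
  let total = 0;
  for (let i = 0; i < logits.length; i++) {
    const e = Math.exp(logits[i] - max); // subtract max for numerical stability
    exps.push(e);
    total += e;
  }
  return exps.map((e) => e / total);
}
```

The check is a heuristic: a logits vector could pass it by coincidence, so the graph topology remains the authoritative answer and this only serves as a quick confirmation.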
  2. Fix input tensor format — the model expects NHWC [1, 160, 160, 3] but our pipeline produces NCHW [3, 160, 160]:

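    // In model-loader.ts tryLoadTFJS(). The current code calls tf.tensor4d here,
    // which rejects a rank-3 shape; tf.tensor3d matches the [3, 160, 160] input.
    const inputTensor = tf
      .tensor3d(Array.from(tensor), [3, 160, 160]) // CHW
      .transpose([1, 2, 0]) // [160, 160, 3]
      .expandDims(0); // [1, 160, 160, 3] NHWC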
    
    • Verify this transpose is correct (NCHW → NHWC)
    • Verify the tensor values are in the expected range (ImageNet-normalized: roughly -2.1 to +2.6)
    • Alternative: reshape directly as [1, 160, 160, 3] if the identify endpoint produces NHWC data
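
To check the transpose independently of TF.js, the same reordering can be written as plain index arithmetic. This is a sketch under assumptions: `chwToHwc` is a hypothetical helper name, and the input is taken to be a flat channel-first `Float32Array` as described above.

```typescript
// Hypothetical helper: reorders a channel-first (CHW) buffer into
// channel-last (HWC) order, the same layout change as transpose([1, 2, 0]).
function chwToHwc(
  chw: Float32Array,
  channels: number,
  height: number,
  width: number,
): Float32Array {
  const expected = channels * height * width;
  if (chw.length !== expected) {
    throw new Error(`Expected ${expected} values, got ${chw.length}`);
  }
  const hwc = new Float32Array(expected);
  for (let c = 0; c < channels; c++) {
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        // CHW index: c*H*W + y*W + x   ->   HWC index: (y*W + x)*C + c
        hwc[(y * width + x) * channels + c] =
          chw[c * height * width + y * width + x];
      }
    }
  }
  return hwc;
}
```

A unit test can compare this reference against the values read back from the transposed tensor for a small input, for example 2 channels of 1×2 pixels.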
  3. Install @tensorflow/tfjs-node for server-side native acceleration:

    npm install @tensorflow/tfjs-node
    
    • The pure-JavaScript @tensorflow/tfjs package works on the server but is significantly slower (no native BLAS)
    • @tensorflow/tfjs-node uses the libtensorflow C library for a ~10-100x speedup (an estimate; confirm with the performance test below)
    • Verify native bindings install correctly (may need @tensorflow/tfjs-node-gpu for GPU, but CPU is fine for this use case)
    • Fallback chain remains: tfjs-node → tfjs (browser) → mock
  4. Verify model loads from filesystem:

    const model = await tf.loadGraphModel(`file://${MODEL_JSON_PATH}`);
    console.log("Model loaded:", model.inputs, model.outputs);
    // Expected:
    // inputs: [{ shape: [-1, 160, 160, 3], dtype: 'float32' }]
    // outputs: [{ shape: [-1, 38], dtype: 'float32' }]
    
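    • Verify model.inputs[0].shape matches [-1, 160, 160, 3] (the batch dimension is dynamic; some TF.js model types report it as null rather than -1)
    • Verify model.outputs[0].shape matches [-1, 38]
    • Verify the model has a predict() method (GraphModel also exposes execute() and executeAsync(); predict() is sufficient for this single-input, single-output model)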
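
Because the dynamic batch dimension can appear as either -1 or null, a shape assertion is more robust if it normalizes unknown dimensions first. A minimal sketch, assuming hypothetical helper names `normalizeShape` and `shapeMatches` that are not part of the codebase:

```typescript
type Dim = number | null | undefined;

// Hypothetical helper: maps every "unknown" dimension (null, undefined,
// or any negative number) to -1 so shapes can be compared directly.
function normalizeShape(shape: ReadonlyArray<Dim>): number[] {
  return shape.map((d) => (d === null || d === undefined || d < 0 ? -1 : d));
}

// Hypothetical helper: true when both shapes have the same rank and the
// same dimensions after normalization.
function shapeMatches(
  actual: ReadonlyArray<Dim> | undefined,
  expected: ReadonlyArray<Dim>,
): boolean {
  if (!actual || actual.length !== expected.length) return false;
  const a = normalizeShape(actual);
  const e = normalizeShape(expected);
  return a.every((dim, i) => dim === e[i]);
}
```

In the integration tests this would be called as `shapeMatches(model.inputs[0].shape, [-1, 160, 160, 3])` and `shapeMatches(model.outputs[0].shape, [-1, 38])`.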
  5. Run inference smoke test:

    // Create a test tensor (random normalized values)
    const testTensor = new Float32Array(3 * 160 * 160);
    for (let i = 0; i < testTensor.length; i++) {
      testTensor[i] = (Math.random() - 0.5) * 2;
    }
    // Reshape to NHWC for TF.js
    const input = tf.tensor4d(
      Array.from(testTensor),
      [1, 160, 160, 3], // NHWC
    );
    const output = model.predict(input);
    const data = await output.data();
    console.log("Output shape:", output.shape);
    console.log(
      "Output sum:",
      data.reduce((a, b) => a + b, 0),
    );
    console.log("Output max:", Math.max(...data));
    console.log("Output min:", Math.min(...data));
    
    • Output should be [1, 38] with 38 float values
    • If values are probabilities: sum ≈ 1.0, all values ≥ 0
    • If values are logits: sum ≠ 1.0, may have negative values
  6. Fix model-loader.ts getStatus() to report real class count:

    getStatus(): ModelStatus {
      return {
        loaded: true,
        backend: "tfjs",
        modelId: MODEL_ID,
        numClasses: 38,  // PlantVillage, not 95
      };
    }
    
  7. Add memory management — dispose tensors after use to prevent memory leaks:

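    // In predict(); nhwcData is the Float32Array already in NHWC order:
    const scores = tf.tidy(() => {
      const input = tf.tensor4d(nhwcData, [1, 160, 160, 3]);
      const output = model.predict(input);
      // dataSync() copies the values out, so tidy can dispose both tensors.
      // tidy does not accept async callbacks, hence dataSync() over data().
      return output.dataSync();
    });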
    
    • Or manually dispose: inputTensor.dispose(), outputTensor.dispose()
    • Use tf.memory() to monitor tensor count during development
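
The "tensor count stays stable" requirement can be turned into an assertion by sampling the tensor count before and after repeated calls. The sketch below assumes a hypothetical helper name `countLeakedTensors`; the counter is injected so the helper itself has no TF.js dependency, and in the real code it would be `() => tf.memory().numTensors`.

```typescript
// Hypothetical helper: runs `fn` repeatedly and returns how many tensors
// remain allocated afterwards compared with the count before the loop.
function countLeakedTensors(
  getTensorCount: () => number, // e.g. () => tf.memory().numTensors
  fn: () => void,
  iterations = 10, // assumed default
): number {
  fn(); // warm-up call, so one-time allocations are not counted as leaks
  const before = getTensorCount();
  for (let i = 0; i < iterations; i++) fn();
  return getTensorCount() - before;
}
```

A result of 0 means repeated inference leaves the tensor count unchanged; a result that grows with `iterations` points to a missing dispose() or a tensor created outside tf.tidy().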
  8. Handle model load failures gracefully:

    • If model files are corrupted, log the specific error
    • If tfjs-node native bindings fail, fall back to browser tfjs with a warning
    • Never crash the server on model load failure — fall back to mock mode with clear logging
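
The fallback behavior described above can be sketched as a loop over loaders ordered by preference. Everything named here (`BackendLoader`, `LoadResult`, `loadWithFallback`, and the warning text) is hypothetical and only illustrates the pattern; the real loaders live in model-loader.ts.

```typescript
// Hypothetical types: one entry per backend, in order of preference.
interface BackendLoader<T> {
  name: string; // e.g. "tfjs-node", "tfjs", "mock"
  load: () => Promise<T>;
}

interface LoadResult<T> {
  backend: string;
  model: T;
}

// Tries each loader in order and returns the first one that succeeds.
// Each failure is reported through `warn` rather than propagated.
async function loadWithFallback<T>(
  loaders: BackendLoader<T>[],
  warn: (message: string) => void = (m) => console.warn(m),
): Promise<LoadResult<T>> {
  for (const { name, load } of loaders) {
    try {
      return { backend: name, model: await load() };
    } catch (err) {
      const reason = err instanceof Error ? err.message : String(err);
      warn(`[model-loader] ${name} failed to load: ${reason}`);
    }
  }
  // Reached only if every loader failed, including the last-resort one.
  throw new Error("[model-loader] all backends failed to load");
}
```

To keep the server from crashing, the last entry in the list should be the mock loader, which is expected never to throw; the final `throw` then acts only as a guard against a misconfigured list.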

tests:

  • Integration: model loads from public/models/plant-disease-classifier/model.json without errors
  • Integration: model.inputs[0].shape is [-1, 160, 160, 3]
  • Integration: model.outputs[0].shape is [-1, 38]
  • Integration: inference on random tensor produces [1, 38] output (38 float values)
  • Integration: if output is probabilities, sum is between 0.99 and 1.01
  • Integration: getStatus() returns { loaded: true, backend: "tfjs", numClasses: 38 }
  • Unit: validateInput() correctly rejects tensors with wrong length
  • Unit: NCHW → NHWC transpose produces correct layout
  • Performance: inference completes in < 500ms on a typical server (with tfjs-node)

acceptance_criteria:

  • getModel() returns a model with loaded: true and backend: "tfjs"
  • model.predict() on a valid [1, 160, 160, 3] input returns [1, 38] output without errors
  • Output interpretation is correctly determined (logits vs probabilities) and handled
  • @tensorflow/tfjs-node is installed and used as primary backend
  • No memory leaks: tensor count stays stable after repeated inference calls
  • Fallback chain works: tfjs-node → tfjs → mock (each failure logs warning)
  • Model load time < 30 seconds on first request
  • Inference time < 500ms per image on server

validation:

  • npm install @tensorflow/tfjs-node — native bindings install successfully
  • npx vitest run src/lib/ml/model-loader.test.ts — all loading tests pass
  • npx vitest run src/lib/ml/inference.test.ts — all inference tests pass
  • Manual: curl -X POST http://localhost:3000/api/identify -H "Content-Type: application/json" -d '{"imageId":"<existing-id>"}' — returns real predictions (no demo_mode: true)
  • Check server logs for [model-loader] Loaded TF.js model (not mock fallback)

notes:

  • The model file best_mnv2_pv_original.keras is the original Keras file — the TF.js conversion is already done (model.json + 3 weight shards)
  • The .keras file can be deleted after confirming TF.js works, saving ~27MB
  • @tensorflow/tfjs-node requires libtensorflow — it downloads automatically during npm install
  • The file:// protocol for loadGraphModel works with @tensorflow/tfjs-node but may not work with browser tfjs (which uses fetch). If the browser tfjs fallback is used, read the model files from disk and pass tf.loadGraphModel a custom IOHandler (for example, one built with tf.io.fromMemory)
  • ImageNet normalization in preprocessImageBuffer() uses mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Verify this matches what the PlantVillage model expects rather than assuming it: the model comes from a .keras file, and Keras's mobilenet_v2.preprocess_input scales pixels to [-1, 1] instead of applying per-channel mean/std
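
To make the preprocessing comparison concrete, the sketch below computes the value range produced by the mean/std constants quoted above and shows the [-1, 1] scaling convention next to it. The function names are hypothetical, and which convention the PlantVillage model was actually trained with is not established here.

```typescript
// Constants quoted in the note above for preprocessImageBuffer().
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

// Hypothetical helper: normalizes one 8-bit value (0..255) of the given
// RGB channel index (0, 1, or 2) with the ImageNet mean/std.
function normalizeImageNet(value: number, channel: number): number {
  return (value / 255 - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel];
}

// Alternative convention: scales 0..255 linearly to [-1, 1].
function scaleToUnitRange(value: number): number {
  return value / 127.5 - 1;
}

// Smallest and largest values the mean/std normalization can produce.
function imageNetRange(): { min: number; max: number } {
  const lows = [0, 1, 2].map((c) => normalizeImageNet(0, c));
  const highs = [0, 1, 2].map((c) => normalizeImageNet(255, c));
  return { min: Math.min(...lows), max: Math.max(...highs) };
}
```

With these constants the range works out to roughly -2.12 to +2.64, so input values confined to [-1, 1] would indicate that the other convention is in use.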