This commit is contained in:
2026-06-06 15:09:46 -04:00
parent 78220d3568
commit 06295c83ca
56 changed files with 12018 additions and 440 deletions

View File

@@ -196,8 +196,8 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
async predict(tensor: Float32Array): Promise<ModelOutput> {
const startTime = performance.now();
// Reshape to [1, 3, 224, 224] NCHW → [1, 224, 224, 3] NHWC for TF.js
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 224, 224])
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 160, 160])
.transpose([1, 2, 0])
.expandDims(0);
@@ -220,7 +220,7 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
loaded: true,
backend: "tfjs",
modelId: MODEL_ID,
numClasses: 95, // Will be updated after model loads
numClasses: 38, // Original PlantVillage model
};
},
};
@@ -256,8 +256,8 @@ async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
async predict(tensor: Float32Array): Promise<ModelOutput> {
const startTime = performance.now();
// ONNX expects NCHW format: [1, 3, 224, 224]
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 224, 224]);
// ONNX expects NCHW format: [1, 3, 160, 160]
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
const feeds = { [session.inputNames[0]]: inputTensor };
const results = await session.run(feeds);
@@ -278,7 +278,7 @@ async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
loaded: true,
backend: "onnx",
modelId: MODEL_ID,
numClasses: 95,
numClasses: 38,
};
},
};
@@ -313,7 +313,7 @@ function createMockModel(): PlantDiseaseModel {
loaded: false,
backend: "mock",
modelId: MODEL_ID,
numClasses: 95,
numClasses: 38,
error: "Model files not found. Running in demo mode with mock predictions.",
};
},
@@ -326,7 +326,7 @@ function createMockModel(): PlantDiseaseModel {
* reproducible but varied predictions.
*/
function generateMockLogits(tensor: Float32Array): Float32Array {
const numClasses = 95;
const numClasses = 38;
const logits = new Float32Array(numClasses);
// Simple hash of input for deterministic output