beepboop
This commit is contained in:
@@ -196,8 +196,8 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Reshape to [1, 3, 224, 224] NCHW → [1, 224, 224, 3] NHWC for TF.js
|
||||
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 224, 224])
|
||||
// Reshape to [1, 3, 160, 160] NCHW → [1, 160, 160, 3] NHWC for TF.js
|
||||
const inputTensor = tf.tensor4d(Array.from(tensor), [3, 160, 160])
|
||||
.transpose([1, 2, 0])
|
||||
.expandDims(0);
|
||||
|
||||
@@ -220,7 +220,7 @@ async function tryLoadTFJS(): Promise<PlantDiseaseModel | null> {
|
||||
loaded: true,
|
||||
backend: "tfjs",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95, // Will be updated after model loads
|
||||
numClasses: 38, // Original PlantVillage model
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -256,8 +256,8 @@ async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
|
||||
async predict(tensor: Float32Array): Promise<ModelOutput> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// ONNX expects NCHW format: [1, 3, 224, 224]
|
||||
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 224, 224]);
|
||||
// ONNX expects NCHW format: [1, 3, 160, 160]
|
||||
const inputTensor = new ort.Tensor("float32", tensor, [1, 3, 160, 160]);
|
||||
const feeds = { [session.inputNames[0]]: inputTensor };
|
||||
const results = await session.run(feeds);
|
||||
|
||||
@@ -278,7 +278,7 @@ async function tryLoadONNX(): Promise<PlantDiseaseModel | null> {
|
||||
loaded: true,
|
||||
backend: "onnx",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95,
|
||||
numClasses: 38,
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -313,7 +313,7 @@ function createMockModel(): PlantDiseaseModel {
|
||||
loaded: false,
|
||||
backend: "mock",
|
||||
modelId: MODEL_ID,
|
||||
numClasses: 95,
|
||||
numClasses: 38,
|
||||
error: "Model files not found. Running in demo mode with mock predictions.",
|
||||
};
|
||||
},
|
||||
@@ -326,7 +326,7 @@ function createMockModel(): PlantDiseaseModel {
|
||||
* reproducible but varied predictions.
|
||||
*/
|
||||
function generateMockLogits(tensor: Float32Array): Float32Array {
|
||||
const numClasses = 95;
|
||||
const numClasses = 38;
|
||||
const logits = new Float32Array(numClasses);
|
||||
|
||||
// Simple hash of input for deterministic output
|
||||
|
||||
Reference in New Issue
Block a user