Fix SMSClassifierService initialization race condition (FRE-4509)

Add promise-based lazy initialization via ensureInitialized() to deduplicate
concurrent initialize() calls. Concurrent callers now await the same
initialization promise instead of each triggering a separate load.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-04-29 22:25:39 -04:00
parent 3aead0d7bb
commit 2b7ff938da
2 changed files with 166 additions and 7 deletions

View File

@@ -0,0 +1,144 @@
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { SMSClassifierService } from '../services/spamshield/spamshield.service';
// Mock shared-db before anything else (Prisma client is not generated in test env).
// Vitest hoists vi.mock() calls above the file's imports, so this stub is in place
// before the service module is loaded even though the import is written first.
vi.mock('@shieldsai/shared-db', () => ({
prisma: {},
SpamFeedback: {},
}));
// Mock the SpamShield config module so the test controls the thresholds and
// forces the enableMLClassifier flag on.
// NOTE(review): the literal values below presumably mirror the real config's
// enums and thresholds — confirm against spamshield.config when those change.
// Keep this factory self-contained: because vi.mock() is hoisted, it cannot
// reference variables declared elsewhere in this file.
vi.mock('../services/spamshield/spamshield.config', () => ({
// Score thresholds exposed to the service under test.
spamShieldEnv: {
SPAM_THRESHOLD_AUTO_BLOCK: 0.85,
SPAM_THRESHOLD_FLAG: 0.6,
},
// Feature flags: the ML classifier is switched on for every test in this file.
spamFeatureFlags: {
enableMLClassifier: true,
},
// Stand-ins for the config module's enum-like exports.
SpamDecision: {
ALLOW: 'allow',
FLAG: 'flag',
BLOCK: 'block',
CHALLENGE: 'challenge',
},
SpamLayer: {
NUMBER_REPUTATION: 'number_reputation',
CONTENT_CLASSIFICATION: 'content_classification',
BEHAVIORAL_ANALYSIS: 'behavioral_analysis',
COMMUNITY_INTELLIGENCE: 'community_intelligence',
},
ConfidenceLevel: {
LOW: 'low',
MEDIUM: 'medium',
HIGH: 'high',
VERY_HIGH: 'very_high',
},
// Empty stub; these tests supply no rate-limit settings.
spamRateLimits: {},
}));
/**
 * Tests for SMSClassifierService.
 *
 * The first group checks that classify() triggers initialize() at most once,
 * whether calls arrive concurrently, in batches, or sequentially. The second
 * group checks the feature tags reported in `spamFeatures`.
 */
describe('SMSClassifierService', () => {
  // Artificial latency of the stubbed initialize(), long enough for every
  // concurrent classify() call to arrive while initialization is in flight.
  const INIT_DELAY_MS = 50;

  let classifier: SMSClassifierService;
  let initializeCalls: number;

  beforeEach(() => {
    // Fresh instance and counter per test so no initialization state carries over.
    initializeCalls = 0;
    classifier = new SMSClassifierService();
    // Replace initialize() on the instance to count invocations. The delay
    // starts when initialize() is actually called, so the overlap window is
    // tied to the call rather than to test setup.
    classifier.initialize = async () => {
      initializeCalls++;
      await new Promise<void>(resolve => setTimeout(resolve, INIT_DELAY_MS));
    };
  });

  describe('initialization race condition', () => {
    it('should call initialize only once under concurrent classify calls', async () => {
      const pending = Array.from({ length: 10 }, () =>
        classifier.classify('ACT NOW - Limited offer!'),
      );
      const results = await Promise.all(pending);

      expect(initializeCalls).toBe(1);
      expect(results).toHaveLength(10);
      for (const result of results) {
        expect(result).toHaveProperty('isSpam');
        expect(result).toHaveProperty('confidence');
        expect(result).toHaveProperty('spamFeatures');
      }
    });

    it('should handle interleaved calls after partial initialization', async () => {
      const firstBatch = Array.from({ length: 5 }, () =>
        classifier.classify('First batch message'),
      );
      await Promise.all(firstBatch);
      expect(initializeCalls).toBe(1);

      const secondBatch = Array.from({ length: 5 }, () =>
        classifier.classify('Second batch message'),
      );
      await Promise.all(secondBatch);
      // initialize should still only have been called once
      expect(initializeCalls).toBe(1);
    });

    it('should return consistent results for same input under concurrency', async () => {
      const text = 'URGENT: Click http://example.com now!';
      const pending = Array.from({ length: 20 }, () => classifier.classify(text));
      const results = await Promise.all(pending);

      // Every result must match the first one field by field.
      const firstResult = results[0];
      for (const result of results) {
        expect(result.isSpam).toBe(firstResult.isSpam);
        expect(result.confidence).toBe(firstResult.confidence);
        expect(result.spamFeatures).toEqual(firstResult.spamFeatures);
      }
    });

    it('should handle rapid sequential calls without re-initializing', async () => {
      for (let i = 0; i < 50; i++) {
        await classifier.classify(`Message ${i}`);
      }
      expect(initializeCalls).toBe(1);
    });
  });

  describe('feature extraction', () => {
    it('should detect URL presence', async () => {
      const result = await classifier.classify('Visit www.example.com');
      expect(result.spamFeatures).toContain('url_present');
    });

    it('should detect urgency keywords', async () => {
      const result = await classifier.classify('Act now! This offer is urgent.');
      expect(result.spamFeatures).toContain('urgency_keyword');
    });

    it('should detect excessive capitalization', async () => {
      const result = await classifier.classify('BUY THIS NOW!!!');
      expect(result.spamFeatures).toContain('excessive_caps');
    });

    it('should detect multiple features', async () => {
      const result = await classifier.classify(
        'URGENT: Visit www.example.com NOW!!!',
      );
      expect(result.spamFeatures).toContain('url_present');
      expect(result.spamFeatures).toContain('urgency_keyword');
      expect(result.spamFeatures).toContain('excessive_caps');
    });
  });
});

View File

@@ -1,6 +1,5 @@
import { prisma, SpamRule, SpamFeedback, User } from '@shieldsai/shared-db';
import { spamShieldEnv, SpamDecision, ConfidenceLevel, spamFeatureFlags } from './spamshield.config';
import { checkFlag } from './feature-flags';
import { prisma, SpamFeedback } from '@shieldsai/shared-db';
import { spamShieldEnv, SpamDecision, spamFeatureFlags } from './spamshield.config';
import { createHash } from 'crypto';
// Number reputation service (Hiya API integration)
@@ -91,9 +90,10 @@ export class NumberReputationService {
// SMS content classifier (BERT-based)
export class SMSClassifierService {
private model: any = null; // BERT model placeholder
private _initPromise: Promise<void> | null = null;
/**
* Initialize the BERT model
* Initialize the BERT model (thread-safe via promise deduplication)
*/
async initialize(): Promise<void> {
// TODO: Load BERT model from path
@@ -101,6 +101,23 @@ export class SMSClassifierService {
console.log('SMS classifier initialized');
}
/**
 * Ensures the model is initialized before use.
 *
 * Concurrent callers share a single in-flight initialization promise, so
 * initialize() runs at most once per attempt. If that attempt fails, the
 * cached promise is discarded so a later call can retry instead of
 * replaying the same rejection forever.
 *
 * @returns A promise that resolves once the model is ready.
 * @throws Whatever initialize() rejects with, for every caller awaiting the failed attempt.
 */
private async ensureInitialized(): Promise<void> {
  if (this._initPromise) {
    return this._initPromise;
  }

  const pending: Promise<void> = (async () => {
    // Skip loading when a model is already present.
    if (this.model) {
      return;
    }
    await this.initialize();
  })().catch((error: unknown) => {
    // Drop the failed attempt so the next caller starts a fresh one; only
    // clear the cache if it still refers to this attempt.
    if (this._initPromise === pending) {
      this._initPromise = null;
    }
    throw error;
  });

  // Publish the promise synchronously, before any await, so callers arriving
  // during initialization join it instead of starting another load.
  this._initPromise = pending;
  return pending;
}
/**
* Classify SMS text as spam or ham
*/
@@ -123,9 +140,7 @@ export class SMSClassifierService {
};
}
if (!this.model) {
await this.initialize();
}
await this.ensureInitialized();
// Extract features
const features = this.extractFeatures(smsText);