Fix SMSClassifierService initialization race condition (FRE-4509)
Add promise-based lazy initialization via ensureInitialized() to deduplicate concurrent initialize() calls. Concurrent callers now await the same initialization promise instead of each triggering a separate load. Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
144
apps/api/src/__tests__/sms-classifier-race-condition.test.ts
Normal file
144
apps/api/src/__tests__/sms-classifier-race-condition.test.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { SMSClassifierService } from '../services/spamshield/spamshield.service';
|
||||
|
||||
// Mock shared-db before anything else (Prisma client is not generated in test env)
|
||||
vi.mock('@shieldsai/shared-db', () => ({
|
||||
prisma: {},
|
||||
SpamFeedback: {},
|
||||
}));
|
||||
|
||||
// Mock the feature flags module to control enableMLClassifier
|
||||
vi.mock('../services/spamshield/spamshield.config', () => ({
|
||||
spamShieldEnv: {
|
||||
SPAM_THRESHOLD_AUTO_BLOCK: 0.85,
|
||||
SPAM_THRESHOLD_FLAG: 0.6,
|
||||
},
|
||||
spamFeatureFlags: {
|
||||
enableMLClassifier: true,
|
||||
},
|
||||
SpamDecision: {
|
||||
ALLOW: 'allow',
|
||||
FLAG: 'flag',
|
||||
BLOCK: 'block',
|
||||
CHALLENGE: 'challenge',
|
||||
},
|
||||
SpamLayer: {
|
||||
NUMBER_REPUTATION: 'number_reputation',
|
||||
CONTENT_CLASSIFICATION: 'content_classification',
|
||||
BEHAVIORAL_ANALYSIS: 'behavioral_analysis',
|
||||
COMMUNITY_INTELLIGENCE: 'community_intelligence',
|
||||
},
|
||||
ConfidenceLevel: {
|
||||
LOW: 'low',
|
||||
MEDIUM: 'medium',
|
||||
HIGH: 'high',
|
||||
VERY_HIGH: 'very_high',
|
||||
},
|
||||
spamRateLimits: {},
|
||||
}));
|
||||
|
||||
describe('SMSClassifierService', () => {
|
||||
let classifier: SMSClassifierService;
|
||||
let initializeCalls: number;
|
||||
let initializeDelay: Promise<void>;
|
||||
|
||||
beforeEach(() => {
|
||||
// Re-import after mock to get fresh module state
|
||||
initializeCalls = 0;
|
||||
initializeDelay = new Promise(resolve => setTimeout(resolve, 50));
|
||||
|
||||
classifier = new SMSClassifierService();
|
||||
// Override initialize to track calls and add delay
|
||||
classifier.initialize = async () => {
|
||||
initializeCalls++;
|
||||
await initializeDelay;
|
||||
};
|
||||
});
|
||||
|
||||
describe('initialization race condition', () => {
|
||||
it('should call initialize only once under concurrent classify calls', async () => {
|
||||
const promises = Array.from({ length: 10 }, () =>
|
||||
classifier.classify('ACT NOW - Limited offer!'),
|
||||
);
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
expect(initializeCalls).toBe(1);
|
||||
expect(results).toHaveLength(10);
|
||||
results.forEach(r => {
|
||||
expect(r).toHaveProperty('isSpam');
|
||||
expect(r).toHaveProperty('confidence');
|
||||
expect(r).toHaveProperty('spamFeatures');
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle interleaved calls after partial initialization', async () => {
|
||||
const batch1 = Array.from({ length: 5 }, () =>
|
||||
classifier.classify('First batch message'),
|
||||
);
|
||||
|
||||
await Promise.all(batch1);
|
||||
|
||||
expect(initializeCalls).toBe(1);
|
||||
|
||||
const batch2 = Array.from({ length: 5 }, () =>
|
||||
classifier.classify('Second batch message'),
|
||||
);
|
||||
|
||||
await Promise.all(batch2);
|
||||
|
||||
// initialize should still only have been called once
|
||||
expect(initializeCalls).toBe(1);
|
||||
});
|
||||
|
||||
it('should return consistent results for same input under concurrency', async () => {
|
||||
const text = 'URGENT: Click http://example.com now!';
|
||||
const promises = Array.from({ length: 20 }, () =>
|
||||
classifier.classify(text),
|
||||
);
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
const firstResult = results[0];
|
||||
results.forEach((r, i) => {
|
||||
expect(r.isSpam).toBe(firstResult.isSpam);
|
||||
expect(r.confidence).toBe(firstResult.confidence);
|
||||
expect(r.spamFeatures).toEqual(firstResult.spamFeatures);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle rapid sequential calls without re-initializing', async () => {
|
||||
for (let i = 0; i < 50; i++) {
|
||||
await classifier.classify(`Message ${i}`);
|
||||
}
|
||||
|
||||
expect(initializeCalls).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('feature extraction', () => {
|
||||
it('should detect URL presence', async () => {
|
||||
const result = await classifier.classify('Visit www.example.com');
|
||||
expect(result.spamFeatures).toContain('url_present');
|
||||
});
|
||||
|
||||
it('should detect urgency keywords', async () => {
|
||||
const result = await classifier.classify('Act now! This offer is urgent.');
|
||||
expect(result.spamFeatures).toContain('urgency_keyword');
|
||||
});
|
||||
|
||||
it('should detect excessive capitalization', async () => {
|
||||
const result = await classifier.classify('BUY THIS NOW!!!');
|
||||
expect(result.spamFeatures).toContain('excessive_caps');
|
||||
});
|
||||
|
||||
it('should detect multiple features', async () => {
|
||||
const result = await classifier.classify(
|
||||
'URGENT: Visit www.example.com NOW!!!',
|
||||
);
|
||||
expect(result.spamFeatures).toContain('url_present');
|
||||
expect(result.spamFeatures).toContain('urgency_keyword');
|
||||
expect(result.spamFeatures).toContain('excessive_caps');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,5 @@
|
||||
import { prisma, SpamRule, SpamFeedback, User } from '@shieldsai/shared-db';
|
||||
import { spamShieldEnv, SpamDecision, ConfidenceLevel, spamFeatureFlags } from './spamshield.config';
|
||||
import { checkFlag } from './feature-flags';
|
||||
import { prisma, SpamFeedback } from '@shieldsai/shared-db';
|
||||
import { spamShieldEnv, SpamDecision, spamFeatureFlags } from './spamshield.config';
|
||||
import { createHash } from 'crypto';
|
||||
|
||||
// Number reputation service (Hiya API integration)
|
||||
@@ -91,9 +90,10 @@ export class NumberReputationService {
|
||||
// SMS content classifier (BERT-based)
|
||||
export class SMSClassifierService {
|
||||
private model: any = null; // BERT model placeholder
|
||||
private _initPromise: Promise<void> | null = null;
|
||||
|
||||
/**
|
||||
* Initialize the BERT model
|
||||
* Initialize the BERT model (thread-safe via promise deduplication)
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
// TODO: Load BERT model from path
|
||||
@@ -101,6 +101,23 @@ export class SMSClassifierService {
|
||||
console.log('SMS classifier initialized');
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures model is initialized before use. Concurrent callers
|
||||
* await the same initialization promise to avoid race conditions.
|
||||
*/
|
||||
private async ensureInitialized(): Promise<void> {
|
||||
if (this._initPromise) {
|
||||
return this._initPromise;
|
||||
}
|
||||
this._initPromise = (async () => {
|
||||
if (this.model) {
|
||||
return;
|
||||
}
|
||||
await this.initialize();
|
||||
})();
|
||||
return this._initPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Classify SMS text as spam or ham
|
||||
*/
|
||||
@@ -123,9 +140,7 @@ export class SMSClassifierService {
|
||||
};
|
||||
}
|
||||
|
||||
if (!this.model) {
|
||||
await this.initialize();
|
||||
}
|
||||
await this.ensureInitialized();
|
||||
|
||||
// Extract features
|
||||
const features = this.extractFeatures(smsText);
|
||||
|
||||
Reference in New Issue
Block a user