From 2b7ff938da9b1164da3174427d9962fcdd78f70f Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Wed, 29 Apr 2026 22:25:39 -0400 Subject: [PATCH] Fix SMSClassifierService initialization race condition (FRE-4509) Add promise-based lazy initialization via ensureInitialized() to deduplicate concurrent initialize() calls. Concurrent callers now await the same initialization promise instead of each triggering a separate load. Co-Authored-By: Paperclip --- .../sms-classifier-race-condition.test.ts | 144 ++++++++++++++++++ .../services/spamshield/spamshield.service.ts | 29 +++- 2 files changed, 166 insertions(+), 7 deletions(-) create mode 100644 apps/api/src/__tests__/sms-classifier-race-condition.test.ts diff --git a/apps/api/src/__tests__/sms-classifier-race-condition.test.ts b/apps/api/src/__tests__/sms-classifier-race-condition.test.ts new file mode 100644 index 000000000..675c75010 --- /dev/null +++ b/apps/api/src/__tests__/sms-classifier-race-condition.test.ts @@ -0,0 +1,144 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { SMSClassifierService } from '../services/spamshield/spamshield.service'; + +// Mock shared-db before anything else (Prisma client is not generated in test env) +vi.mock('@shieldsai/shared-db', () => ({ + prisma: {}, + SpamFeedback: {}, +})); + +// Mock the feature flags module to control enableMLClassifier +vi.mock('../services/spamshield/spamshield.config', () => ({ + spamShieldEnv: { + SPAM_THRESHOLD_AUTO_BLOCK: 0.85, + SPAM_THRESHOLD_FLAG: 0.6, + }, + spamFeatureFlags: { + enableMLClassifier: true, + }, + SpamDecision: { + ALLOW: 'allow', + FLAG: 'flag', + BLOCK: 'block', + CHALLENGE: 'challenge', + }, + SpamLayer: { + NUMBER_REPUTATION: 'number_reputation', + CONTENT_CLASSIFICATION: 'content_classification', + BEHAVIORAL_ANALYSIS: 'behavioral_analysis', + COMMUNITY_INTELLIGENCE: 'community_intelligence', + }, + ConfidenceLevel: { + LOW: 'low', + MEDIUM: 'medium', + HIGH: 'high', + VERY_HIGH: 'very_high', + }, + spamRateLimits: {}, +})); + +describe('SMSClassifierService', () => { + let classifier: SMSClassifierService; + let initializeCalls: number; + let initializeDelay: Promise; + + beforeEach(() => { + // Re-import after mock to get fresh module state + initializeCalls = 0; + initializeDelay = new Promise(resolve => setTimeout(resolve, 50)); + + classifier = new SMSClassifierService(); + // Override initialize to track calls and add delay + classifier.initialize = async () => { + initializeCalls++; + await initializeDelay; + }; + }); + + describe('initialization race condition', () => { + it('should call initialize only once under concurrent classify calls', async () => { + const promises = Array.from({ length: 10 }, () => + classifier.classify('ACT NOW - Limited offer!'), + ); + + const results = await Promise.all(promises); + + expect(initializeCalls).toBe(1); + expect(results).toHaveLength(10); + results.forEach(r => { + expect(r).toHaveProperty('isSpam'); + expect(r).toHaveProperty('confidence'); + expect(r).toHaveProperty('spamFeatures'); + }); + }); + + it('should handle interleaved calls after partial initialization', async () => { + const batch1 = Array.from({ length: 5 }, () => + classifier.classify('First batch message'), + ); + + await Promise.all(batch1); + + expect(initializeCalls).toBe(1); + + const batch2 = Array.from({ length: 5 }, () => + classifier.classify('Second batch message'), + ); + + await Promise.all(batch2); + + // initialize should still only have been called once + expect(initializeCalls).toBe(1); + }); + + it('should return consistent results for same input under concurrency', async () => { + const text = 'URGENT: Click http://example.com now!'; + const promises = Array.from({ length: 20 }, () => + classifier.classify(text), + ); + + const results = await Promise.all(promises); + + const firstResult = results[0]; + results.forEach((r, i) => { + expect(r.isSpam).toBe(firstResult.isSpam); + expect(r.confidence).toBe(firstResult.confidence); + expect(r.spamFeatures).toEqual(firstResult.spamFeatures); + }); + }); + + it('should handle rapid sequential calls without re-initializing', async () => { + for (let i = 0; i < 50; i++) { + await classifier.classify(`Message ${i}`); + } + + expect(initializeCalls).toBe(1); + }); + }); + + describe('feature extraction', () => { + it('should detect URL presence', async () => { + const result = await classifier.classify('Visit www.example.com'); + expect(result.spamFeatures).toContain('url_present'); + }); + + it('should detect urgency keywords', async () => { + const result = await classifier.classify('Act now! This offer is urgent.'); + expect(result.spamFeatures).toContain('urgency_keyword'); + }); + + it('should detect excessive capitalization', async () => { + const result = await classifier.classify('BUY THIS NOW!!!'); + expect(result.spamFeatures).toContain('excessive_caps'); + }); + + it('should detect multiple features', async () => { + const result = await classifier.classify( + 'URGENT: Visit www.example.com NOW!!!', + ); + expect(result.spamFeatures).toContain('url_present'); + expect(result.spamFeatures).toContain('urgency_keyword'); + expect(result.spamFeatures).toContain('excessive_caps'); + }); + }); +}); diff --git a/apps/api/src/services/spamshield/spamshield.service.ts b/apps/api/src/services/spamshield/spamshield.service.ts index f31000ba4..2cbaed271 100644 --- a/apps/api/src/services/spamshield/spamshield.service.ts +++ b/apps/api/src/services/spamshield/spamshield.service.ts @@ -1,6 +1,5 @@ -import { prisma, SpamRule, SpamFeedback, User } from '@shieldsai/shared-db'; -import { spamShieldEnv, SpamDecision, ConfidenceLevel, spamFeatureFlags } from './spamshield.config'; -import { checkFlag } from './feature-flags'; +import { prisma, SpamFeedback } from '@shieldsai/shared-db'; +import { spamShieldEnv, SpamDecision, spamFeatureFlags } from './spamshield.config'; import { createHash } from 'crypto'; // Number reputation service (Hiya API integration) @@ -91,9 +90,10 @@ export class NumberReputationService { // SMS content classifier (BERT-based) export class SMSClassifierService { private model: any = null; // BERT model placeholder + private _initPromise: Promise | null = null; /** - * Initialize the BERT model + * Initialize the BERT model (thread-safe via promise deduplication) */ async initialize(): Promise { // TODO: Load BERT model from path @@ -101,6 +101,23 @@ export class SMSClassifierService { console.log('SMS classifier initialized'); } + /** + * Ensures model is initialized before use. Concurrent callers + * await the same initialization promise to avoid race conditions. + */ + private async ensureInitialized(): Promise { + if (this._initPromise) { + return this._initPromise; + } + this._initPromise = (async () => { + if (this.model) { + return; + } + await this.initialize(); + })(); + return this._initPromise; + } + /** * Classify SMS text as spam or ham */ @@ -123,9 +140,7 @@ export class SMSClassifierService { }; } - if (!this.model) { - await this.initialize(); - } + await this.ensureInitialized(); // Extract features const features = this.extractFeatures(smsText);