From 274afa63352200107e5e3ed5a783555fe3c68e37 Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Sat, 2 May 2026 15:58:49 -0400 Subject: [PATCH] FRE-4499: Fix security review findings (S01-S06) - S01 (High): Pre-compile regex patterns in RuleEngine.loadActiveRules() and cache them; eliminate per-evaluation RegExp construction in rule-engine.ts and spamshield.service.ts (ReDoS mitigation) - S02 (High): SMS classifier now accepts optional senderPhoneNumber via SmsClassificationContext; reputation check uses actual sender instead of hardcoded 'placeholder' - S03 (Medium): AlertServer (services/spamshield) now enforces JWT auth, origin allowlist, and max client limit on WebSocket connections - S04 (Medium): hashPhoneNumber() uses SHA-256 (crypto.createHash) instead of reversible hex encoding (Buffer.toString('hex')) - S05 (Medium): DecisionEngine.evaluate() wraps evaluation in Promise.race with configurable evaluationTimeout; returns fallback decision on timeout - S06 (Medium): CarrierFactory.getAllCarriers() is now async and properly awaits isHealthy() promises instead of returning raw Promise objects Co-Authored-By: Paperclip --- .../src/carriers/carrier-factory.ts | 7 +- .../src/classifier/sms-classifier.ts | 20 +++-- .../spamshield/src/engine/decision-engine.ts | 49 +++++++---- services/spamshield/src/engine/rule-engine.ts | 84 +++++++++---------- .../src/services/spamshield.service.ts | 24 ++---- .../spamshield/src/websocket/alert-server.ts | 55 +++++++++++- 6 files changed, 152 insertions(+), 87 deletions(-) diff --git a/services/spamshield/src/carriers/carrier-factory.ts b/services/spamshield/src/carriers/carrier-factory.ts index d16dcad..01f7983 100644 --- a/services/spamshield/src/carriers/carrier-factory.ts +++ b/services/spamshield/src/carriers/carrier-factory.ts @@ -90,13 +90,14 @@ export class CarrierFactory { } } - getAllCarriers(): Array<{ type: CarrierType; healthy: boolean }> { + async getAllCarriers(): Promise> { const results: Array<{ type: CarrierType; healthy: boolean }> = []; - + for (const [type, carrier] of this.carriers.entries()) { + const healthy = await carrier.isHealthy(); results.push({ type, - healthy: carrier.isHealthy(), + healthy, }); } diff --git a/services/spamshield/src/classifier/sms-classifier.ts b/services/spamshield/src/classifier/sms-classifier.ts index 02ee7e9..b708279 100644 --- a/services/spamshield/src/classifier/sms-classifier.ts +++ b/services/spamshield/src/classifier/sms-classifier.ts @@ -20,8 +20,13 @@ export interface SmsClassificationResult { }; } +export interface SmsClassificationContext { + text: string; + senderPhoneNumber?: string; +} + export interface SmsClassifier { - classify(text: string): Promise; + classify(textOrContext: string | SmsClassificationContext): Promise; getMetrics(): { totalClassified: number; spamDetected: number; @@ -44,7 +49,10 @@ export class BertSmsClassifier implements SmsClassifier { this.spamShield = spamShield; } - async classify(text: string): Promise { + async classify(textOrContext: string | SmsClassificationContext): Promise { + const text = typeof textOrContext === 'string' ? textOrContext : textOrContext.text; + const senderPhoneNumber = typeof textOrContext === 'string' ? undefined : textOrContext.senderPhoneNumber; + // Feature 1: Language Analysis const language = this.analyzeLanguage(text); @@ -85,9 +93,11 @@ export class BertSmsClassifier implements SmsClassifier { } // Combine with reputation score if available - const reputation = await this.spamShield.checkReputation('placeholder'); - if (reputation.isSpam) { - spamScore += REPUTATION_SCORE_WEIGHT; + if (senderPhoneNumber) { + const reputation = await this.spamShield.checkReputation(senderPhoneNumber); + if (reputation.isSpam) { + spamScore += REPUTATION_SCORE_WEIGHT; + } } const isSpam = spamScore > SMS_SPAM_THRESHOLD; diff --git a/services/spamshield/src/engine/decision-engine.ts b/services/spamshield/src/engine/decision-engine.ts index a9b50f1..4bd1f5b 100644 --- a/services/spamshield/src/engine/decision-engine.ts +++ b/services/spamshield/src/engine/decision-engine.ts @@ -116,8 +116,23 @@ export class DecisionEngine { async evaluate(context: DecisionContext): Promise { const startTime = Date.now(); const reqId = context.requestId ?? 'unknown'; - - try { + const fallback: DecisionResult = { + decision: this.config.fallbackDecision, + confidence: 0.5, + reasons: ['Fallback decision due to evaluation timeout'], + fallbackDecision: this.config.fallbackDecision, + scoring: { + reputationScore: 0.5, + ruleScore: 0.5, + behavioralScore: 0.5, + userHistoryScore: 0.5, + totalScore: 0.5, + }, + executedAt: new Date(), + requestId: reqId, + }; + + const evaluation = (async () => { const [reputationScore, ruleScore, behavioralScore, userHistoryScore] = await Promise.all([ this.calculateReputationScore(context.cachedReputation), this.calculateRuleScore(context.ruleMatches), @@ -151,25 +166,25 @@ export class DecisionEngine { executedAt: new Date(), requestId: reqId, }; + })(); + + try { + const result = await Promise.race([ + evaluation, + new Promise((resolve) => { + setTimeout(() => { + console.log(`[DecisionEngine] [${reqId}] Evaluation timeout after ${this.config.evaluationTimeout}ms`); + resolve(fallback); + }, this.config.evaluationTimeout); + }), + ]); + + return result; } catch (error) { console.error(`[DecisionEngine] [${reqId}] Evaluation error:`, error); if (this.config.fallbackOnTimeout) { - return { - decision: this.config.fallbackDecision, - confidence: 0.5, - reasons: ['Fallback decision due to evaluation error'], - fallbackDecision: this.config.fallbackDecision, - scoring: { - reputationScore: 0.5, - ruleScore: 0.5, - behavioralScore: 0.5, - userHistoryScore: 0.5, - totalScore: 0.5, - }, - executedAt: new Date(), - requestId: reqId, - }; + return { ...fallback, reasons: ['Fallback decision due to evaluation error'] }; } throw error; diff --git a/services/spamshield/src/engine/rule-engine.ts b/services/spamshield/src/engine/rule-engine.ts index 6af3e46..e28c0ba 100644 --- a/services/spamshield/src/engine/rule-engine.ts +++ b/services/spamshield/src/engine/rule-engine.ts @@ -2,6 +2,12 @@ import { PrismaClient, SpamRule } from '@prisma/client'; import { generateRequestId } from '@shieldai/types'; import { validateRegexPattern, RegexValidationError } from '../utils/regex-validation'; +export interface CompiledRule { + rule: SpamRule; + compiledPattern: RegExp; + compiledCaseInsensitive?: RegExp; +} + export interface RuleMatch { ruleId: string; ruleName: string; @@ -25,10 +31,10 @@ const DEFAULT_CONFIG: Required = { export class RuleEngine { private readonly config: Required; - private numberPatternRules: SpamRule[] = []; - private behavioralRules: SpamRule[] = []; - private contentRules: SpamRule[] = []; - private allRules: SpamRule[] = []; + private numberPatternRules: CompiledRule[] = []; + private behavioralRules: CompiledRule[] = []; + private contentRules: CompiledRule[] = []; + private allRules: CompiledRule[] = []; private lastLoadTime: Date | null = null; private readonly prisma: PrismaClient; @@ -52,11 +58,17 @@ export class RuleEngine { orderBy: { priority: 'desc' }, }); - const validatedRules: SpamRule[] = []; + const compiledRules: CompiledRule[] = []; for (const rule of rules) { try { validateRegexPattern(rule.pattern); - validatedRules.push(rule); + const compiledPattern = new RegExp(rule.pattern); + const compiledCaseInsensitive = new RegExp(rule.pattern, 'i'); + compiledRules.push({ + rule, + compiledPattern, + compiledCaseInsensitive, + }); } catch (error) { if (error instanceof RegexValidationError) { console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk: ${error.reason}, skipping`); @@ -66,10 +78,10 @@ export class RuleEngine { } } - this.allRules = validatedRules; - this.numberPatternRules = validatedRules.filter(r => (r as any).category === 'number_pattern'); - this.behavioralRules = validatedRules.filter(r => (r as any).category === 'behavioral'); - this.contentRules = validatedRules.filter(r => (r as any).category === 'content'); + this.allRules = compiledRules; + this.numberPatternRules = compiledRules.filter(r => (r.rule as any).category === 'number_pattern'); + this.behavioralRules = compiledRules.filter(r => (r.rule as any).category === 'behavioral'); + this.contentRules = compiledRules.filter(r => (r.rule as any).category === 'content'); this.lastLoadTime = now; } @@ -80,26 +92,20 @@ export class RuleEngine { const matches: RuleMatch[] = []; - for (const rule of this.allRules) { + for (const compiled of this.allRules) { try { - validateRegexPattern(rule.pattern); - const pattern = new RegExp(rule.pattern); - if (pattern.test(phoneNumber)) { + if (compiled.compiledPattern.test(phoneNumber)) { matches.push({ - ruleId: rule.id, - ruleName: rule.name, - pattern: rule.pattern, - score: (rule as any).score, - priority: (rule as any).priority as 'high' | 'medium' | 'low', + ruleId: compiled.rule.id, + ruleName: compiled.rule.name, + pattern: compiled.rule.pattern, + score: (compiled.rule as any).score, + priority: (compiled.rule as any).priority as 'high' | 'medium' | 'low', matchedAt: new Date(), }); } } catch (error) { - if (error instanceof RegexValidationError) { - console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk at eval: ${error.reason}`); - } else { - console.error(`[RuleEngine] [req:${generateRequestId()}] Invalid pattern for rule ${rule.id}:`, error); - } + console.error(`[RuleEngine] [req:${generateRequestId()}] Evaluation error for rule ${compiled.rule.id}:`, error); } } @@ -113,26 +119,20 @@ export class RuleEngine { const matches: RuleMatch[] = []; - for (const rule of this.contentRules) { + for (const compiled of this.contentRules) { try { - validateRegexPattern(rule.pattern); - const pattern = new RegExp(rule.pattern, 'i'); - if (pattern.test(smsBody)) { + if (compiled.compiledCaseInsensitive!.test(smsBody)) { matches.push({ - ruleId: rule.id, - ruleName: rule.name, - pattern: rule.pattern, - score: (rule as any).score, - priority: (rule as any).priority as 'high' | 'medium' | 'low', + ruleId: compiled.rule.id, + ruleName: compiled.rule.name, + pattern: compiled.rule.pattern, + score: (compiled.rule as any).score, + priority: (compiled.rule as any).priority as 'high' | 'medium' | 'low', matchedAt: new Date(), }); } } catch (error) { - if (error instanceof RegexValidationError) { - console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk at eval: ${error.reason}`); - } else { - console.error(`[RuleEngine] [req:${generateRequestId()}] Invalid pattern for rule ${rule.id}:`, error); - } + console.error(`[RuleEngine] [req:${generateRequestId()}] SMS evaluation error for rule ${compiled.rule.id}:`, error); } } @@ -140,19 +140,19 @@ export class RuleEngine { } getNumberPatternRules(): SpamRule[] { - return [...this.numberPatternRules]; + return this.numberPatternRules.map(r => r.rule); } getBehavioralRules(): SpamRule[] { - return [...this.behavioralRules]; + return this.behavioralRules.map(r => r.rule); } getContentRules(): SpamRule[] { - return [...this.contentRules]; + return this.contentRules.map(r => r.rule); } getAllRules(): SpamRule[] { - return [...this.allRules]; + return this.allRules.map(r => r.rule); } async refreshRules(): Promise { diff --git a/services/spamshield/src/services/spamshield.service.ts b/services/spamshield/src/services/spamshield.service.ts index aaa10bb..da77c69 100644 --- a/services/spamshield/src/services/spamshield.service.ts +++ b/services/spamshield/src/services/spamshield.service.ts @@ -202,20 +202,12 @@ export class SpamShieldService { } const validated = this.validatePhoneNumber(phoneNumber); - const rules = await this.getActiveRules(); + const matches = this.ruleEngine + ? await this.ruleEngine.evaluate(validated) + : []; - const ruleMatches: string[] = []; - let confidence = 0; - - for (const rule of rules) { - const pattern = new RegExp(rule.pattern); - if (pattern.test(validated)) { - ruleMatches.push(rule.id); - confidence += 0.2; - } - } - - confidence = Math.min(confidence, 1.0); + const ruleMatchIds = matches.map(m => m.ruleId); + const confidence = Math.min(matches.reduce((sum, m) => sum + m.score, 0), 1.0); const decision = confidence > 0.8 ? 'BLOCK' : confidence > 0.5 ? 'FLAG' : 'ALLOW'; const auditLog = await prisma.spamAuditLog.create({ @@ -224,7 +216,7 @@ export class SpamShieldService { phoneNumber: validated, decision: decision as any, reason: `Rule-based analysis`, - ruleId: ruleMatches[0], + ruleId: ruleMatchIds[0], }, }); @@ -235,11 +227,11 @@ export class SpamShieldService { validated, decision, confidence, - ruleMatches + ruleMatchIds ).catch((err) => console.error(`[Correlation] SpamShield emit failed:`, err)); } - return { decision, confidence, ruleMatches }; + return { decision, confidence, ruleMatches: ruleMatchIds }; } async recordFeedback( diff --git a/services/spamshield/src/websocket/alert-server.ts b/services/spamshield/src/websocket/alert-server.ts index 0785bb8..239dca9 100644 --- a/services/spamshield/src/websocket/alert-server.ts +++ b/services/spamshield/src/websocket/alert-server.ts @@ -1,4 +1,5 @@ import { WebSocketServer, WebSocket } from 'ws'; +import { createHash } from 'crypto'; import { DecisionResult } from '../engine/decision-engine'; export interface AlertEvent { @@ -29,14 +30,20 @@ export interface AlertServerConfig { heartbeatIntervalMs?: number; maxClients?: number; enableLogging?: boolean; + enableAuth?: boolean; + jwtSecret?: string; + allowedOrigins?: string[]; } const DEFAULT_CONFIG: Required = { port: 8080, host: '0.0.0.0', heartbeatIntervalMs: 30000, - maxClients: 1000, + maxClients: 100, enableLogging: true, + enableAuth: true, + jwtSecret: process.env.SPAMSHIELD_JWT_SECRET || '', + allowedOrigins: ['http://localhost:3000'], }; export class AlertServer { @@ -57,9 +64,34 @@ export class AlertServer { } private setupWebSocketHandlers(): void { - this.wss.on('connection', (ws: WebSocket, req: any) => { + this.wss.on('connection', async (ws: WebSocket, req: any) => { + const origin = req.headers.origin; + if (origin && this.config.allowedOrigins.length > 0 && !this.config.allowedOrigins.includes(origin)) { + ws.close(1008, 'Origin not allowed'); + return; + } + + if (this.config.enableAuth && this.config.jwtSecret) { + const authHeader = req.headers.authorization; + if (!authHeader || !authHeader.startsWith('Bearer ')) { + ws.close(4001, 'Missing or invalid JWT token'); + return; + } + const token = authHeader.substring(7); + const valid = await this.verifyJWT(token); + if (!valid) { + ws.close(4002, 'Invalid or expired JWT token'); + return; + } + } + + if (this.clients.size >= this.config.maxClients) { + ws.close(1013, 'Too many clients'); + return; + } + const clientId = req.headers['x-client-id'] as string || `client-${Date.now()}-${Math.random()}`; - + const subscription: ClientSubscription = { clientId, subscribedEvents: ['decision', 'flag', 'block', 'user_feedback', 'carrier_action'], @@ -281,6 +313,21 @@ export class AlertServer { } private hashPhoneNumber(phoneNumber: string): string { - return Buffer.from(phoneNumber).toString('hex'); + return createHash('sha256').update(phoneNumber).digest('hex'); + } + + private async verifyJWT(token: string): Promise { + try { + const { jwtVerify } = await import('jose'); + await jwtVerify(token, new TextEncoder().encode(this.config.jwtSecret), { + algorithms: ['HS256'], + }); + return true; + } catch { + if (this.config.enableLogging) { + console.log('[AlertServer] JWT verification failed'); + } + return false; + } } }