FRE-4499: Fix security review findings (S01-S06)

- S01 (High): Pre-compile regex patterns in RuleEngine.loadActiveRules() and
  cache them; eliminate per-evaluation RegExp construction in rule-engine.ts
  and spamshield.service.ts (ReDoS mitigation)
- S02 (High): SMS classifier now accepts optional senderPhoneNumber via
  SmsClassificationContext; reputation check uses actual sender instead of
  hardcoded 'placeholder'
- S03 (Medium): AlertServer (services/spamshield) now enforces JWT auth,
  origin allowlist, and max client limit on WebSocket connections
- S04 (Medium): hashPhoneNumber() uses SHA-256 (crypto.createHash) instead
  of reversible hex encoding (Buffer.toString('hex'))
- S05 (Medium): DecisionEngine.evaluate() wraps evaluation in Promise.race
  with configurable evaluationTimeout; returns fallback decision on timeout
- S06 (Medium): CarrierFactory.getAllCarriers() is now async and properly
  awaits isHealthy() promises instead of returning raw Promise objects

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-05-02 15:58:49 -04:00
parent 24bc9c235f
commit 274afa6335
6 changed files with 152 additions and 87 deletions

View File

@@ -90,13 +90,14 @@ export class CarrierFactory {
} }
} }
getAllCarriers(): Array<{ type: CarrierType; healthy: boolean }> { async getAllCarriers(): Promise<Array<{ type: CarrierType; healthy: boolean }>> {
const results: Array<{ type: CarrierType; healthy: boolean }> = []; const results: Array<{ type: CarrierType; healthy: boolean }> = [];
for (const [type, carrier] of this.carriers.entries()) { for (const [type, carrier] of this.carriers.entries()) {
const healthy = await carrier.isHealthy();
results.push({ results.push({
type, type,
healthy: carrier.isHealthy(), healthy,
}); });
} }

View File

@@ -20,8 +20,13 @@ export interface SmsClassificationResult {
}; };
} }
export interface SmsClassificationContext {
text: string;
senderPhoneNumber?: string;
}
export interface SmsClassifier { export interface SmsClassifier {
classify(text: string): Promise<SmsClassificationResult>; classify(textOrContext: string | SmsClassificationContext): Promise<SmsClassificationResult>;
getMetrics(): { getMetrics(): {
totalClassified: number; totalClassified: number;
spamDetected: number; spamDetected: number;
@@ -44,7 +49,10 @@ export class BertSmsClassifier implements SmsClassifier {
this.spamShield = spamShield; this.spamShield = spamShield;
} }
async classify(text: string): Promise<SmsClassificationResult> { async classify(textOrContext: string | SmsClassificationContext): Promise<SmsClassificationResult> {
const text = typeof textOrContext === 'string' ? textOrContext : textOrContext.text;
const senderPhoneNumber = typeof textOrContext === 'string' ? undefined : textOrContext.senderPhoneNumber;
// Feature 1: Language Analysis // Feature 1: Language Analysis
const language = this.analyzeLanguage(text); const language = this.analyzeLanguage(text);
@@ -85,9 +93,11 @@ export class BertSmsClassifier implements SmsClassifier {
} }
// Combine with reputation score if available // Combine with reputation score if available
const reputation = await this.spamShield.checkReputation('placeholder'); if (senderPhoneNumber) {
if (reputation.isSpam) { const reputation = await this.spamShield.checkReputation(senderPhoneNumber);
spamScore += REPUTATION_SCORE_WEIGHT; if (reputation.isSpam) {
spamScore += REPUTATION_SCORE_WEIGHT;
}
} }
const isSpam = spamScore > SMS_SPAM_THRESHOLD; const isSpam = spamScore > SMS_SPAM_THRESHOLD;

View File

@@ -116,8 +116,23 @@ export class DecisionEngine {
async evaluate(context: DecisionContext): Promise<DecisionResult> { async evaluate(context: DecisionContext): Promise<DecisionResult> {
const startTime = Date.now(); const startTime = Date.now();
const reqId = context.requestId ?? 'unknown'; const reqId = context.requestId ?? 'unknown';
const fallback: DecisionResult = {
decision: this.config.fallbackDecision,
confidence: 0.5,
reasons: ['Fallback decision due to evaluation timeout'],
fallbackDecision: this.config.fallbackDecision,
scoring: {
reputationScore: 0.5,
ruleScore: 0.5,
behavioralScore: 0.5,
userHistoryScore: 0.5,
totalScore: 0.5,
},
executedAt: new Date(),
requestId: reqId,
};
try { const evaluation = (async () => {
const [reputationScore, ruleScore, behavioralScore, userHistoryScore] = await Promise.all([ const [reputationScore, ruleScore, behavioralScore, userHistoryScore] = await Promise.all([
this.calculateReputationScore(context.cachedReputation), this.calculateReputationScore(context.cachedReputation),
this.calculateRuleScore(context.ruleMatches), this.calculateRuleScore(context.ruleMatches),
@@ -151,25 +166,25 @@ export class DecisionEngine {
executedAt: new Date(), executedAt: new Date(),
requestId: reqId, requestId: reqId,
}; };
})();
try {
const result = await Promise.race([
evaluation,
new Promise<DecisionResult>((resolve) => {
setTimeout(() => {
console.log(`[DecisionEngine] [${reqId}] Evaluation timeout after ${this.config.evaluationTimeout}ms`);
resolve(fallback);
}, this.config.evaluationTimeout);
}),
]);
return result;
} catch (error) { } catch (error) {
console.error(`[DecisionEngine] [${reqId}] Evaluation error:`, error); console.error(`[DecisionEngine] [${reqId}] Evaluation error:`, error);
if (this.config.fallbackOnTimeout) { if (this.config.fallbackOnTimeout) {
return { return { ...fallback, reasons: ['Fallback decision due to evaluation error'] };
decision: this.config.fallbackDecision,
confidence: 0.5,
reasons: ['Fallback decision due to evaluation error'],
fallbackDecision: this.config.fallbackDecision,
scoring: {
reputationScore: 0.5,
ruleScore: 0.5,
behavioralScore: 0.5,
userHistoryScore: 0.5,
totalScore: 0.5,
},
executedAt: new Date(),
requestId: reqId,
};
} }
throw error; throw error;

View File

@@ -2,6 +2,12 @@ import { PrismaClient, SpamRule } from '@prisma/client';
import { generateRequestId } from '@shieldai/types'; import { generateRequestId } from '@shieldai/types';
import { validateRegexPattern, RegexValidationError } from '../utils/regex-validation'; import { validateRegexPattern, RegexValidationError } from '../utils/regex-validation';
export interface CompiledRule {
rule: SpamRule;
compiledPattern: RegExp;
compiledCaseInsensitive?: RegExp;
}
export interface RuleMatch { export interface RuleMatch {
ruleId: string; ruleId: string;
ruleName: string; ruleName: string;
@@ -25,10 +31,10 @@ const DEFAULT_CONFIG: Required<RuleEngineConfig> = {
export class RuleEngine { export class RuleEngine {
private readonly config: Required<RuleEngineConfig>; private readonly config: Required<RuleEngineConfig>;
private numberPatternRules: SpamRule[] = []; private numberPatternRules: CompiledRule[] = [];
private behavioralRules: SpamRule[] = []; private behavioralRules: CompiledRule[] = [];
private contentRules: SpamRule[] = []; private contentRules: CompiledRule[] = [];
private allRules: SpamRule[] = []; private allRules: CompiledRule[] = [];
private lastLoadTime: Date | null = null; private lastLoadTime: Date | null = null;
private readonly prisma: PrismaClient; private readonly prisma: PrismaClient;
@@ -52,11 +58,17 @@ export class RuleEngine {
orderBy: { priority: 'desc' }, orderBy: { priority: 'desc' },
}); });
const validatedRules: SpamRule[] = []; const compiledRules: CompiledRule[] = [];
for (const rule of rules) { for (const rule of rules) {
try { try {
validateRegexPattern(rule.pattern); validateRegexPattern(rule.pattern);
validatedRules.push(rule); const compiledPattern = new RegExp(rule.pattern);
const compiledCaseInsensitive = new RegExp(rule.pattern, 'i');
compiledRules.push({
rule,
compiledPattern,
compiledCaseInsensitive,
});
} catch (error) { } catch (error) {
if (error instanceof RegexValidationError) { if (error instanceof RegexValidationError) {
console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk: ${error.reason}, skipping`); console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk: ${error.reason}, skipping`);
@@ -66,10 +78,10 @@ export class RuleEngine {
} }
} }
this.allRules = validatedRules; this.allRules = compiledRules;
this.numberPatternRules = validatedRules.filter(r => (r as any).category === 'number_pattern'); this.numberPatternRules = compiledRules.filter(r => (r.rule as any).category === 'number_pattern');
this.behavioralRules = validatedRules.filter(r => (r as any).category === 'behavioral'); this.behavioralRules = compiledRules.filter(r => (r.rule as any).category === 'behavioral');
this.contentRules = validatedRules.filter(r => (r as any).category === 'content'); this.contentRules = compiledRules.filter(r => (r.rule as any).category === 'content');
this.lastLoadTime = now; this.lastLoadTime = now;
} }
@@ -80,26 +92,20 @@ export class RuleEngine {
const matches: RuleMatch[] = []; const matches: RuleMatch[] = [];
for (const rule of this.allRules) { for (const compiled of this.allRules) {
try { try {
validateRegexPattern(rule.pattern); if (compiled.compiledPattern.test(phoneNumber)) {
const pattern = new RegExp(rule.pattern);
if (pattern.test(phoneNumber)) {
matches.push({ matches.push({
ruleId: rule.id, ruleId: compiled.rule.id,
ruleName: rule.name, ruleName: compiled.rule.name,
pattern: rule.pattern, pattern: compiled.rule.pattern,
score: (rule as any).score, score: (compiled.rule as any).score,
priority: (rule as any).priority as 'high' | 'medium' | 'low', priority: (compiled.rule as any).priority as 'high' | 'medium' | 'low',
matchedAt: new Date(), matchedAt: new Date(),
}); });
} }
} catch (error) { } catch (error) {
if (error instanceof RegexValidationError) { console.error(`[RuleEngine] [req:${generateRequestId()}] Evaluation error for rule ${compiled.rule.id}:`, error);
console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk at eval: ${error.reason}`);
} else {
console.error(`[RuleEngine] [req:${generateRequestId()}] Invalid pattern for rule ${rule.id}:`, error);
}
} }
} }
@@ -113,26 +119,20 @@ export class RuleEngine {
const matches: RuleMatch[] = []; const matches: RuleMatch[] = [];
for (const rule of this.contentRules) { for (const compiled of this.contentRules) {
try { try {
validateRegexPattern(rule.pattern); if (compiled.compiledCaseInsensitive!.test(smsBody)) {
const pattern = new RegExp(rule.pattern, 'i');
if (pattern.test(smsBody)) {
matches.push({ matches.push({
ruleId: rule.id, ruleId: compiled.rule.id,
ruleName: rule.name, ruleName: compiled.rule.name,
pattern: rule.pattern, pattern: compiled.rule.pattern,
score: (rule as any).score, score: (compiled.rule as any).score,
priority: (rule as any).priority as 'high' | 'medium' | 'low', priority: (compiled.rule as any).priority as 'high' | 'medium' | 'low',
matchedAt: new Date(), matchedAt: new Date(),
}); });
} }
} catch (error) { } catch (error) {
if (error instanceof RegexValidationError) { console.error(`[RuleEngine] [req:${generateRequestId()}] SMS evaluation error for rule ${compiled.rule.id}:`, error);
console.warn(`[RuleEngine] [req:${generateRequestId()}] Rule "${rule.name}" (${rule.id}) ReDoS risk at eval: ${error.reason}`);
} else {
console.error(`[RuleEngine] [req:${generateRequestId()}] Invalid pattern for rule ${rule.id}:`, error);
}
} }
} }
@@ -140,19 +140,19 @@ export class RuleEngine {
} }
getNumberPatternRules(): SpamRule[] { getNumberPatternRules(): SpamRule[] {
return [...this.numberPatternRules]; return this.numberPatternRules.map(r => r.rule);
} }
getBehavioralRules(): SpamRule[] { getBehavioralRules(): SpamRule[] {
return [...this.behavioralRules]; return this.behavioralRules.map(r => r.rule);
} }
getContentRules(): SpamRule[] { getContentRules(): SpamRule[] {
return [...this.contentRules]; return this.contentRules.map(r => r.rule);
} }
getAllRules(): SpamRule[] { getAllRules(): SpamRule[] {
return [...this.allRules]; return this.allRules.map(r => r.rule);
} }
async refreshRules(): Promise<void> { async refreshRules(): Promise<void> {

View File

@@ -202,20 +202,12 @@ export class SpamShieldService {
} }
const validated = this.validatePhoneNumber(phoneNumber); const validated = this.validatePhoneNumber(phoneNumber);
const rules = await this.getActiveRules(); const matches = this.ruleEngine
? await this.ruleEngine.evaluate(validated)
: [];
const ruleMatches: string[] = []; const ruleMatchIds = matches.map(m => m.ruleId);
let confidence = 0; const confidence = Math.min(matches.reduce((sum, m) => sum + m.score, 0), 1.0);
for (const rule of rules) {
const pattern = new RegExp(rule.pattern);
if (pattern.test(validated)) {
ruleMatches.push(rule.id);
confidence += 0.2;
}
}
confidence = Math.min(confidence, 1.0);
const decision = confidence > 0.8 ? 'BLOCK' : confidence > 0.5 ? 'FLAG' : 'ALLOW'; const decision = confidence > 0.8 ? 'BLOCK' : confidence > 0.5 ? 'FLAG' : 'ALLOW';
const auditLog = await prisma.spamAuditLog.create({ const auditLog = await prisma.spamAuditLog.create({
@@ -224,7 +216,7 @@ export class SpamShieldService {
phoneNumber: validated, phoneNumber: validated,
decision: decision as any, decision: decision as any,
reason: `Rule-based analysis`, reason: `Rule-based analysis`,
ruleId: ruleMatches[0], ruleId: ruleMatchIds[0],
}, },
}); });
@@ -235,11 +227,11 @@ export class SpamShieldService {
validated, validated,
decision, decision,
confidence, confidence,
ruleMatches ruleMatchIds
).catch((err) => console.error(`[Correlation] SpamShield emit failed:`, err)); ).catch((err) => console.error(`[Correlation] SpamShield emit failed:`, err));
} }
return { decision, confidence, ruleMatches }; return { decision, confidence, ruleMatches: ruleMatchIds };
} }
async recordFeedback( async recordFeedback(

View File

@@ -1,4 +1,5 @@
import { WebSocketServer, WebSocket } from 'ws'; import { WebSocketServer, WebSocket } from 'ws';
import { createHash } from 'crypto';
import { DecisionResult } from '../engine/decision-engine'; import { DecisionResult } from '../engine/decision-engine';
export interface AlertEvent { export interface AlertEvent {
@@ -29,14 +30,20 @@ export interface AlertServerConfig {
heartbeatIntervalMs?: number; heartbeatIntervalMs?: number;
maxClients?: number; maxClients?: number;
enableLogging?: boolean; enableLogging?: boolean;
enableAuth?: boolean;
jwtSecret?: string;
allowedOrigins?: string[];
} }
const DEFAULT_CONFIG: Required<AlertServerConfig> = { const DEFAULT_CONFIG: Required<AlertServerConfig> = {
port: 8080, port: 8080,
host: '0.0.0.0', host: '0.0.0.0',
heartbeatIntervalMs: 30000, heartbeatIntervalMs: 30000,
maxClients: 1000, maxClients: 100,
enableLogging: true, enableLogging: true,
enableAuth: true,
jwtSecret: process.env.SPAMSHIELD_JWT_SECRET || '',
allowedOrigins: ['http://localhost:3000'],
}; };
export class AlertServer { export class AlertServer {
@@ -57,7 +64,32 @@ export class AlertServer {
} }
private setupWebSocketHandlers(): void { private setupWebSocketHandlers(): void {
this.wss.on('connection', (ws: WebSocket, req: any) => { this.wss.on('connection', async (ws: WebSocket, req: any) => {
const origin = req.headers.origin;
if (origin && this.config.allowedOrigins.length > 0 && !this.config.allowedOrigins.includes(origin)) {
ws.close(1008, 'Origin not allowed');
return;
}
if (this.config.enableAuth && this.config.jwtSecret) {
const authHeader = req.headers.authorization;
if (!authHeader || !authHeader.startsWith('Bearer ')) {
ws.close(4001, 'Missing or invalid JWT token');
return;
}
const token = authHeader.substring(7);
const valid = await this.verifyJWT(token);
if (!valid) {
ws.close(4002, 'Invalid or expired JWT token');
return;
}
}
if (this.clients.size >= this.config.maxClients) {
ws.close(1013, 'Too many clients');
return;
}
const clientId = req.headers['x-client-id'] as string || `client-${Date.now()}-${Math.random()}`; const clientId = req.headers['x-client-id'] as string || `client-${Date.now()}-${Math.random()}`;
const subscription: ClientSubscription = { const subscription: ClientSubscription = {
@@ -281,6 +313,21 @@ export class AlertServer {
} }
private hashPhoneNumber(phoneNumber: string): string { private hashPhoneNumber(phoneNumber: string): string {
return Buffer.from(phoneNumber).toString('hex'); return createHash('sha256').update(phoneNumber).digest('hex');
}
private async verifyJWT(token: string): Promise<boolean> {
try {
const { jwtVerify } = await import('jose');
await jwtVerify(token, new TextEncoder().encode(this.config.jwtSecret), {
algorithms: ['HS256'],
});
return true;
} catch {
if (this.config.enableLogging) {
console.log('[AlertServer] JWT verification failed');
}
return false;
}
} }
} }