import { prisma } from "@shieldai/db";
import {
  AlertSource,
  AlertCategory,
  Severity,
  EntityType,
  CorrelationStatus,
  NormalizedAlertInput,
  CorrelationGroupOutput,
  CorrelatedAlertOutput,
  CorrelationQuery,
} from "@shieldai/types";
import { alertNormalizer, AlertNormalizer } from "./normalizer";

/** Ordinal rank of each known severity label; higher means more severe. */
const SEVERITY_RANK: Record<string, number> = {
  LOW: 0,
  INFO: 1,
  MEDIUM: 2,
  WARNING: 3,
  HIGH: 4,
  CRITICAL: 5,
};

/** One correlatable entity as stored in the JSON `entities` columns. */
type Entity = { type: string; value: string };

/**
 * Returns the more severe of two severity labels.
 *
 * Labels missing from SEVERITY_RANK rank below every known label, so an
 * unrecognised value can never displace a known severity. On a tie the
 * first argument is returned.
 */
function higherSeverity(a: string, b: string): string {
  const rankA = SEVERITY_RANK[a] ?? -1;
  const rankB = SEVERITY_RANK[b] ?? -1;
  return rankA >= rankB ? a : b;
}

/**
 * Coerces a raw JSON `entities` value into a list of well-formed entities.
 * Non-array input (e.g. null) yields `[]`; elements without string `type`
 * and `value` fields are dropped.
 */
function toEntityList(raw: unknown): Entity[] {
  if (!Array.isArray(raw)) {
    return [];
  }
  return raw.filter(
    (e): e is Entity =>
      typeof e === "object" &&
      e !== null &&
      typeof (e as Entity).type === "string" &&
      typeof (e as Entity).value === "string"
  );
}

/**
 * Identity key for an entity: exact `type`, case-insensitive `value`.
 * JSON-encoding the pair keeps the key unambiguous even if either part
 * contains a separator character.
 */
function entityKey(e: Entity): string {
  return JSON.stringify([e.type, e.value.toLowerCase()]);
}

/** True when some entity in `a` and some entity in `b` share the same key. */
function entitiesOverlap(a: Entity[], b: Entity[]): boolean {
  const keys = new Set(a.map(entityKey));
  return b.some((e) => keys.has(entityKey(e)));
}

/** A Date `minutes` minutes before now. */
function minutesAgo(minutes: number): Date {
  return new Date(Date.now() - minutes * 60 * 1000);
}

type AlertRow = {
  id: string;
  source: string;
  category: string;
  severity: string;
  userId: string;
  title: string;
  description: string;
  entities: unknown;
  sourceAlertId: string;
  groupId: string | null;
  payload: unknown;
  createdAt: Date;
};

type GroupRow = {
  id: string;
  userId: string;
  entities: unknown;
  highestSeverity: string;
  status: string;
  alertCount: number;
  summary: string | null;
  resolvedAt: Date | null;
  createdAt: Date;
  updatedAt: Date;
};

/**
 * Correlates normalized alerts into groups that share at least one entity
 * for the same user within a time window, and exposes query helpers over
 * the stored alerts and groups.
 */
export class CorrelationEngine {
  private readonly timeWindowMinutes: number;

  /**
   * @param timeWindowMinutes How far back (in minutes) to look for active
   *   groups and related alerts when correlating a newly ingested alert.
   */
  constructor(timeWindowMinutes: number = 30) {
    this.timeWindowMinutes = timeWindowMinutes;
  }

  /**
   * Stores a normalized alert, tries to correlate it, and returns the stored
   * row. `groupId` is populated when the alert was attached to a group.
   */
  public async ingestAlert(
    input: NormalizedAlertInput
  ): Promise<CorrelatedAlertOutput> {
    const alert: AlertRow = await (prisma as any).normalizedAlert.create({
      data: {
        source: input.source,
        category: input.category,
        severity: input.severity,
        userId: input.userId,
        title: input.title,
        description: input.description,
        entities: input.entities,
        sourceAlertId: input.sourceAlertId,
        payload: input.payload,
        createdAt: input.timestamp || new Date(),
      },
    });

    const correlation = await this.findOrCreateCorrelation(alert);
    if (!correlation) {
      return this.toOutput(alert);
    }

    // `update` already returns the updated row; no follow-up read needed.
    const updated: AlertRow = await (prisma as any).normalizedAlert.update({
      where: { id: alert.id },
      data: { groupId: correlation.id },
    });
    return this.toOutput(updated);
  }

  /**
   * Finds, or creates when warranted, the correlation group a freshly
   * stored alert belongs to.
   *
   * 1. If an ACTIVE group for the same user, created within the time window,
   *    shares an entity with the alert, that group absorbs it. Its count is
   *    incremented, its severity raised if needed, and its entities merged.
   * 2. Otherwise, ungrouped alerts for the same user from within the window
   *    that share an entity are collected. If they and the new alert span
   *    more than one source or category, a new group is created and those
   *    earlier alerts are attached to it.
   *
   * The caller attaches `alert` itself to the returned group.
   *
   * @returns The group the alert should join, or `null` if it stays
   *   uncorrelated.
   */
  private async findOrCreateCorrelation(
    alert: AlertRow
  ): Promise<GroupRow | null> {
    const cutoff = minutesAgo(this.timeWindowMinutes);
    const alertEntities = toEntityList(alert.entities);
    // Correlation is entity-based: without entities nothing can match.
    if (alertEntities.length === 0) {
      return null;
    }

    const activeGroups: GroupRow[] = await (
      prisma as any
    ).correlationGroup.findMany({
      where: {
        userId: alert.userId,
        status: CorrelationStatus.ACTIVE,
        createdAt: { gte: cutoff },
      },
      orderBy: { createdAt: "desc" },
    });

    const match = activeGroups.find((g) =>
      entitiesOverlap(toEntityList(g.entities), alertEntities)
    );
    if (match) {
      return (prisma as any).correlationGroup.update({
        where: { id: match.id },
        data: {
          highestSeverity: higherSeverity(
            match.highestSeverity,
            alert.severity
          ),
          // Atomic increment: a read-then-write `+ 1` loses updates when
          // alerts for the same group are ingested concurrently.
          alertCount: { increment: 1 },
          entities: this.mergeEntities(
            toEntityList(match.entities),
            alertEntities
          ),
        },
      });
    }

    const candidates: AlertRow[] = await (
      prisma as any
    ).normalizedAlert.findMany({
      where: {
        userId: alert.userId,
        groupId: null,
        id: { not: alert.id },
        createdAt: { gte: cutoff },
      },
    });
    const related = candidates.filter((c) =>
      entitiesOverlap(toEntityList(c.entities), alertEntities)
    );
    if (related.length === 0) {
      return null;
    }

    const members = [...related, alert];
    const sourceCount = new Set(members.map((m) => m.source)).size;
    const categoryCount = new Set(members.map((m) => m.category)).size;
    // Only cross-source or cross-category activity forms a group.
    if (sourceCount < 2 && categoryCount < 2) {
      return null;
    }

    let highestSeverity = alert.severity;
    let entities: Entity[] = alertEntities;
    for (const m of related) {
      highestSeverity = higherSeverity(highestSeverity, m.severity);
      entities = this.mergeEntities(entities, toEntityList(m.entities));
    }

    const group: GroupRow = await (prisma as any).correlationGroup.create({
      data: {
        userId: alert.userId,
        entities,
        highestSeverity,
        status: CorrelationStatus.ACTIVE,
        alertCount: members.length,
        summary: this.generateSummary(
          alert.source,
          alert.category,
          alert.title
        ),
      },
    });

    await (prisma as any).normalizedAlert.updateMany({
      where: { id: { in: related.map((r) => r.id) } },
      data: { groupId: group.id },
    });

    return group;
  }

  /**
   * Union of two entity lists, de-duplicated by exact type and
   * case-insensitive value. The first occurrence (and its original casing)
   * wins.
   */
  private mergeEntities(a: Entity[], b: Entity[]): Entity[] {
    const merged = new Map<string, Entity>();
    for (const e of [...a, ...b]) {
      const key = entityKey(e);
      if (!merged.has(key)) {
        merged.set(key, { type: e.type, value: e.value });
      }
    }
    return Array.from(merged.values());
  }

  /** Builds the one-line summary stored on a newly created group. */
  private generateSummary(
    source: string,
    category: string,
    title: string
  ): string {
    return `${source} - ${category}: ${title}`;
  }

  /**
   * Lists stored alerts matching the optional filters in `query`, newest
   * first, along with the total match count. The default page size is 50.
   */
  public async getCorrelatedAlerts(
    query: CorrelationQuery
  ): Promise<{ alerts: CorrelatedAlertOutput[]; total: number }> {
    const where: Record<string, unknown> = {};
    if (query.userId) where.userId = query.userId;
    if (query.groupId) where.groupId = query.groupId;
    if (query.source) where.source = query.source;
    if (query.category) where.category = query.category;
    if (query.severity) where.severity = query.severity;
    if (query.timeWindowMinutes) {
      where.createdAt = { gte: minutesAgo(query.timeWindowMinutes) };
    }
    if (query.entityType && query.entityId) {
      // Prisma Json filter: match rows whose `entities` array contains an
      // element with this type/value. (`contains` is not a Json filter key.)
      // NOTE(review): this match is case-sensitive, unlike entitiesOverlap.
      where.entities = {
        array_contains: [{ type: query.entityType, value: query.entityId }],
      };
    }

    const [alerts, total] = await Promise.all([
      (prisma as any).normalizedAlert.findMany({
        where,
        orderBy: { createdAt: "desc" },
        take: query.limit || 50,
        skip: query.offset || 0,
      }),
      (prisma as any).normalizedAlert.count({ where }),
    ]);

    return {
      alerts: alerts.map((a: AlertRow) => this.toOutput(a)),
      total,
    };
  }

  /**
   * Lists correlation groups matching the optional filters, newest first,
   * with the total match count. Each group's `sources` and `categories` are
   * derived from its 100 most recent alerts.
   */
  public async getCorrelationGroups(
    query: CorrelationQuery
  ): Promise<{ groups: CorrelationGroupOutput[]; total: number }> {
    const where: Record<string, unknown> = {};
    if (query.userId) where.userId = query.userId;
    if (query.status) where.status = query.status;
    if (query.timeWindowMinutes) {
      where.createdAt = { gte: minutesAgo(query.timeWindowMinutes) };
    }

    const [groups, total] = await Promise.all([
      (prisma as any).correlationGroup.findMany({
        where,
        orderBy: { createdAt: "desc" },
        take: query.limit || 50,
        skip: query.offset || 0,
        include: {
          alerts: {
            orderBy: { createdAt: "desc" },
            take: 100,
          },
        },
      }),
      (prisma as any).correlationGroup.count({ where }),
    ]);

    return {
      groups: groups.map((g: GroupRow & { alerts: AlertRow[] }) =>
        this.toGroupOutput(g)
      ),
      total,
    };
  }

  /** Loads a single group with all of its alerts, or `null` if not found. */
  public async getGroupById(
    groupId: string
  ): Promise<CorrelationGroupOutput | null> {
    const group = await (prisma as any).correlationGroup.findUnique({
      where: { id: groupId },
      include: {
        alerts: {
          orderBy: { createdAt: "asc" },
        },
      },
    });
    return group
      ? this.toGroupOutput(group as GroupRow & { alerts: AlertRow[] })
      : null;
  }

  /**
   * Sets a group's status (RESOLVED by default) and returns the updated
   * group. `resolvedAt` is stamped for any non-ACTIVE status and cleared
   * when a group is reopened as ACTIVE.
   */
  public async resolveGroup(
    groupId: string,
    status: string = CorrelationStatus.RESOLVED
  ): Promise<CorrelationGroupOutput> {
    const group = await (prisma as any).correlationGroup.update({
      where: { id: groupId },
      data: {
        status,
        resolvedAt: status === CorrelationStatus.ACTIVE ? null : new Date(),
      },
      include: {
        alerts: {
          orderBy: { createdAt: "asc" },
        },
      },
    });
    return this.toGroupOutput(group as GroupRow & { alerts: AlertRow[] });
  }

  /**
   * Aggregates dashboard figures for a user over the last
   * `timeWindowMinutes` minutes (default 60). All queries run in parallel.
   */
  public async getDashboardData(
    userId: string,
    timeWindowMinutes: number = 60
  ): Promise<{
    totalAlerts: number;
    activeCorrelations: number;
    alertsBySource: Record<string, number>;
    alertsBySeverity: Record<string, number>;
    recentGroups: CorrelationGroupOutput[];
  }> {
    const cutoff = minutesAgo(timeWindowMinutes);
    const alertsInWindow = { userId, createdAt: { gte: cutoff } };
    const activeGroupsInWindow = {
      userId,
      status: CorrelationStatus.ACTIVE,
      createdAt: { gte: cutoff },
    };

    const [totalAlerts, activeCorrelations, recentGroups, recentAlerts] =
      await Promise.all([
        (prisma as any).normalizedAlert.count({ where: alertsInWindow }),
        (prisma as any).correlationGroup.count({
          where: activeGroupsInWindow,
        }),
        (prisma as any).correlationGroup.findMany({
          where: activeGroupsInWindow,
          orderBy: { createdAt: "desc" },
          take: 10,
          include: { alerts: { orderBy: { createdAt: "desc" }, take: 100 } },
        }),
        (prisma as any).normalizedAlert.findMany({
          where: alertsInWindow,
          select: { source: true, severity: true },
        }),
      ]);

    const alertsBySource: Record<string, number> = {};
    const alertsBySeverity: Record<string, number> = {};
    for (const { source, severity } of recentAlerts as Array<{
      source: string;
      severity: string;
    }>) {
      alertsBySource[source] = (alertsBySource[source] || 0) + 1;
      alertsBySeverity[severity] = (alertsBySeverity[severity] || 0) + 1;
    }

    return {
      totalAlerts,
      activeCorrelations,
      alertsBySource,
      alertsBySeverity,
      recentGroups: recentGroups.map(
        (g: GroupRow & { alerts: AlertRow[] }) => this.toGroupOutput(g)
      ),
    };
  }

  /** Maps a stored alert row to its API shape; a null groupId becomes "". */
  private toOutput(alert: AlertRow): CorrelatedAlertOutput {
    return {
      id: alert.id,
      source: alert.source as AlertSource,
      category: alert.category as AlertCategory,
      severity: alert.severity as Severity,
      userId: alert.userId,
      title: alert.title,
      description: alert.description,
      entities: toEntityList(alert.entities) as Array<{
        type: EntityType;
        value: string;
      }>,
      sourceAlertId: alert.sourceAlertId,
      groupId: alert.groupId || "",
      payload: alert.payload as Record<string, unknown>,
      createdAt: alert.createdAt,
    };
  }

  /**
   * Maps a group row (with its loaded alerts) to its API shape. `sources`
   * and `categories` reflect only the alerts included in `group.alerts`.
   */
  private toGroupOutput(
    group: GroupRow & { alerts: AlertRow[] }
  ): CorrelationGroupOutput {
    const sources = new Set<string>();
    const categories = new Set<string>();
    for (const alert of group.alerts) {
      sources.add(alert.source);
      categories.add(alert.category);
    }

    return {
      id: group.id,
      groupId: group.id,
      alertCount: group.alertCount,
      highestSeverity: group.highestSeverity as Severity,
      status: group.status as CorrelationStatus,
      entities: toEntityList(group.entities) as Array<{
        type: EntityType;
        value: string;
      }>,
      sources: Array.from(sources) as AlertSource[],
      categories: Array.from(categories) as AlertCategory[],
      createdAt: group.createdAt,
      updatedAt: group.updatedAt,
    };
  }
}

export const correlationEngine = new CorrelationEngine();