From c02457c66ad72489a2ced949acdbe8a0f314d2f9 Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Mon, 25 May 2026 17:58:47 -0400 Subject: [PATCH] feat: real-time alerts via WebSocket push notifications MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ws WebSocket server (port 3001) with JWT auth and user-socket mapping - Add WebSocket client with exponential backoff reconnection and heartbeat - Add useRealtimeAlerts hook with toast notifications and unread badge - Add alert.publisher service (WS → push → email fallback) - Integrate publisher into DarkWatch, VoicePrint, HomeTitle, SpamShield, RemoveBrokers - Update Navbar with connection status indicator and unread count - Add comprehensive tests (14 passing) for server, client, and publisher --- pnpm-lock.yaml | 6 + web/package.json | 6 +- web/src/components/layout/Navbar.tsx | 49 +++- web/src/hooks/useRealtimeAlerts.ts | 91 ++++++++ web/src/lib/websocket.test.ts | 174 ++++++++++++++ web/src/lib/websocket.ts | 221 ++++++++++++++++++ web/src/server/api/routers/spamshield.ts | 7 +- .../server/services/alert.publisher.test.ts | 111 +++++++++ web/src/server/services/alert.publisher.ts | 66 ++++++ .../services/darkwatch/alert.pipeline.ts | 28 ++- web/src/server/services/hometitle.service.ts | 13 +- .../server/services/removebrokers.service.ts | 35 ++- web/src/server/services/spamshield.service.ts | 70 +++++- web/src/server/services/voiceprint.service.ts | 35 ++- web/src/server/websocket.test.ts | 95 ++++++++ web/src/server/websocket.ts | 216 +++++++++++++++++ 16 files changed, 1197 insertions(+), 26 deletions(-) create mode 100644 web/src/hooks/useRealtimeAlerts.ts create mode 100644 web/src/lib/websocket.test.ts create mode 100644 web/src/lib/websocket.ts create mode 100644 web/src/server/services/alert.publisher.test.ts create mode 100644 web/src/server/services/alert.publisher.ts create mode 100644 web/src/server/websocket.test.ts create mode 100644 web/src/server/websocket.ts diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4459802..de8d79b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -112,6 +112,9 @@ importers: vite: specifier: ^7.0.0 version: 7.3.3(@types/node@25.9.1)(jiti@2.7.0)(lightningcss@1.32.0)(terser@5.48.0)(tsx@4.22.3) + ws: + specifier: ^8.21.0 + version: 8.21.0 devDependencies: '@types/node-cron': specifier: ^3.0.11 @@ -119,6 +122,9 @@ importers: '@types/pg': specifier: ^8.20.0 version: 8.20.0 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 drizzle-kit: specifier: ^0.31.10 version: 0.31.10 diff --git a/web/package.json b/web/package.json index 5103854..17c3ff7 100644 --- a/web/package.json +++ b/web/package.json @@ -14,6 +14,7 @@ "db:seed": "tsx src/server/db/seed.ts" }, "dependencies": { + "@libsql/client": "^0.15.0", "@solidjs/meta": "^0.29.4", "@solidjs/router": "^0.15.0", "@solidjs/start": "2.0.0-alpha.2", @@ -26,7 +27,6 @@ "bcryptjs": "^3.0.3", "bullmq": "^5.77.3", "clerk-solidjs": "^2.0.10", - "@libsql/client": "^0.15.0", "drizzle-orm": "^0.45.2", "firebase-admin": "^13.10.0", "ioredis": "^5.10.1", @@ -41,7 +41,8 @@ "three": "^0.184.0", "twilio": "^6.0.2", "valibot": "^0.29.0", - "vite": "^7.0.0" + "vite": "^7.0.0", + "ws": "^8.21.0" }, "engines": { "node": ">=22" @@ -49,6 +50,7 @@ "devDependencies": { "@types/node-cron": "^3.0.11", "@types/pg": "^8.20.0", + "@types/ws": "^8.18.1", "drizzle-kit": "^0.31.10", "jsdom": "^29.1.1", "tsx": "^4.22.3", diff --git a/web/src/components/layout/Navbar.tsx b/web/src/components/layout/Navbar.tsx index 0c80abe..626cf7f 100644 --- a/web/src/components/layout/Navbar.tsx +++ b/web/src/components/layout/Navbar.tsx @@ -1,10 +1,11 @@ -import { createSignal, onMount, onCleanup, Show, Suspense } from "solid-js"; +import { createSignal, onMount, onCleanup, Show } from "solid-js"; import { A } from "@solidjs/router"; import { cn } from "~/lib/utils"; import { Button } from "~/components/ui"; import { Typewriter } from "~/components/ui/Typewriter"; import { useTheme } from "~/lib/theme"; import { SignedIn, SignedOut, UserButton } from "clerk-solidjs"; +import { useRealtimeAlerts } from "~/hooks/useRealtimeAlerts"; function ShieldLogo() { return ( @@ -125,6 +126,51 @@ const navLinks = [ { label: "Dashboard", href: "/dashboard" }, ]; +function RealtimeIndicator() { + const { connectionStatus, unreadCount, clearUnread } = useRealtimeAlerts(); + + return ( +
+ 0}> + + + + + + + + +
+ ) : ( +
+ + +
+ ) + } + > +
+ + + + +
+ + + ); +} + export default function Navbar() { const [mobileOpen, setMobileOpen] = createSignal(false); const [scrolled, setScrolled] = createSignal(false); @@ -169,6 +215,7 @@ export default function Navbar() { + diff --git a/web/src/hooks/useRealtimeAlerts.ts b/web/src/hooks/useRealtimeAlerts.ts new file mode 100644 index 0000000..13ffa82 --- /dev/null +++ b/web/src/hooks/useRealtimeAlerts.ts @@ -0,0 +1,91 @@ +import { createSignal, createEffect, onMount, onCleanup } from "solid-js"; +import { createWebSocketClient, type AlertPayload, type ConnectionStatus } from "~/lib/websocket"; +import { useToast } from "~/components/ui"; + +const UNREAD_STORAGE_KEY = "shieldai_unread_count"; + +function loadUnreadCount(): number { + try { + return parseInt(localStorage.getItem(UNREAD_STORAGE_KEY) ?? "0", 10); + } catch { + return 0; + } +} + +function saveUnreadCount(count: number) { + try { + localStorage.setItem(UNREAD_STORAGE_KEY, String(count)); + } catch { + // localStorage unavailable + } +} + +function prefersReducedMotion(): boolean { + if (typeof window === "undefined") return false; + return window.matchMedia("(prefers-reduced-motion: reduce)").matches; +} + +export function useRealtimeAlerts() { + const { showToast } = useToast(); + const [unreadCount, setUnreadCount] = createSignal(loadUnreadCount()); + const [connectionStatus, setConnectionStatus] = createSignal("disconnected"); + const client = createWebSocketClient(); + const reducedMotion = prefersReducedMotion(); + + function handleAlert(alert: AlertPayload) { + setUnreadCount((prev) => { + const next = prev + 1; + saveUnreadCount(next); + return next; + }); + + const severityMap: Record = { + LOW: "info", + INFO: "info", + MEDIUM: "warning", + WARNING: "warning", + HIGH: "error", + CRITICAL: "error", + }; + + showToast( + alert.title, + severityMap[alert.severity] ?? "info", + reducedMotion ? 6000 : 4000, + ); + } + + function handleStatusChange(status: ConnectionStatus) { + setConnectionStatus(status); + if (status === "reconnecting") { + showToast("Reconnecting to real-time alerts...", "warning", 3000); + } + } + + let removeAlertListener: (() => void) | null = null; + let removeStatusListener: (() => void) | null = null; + + onMount(() => { + removeAlertListener = client.onAlert(handleAlert); + removeStatusListener = client.onStatusChange(handleStatusChange); + client.connect(); + }); + + onCleanup(() => { + removeAlertListener?.(); + removeStatusListener?.(); + client.cleanup(); + }); + + function clearUnread() { + setUnreadCount(0); + saveUnreadCount(0); + } + + return { + unreadCount, + connectionStatus, + clearUnread, + lastAlert: client.lastAlert, + }; +} diff --git a/web/src/lib/websocket.test.ts b/web/src/lib/websocket.test.ts new file mode 100644 index 0000000..237a4b8 --- /dev/null +++ b/web/src/lib/websocket.test.ts @@ -0,0 +1,174 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { createRoot } from "solid-js"; + +function createMockWs() { + let onopen: (() => void) | null = null; + let onclose: ((event: { code: number }) => void) | null = null; + let onmessage: ((event: { data: string }) => void) | null = null; + + return { + readyState: 1, + send: vi.fn(), + close: vi.fn((code?: number) => { + onclose?.({ code: code ?? 1000 }); + }), + get onopen() { return onopen; }, + set onopen(fn) { onopen = fn; }, + get onclose() { return onclose; }, + set onclose(fn) { onclose = fn; }, + get onmessage() { return onmessage; }, + set onmessage(fn) { onmessage = fn; }, + CONNECTING: 0, + OPEN: 1, + CLOSING: 2, + CLOSED: 3, + }; +} + +Object.defineProperty(globalThis, "crypto", { + value: { randomUUID: () => "test-uuid" }, +}); + +describe("WebSocket client", () => { + let mockWs: ReturnType; + let originalWs: typeof globalThis.WebSocket; + let wsConstructorUrls: string[] = []; + + beforeEach(() => { + mockWs = createMockWs(); + originalWs = globalThis.WebSocket; + wsConstructorUrls = []; + + function MockWebSocket(this: any, url: string) { + wsConstructorUrls.push(url); + setTimeout(() => { + if (typeof mockWs.onopen === "function") mockWs.onopen(); + }, 1); + return mockWs; + } + MockWebSocket.OPEN = 1; + MockWebSocket.CONNECTING = 0; + MockWebSocket.CLOSING = 2; + MockWebSocket.CLOSED = 3; + + globalThis.WebSocket = MockWebSocket as unknown as typeof globalThis.WebSocket; + + Object.defineProperty(document, "cookie", { + value: "session_token=test-session-token", + configurable: true, + }); + }); + + afterEach(() => { + globalThis.WebSocket = originalWs; + }); + + function runWithRoot(fn: (dispose: () => void) => T): T { + let result!: T; + createRoot((dispose) => { + result = fn(dispose); + }); + return result; + } + + it("should call WebSocket constructor with token", () => { + const ws = new globalThis.WebSocket("ws://test/tok"); + expect(wsConstructorUrls).toHaveLength(1); + expect(wsConstructorUrls[0]).toContain("ws://test/tok"); + expect(ws).toBe(mockWs); + }); + + it("should connect and set connected status", async () => { + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + client.connect(); + + expect(wsConstructorUrls).toHaveLength(1); + expect(wsConstructorUrls[0]).toContain("token=test-session-token"); + + mockWs.onopen?.(); + await new Promise((r) => setTimeout(r, 5)); + expect(client.connectionStatus()).toBe("connected"); + client.disconnect(); + }); + + it("should not connect without auth token", async () => { + Object.defineProperty(document, "cookie", { + value: "", + configurable: true, + }); + + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + client.connect(); + + expect(wsConstructorUrls).toHaveLength(0); + expect(client.connectionStatus()).toBe("disconnected"); + }); + + it("should reconnect on unexpected disconnect", async () => { + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + client.connect(); + mockWs.onopen?.(); + await new Promise((r) => setTimeout(r, 5)); + expect(client.connectionStatus()).toBe("connected"); + + mockWs.onclose?.({ code: 1006 }); + await new Promise((r) => setTimeout(r, 5)); + expect(client.connectionStatus()).toBe("reconnecting"); + client.disconnect(); + }); + + it("should fire onAlert callback", async () => { + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + const alerts: Array<{ title: string }> = []; + client.onAlert((a) => alerts.push(a)); + client.connect(); + mockWs.onopen?.(); + await new Promise((r) => setTimeout(r, 5)); + + mockWs.onmessage?.({ + data: JSON.stringify({ + type: "alert", + alert: { + id: "a1", title: "Test Alert", message: "Test", + severity: "HIGH", source: "DARKWATCH", + category: "EXPOSURE_DETECTED", + createdAt: new Date().toISOString(), + }, + }), + }); + + expect(alerts).toHaveLength(1); + expect(alerts[0].title).toBe("Test Alert"); + client.disconnect(); + }); + + it("should disconnect and set disconnected", async () => { + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + client.connect(); + mockWs.onopen?.(); + await new Promise((r) => setTimeout(r, 5)); + expect(client.connectionStatus()).toBe("connected"); + + client.disconnect(); + expect(client.connectionStatus()).toBe("disconnected"); + }); + + it("should fire pong handler without throwing", async () => { + const { createWebSocketClient } = await import("./websocket"); + const client = runWithRoot(() => createWebSocketClient()); + client.connect(); + mockWs.onopen?.(); + await new Promise((r) => setTimeout(r, 5)); + expect(client.connectionStatus()).toBe("connected"); + + expect(() => { + mockWs.onmessage?.({ data: JSON.stringify({ type: "pong" }) }); + }).not.toThrow(); + client.disconnect(); + }); +}); diff --git a/web/src/lib/websocket.ts b/web/src/lib/websocket.ts new file mode 100644 index 0000000..447484e --- /dev/null +++ b/web/src/lib/websocket.ts @@ -0,0 +1,221 @@ +import { createSignal, onCleanup } from "solid-js"; + +const WS_URL = import.meta.env.VITE_WS_URL ?? "ws://localhost:3001"; +const RECONNECT_DELAYS = [1000, 2000, 5000, 10_000, 30_000]; +const MAX_RECONNECT_ATTEMPTS = 10; +const HEARTBEAT_INTERVAL = 30_000; +const PONG_TIMEOUT = 10_000; + +export interface AlertPayload { + id: string; + title: string; + message: string; + severity: string; + source: string; + category: string; + createdAt: string; +} + +export type ConnectionStatus = "connecting" | "connected" | "disconnected" | "reconnecting"; + +function getAuthToken(): string | null { + if (typeof document === "undefined") return null; + const match = document.cookie.match(/(?:^|;\s*)session_token=([^;]*)/); + if (match) return match[1]; + try { + return localStorage.getItem("auth_token"); + } catch { + return null; + } +} + +export function createWebSocketClient() { + const [connectionStatus, setConnectionStatus] = createSignal("disconnected"); + const [lastAlert, setLastAlert] = createSignal(null); + + let ws: WebSocket | null = null; + let reconnectAttempts = 0; + let reconnectTimer: ReturnType | null = null; + let heartbeatTimer: ReturnType | null = null; + let pongTimer: ReturnType | null = null; + let intentionalClose = false; + let listeners: Array<(alert: AlertPayload) => void> = []; + let statusListeners: Array<(status: ConnectionStatus) => void> = []; + let mountedCleanups: Array<() => void> = []; + + function notifyAlert(alert: AlertPayload) { + setLastAlert(alert); + for (const listener of listeners) { + try { + listener(alert); + } catch { + // ignore + } + } + } + + function notifyStatus(status: ConnectionStatus) { + setConnectionStatus(status); + for (const listener of statusListeners) { + try { + listener(status); + } catch { + // ignore + } + } + } + + function startHeartbeat() { + stopHeartbeat(); + heartbeatTimer = setInterval(() => { + if (ws?.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: "ping" })); + pongTimer = setTimeout(() => { + intentionalClose = false; + ws?.close(); + }, PONG_TIMEOUT); + } + }, HEARTBEAT_INTERVAL); + } + + function stopHeartbeat() { + if (heartbeatTimer) { + clearInterval(heartbeatTimer); + heartbeatTimer = null; + } + if (pongTimer) { + clearTimeout(pongTimer); + pongTimer = null; + } + } + + function scheduleReconnect() { + if (intentionalClose) return; + if (reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) { + notifyStatus("disconnected"); + return; + } + + notifyStatus("reconnecting"); + const delay = + RECONNECT_DELAYS[reconnectAttempts] ?? + RECONNECT_DELAYS[RECONNECT_DELAYS.length - 1]; + reconnectAttempts++; + + reconnectTimer = setTimeout(() => { + connect(); + }, delay); + } + + function connect() { + if (ws?.readyState === WebSocket.OPEN || ws?.readyState === WebSocket.CONNECTING) { + return; + } + + const token = getAuthToken(); + if (!token) { + notifyStatus("disconnected"); + return; + } + + intentionalClose = false; + notifyStatus("connecting"); + + try { + ws = new WebSocket(`${WS_URL}?token=${encodeURIComponent(token)}`); + + ws.onopen = () => { + reconnectAttempts = 0; + notifyStatus("connected"); + startHeartbeat(); + }; + + ws.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + + if (data.type === "pong") { + if (pongTimer) { + clearTimeout(pongTimer); + pongTimer = null; + } + return; + } + + if (data.type === "alert") { + notifyAlert(data.alert as AlertPayload); + } + } catch { + // ignore invalid messages + } + }; + + ws.onclose = () => { + stopHeartbeat(); + if (!intentionalClose) { + scheduleReconnect(); + } else { + notifyStatus("disconnected"); + } + }; + + ws.onerror = () => { + // onclose will handle reconnection + }; + } catch { + notifyStatus("disconnected"); + } + } + + function disconnect() { + intentionalClose = true; + stopHeartbeat(); + if (reconnectTimer) { + clearTimeout(reconnectTimer); + reconnectTimer = null; + } + if (ws) { + ws.close(1000, "Client disconnecting"); + ws = null; + } + notifyStatus("disconnected"); + } + + function onAlert(listener: (alert: AlertPayload) => void) { + listeners.push(listener); + return () => { + listeners = listeners.filter((l) => l !== listener); + }; + } + + function onStatusChange(listener: (status: ConnectionStatus) => void) { + statusListeners.push(listener); + return () => { + statusListeners = statusListeners.filter((l) => l !== listener); + }; + } + + function cleanup() { + disconnect(); + listeners = []; + statusListeners = []; + for (const c of mountedCleanups) { + c(); + } + mountedCleanups = []; + } + + onCleanup(() => cleanup()); + + return { + connect, + disconnect, + onAlert, + onStatusChange, + connectionStatus, + lastAlert, + cleanup, + }; +} + +export type WebSocketClient = ReturnType; diff --git a/web/src/server/api/routers/spamshield.ts b/web/src/server/api/routers/spamshield.ts index b2ceeb0..8eb5dd2 100644 --- a/web/src/server/api/routers/spamshield.ts +++ b/web/src/server/api/routers/spamshield.ts @@ -20,17 +20,18 @@ export const spamshieldRouter = createTRPCRouter({ classifySMS: publicProcedure .input(wrap(ClassifySMSSchema)) - .query(async ({ input }) => { - return spamshieldService.classifySMS(input.text); + .query(async ({ input, ctx }) => { + return spamshieldService.classifySMS(input.text, ctx.user?.id); }), classifyCall: publicProcedure .input(wrap(ClassifyCallSchema)) - .query(async ({ input }) => { + .query(async ({ input, ctx }) => { return spamshieldService.classifyCall( input.callerNumber, input.duration, input.timeOfDay, + ctx.user?.id, ); }), diff --git a/web/src/server/services/alert.publisher.test.ts b/web/src/server/services/alert.publisher.test.ts new file mode 100644 index 0000000..d1e0994 --- /dev/null +++ b/web/src/server/services/alert.publisher.test.ts @@ -0,0 +1,111 @@ +// @vitest-environment node +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const mockBroadcastToUser = vi.fn(); + +vi.mock("~/server/websocket", () => ({ + broadcastToUser: mockBroadcastToUser, +})); + +const mockSendPush = vi.fn(); +const mockSendEmail = vi.fn(); + +vi.mock("~/server/services/notification.service", () => ({ + sendPush: mockSendPush, + sendEmail: mockSendEmail, +})); + +vi.mock("~/server/db", () => ({ + db: { + select: vi.fn(), + insert: vi.fn(), + update: vi.fn(), + }, +})); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("alert.publisher", () => { + it("should send alert via WebSocket when user is connected", async () => { + mockBroadcastToUser.mockReturnValue(true); + + const { publishAlert } = await import("./alert.publisher"); + await publishAlert("user-1", { + id: "alert-1", + title: "Test Alert", + message: "Test message", + severity: "HIGH", + source: "DARKWATCH", + category: "EXPOSURE_DETECTED", + createdAt: new Date(), + }); + + expect(mockBroadcastToUser).toHaveBeenCalledWith("user-1", { + type: "alert", + alert: { + id: "alert-1", + title: "Test Alert", + message: "Test message", + severity: "HIGH", + source: "DARKWATCH", + category: "EXPOSURE_DETECTED", + createdAt: expect.any(String), + }, + }); + expect(mockSendPush).not.toHaveBeenCalled(); + expect(mockSendEmail).not.toHaveBeenCalled(); + }); + + it("should fall back to push notification when user is not connected", async () => { + mockBroadcastToUser.mockReturnValue(false); + mockSendPush.mockResolvedValue({ successCount: 1 }); + + const { publishAlert } = await import("./alert.publisher"); + await publishAlert("user-1", { + id: "alert-2", + title: "Offline Alert", + message: "Offline message", + severity: "WARNING", + source: "VOICEPRINT", + category: "SYNTHETIC_VOICE", + createdAt: new Date(), + }); + + expect(mockBroadcastToUser).toHaveBeenCalled(); + expect(mockSendPush).toHaveBeenCalledWith( + "user-1", + "Offline Alert", + "Offline message", + { alertId: "alert-2", source: "VOICEPRINT", severity: "WARNING" }, + ); + }); + + it("should publish alert to multiple users", async () => { + mockBroadcastToUser.mockReturnValue(false); + mockSendPush.mockResolvedValue({ successCount: 0 }); + + const db = await import("~/server/db"); + (db.db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + }); + + const { publishToGroup } = await import("./alert.publisher"); + await publishToGroup(["user-1", "user-2"], { + id: "alert-3", + title: "Group Alert", + message: "Group message", + severity: "INFO", + source: "HOME_TITLE", + category: "HOME_TITLE", + createdAt: new Date(), + }); + + expect(mockBroadcastToUser).toHaveBeenCalledTimes(2); + }); +}); diff --git a/web/src/server/services/alert.publisher.ts b/web/src/server/services/alert.publisher.ts new file mode 100644 index 0000000..7cdc829 --- /dev/null +++ b/web/src/server/services/alert.publisher.ts @@ -0,0 +1,66 @@ +import { broadcastToUser } from "~/server/websocket"; +import { sendPush, sendEmail } from "~/server/services/notification.service"; +import { db } from "~/server/db"; +import { users } from "~/server/db/schema/auth"; +import { eq } from "drizzle-orm"; + +export interface PublishableAlert { + id: string; + title: string; + message: string; + severity: string; + source: string; + category: string; + createdAt: Date; +} + +export async function publishAlert(userId: string, alert: PublishableAlert): Promise { + const message = { + type: "alert" as const, + alert: { + id: alert.id, + title: alert.title, + message: alert.message, + severity: alert.severity, + source: alert.source, + category: alert.category, + createdAt: alert.createdAt.toISOString(), + }, + }; + + const sent = broadcastToUser(userId, message); + + if (!sent) { + try { + const pushResult = await sendPush(userId, alert.title, alert.message, { + alertId: alert.id, + source: alert.source, + severity: alert.severity, + }); + + if (pushResult.successCount === 0) { + const [user] = await db + .select() + .from(users) + .where(eq(users.id, userId)) + .limit(1); + + if (user?.email) { + await sendEmail( + user.email, + `[ShieldAI] ${alert.title}`, + `

${alert.message}

`, + alert.message, + ); + } + } + } catch (err) { + console.error("[alert.publisher] Fallback notification failed:", err); + } + } +} + +export async function publishToGroup(userIds: string[], alert: PublishableAlert): Promise { + const promises = userIds.map((userId) => publishAlert(userId, alert)); + await Promise.allSettled(promises); +} diff --git a/web/src/server/services/darkwatch/alert.pipeline.ts b/web/src/server/services/darkwatch/alert.pipeline.ts index 766e142..32a23cd 100644 --- a/web/src/server/services/darkwatch/alert.pipeline.ts +++ b/web/src/server/services/darkwatch/alert.pipeline.ts @@ -1,6 +1,7 @@ import { eq, and } from "drizzle-orm"; import { db } from "~/server/db"; import { exposures, alerts, subscriptions } from "~/server/db/schema"; +import { publishAlert } from "~/server/services/alert.publisher"; export function severityScore(exposure: { source: string; @@ -105,14 +106,27 @@ async function createAlertForExposure( if (!sub) return; - await db.insert(alerts).values({ - subscriptionId: exposure.subscriptionId, - userId: sub.userId, - exposureId: exposure.id, - type: "exposure_detected", + const [alert] = await db + .insert(alerts) + .values({ + subscriptionId: exposure.subscriptionId, + userId: sub.userId, + exposureId: exposure.id, + type: "exposure_detected", + title, + message, + severity: alertSeverityMap[severity] ?? "info", + channel: ["email", "push"], + }) + .returning(); + + publishAlert(sub.userId, { + id: alert.id, title, message, severity: alertSeverityMap[severity] ?? "info", - channel: ["email", "push"], - }); + source: "DARKWATCH", + category: "EXPOSURE_DETECTED", + createdAt: alert.createdAt, + }).catch((err) => console.error("[darkwatch] Failed to publish alert:", err)); } diff --git a/web/src/server/services/hometitle.service.ts b/web/src/server/services/hometitle.service.ts index 4c3e407..f156ef2 100644 --- a/web/src/server/services/hometitle.service.ts +++ b/web/src/server/services/hometitle.service.ts @@ -14,6 +14,7 @@ import { type DetectedChange, type SnapshotData, } from "./hometitle/change.detector"; +import { publishAlert } from "~/server/services/alert.publisher"; import { geocodeAddress, fetchCountyRecords, @@ -433,7 +434,7 @@ async function generateAlert( critical: "CRITICAL", }; - await db + const [normalized] = await db .insert(normalizedAlerts) .values({ source: "HOME_TITLE", @@ -454,5 +455,15 @@ async function generateAlert( }) .returning(); + publishAlert(sub.userId, { + id: alert.id, + title, + message, + severity: change.severity, + source: "HOME_TITLE", + category: "HOME_TITLE", + createdAt: alert.createdAt, + }).catch((err) => console.error("[hometitle] Failed to publish alert:", err)); + return alert; } diff --git a/web/src/server/services/removebrokers.service.ts b/web/src/server/services/removebrokers.service.ts index 44d96de..02feca2 100644 --- a/web/src/server/services/removebrokers.service.ts +++ b/web/src/server/services/removebrokers.service.ts @@ -1,8 +1,9 @@ import { TRPCError } from "@trpc/server"; import { eq, and, desc, count, inArray, or, isNull, lt } from "drizzle-orm"; import { db } from "~/server/db"; -import { infoBrokers, removalRequests, brokerListings, subscriptions } from "~/server/db/schema"; +import { infoBrokers, removalRequests, brokerListings, subscriptions, normalizedAlerts } from "~/server/db/schema"; import { getActiveBrokers } from "./removebrokers/broker.registry"; +import { publishAlert } from "~/server/services/alert.publisher"; import { submitAutomatedRemoval, sendRemovalEmail } from "./removebrokers/removal.engine"; import type { PersonalInfo } from "./removebrokers/removal.engine"; import type { RemovalMethod } from "./removebrokers/broker.registry"; @@ -273,6 +274,38 @@ export async function scanForListings(userId: string, brokerId?: string) { createdListings.push({ id: listing.id, brokerId: listing.brokerId, url: listing.url }); } + if (createdListings.length > 0) { + const sourceAlertId = `removebrokers:scan:${sub.id}:${Date.now()}`; + try { + const [na] = await db + .insert(normalizedAlerts) + .values({ + source: "REMOVEBROKERS", + category: "BROKER_LISTING", + severity: "INFO", + userId: sub.userId, + title: "New Broker Listings Found", + description: `Found ${createdListings.length} new broker listing(s)`, + entities: { listingCount: createdListings.length, brokerIds: createdListings.map((l) => l.brokerId) }, + sourceAlertId, + createdAt: new Date(), + }) + .returning(); + + publishAlert(sub.userId, { + id: na.id, + title: "New Broker Listings Found", + message: `Found ${createdListings.length} new broker listing(s)`, + severity: "INFO", + source: "REMOVEBROKERS", + category: "BROKER_LISTING", + createdAt: na.createdAt, + }).catch(() => {}); + } catch { + // alert creation is best-effort + } + } + return { scanned: brokers.length, listingsFound: createdListings.length, diff --git a/web/src/server/services/spamshield.service.ts b/web/src/server/services/spamshield.service.ts index 1625009..d53a1ff 100644 --- a/web/src/server/services/spamshield.service.ts +++ b/web/src/server/services/spamshield.service.ts @@ -1,10 +1,11 @@ import { TRPCError } from "@trpc/server"; import { eq, and, desc, count, sql, gte } from "drizzle-orm"; import { db } from "~/server/db"; -import { spamFeedback, spamRules, auditLogs } from "~/server/db/schema"; +import { spamFeedback, spamRules, auditLogs, normalizedAlerts } from "~/server/db/schema"; import { classifyTextBERT, extractFeatures, ruleEngine } from "./spamshield/ml.engine"; import { checkReputation } from "./spamshield/reputation.api"; import type { ReputationResult } from "./spamshield/reputation.api"; +import { publishAlert } from "~/server/services/alert.publisher"; function normalizePhoneNumber(phone: string): string { let cleaned = phone.replace(/[^\d+]/g, ""); @@ -72,7 +73,7 @@ export async function checkNumberReputation(phoneNumber: string) { return response; } -export async function classifySMS(text: string) { +export async function classifySMS(text: string, userId?: string) { const classification = await classifyTextBERT(text); const response = { @@ -83,6 +84,38 @@ export async function classifySMS(text: string) { await logAudit(undefined, "classifySMS", { textLength: text.length }, response, response.confidence); + if (classification.isSpam && classification.confidence >= 0.8 && userId) { + const sourceAlertId = `spamshield:sms:${Date.now()}`; + try { + const [na] = await db + .insert(normalizedAlerts) + .values({ + source: "SPAMSHIELD", + category: "SPAM_SMS", + severity: "LOW", + userId, + title: "Spam SMS Detected", + description: "Suspected spam SMS detected", + entities: { textLength: text.length, confidence: classification.confidence }, + sourceAlertId, + createdAt: new Date(), + }) + .returning(); + + publishAlert(userId, { + id: na.id, + title: "Spam SMS Detected", + message: "Suspected spam SMS detected", + severity: "LOW", + source: "SPAMSHIELD", + category: "SPAM_SMS", + createdAt: na.createdAt, + }).catch(() => {}); + } catch { + // alert creation is best-effort + } + } + return response; } @@ -90,6 +123,7 @@ export async function classifyCall( callerNumber: string, duration?: number, timeOfDay?: number, + userId?: string, ) { const normalized = normalizePhoneNumber(callerNumber); const features = await extractFeatures({ callerNumber: normalized, duration, timeOfDay }); @@ -138,6 +172,38 @@ export async function classifyCall( await logAudit(undefined, "classifyCall", { callerNumber: normalized, duration, timeOfDay }, response, confidence); + if (isSpam && confidence >= 0.8 && userId) { + const sourceAlertId = `spamshield:call:${normalized}:${Date.now()}`; + try { + const [na] = await db + .insert(normalizedAlerts) + .values({ + source: "SPAMSHIELD", + category: "SPAM_CALL", + severity: "MEDIUM", + userId, + title: "Spam Call Blocked", + description: `Blocked spam call from ${normalized}`, + entities: { callerNumber: normalized, confidence, ruleMatch: !!ruleMatch }, + sourceAlertId, + createdAt: new Date(), + }) + .returning(); + + publishAlert(userId, { + id: na.id, + title: "Spam Call Blocked", + message: `Blocked spam call from ${normalized}`, + severity: "MEDIUM", + source: "SPAMSHIELD", + category: "SPAM_CALL", + createdAt: na.createdAt, + }).catch(() => {}); + } catch { + // alert creation is best-effort + } + } + return response; } diff --git a/web/src/server/services/voiceprint.service.ts b/web/src/server/services/voiceprint.service.ts index 14cc66f..80b0b62 100644 --- a/web/src/server/services/voiceprint.service.ts +++ b/web/src/server/services/voiceprint.service.ts @@ -10,6 +10,7 @@ import { normalizedAlerts, } from "~/server/db/schema"; import { saveAudio, getAudioUrl, deleteFile } from "./voiceprint/storage"; +import { publishAlert } from "~/server/services/alert.publisher"; import { preprocessAudio, detectSynthetic, @@ -100,17 +101,33 @@ async function createVoiceAlert( if (sub) { const category = verdict === "SYNTHETIC" ? "SYNTHETIC_VOICE" : "VOICE_MISMATCH"; - await db.insert(normalizedAlerts).values({ + const title = verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected"; + const description = `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`; + + const [alert] = await db + .insert(normalizedAlerts) + .values({ + source: "VOICEPRINT", + category, + severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM", + userId, + title, + description, + entities: { analysisId, verdict, confidence }, + sourceAlertId: `voiceprint-${analysisId}`, + createdAt: new Date(), + }) + .returning(); + + publishAlert(userId, { + id: alert.id, + title, + message: description, + severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM", source: "VOICEPRINT", category, - severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM", - userId, - title: verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected", - description: `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`, - entities: { analysisId, verdict, confidence }, - sourceAlertId: `voiceprint-${analysisId}`, - createdAt: new Date(), - }); + createdAt: alert.createdAt, + }).catch((err) => console.error("[voiceprint] Failed to publish alert:", err)); } } catch (err) { console.error("[voiceprint] Failed to create alert:", err); diff --git a/web/src/server/websocket.test.ts b/web/src/server/websocket.test.ts new file mode 100644 index 0000000..f3cf4c2 --- /dev/null +++ b/web/src/server/websocket.test.ts @@ -0,0 +1,95 @@ +// @vitest-environment node +import { describe, it, expect, vi, beforeAll, afterAll } from "vitest"; + +const mockVerifyJWT = vi.fn(); + +vi.mock("~/server/auth/jwt", () => ({ + verifyJWT: mockVerifyJWT, + signJWT: vi.fn(), +})); + +let mockServer: any; +let connectionHandler: ((ws: any, req: any) => void) | null = null; + +vi.mock("ws", () => { + mockServer = { + on: vi.fn((event: string, handler: any) => { + if (event === "connection") connectionHandler = handler; + }), + close: vi.fn((cb: () => void) => cb()), + clients: new Set(), + }; + + return { + WebSocketServer: vi.fn(function (_opts: any, cb?: () => void) { + if (cb) setTimeout(cb, 0); + return mockServer; + }), + WebSocket: { OPEN: 1, CONNECTING: 0, CLOSING: 2, CLOSED: 3 }, + }; +}); + +function makeWs() { + const handlers: Record void> = {}; + return { + close: vi.fn(), + send: vi.fn(), + ping: vi.fn(), + terminate: vi.fn(), + readyState: 1, + on: vi.fn((event: string, handler: any) => { + handlers[event] = handler; + }), + emit: (event: string, ...args: any[]) => { + handlers[event]?.(...args); + }, + }; +} + +describe("WebSocket server", () => { + beforeAll(async () => { + process.env.WS_PORT = "3099"; + const { start } = await import("~/server/websocket"); + await start(); + }, 15000); + + afterAll(async () => { + const { stop } = await import("~/server/websocket"); + await stop(); + }); + + it("should reject connection without JWT", async () => { + const ws = makeWs(); + await connectionHandler!(ws, { url: "/" }); + expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed"); + }); + + it("should reject connection with invalid JWT", async () => { + mockVerifyJWT.mockRejectedValue(new Error("Invalid token")); + + const ws = makeWs(); + await connectionHandler!(ws, { url: "/?token=bad" }); + expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed"); + }); + + it("should accept connection with valid JWT", async () => { + mockVerifyJWT.mockResolvedValue({ sub: "user-1" }); + + const ws = makeWs(); + await connectionHandler!(ws, { url: "/?token=good" }); + expect(ws.close).not.toHaveBeenCalled(); + }); + + it("should return false when broadcasting to non-existent user", async () => { + const { broadcastToUser } = await import("~/server/websocket"); + const result = broadcastToUser("nonexistent", { + type: "alert", + alert: { + id: "a1", title: "T", message: "M", + severity: "INFO", source: "TEST", category: "TEST", + createdAt: new Date().toISOString(), + }, + }); + expect(result).toBe(false); + }); +}); diff --git a/web/src/server/websocket.ts b/web/src/server/websocket.ts new file mode 100644 index 0000000..6bb137e --- /dev/null +++ b/web/src/server/websocket.ts @@ -0,0 +1,216 @@ +import { WebSocketServer, WebSocket } from "ws"; +import type { Server } from "ws"; +import { IncomingMessage } from "node:http"; +import { URL } from "node:url"; +import { verifyJWT } from "~/server/auth/jwt"; + +const WS_PORT = parseInt(process.env.WS_PORT ?? "3001", 10); +const HEARTBEAT_INTERVAL = 30_000; +const PONG_TIMEOUT = 10_000; + +interface AlertMessage { + type: "alert"; + alert: { + id: string; + title: string; + message: string; + severity: string; + source: string; + category: string; + createdAt: string; + }; +} + +interface WsClient extends WebSocket { + userId?: string; + isAlive?: boolean; + pongTimer?: ReturnType; +} + +const userSockets = new Map>(); +let wss: Server | null = null; +let heartbeatTimer: ReturnType | null = null; + +function getTokenFromRequest(req: IncomingMessage): string | null { + const url = new URL(req.url ?? "/", "http://localhost"); + return url.searchParams.get("token"); +} + +async function authenticateConnection( + ws: WsClient, + req: IncomingMessage, +): Promise { + const token = getTokenFromRequest(req); + if (!token) return null; + + try { + const payload = await verifyJWT<{ sub?: string; userId?: string }>(token); + const userId = payload.sub ?? payload.userId; + if (!userId) return null; + return userId; + } catch { + return null; + } +} + +function addSocket(userId: string, ws: WsClient) { + let sockets = userSockets.get(userId); + if (!sockets) { + sockets = new Set(); + userSockets.set(userId, sockets); + } + sockets.add(ws); +} + +function removeSocket(userId: string, ws: WsClient) { + const sockets = userSockets.get(userId); + if (!sockets) return; + sockets.delete(ws); + if (sockets.size === 0) { + userSockets.delete(userId); + } +} + +function heartbeat(ws: WsClient) { + ws.isAlive = true; +} + +function startHeartbeat() { + if (heartbeatTimer) clearInterval(heartbeatTimer); + heartbeatTimer = setInterval(() => { + if (!wss) return; + wss.clients.forEach((client) => { + const ws = client as WsClient; + if (ws.isAlive === false) { + ws.terminate(); + return; + } + ws.isAlive = false; + ws.ping(); + + ws.pongTimer = setTimeout(() => { + ws.terminate(); + }, PONG_TIMEOUT); + }); + }, HEARTBEAT_INTERVAL); + + if (heartbeatTimer && typeof heartbeatTimer === "object") { + heartbeatTimer.unref(); + } +} + +function stopHeartbeat() { + if (heartbeatTimer) { + clearInterval(heartbeatTimer); + heartbeatTimer = null; + } +} + +export function broadcastToUser(userId: string, data: AlertMessage) { + const sockets = userSockets.get(userId); + if (!sockets || sockets.size === 0) return false; + + const message = JSON.stringify(data); + let sent = false; + for (const ws of sockets) { + if (ws.readyState === WebSocket.OPEN) { + ws.send(message); + sent = true; + } + } + return sent; +} + +export function getConnectedUsers(): string[] { + return Array.from(userSockets.keys()); +} + +export function getConnectionCount(): number { + let count = 0; + for (const sockets of userSockets.values()) { + count += sockets.size; + } + return count; +} + +export function start(): Promise { + return new Promise((resolve) => { + if (wss) { + resolve(); + return; + } + + wss = new WebSocketServer({ port: WS_PORT }, () => { + console.log(`[websocket] Server listening on port ${WS_PORT}`); + resolve(); + }); + + wss.on("connection", async (ws: WsClient, req: IncomingMessage) => { + const userId = await authenticateConnection(ws, req); + + if (!userId) { + ws.close(4001, "Authentication failed"); + return; + } + + ws.userId = userId; + ws.isAlive = true; + addSocket(userId, ws); + + ws.on("pong", () => { + heartbeat(ws); + if (ws.pongTimer) { + clearTimeout(ws.pongTimer); + ws.pongTimer = undefined; + } + }); + + ws.on("message", (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.type === "ping") { + ws.send(JSON.stringify({ type: "pong" })); + } + } catch { + // ignore invalid messages + } + }); + + ws.on("close", () => { + if (ws.userId) { + removeSocket(ws.userId, ws); + } + if (ws.pongTimer) { + clearTimeout(ws.pongTimer); + } + }); + + ws.on("error", (err) => { + console.error("[websocket] Client error:", err.message); + }); + }); + + startHeartbeat(); + }); +} + +export function stop(): Promise { + return new Promise((resolve) => { + stopHeartbeat(); + if (!wss) { + resolve(); + return; + } + + for (const ws of wss.clients) { + ws.close(1001, "Server shutting down"); + } + + wss.close(() => { + wss = null; + userSockets.clear(); + console.log("[websocket] Server stopped"); + resolve(); + }); + }); +}