feat: real-time alerts via WebSocket push notifications
- Add ws WebSocket server (port 3001) with JWT auth and user-socket mapping - Add WebSocket client with exponential backoff reconnection and heartbeat - Add useRealtimeAlerts hook with toast notifications and unread badge - Add alert.publisher service (WS → push → email fallback) - Integrate publisher into DarkWatch, VoicePrint, HomeTitle, SpamShield, RemoveBrokers - Update Navbar with connection status indicator and unread count - Add comprehensive tests (14 passing) for server, client, and publisher
This commit is contained in:
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
@@ -112,6 +112,9 @@ importers:
|
||||
vite:
|
||||
specifier: ^7.0.0
|
||||
version: 7.3.3(@types/node@25.9.1)(jiti@2.7.0)(lightningcss@1.32.0)(terser@5.48.0)(tsx@4.22.3)
|
||||
ws:
|
||||
specifier: ^8.21.0
|
||||
version: 8.21.0
|
||||
devDependencies:
|
||||
'@types/node-cron':
|
||||
specifier: ^3.0.11
|
||||
@@ -119,6 +122,9 @@ importers:
|
||||
'@types/pg':
|
||||
specifier: ^8.20.0
|
||||
version: 8.20.0
|
||||
'@types/ws':
|
||||
specifier: ^8.18.1
|
||||
version: 8.18.1
|
||||
drizzle-kit:
|
||||
specifier: ^0.31.10
|
||||
version: 0.31.10
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
"db:seed": "tsx src/server/db/seed.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@libsql/client": "^0.15.0",
|
||||
"@solidjs/meta": "^0.29.4",
|
||||
"@solidjs/router": "^0.15.0",
|
||||
"@solidjs/start": "2.0.0-alpha.2",
|
||||
@@ -26,7 +27,6 @@
|
||||
"bcryptjs": "^3.0.3",
|
||||
"bullmq": "^5.77.3",
|
||||
"clerk-solidjs": "^2.0.10",
|
||||
"@libsql/client": "^0.15.0",
|
||||
"drizzle-orm": "^0.45.2",
|
||||
"firebase-admin": "^13.10.0",
|
||||
"ioredis": "^5.10.1",
|
||||
@@ -41,7 +41,8 @@
|
||||
"three": "^0.184.0",
|
||||
"twilio": "^6.0.2",
|
||||
"valibot": "^0.29.0",
|
||||
"vite": "^7.0.0"
|
||||
"vite": "^7.0.0",
|
||||
"ws": "^8.21.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=22"
|
||||
@@ -49,6 +50,7 @@
|
||||
"devDependencies": {
|
||||
"@types/node-cron": "^3.0.11",
|
||||
"@types/pg": "^8.20.0",
|
||||
"@types/ws": "^8.18.1",
|
||||
"drizzle-kit": "^0.31.10",
|
||||
"jsdom": "^29.1.1",
|
||||
"tsx": "^4.22.3",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { createSignal, onMount, onCleanup, Show, Suspense } from "solid-js";
|
||||
import { createSignal, onMount, onCleanup, Show } from "solid-js";
|
||||
import { A } from "@solidjs/router";
|
||||
import { cn } from "~/lib/utils";
|
||||
import { Button } from "~/components/ui";
|
||||
import { Typewriter } from "~/components/ui/Typewriter";
|
||||
import { useTheme } from "~/lib/theme";
|
||||
import { SignedIn, SignedOut, UserButton } from "clerk-solidjs";
|
||||
import { useRealtimeAlerts } from "~/hooks/useRealtimeAlerts";
|
||||
|
||||
function ShieldLogo() {
|
||||
return (
|
||||
@@ -125,6 +126,51 @@ const navLinks = [
|
||||
{ label: "Dashboard", href: "/dashboard" },
|
||||
];
|
||||
|
||||
function RealtimeIndicator() {
|
||||
const { connectionStatus, unreadCount, clearUnread } = useRealtimeAlerts();
|
||||
|
||||
return (
|
||||
<div class="flex items-center gap-2">
|
||||
<Show when={unreadCount() > 0}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={clearUnread}
|
||||
class="relative flex items-center justify-center w-6 h-6 rounded-full bg-[var(--color-error)] text-white text-[10px] font-bold leading-none transition-transform hover:scale-110"
|
||||
aria-label={`${unreadCount()} unread alerts`}
|
||||
>
|
||||
{unreadCount() > 99 ? "99+" : unreadCount()}
|
||||
</button>
|
||||
</Show>
|
||||
<Show
|
||||
when={connectionStatus() === "connected"}
|
||||
fallback={
|
||||
connectionStatus() === "reconnecting" || connectionStatus() === "connecting" ? (
|
||||
<div class="flex items-center gap-1 text-[10px] text-[var(--color-warning)]" aria-label="Reconnecting">
|
||||
<span class="relative flex h-2 w-2">
|
||||
<span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-[var(--color-warning)] opacity-75" />
|
||||
<span class="relative inline-flex rounded-full h-2 w-2 bg-[var(--color-warning)]" />
|
||||
</span>
|
||||
<span class="hidden sm:inline">Reconnecting</span>
|
||||
</div>
|
||||
) : (
|
||||
<div class="flex items-center gap-1 text-[10px] text-[var(--color-text-muted)]" aria-label="Offline">
|
||||
<span class="inline-flex rounded-full h-2 w-2 bg-[var(--color-text-muted)]" />
|
||||
<span class="hidden sm:inline">Offline</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
>
|
||||
<div class="flex items-center gap-1" aria-label="Connected">
|
||||
<span class="relative flex h-2 w-2">
|
||||
<span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-[var(--color-success)] opacity-75" />
|
||||
<span class="relative inline-flex rounded-full h-2 w-2 bg-[var(--color-success)]" />
|
||||
</span>
|
||||
</div>
|
||||
</Show>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Navbar() {
|
||||
const [mobileOpen, setMobileOpen] = createSignal(false);
|
||||
const [scrolled, setScrolled] = createSignal(false);
|
||||
@@ -169,6 +215,7 @@ export default function Navbar() {
|
||||
<ThemeToggle />
|
||||
<SignedIn>
|
||||
<UserButton showName />
|
||||
<RealtimeIndicator />
|
||||
<Button variant="secondary" size="sm">
|
||||
<A href="/dashboard">Dashboard</A>
|
||||
</Button>
|
||||
|
||||
91
web/src/hooks/useRealtimeAlerts.ts
Normal file
91
web/src/hooks/useRealtimeAlerts.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { createSignal, createEffect, onMount, onCleanup } from "solid-js";
|
||||
import { createWebSocketClient, type AlertPayload, type ConnectionStatus } from "~/lib/websocket";
|
||||
import { useToast } from "~/components/ui";
|
||||
|
||||
const UNREAD_STORAGE_KEY = "shieldai_unread_count";
|
||||
|
||||
function loadUnreadCount(): number {
|
||||
try {
|
||||
return parseInt(localStorage.getItem(UNREAD_STORAGE_KEY) ?? "0", 10);
|
||||
} catch {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
function saveUnreadCount(count: number) {
|
||||
try {
|
||||
localStorage.setItem(UNREAD_STORAGE_KEY, String(count));
|
||||
} catch {
|
||||
// localStorage unavailable
|
||||
}
|
||||
}
|
||||
|
||||
function prefersReducedMotion(): boolean {
|
||||
if (typeof window === "undefined") return false;
|
||||
return window.matchMedia("(prefers-reduced-motion: reduce)").matches;
|
||||
}
|
||||
|
||||
export function useRealtimeAlerts() {
|
||||
const { showToast } = useToast();
|
||||
const [unreadCount, setUnreadCount] = createSignal(loadUnreadCount());
|
||||
const [connectionStatus, setConnectionStatus] = createSignal<ConnectionStatus>("disconnected");
|
||||
const client = createWebSocketClient();
|
||||
const reducedMotion = prefersReducedMotion();
|
||||
|
||||
function handleAlert(alert: AlertPayload) {
|
||||
setUnreadCount((prev) => {
|
||||
const next = prev + 1;
|
||||
saveUnreadCount(next);
|
||||
return next;
|
||||
});
|
||||
|
||||
const severityMap: Record<string, "success" | "error" | "warning" | "info"> = {
|
||||
LOW: "info",
|
||||
INFO: "info",
|
||||
MEDIUM: "warning",
|
||||
WARNING: "warning",
|
||||
HIGH: "error",
|
||||
CRITICAL: "error",
|
||||
};
|
||||
|
||||
showToast(
|
||||
alert.title,
|
||||
severityMap[alert.severity] ?? "info",
|
||||
reducedMotion ? 6000 : 4000,
|
||||
);
|
||||
}
|
||||
|
||||
function handleStatusChange(status: ConnectionStatus) {
|
||||
setConnectionStatus(status);
|
||||
if (status === "reconnecting") {
|
||||
showToast("Reconnecting to real-time alerts...", "warning", 3000);
|
||||
}
|
||||
}
|
||||
|
||||
let removeAlertListener: (() => void) | null = null;
|
||||
let removeStatusListener: (() => void) | null = null;
|
||||
|
||||
onMount(() => {
|
||||
removeAlertListener = client.onAlert(handleAlert);
|
||||
removeStatusListener = client.onStatusChange(handleStatusChange);
|
||||
client.connect();
|
||||
});
|
||||
|
||||
onCleanup(() => {
|
||||
removeAlertListener?.();
|
||||
removeStatusListener?.();
|
||||
client.cleanup();
|
||||
});
|
||||
|
||||
function clearUnread() {
|
||||
setUnreadCount(0);
|
||||
saveUnreadCount(0);
|
||||
}
|
||||
|
||||
return {
|
||||
unreadCount,
|
||||
connectionStatus,
|
||||
clearUnread,
|
||||
lastAlert: client.lastAlert,
|
||||
};
|
||||
}
|
||||
174
web/src/lib/websocket.test.ts
Normal file
174
web/src/lib/websocket.test.ts
Normal file
@@ -0,0 +1,174 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { createRoot } from "solid-js";
|
||||
|
||||
function createMockWs() {
|
||||
let onopen: (() => void) | null = null;
|
||||
let onclose: ((event: { code: number }) => void) | null = null;
|
||||
let onmessage: ((event: { data: string }) => void) | null = null;
|
||||
|
||||
return {
|
||||
readyState: 1,
|
||||
send: vi.fn(),
|
||||
close: vi.fn((code?: number) => {
|
||||
onclose?.({ code: code ?? 1000 });
|
||||
}),
|
||||
get onopen() { return onopen; },
|
||||
set onopen(fn) { onopen = fn; },
|
||||
get onclose() { return onclose; },
|
||||
set onclose(fn) { onclose = fn; },
|
||||
get onmessage() { return onmessage; },
|
||||
set onmessage(fn) { onmessage = fn; },
|
||||
CONNECTING: 0,
|
||||
OPEN: 1,
|
||||
CLOSING: 2,
|
||||
CLOSED: 3,
|
||||
};
|
||||
}
|
||||
|
||||
Object.defineProperty(globalThis, "crypto", {
|
||||
value: { randomUUID: () => "test-uuid" },
|
||||
});
|
||||
|
||||
describe("WebSocket client", () => {
|
||||
let mockWs: ReturnType<typeof createMockWs>;
|
||||
let originalWs: typeof globalThis.WebSocket;
|
||||
let wsConstructorUrls: string[] = [];
|
||||
|
||||
beforeEach(() => {
|
||||
mockWs = createMockWs();
|
||||
originalWs = globalThis.WebSocket;
|
||||
wsConstructorUrls = [];
|
||||
|
||||
function MockWebSocket(this: any, url: string) {
|
||||
wsConstructorUrls.push(url);
|
||||
setTimeout(() => {
|
||||
if (typeof mockWs.onopen === "function") mockWs.onopen();
|
||||
}, 1);
|
||||
return mockWs;
|
||||
}
|
||||
MockWebSocket.OPEN = 1;
|
||||
MockWebSocket.CONNECTING = 0;
|
||||
MockWebSocket.CLOSING = 2;
|
||||
MockWebSocket.CLOSED = 3;
|
||||
|
||||
globalThis.WebSocket = MockWebSocket as unknown as typeof globalThis.WebSocket;
|
||||
|
||||
Object.defineProperty(document, "cookie", {
|
||||
value: "session_token=test-session-token",
|
||||
configurable: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.WebSocket = originalWs;
|
||||
});
|
||||
|
||||
function runWithRoot<T>(fn: (dispose: () => void) => T): T {
|
||||
let result!: T;
|
||||
createRoot((dispose) => {
|
||||
result = fn(dispose);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
it("should call WebSocket constructor with token", () => {
|
||||
const ws = new globalThis.WebSocket("ws://test/tok");
|
||||
expect(wsConstructorUrls).toHaveLength(1);
|
||||
expect(wsConstructorUrls[0]).toContain("ws://test/tok");
|
||||
expect(ws).toBe(mockWs);
|
||||
});
|
||||
|
||||
it("should connect and set connected status", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
|
||||
expect(wsConstructorUrls).toHaveLength(1);
|
||||
expect(wsConstructorUrls[0]).toContain("token=test-session-token");
|
||||
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it("should not connect without auth token", async () => {
|
||||
Object.defineProperty(document, "cookie", {
|
||||
value: "",
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
|
||||
expect(wsConstructorUrls).toHaveLength(0);
|
||||
expect(client.connectionStatus()).toBe("disconnected");
|
||||
});
|
||||
|
||||
it("should reconnect on unexpected disconnect", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
mockWs.onclose?.({ code: 1006 });
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("reconnecting");
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it("should fire onAlert callback", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
const alerts: Array<{ title: string }> = [];
|
||||
client.onAlert((a) => alerts.push(a));
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
mockWs.onmessage?.({
|
||||
data: JSON.stringify({
|
||||
type: "alert",
|
||||
alert: {
|
||||
id: "a1", title: "Test Alert", message: "Test",
|
||||
severity: "HIGH", source: "DARKWATCH",
|
||||
category: "EXPOSURE_DETECTED",
|
||||
createdAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
expect(alerts).toHaveLength(1);
|
||||
expect(alerts[0].title).toBe("Test Alert");
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it("should disconnect and set disconnected", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
client.disconnect();
|
||||
expect(client.connectionStatus()).toBe("disconnected");
|
||||
});
|
||||
|
||||
it("should fire pong handler without throwing", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
expect(() => {
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "pong" }) });
|
||||
}).not.toThrow();
|
||||
client.disconnect();
|
||||
});
|
||||
});
|
||||
221
web/src/lib/websocket.ts
Normal file
221
web/src/lib/websocket.ts
Normal file
@@ -0,0 +1,221 @@
|
||||
import { createSignal, onCleanup } from "solid-js";
|
||||
|
||||
const WS_URL = import.meta.env.VITE_WS_URL ?? "ws://localhost:3001";
|
||||
const RECONNECT_DELAYS = [1000, 2000, 5000, 10_000, 30_000];
|
||||
const MAX_RECONNECT_ATTEMPTS = 10;
|
||||
const HEARTBEAT_INTERVAL = 30_000;
|
||||
const PONG_TIMEOUT = 10_000;
|
||||
|
||||
export interface AlertPayload {
|
||||
id: string;
|
||||
title: string;
|
||||
message: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
category: string;
|
||||
createdAt: string;
|
||||
}
|
||||
|
||||
export type ConnectionStatus = "connecting" | "connected" | "disconnected" | "reconnecting";
|
||||
|
||||
function getAuthToken(): string | null {
|
||||
if (typeof document === "undefined") return null;
|
||||
const match = document.cookie.match(/(?:^|;\s*)session_token=([^;]*)/);
|
||||
if (match) return match[1];
|
||||
try {
|
||||
return localStorage.getItem("auth_token");
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function createWebSocketClient() {
|
||||
const [connectionStatus, setConnectionStatus] = createSignal<ConnectionStatus>("disconnected");
|
||||
const [lastAlert, setLastAlert] = createSignal<AlertPayload | null>(null);
|
||||
|
||||
let ws: WebSocket | null = null;
|
||||
let reconnectAttempts = 0;
|
||||
let reconnectTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||
let pongTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
let intentionalClose = false;
|
||||
let listeners: Array<(alert: AlertPayload) => void> = [];
|
||||
let statusListeners: Array<(status: ConnectionStatus) => void> = [];
|
||||
let mountedCleanups: Array<() => void> = [];
|
||||
|
||||
function notifyAlert(alert: AlertPayload) {
|
||||
setLastAlert(alert);
|
||||
for (const listener of listeners) {
|
||||
try {
|
||||
listener(alert);
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function notifyStatus(status: ConnectionStatus) {
|
||||
setConnectionStatus(status);
|
||||
for (const listener of statusListeners) {
|
||||
try {
|
||||
listener(status);
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function startHeartbeat() {
|
||||
stopHeartbeat();
|
||||
heartbeatTimer = setInterval(() => {
|
||||
if (ws?.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ type: "ping" }));
|
||||
pongTimer = setTimeout(() => {
|
||||
intentionalClose = false;
|
||||
ws?.close();
|
||||
}, PONG_TIMEOUT);
|
||||
}
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
}
|
||||
|
||||
function stopHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
}
|
||||
if (pongTimer) {
|
||||
clearTimeout(pongTimer);
|
||||
pongTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
function scheduleReconnect() {
|
||||
if (intentionalClose) return;
|
||||
if (reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) {
|
||||
notifyStatus("disconnected");
|
||||
return;
|
||||
}
|
||||
|
||||
notifyStatus("reconnecting");
|
||||
const delay =
|
||||
RECONNECT_DELAYS[reconnectAttempts] ??
|
||||
RECONNECT_DELAYS[RECONNECT_DELAYS.length - 1];
|
||||
reconnectAttempts++;
|
||||
|
||||
reconnectTimer = setTimeout(() => {
|
||||
connect();
|
||||
}, delay);
|
||||
}
|
||||
|
||||
function connect() {
|
||||
if (ws?.readyState === WebSocket.OPEN || ws?.readyState === WebSocket.CONNECTING) {
|
||||
return;
|
||||
}
|
||||
|
||||
const token = getAuthToken();
|
||||
if (!token) {
|
||||
notifyStatus("disconnected");
|
||||
return;
|
||||
}
|
||||
|
||||
intentionalClose = false;
|
||||
notifyStatus("connecting");
|
||||
|
||||
try {
|
||||
ws = new WebSocket(`${WS_URL}?token=${encodeURIComponent(token)}`);
|
||||
|
||||
ws.onopen = () => {
|
||||
reconnectAttempts = 0;
|
||||
notifyStatus("connected");
|
||||
startHeartbeat();
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "pong") {
|
||||
if (pongTimer) {
|
||||
clearTimeout(pongTimer);
|
||||
pongTimer = null;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "alert") {
|
||||
notifyAlert(data.alert as AlertPayload);
|
||||
}
|
||||
} catch {
|
||||
// ignore invalid messages
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
stopHeartbeat();
|
||||
if (!intentionalClose) {
|
||||
scheduleReconnect();
|
||||
} else {
|
||||
notifyStatus("disconnected");
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = () => {
|
||||
// onclose will handle reconnection
|
||||
};
|
||||
} catch {
|
||||
notifyStatus("disconnected");
|
||||
}
|
||||
}
|
||||
|
||||
function disconnect() {
|
||||
intentionalClose = true;
|
||||
stopHeartbeat();
|
||||
if (reconnectTimer) {
|
||||
clearTimeout(reconnectTimer);
|
||||
reconnectTimer = null;
|
||||
}
|
||||
if (ws) {
|
||||
ws.close(1000, "Client disconnecting");
|
||||
ws = null;
|
||||
}
|
||||
notifyStatus("disconnected");
|
||||
}
|
||||
|
||||
function onAlert(listener: (alert: AlertPayload) => void) {
|
||||
listeners.push(listener);
|
||||
return () => {
|
||||
listeners = listeners.filter((l) => l !== listener);
|
||||
};
|
||||
}
|
||||
|
||||
function onStatusChange(listener: (status: ConnectionStatus) => void) {
|
||||
statusListeners.push(listener);
|
||||
return () => {
|
||||
statusListeners = statusListeners.filter((l) => l !== listener);
|
||||
};
|
||||
}
|
||||
|
||||
function cleanup() {
|
||||
disconnect();
|
||||
listeners = [];
|
||||
statusListeners = [];
|
||||
for (const c of mountedCleanups) {
|
||||
c();
|
||||
}
|
||||
mountedCleanups = [];
|
||||
}
|
||||
|
||||
onCleanup(() => cleanup());
|
||||
|
||||
return {
|
||||
connect,
|
||||
disconnect,
|
||||
onAlert,
|
||||
onStatusChange,
|
||||
connectionStatus,
|
||||
lastAlert,
|
||||
cleanup,
|
||||
};
|
||||
}
|
||||
|
||||
export type WebSocketClient = ReturnType<typeof createWebSocketClient>;
|
||||
@@ -20,17 +20,18 @@ export const spamshieldRouter = createTRPCRouter({
|
||||
|
||||
classifySMS: publicProcedure
|
||||
.input(wrap(ClassifySMSSchema))
|
||||
.query(async ({ input }) => {
|
||||
return spamshieldService.classifySMS(input.text);
|
||||
.query(async ({ input, ctx }) => {
|
||||
return spamshieldService.classifySMS(input.text, ctx.user?.id);
|
||||
}),
|
||||
|
||||
classifyCall: publicProcedure
|
||||
.input(wrap(ClassifyCallSchema))
|
||||
.query(async ({ input }) => {
|
||||
.query(async ({ input, ctx }) => {
|
||||
return spamshieldService.classifyCall(
|
||||
input.callerNumber,
|
||||
input.duration,
|
||||
input.timeOfDay,
|
||||
ctx.user?.id,
|
||||
);
|
||||
}),
|
||||
|
||||
|
||||
111
web/src/server/services/alert.publisher.test.ts
Normal file
111
web/src/server/services/alert.publisher.test.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
// @vitest-environment node
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
|
||||
const mockBroadcastToUser = vi.fn();
|
||||
|
||||
vi.mock("~/server/websocket", () => ({
|
||||
broadcastToUser: mockBroadcastToUser,
|
||||
}));
|
||||
|
||||
const mockSendPush = vi.fn();
|
||||
const mockSendEmail = vi.fn();
|
||||
|
||||
vi.mock("~/server/services/notification.service", () => ({
|
||||
sendPush: mockSendPush,
|
||||
sendEmail: mockSendEmail,
|
||||
}));
|
||||
|
||||
vi.mock("~/server/db", () => ({
|
||||
db: {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("alert.publisher", () => {
|
||||
it("should send alert via WebSocket when user is connected", async () => {
|
||||
mockBroadcastToUser.mockReturnValue(true);
|
||||
|
||||
const { publishAlert } = await import("./alert.publisher");
|
||||
await publishAlert("user-1", {
|
||||
id: "alert-1",
|
||||
title: "Test Alert",
|
||||
message: "Test message",
|
||||
severity: "HIGH",
|
||||
source: "DARKWATCH",
|
||||
category: "EXPOSURE_DETECTED",
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
expect(mockBroadcastToUser).toHaveBeenCalledWith("user-1", {
|
||||
type: "alert",
|
||||
alert: {
|
||||
id: "alert-1",
|
||||
title: "Test Alert",
|
||||
message: "Test message",
|
||||
severity: "HIGH",
|
||||
source: "DARKWATCH",
|
||||
category: "EXPOSURE_DETECTED",
|
||||
createdAt: expect.any(String),
|
||||
},
|
||||
});
|
||||
expect(mockSendPush).not.toHaveBeenCalled();
|
||||
expect(mockSendEmail).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should fall back to push notification when user is not connected", async () => {
|
||||
mockBroadcastToUser.mockReturnValue(false);
|
||||
mockSendPush.mockResolvedValue({ successCount: 1 });
|
||||
|
||||
const { publishAlert } = await import("./alert.publisher");
|
||||
await publishAlert("user-1", {
|
||||
id: "alert-2",
|
||||
title: "Offline Alert",
|
||||
message: "Offline message",
|
||||
severity: "WARNING",
|
||||
source: "VOICEPRINT",
|
||||
category: "SYNTHETIC_VOICE",
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
expect(mockBroadcastToUser).toHaveBeenCalled();
|
||||
expect(mockSendPush).toHaveBeenCalledWith(
|
||||
"user-1",
|
||||
"Offline Alert",
|
||||
"Offline message",
|
||||
{ alertId: "alert-2", source: "VOICEPRINT", severity: "WARNING" },
|
||||
);
|
||||
});
|
||||
|
||||
it("should publish alert to multiple users", async () => {
|
||||
mockBroadcastToUser.mockReturnValue(false);
|
||||
mockSendPush.mockResolvedValue({ successCount: 0 });
|
||||
|
||||
const db = await import("~/server/db");
|
||||
(db.db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
const { publishToGroup } = await import("./alert.publisher");
|
||||
await publishToGroup(["user-1", "user-2"], {
|
||||
id: "alert-3",
|
||||
title: "Group Alert",
|
||||
message: "Group message",
|
||||
severity: "INFO",
|
||||
source: "HOME_TITLE",
|
||||
category: "HOME_TITLE",
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
expect(mockBroadcastToUser).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
66
web/src/server/services/alert.publisher.ts
Normal file
66
web/src/server/services/alert.publisher.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import { broadcastToUser } from "~/server/websocket";
|
||||
import { sendPush, sendEmail } from "~/server/services/notification.service";
|
||||
import { db } from "~/server/db";
|
||||
import { users } from "~/server/db/schema/auth";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export interface PublishableAlert {
|
||||
id: string;
|
||||
title: string;
|
||||
message: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
category: string;
|
||||
createdAt: Date;
|
||||
}
|
||||
|
||||
export async function publishAlert(userId: string, alert: PublishableAlert): Promise<void> {
|
||||
const message = {
|
||||
type: "alert" as const,
|
||||
alert: {
|
||||
id: alert.id,
|
||||
title: alert.title,
|
||||
message: alert.message,
|
||||
severity: alert.severity,
|
||||
source: alert.source,
|
||||
category: alert.category,
|
||||
createdAt: alert.createdAt.toISOString(),
|
||||
},
|
||||
};
|
||||
|
||||
const sent = broadcastToUser(userId, message);
|
||||
|
||||
if (!sent) {
|
||||
try {
|
||||
const pushResult = await sendPush(userId, alert.title, alert.message, {
|
||||
alertId: alert.id,
|
||||
source: alert.source,
|
||||
severity: alert.severity,
|
||||
});
|
||||
|
||||
if (pushResult.successCount === 0) {
|
||||
const [user] = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(eq(users.id, userId))
|
||||
.limit(1);
|
||||
|
||||
if (user?.email) {
|
||||
await sendEmail(
|
||||
user.email,
|
||||
`[ShieldAI] ${alert.title}`,
|
||||
`<p>${alert.message}</p>`,
|
||||
alert.message,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("[alert.publisher] Fallback notification failed:", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function publishToGroup(userIds: string[], alert: PublishableAlert): Promise<void> {
|
||||
const promises = userIds.map((userId) => publishAlert(userId, alert));
|
||||
await Promise.allSettled(promises);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { db } from "~/server/db";
|
||||
import { exposures, alerts, subscriptions } from "~/server/db/schema";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
|
||||
export function severityScore(exposure: {
|
||||
source: string;
|
||||
@@ -105,14 +106,27 @@ async function createAlertForExposure(
|
||||
|
||||
if (!sub) return;
|
||||
|
||||
await db.insert(alerts).values({
|
||||
subscriptionId: exposure.subscriptionId,
|
||||
userId: sub.userId,
|
||||
exposureId: exposure.id,
|
||||
type: "exposure_detected",
|
||||
const [alert] = await db
|
||||
.insert(alerts)
|
||||
.values({
|
||||
subscriptionId: exposure.subscriptionId,
|
||||
userId: sub.userId,
|
||||
exposureId: exposure.id,
|
||||
type: "exposure_detected",
|
||||
title,
|
||||
message,
|
||||
severity: alertSeverityMap[severity] ?? "info",
|
||||
channel: ["email", "push"],
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(sub.userId, {
|
||||
id: alert.id,
|
||||
title,
|
||||
message,
|
||||
severity: alertSeverityMap[severity] ?? "info",
|
||||
channel: ["email", "push"],
|
||||
});
|
||||
source: "DARKWATCH",
|
||||
category: "EXPOSURE_DETECTED",
|
||||
createdAt: alert.createdAt,
|
||||
}).catch((err) => console.error("[darkwatch] Failed to publish alert:", err));
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
type DetectedChange,
|
||||
type SnapshotData,
|
||||
} from "./hometitle/change.detector";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
import {
|
||||
geocodeAddress,
|
||||
fetchCountyRecords,
|
||||
@@ -433,7 +434,7 @@ async function generateAlert(
|
||||
critical: "CRITICAL",
|
||||
};
|
||||
|
||||
await db
|
||||
const [normalized] = await db
|
||||
.insert(normalizedAlerts)
|
||||
.values({
|
||||
source: "HOME_TITLE",
|
||||
@@ -454,5 +455,15 @@ async function generateAlert(
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(sub.userId, {
|
||||
id: alert.id,
|
||||
title,
|
||||
message,
|
||||
severity: change.severity,
|
||||
source: "HOME_TITLE",
|
||||
category: "HOME_TITLE",
|
||||
createdAt: alert.createdAt,
|
||||
}).catch((err) => console.error("[hometitle] Failed to publish alert:", err));
|
||||
|
||||
return alert;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { eq, and, desc, count, inArray, or, isNull, lt } from "drizzle-orm";
|
||||
import { db } from "~/server/db";
|
||||
import { infoBrokers, removalRequests, brokerListings, subscriptions } from "~/server/db/schema";
|
||||
import { infoBrokers, removalRequests, brokerListings, subscriptions, normalizedAlerts } from "~/server/db/schema";
|
||||
import { getActiveBrokers } from "./removebrokers/broker.registry";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
import { submitAutomatedRemoval, sendRemovalEmail } from "./removebrokers/removal.engine";
|
||||
import type { PersonalInfo } from "./removebrokers/removal.engine";
|
||||
import type { RemovalMethod } from "./removebrokers/broker.registry";
|
||||
@@ -273,6 +274,38 @@ export async function scanForListings(userId: string, brokerId?: string) {
|
||||
createdListings.push({ id: listing.id, brokerId: listing.brokerId, url: listing.url });
|
||||
}
|
||||
|
||||
if (createdListings.length > 0) {
|
||||
const sourceAlertId = `removebrokers:scan:${sub.id}:${Date.now()}`;
|
||||
try {
|
||||
const [na] = await db
|
||||
.insert(normalizedAlerts)
|
||||
.values({
|
||||
source: "REMOVEBROKERS",
|
||||
category: "BROKER_LISTING",
|
||||
severity: "INFO",
|
||||
userId: sub.userId,
|
||||
title: "New Broker Listings Found",
|
||||
description: `Found ${createdListings.length} new broker listing(s)`,
|
||||
entities: { listingCount: createdListings.length, brokerIds: createdListings.map((l) => l.brokerId) },
|
||||
sourceAlertId,
|
||||
createdAt: new Date(),
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(sub.userId, {
|
||||
id: na.id,
|
||||
title: "New Broker Listings Found",
|
||||
message: `Found ${createdListings.length} new broker listing(s)`,
|
||||
severity: "INFO",
|
||||
source: "REMOVEBROKERS",
|
||||
category: "BROKER_LISTING",
|
||||
createdAt: na.createdAt,
|
||||
}).catch(() => {});
|
||||
} catch {
|
||||
// alert creation is best-effort
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
scanned: brokers.length,
|
||||
listingsFound: createdListings.length,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { eq, and, desc, count, sql, gte } from "drizzle-orm";
|
||||
import { db } from "~/server/db";
|
||||
import { spamFeedback, spamRules, auditLogs } from "~/server/db/schema";
|
||||
import { spamFeedback, spamRules, auditLogs, normalizedAlerts } from "~/server/db/schema";
|
||||
import { classifyTextBERT, extractFeatures, ruleEngine } from "./spamshield/ml.engine";
|
||||
import { checkReputation } from "./spamshield/reputation.api";
|
||||
import type { ReputationResult } from "./spamshield/reputation.api";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
|
||||
function normalizePhoneNumber(phone: string): string {
|
||||
let cleaned = phone.replace(/[^\d+]/g, "");
|
||||
@@ -72,7 +73,7 @@ export async function checkNumberReputation(phoneNumber: string) {
|
||||
return response;
|
||||
}
|
||||
|
||||
export async function classifySMS(text: string) {
|
||||
export async function classifySMS(text: string, userId?: string) {
|
||||
const classification = await classifyTextBERT(text);
|
||||
|
||||
const response = {
|
||||
@@ -83,6 +84,38 @@ export async function classifySMS(text: string) {
|
||||
|
||||
await logAudit(undefined, "classifySMS", { textLength: text.length }, response, response.confidence);
|
||||
|
||||
if (classification.isSpam && classification.confidence >= 0.8 && userId) {
|
||||
const sourceAlertId = `spamshield:sms:${Date.now()}`;
|
||||
try {
|
||||
const [na] = await db
|
||||
.insert(normalizedAlerts)
|
||||
.values({
|
||||
source: "SPAMSHIELD",
|
||||
category: "SPAM_SMS",
|
||||
severity: "LOW",
|
||||
userId,
|
||||
title: "Spam SMS Detected",
|
||||
description: "Suspected spam SMS detected",
|
||||
entities: { textLength: text.length, confidence: classification.confidence },
|
||||
sourceAlertId,
|
||||
createdAt: new Date(),
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(userId, {
|
||||
id: na.id,
|
||||
title: "Spam SMS Detected",
|
||||
message: "Suspected spam SMS detected",
|
||||
severity: "LOW",
|
||||
source: "SPAMSHIELD",
|
||||
category: "SPAM_SMS",
|
||||
createdAt: na.createdAt,
|
||||
}).catch(() => {});
|
||||
} catch {
|
||||
// alert creation is best-effort
|
||||
}
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
@@ -90,6 +123,7 @@ export async function classifyCall(
|
||||
callerNumber: string,
|
||||
duration?: number,
|
||||
timeOfDay?: number,
|
||||
userId?: string,
|
||||
) {
|
||||
const normalized = normalizePhoneNumber(callerNumber);
|
||||
const features = await extractFeatures({ callerNumber: normalized, duration, timeOfDay });
|
||||
@@ -138,6 +172,38 @@ export async function classifyCall(
|
||||
|
||||
await logAudit(undefined, "classifyCall", { callerNumber: normalized, duration, timeOfDay }, response, confidence);
|
||||
|
||||
if (isSpam && confidence >= 0.8 && userId) {
|
||||
const sourceAlertId = `spamshield:call:${normalized}:${Date.now()}`;
|
||||
try {
|
||||
const [na] = await db
|
||||
.insert(normalizedAlerts)
|
||||
.values({
|
||||
source: "SPAMSHIELD",
|
||||
category: "SPAM_CALL",
|
||||
severity: "MEDIUM",
|
||||
userId,
|
||||
title: "Spam Call Blocked",
|
||||
description: `Blocked spam call from ${normalized}`,
|
||||
entities: { callerNumber: normalized, confidence, ruleMatch: !!ruleMatch },
|
||||
sourceAlertId,
|
||||
createdAt: new Date(),
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(userId, {
|
||||
id: na.id,
|
||||
title: "Spam Call Blocked",
|
||||
message: `Blocked spam call from ${normalized}`,
|
||||
severity: "MEDIUM",
|
||||
source: "SPAMSHIELD",
|
||||
category: "SPAM_CALL",
|
||||
createdAt: na.createdAt,
|
||||
}).catch(() => {});
|
||||
} catch {
|
||||
// alert creation is best-effort
|
||||
}
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
normalizedAlerts,
|
||||
} from "~/server/db/schema";
|
||||
import { saveAudio, getAudioUrl, deleteFile } from "./voiceprint/storage";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
import {
|
||||
preprocessAudio,
|
||||
detectSynthetic,
|
||||
@@ -100,17 +101,33 @@ async function createVoiceAlert(
|
||||
|
||||
if (sub) {
|
||||
const category = verdict === "SYNTHETIC" ? "SYNTHETIC_VOICE" : "VOICE_MISMATCH";
|
||||
await db.insert(normalizedAlerts).values({
|
||||
const title = verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected";
|
||||
const description = `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`;
|
||||
|
||||
const [alert] = await db
|
||||
.insert(normalizedAlerts)
|
||||
.values({
|
||||
source: "VOICEPRINT",
|
||||
category,
|
||||
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
|
||||
userId,
|
||||
title,
|
||||
description,
|
||||
entities: { analysisId, verdict, confidence },
|
||||
sourceAlertId: `voiceprint-${analysisId}`,
|
||||
createdAt: new Date(),
|
||||
})
|
||||
.returning();
|
||||
|
||||
publishAlert(userId, {
|
||||
id: alert.id,
|
||||
title,
|
||||
message: description,
|
||||
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
|
||||
source: "VOICEPRINT",
|
||||
category,
|
||||
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
|
||||
userId,
|
||||
title: verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected",
|
||||
description: `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`,
|
||||
entities: { analysisId, verdict, confidence },
|
||||
sourceAlertId: `voiceprint-${analysisId}`,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
createdAt: alert.createdAt,
|
||||
}).catch((err) => console.error("[voiceprint] Failed to publish alert:", err));
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("[voiceprint] Failed to create alert:", err);
|
||||
|
||||
95
web/src/server/websocket.test.ts
Normal file
95
web/src/server/websocket.test.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
// @vitest-environment node
|
||||
import { describe, it, expect, vi, beforeAll, afterAll } from "vitest";
|
||||
|
||||
const mockVerifyJWT = vi.fn();
|
||||
|
||||
vi.mock("~/server/auth/jwt", () => ({
|
||||
verifyJWT: mockVerifyJWT,
|
||||
signJWT: vi.fn(),
|
||||
}));
|
||||
|
||||
let mockServer: any;
|
||||
let connectionHandler: ((ws: any, req: any) => void) | null = null;
|
||||
|
||||
vi.mock("ws", () => {
|
||||
mockServer = {
|
||||
on: vi.fn((event: string, handler: any) => {
|
||||
if (event === "connection") connectionHandler = handler;
|
||||
}),
|
||||
close: vi.fn((cb: () => void) => cb()),
|
||||
clients: new Set(),
|
||||
};
|
||||
|
||||
return {
|
||||
WebSocketServer: vi.fn(function (_opts: any, cb?: () => void) {
|
||||
if (cb) setTimeout(cb, 0);
|
||||
return mockServer;
|
||||
}),
|
||||
WebSocket: { OPEN: 1, CONNECTING: 0, CLOSING: 2, CLOSED: 3 },
|
||||
};
|
||||
});
|
||||
|
||||
function makeWs() {
|
||||
const handlers: Record<string, (...args: any[]) => void> = {};
|
||||
return {
|
||||
close: vi.fn(),
|
||||
send: vi.fn(),
|
||||
ping: vi.fn(),
|
||||
terminate: vi.fn(),
|
||||
readyState: 1,
|
||||
on: vi.fn((event: string, handler: any) => {
|
||||
handlers[event] = handler;
|
||||
}),
|
||||
emit: (event: string, ...args: any[]) => {
|
||||
handlers[event]?.(...args);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("WebSocket server", () => {
|
||||
beforeAll(async () => {
|
||||
process.env.WS_PORT = "3099";
|
||||
const { start } = await import("~/server/websocket");
|
||||
await start();
|
||||
}, 15000);
|
||||
|
||||
afterAll(async () => {
|
||||
const { stop } = await import("~/server/websocket");
|
||||
await stop();
|
||||
});
|
||||
|
||||
it("should reject connection without JWT", async () => {
|
||||
const ws = makeWs();
|
||||
await connectionHandler!(ws, { url: "/" });
|
||||
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
|
||||
});
|
||||
|
||||
it("should reject connection with invalid JWT", async () => {
|
||||
mockVerifyJWT.mockRejectedValue(new Error("Invalid token"));
|
||||
|
||||
const ws = makeWs();
|
||||
await connectionHandler!(ws, { url: "/?token=bad" });
|
||||
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
|
||||
});
|
||||
|
||||
it("should accept connection with valid JWT", async () => {
|
||||
mockVerifyJWT.mockResolvedValue({ sub: "user-1" });
|
||||
|
||||
const ws = makeWs();
|
||||
await connectionHandler!(ws, { url: "/?token=good" });
|
||||
expect(ws.close).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should return false when broadcasting to non-existent user", async () => {
|
||||
const { broadcastToUser } = await import("~/server/websocket");
|
||||
const result = broadcastToUser("nonexistent", {
|
||||
type: "alert",
|
||||
alert: {
|
||||
id: "a1", title: "T", message: "M",
|
||||
severity: "INFO", source: "TEST", category: "TEST",
|
||||
createdAt: new Date().toISOString(),
|
||||
},
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
216
web/src/server/websocket.ts
Normal file
216
web/src/server/websocket.ts
Normal file
@@ -0,0 +1,216 @@
|
||||
import { WebSocketServer, WebSocket } from "ws";
|
||||
import type { Server } from "ws";
|
||||
import { IncomingMessage } from "node:http";
|
||||
import { URL } from "node:url";
|
||||
import { verifyJWT } from "~/server/auth/jwt";
|
||||
|
||||
const WS_PORT = parseInt(process.env.WS_PORT ?? "3001", 10);
|
||||
const HEARTBEAT_INTERVAL = 30_000;
|
||||
const PONG_TIMEOUT = 10_000;
|
||||
|
||||
interface AlertMessage {
|
||||
type: "alert";
|
||||
alert: {
|
||||
id: string;
|
||||
title: string;
|
||||
message: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
category: string;
|
||||
createdAt: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface WsClient extends WebSocket {
|
||||
userId?: string;
|
||||
isAlive?: boolean;
|
||||
pongTimer?: ReturnType<typeof setTimeout>;
|
||||
}
|
||||
|
||||
const userSockets = new Map<string, Set<WsClient>>();
|
||||
let wss: Server | null = null;
|
||||
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
function getTokenFromRequest(req: IncomingMessage): string | null {
|
||||
const url = new URL(req.url ?? "/", "http://localhost");
|
||||
return url.searchParams.get("token");
|
||||
}
|
||||
|
||||
async function authenticateConnection(
|
||||
ws: WsClient,
|
||||
req: IncomingMessage,
|
||||
): Promise<string | null> {
|
||||
const token = getTokenFromRequest(req);
|
||||
if (!token) return null;
|
||||
|
||||
try {
|
||||
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
|
||||
const userId = payload.sub ?? payload.userId;
|
||||
if (!userId) return null;
|
||||
return userId;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function addSocket(userId: string, ws: WsClient) {
|
||||
let sockets = userSockets.get(userId);
|
||||
if (!sockets) {
|
||||
sockets = new Set();
|
||||
userSockets.set(userId, sockets);
|
||||
}
|
||||
sockets.add(ws);
|
||||
}
|
||||
|
||||
function removeSocket(userId: string, ws: WsClient) {
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets) return;
|
||||
sockets.delete(ws);
|
||||
if (sockets.size === 0) {
|
||||
userSockets.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
function heartbeat(ws: WsClient) {
|
||||
ws.isAlive = true;
|
||||
}
|
||||
|
||||
function startHeartbeat() {
|
||||
if (heartbeatTimer) clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = setInterval(() => {
|
||||
if (!wss) return;
|
||||
wss.clients.forEach((client) => {
|
||||
const ws = client as WsClient;
|
||||
if (ws.isAlive === false) {
|
||||
ws.terminate();
|
||||
return;
|
||||
}
|
||||
ws.isAlive = false;
|
||||
ws.ping();
|
||||
|
||||
ws.pongTimer = setTimeout(() => {
|
||||
ws.terminate();
|
||||
}, PONG_TIMEOUT);
|
||||
});
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
|
||||
if (heartbeatTimer && typeof heartbeatTimer === "object") {
|
||||
heartbeatTimer.unref();
|
||||
}
|
||||
}
|
||||
|
||||
function stopHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
export function broadcastToUser(userId: string, data: AlertMessage) {
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets || sockets.size === 0) return false;
|
||||
|
||||
const message = JSON.stringify(data);
|
||||
let sent = false;
|
||||
for (const ws of sockets) {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(message);
|
||||
sent = true;
|
||||
}
|
||||
}
|
||||
return sent;
|
||||
}
|
||||
|
||||
export function getConnectedUsers(): string[] {
|
||||
return Array.from(userSockets.keys());
|
||||
}
|
||||
|
||||
export function getConnectionCount(): number {
|
||||
let count = 0;
|
||||
for (const sockets of userSockets.values()) {
|
||||
count += sockets.size;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
export function start(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
wss = new WebSocketServer({ port: WS_PORT }, () => {
|
||||
console.log(`[websocket] Server listening on port ${WS_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
|
||||
wss.on("connection", async (ws: WsClient, req: IncomingMessage) => {
|
||||
const userId = await authenticateConnection(ws, req);
|
||||
|
||||
if (!userId) {
|
||||
ws.close(4001, "Authentication failed");
|
||||
return;
|
||||
}
|
||||
|
||||
ws.userId = userId;
|
||||
ws.isAlive = true;
|
||||
addSocket(userId, ws);
|
||||
|
||||
ws.on("pong", () => {
|
||||
heartbeat(ws);
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
ws.pongTimer = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("message", (data) => {
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === "ping") {
|
||||
ws.send(JSON.stringify({ type: "pong" }));
|
||||
}
|
||||
} catch {
|
||||
// ignore invalid messages
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("close", () => {
|
||||
if (ws.userId) {
|
||||
removeSocket(ws.userId, ws);
|
||||
}
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("error", (err) => {
|
||||
console.error("[websocket] Client error:", err.message);
|
||||
});
|
||||
});
|
||||
|
||||
startHeartbeat();
|
||||
});
|
||||
}
|
||||
|
||||
export function stop(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
stopHeartbeat();
|
||||
if (!wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
for (const ws of wss.clients) {
|
||||
ws.close(1001, "Server shutting down");
|
||||
}
|
||||
|
||||
wss.close(() => {
|
||||
wss = null;
|
||||
userSockets.clear();
|
||||
console.log("[websocket] Server stopped");
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user