import { WebSocketServer, WebSocket } from "ws";
import type { Server } from "ws";
import type { IncomingMessage } from "node:http";
import { verifyJWT } from "~/server/auth/jwt";

/**
 * Builds the trusted WebSocket origins allowlist.
 *
 * The list always contains the localhost development origins, then the
 * origin of `APP_URL` (when it parses as an http/https URL), then every
 * non-empty entry of the comma-separated `VALID_WEBSOCKET_ORIGINS` variable.
 *
 * NOTE(review): the localhost origins are trusted unconditionally, including
 * in production — confirm this is intended.
 *
 * @returns The allowlist of origin strings, compared verbatim against the
 *          `Origin` header of incoming upgrade requests.
 */
function getTrustedOrigins(): string[] {
  const origins = [
    "http://localhost:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:3001",
  ];

  // Validate APP_URL before trusting it as a WebSocket origin.
  const appUrl = process.env.APP_URL;
  if (appUrl) {
    try {
      const parsed = new URL(appUrl);
      if (/^https?:$/.test(parsed.protocol) && parsed.hostname) {
        // An Origin header is only "scheme://host[:port]". Pushing the raw
        // APP_URL would never match if it carries a trailing slash or a path,
        // so store the normalized origin instead.
        origins.push(parsed.origin);
      }
    } catch {
      // Invalid URL — skip.
    }
  }

  // Allow explicit additions via VALID_WEBSOCKET_ORIGINS (comma-separated).
  const explicit = process.env.VALID_WEBSOCKET_ORIGINS;
  if (explicit) {
    for (const origin of explicit.split(",").map((o) => o.trim())) {
      if (origin) origins.push(origin);
    }
  }

  return origins;
}

/**
 * Validates the Origin header against the trusted origins allowlist.
 *
 * @param origin - Value of the `Origin` header, if any.
 * @param trustedOrigins - Allowlist to compare against (exact string match).
 * @returns `false` for missing, empty, whitespace-only, or unlisted origins.
 */
function isTrustedOrigin(
  origin: string | undefined,
  trustedOrigins: readonly string[],
): boolean {
  if (!origin || !origin.trim()) return false;
  return trustedOrigins.includes(origin);
}

// Pre-compute trusted origins at startup.
const TRUSTED_ORIGINS = getTrustedOrigins();

const DEFAULT_WS_PORT = 3001;

/**
 * Parses the configured port, falling back to the default when the value is
 * missing or is not a valid TCP port number (a malformed value would
 * otherwise yield NaN).
 */
function resolvePort(raw: string | undefined): number {
  const port = Number.parseInt(raw ?? "", 10);
  return Number.isInteger(port) && port >= 0 && port <= 65535
    ? port
    : DEFAULT_WS_PORT;
}

const WS_PORT = resolvePort(process.env.WS_PORT);
const HEARTBEAT_INTERVAL = 30_000;
const PONG_TIMEOUT = 10_000;
const AUTH_TIMEOUT = 5_000;

/** Payload pushed to a user's sockets by {@link broadcastToUser}. */
interface AlertMessage {
  type: "alert";
  alert: {
    id: string;
    title: string;
    message: string;
    severity: string;
    source: string;
    category: string;
    createdAt: string;
  };
}

/** A `ws` socket augmented with the per-connection state this module tracks. */
interface WsClient extends WebSocket {
  /** Set once the client has authenticated. */
  userId?: string;
  /** Heartbeat flag: cleared before each ping, set again on pong. */
  isAlive?: boolean;
  isAuthed?: boolean;
  /** Pending "no pong received" termination timer. */
  pongTimer?: ReturnType<typeof setTimeout>;
  /** Pending "never authenticated" close timer. */
  authTimer?: ReturnType<typeof setTimeout>;
}

/** Authenticated sockets, grouped by user id. */
const userSockets = new Map<string, Set<WsClient>>();

let wss: Server | null = null;
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;

/**
 * Verifies a JWT and extracts the user id from it.
 *
 * @param token - Raw JWT string supplied by the client.
 * @returns The `sub` claim (or `userId` as a fallback), or `null` when
 *          verification throws or neither claim is present. Never rejects.
 */
async function authenticateToken(token: string): Promise<string | null> {
  try {
    const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
    const userId = payload.sub ?? payload.userId;
    if (!userId) return null;
    return userId;
  } catch {
    return null;
  }
}

/** Registers `ws` under `userId`, creating the user's socket set on demand. */
function addSocket(userId: string, ws: WsClient): void {
  let sockets = userSockets.get(userId);
  if (!sockets) {
    sockets = new Set();
    userSockets.set(userId, sockets);
  }
  sockets.add(ws);
}

/** Unregisters `ws` from `userId`, dropping the user entry once it is empty. */
function removeSocket(userId: string, ws: WsClient): void {
  const sockets = userSockets.get(userId);
  if (!sockets) return;
  sockets.delete(ws);
  if (sockets.size === 0) {
    userSockets.delete(userId);
  }
}

/** Marks a socket as alive; called whenever a pong frame arrives. */
function heartbeat(ws: WsClient): void {
  ws.isAlive = true;
}

/**
 * Starts (or restarts) the periodic liveness check.
 *
 * Every HEARTBEAT_INTERVAL each client is pinged. A client whose `isAlive`
 * flag is still `false` from the previous round, or that does not answer
 * within PONG_TIMEOUT, is terminated.
 */
function startHeartbeat(): void {
  if (heartbeatTimer) clearInterval(heartbeatTimer);

  heartbeatTimer = setInterval(() => {
    if (!wss) return;
    wss.clients.forEach((client) => {
      const ws = client as WsClient;
      if (ws.isAlive === false) {
        ws.terminate();
        return;
      }
      ws.isAlive = false;
      ws.ping();
      // Never leave a previous timer dangling when re-arming.
      if (ws.pongTimer) clearTimeout(ws.pongTimer);
      ws.pongTimer = setTimeout(() => {
        ws.terminate();
      }, PONG_TIMEOUT);
    });
  }, HEARTBEAT_INTERVAL);

  // Do not let the heartbeat alone keep the process alive.
  if (heartbeatTimer && typeof heartbeatTimer === "object") {
    heartbeatTimer.unref();
  }
}

/** Stops the periodic liveness check, if running. */
function stopHeartbeat(): void {
  if (heartbeatTimer) {
    clearInterval(heartbeatTimer);
    heartbeatTimer = null;
  }
}

/**
 * Enforces the post-connection auth timeout.
 * If the client has not authenticated within AUTH_TIMEOUT, the connection is
 * closed with code 4001.
 */
function enforceAuthTimeout(ws: WsClient): void {
  ws.authTimer = setTimeout(() => {
    if (!ws.isAuthed) {
      console.log(
        "[websocket] Auth timeout — closing unauthenticated connection",
      );
      ws.close(4001, "Authentication timeout");
    }
  }, AUTH_TIMEOUT);
}

/**
 * Sends an alert to every open socket of a user.
 *
 * @param userId - Recipient user id.
 * @param data - Alert payload; serialized once as JSON.
 * @returns `true` if the message was handed to at least one open socket,
 *          `false` if the user has no open sockets.
 */
export function broadcastToUser(userId: string, data: AlertMessage): boolean {
  const sockets = userSockets.get(userId);
  if (!sockets || sockets.size === 0) return false;

  const message = JSON.stringify(data);
  let sent = false;
  for (const ws of sockets) {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(message);
      sent = true;
    }
  }
  return sent;
}

/** @returns The ids of all users with at least one registered socket. */
export function getConnectedUsers(): string[] {
  return Array.from(userSockets.keys());
}

/** @returns The total number of registered (authenticated) sockets. */
export function getConnectionCount(): number {
  let count = 0;
  for (const sockets of userSockets.values()) {
    count += sockets.size;
  }
  return count;
}

/**
 * Handles one raw client message: `auth` (JWT authentication) and `ping`.
 * Malformed JSON and unknown message types are ignored. Never rejects.
 */
async function handleClientMessage(
  ws: WsClient,
  data: { toString(): string },
): Promise<void> {
  try {
    const msg: unknown = JSON.parse(data.toString());
    if (typeof msg !== "object" || msg === null) return;
    const { type, token } = msg as { type?: unknown; token?: unknown };

    // Handle auth messages (post-connection JWT authentication).
    if (type === "auth" && typeof token === "string" && token.length > 0) {
      const userId = await authenticateToken(token);

      // The socket may have closed (or hit the auth timeout) while the token
      // was being verified. Registering it now would leak it in userSockets,
      // because its close handler has already run.
      if (ws.readyState !== WebSocket.OPEN) return;

      if (userId) {
        // Re-authentication as a different user: drop the old registration
        // so the socket stops receiving the previous user's alerts.
        if (ws.userId && ws.userId !== userId) {
          removeSocket(ws.userId, ws);
        }
        ws.isAuthed = true;
        ws.userId = userId;
        ws.isAlive = true;
        // Clear the auth timeout — client is now authenticated.
        if (ws.authTimer) {
          clearTimeout(ws.authTimer);
          ws.authTimer = undefined;
        }
        addSocket(userId, ws);
        ws.send(JSON.stringify({ type: "auth_success" }));
      } else {
        ws.send(
          JSON.stringify({
            type: "auth_error",
            message: "Invalid token",
          }),
        );
        ws.close(4001, "Authentication failed");
      }
      return;
    }

    // Unauthenticated connections may only ping (they might not have sent
    // auth yet); everything else from them is dropped.
    if (!ws.isAuthed) {
      if (type === "ping") {
        ws.send(JSON.stringify({ type: "pong" }));
      }
      return;
    }

    // Handle normal messages from authenticated clients.
    if (type === "ping") {
      ws.send(JSON.stringify({ type: "pong" }));
    }
  } catch {
    // Ignore invalid messages.
  }
}

/** Wires up per-connection state, timers, and event handlers for a new socket. */
function handleConnection(ws: WsClient): void {
  // Mark as unauthenticated initially; client must authenticate within timeout.
  ws.isAuthed = false;
  enforceAuthTimeout(ws);

  ws.on("message", (data) => {
    void handleClientMessage(ws, data);
  });

  ws.on("pong", () => {
    heartbeat(ws);
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
      ws.pongTimer = undefined;
    }
  });

  ws.on("close", () => {
    if (ws.userId) {
      removeSocket(ws.userId, ws);
    }
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
    }
    if (ws.authTimer) {
      clearTimeout(ws.authTimer);
    }
  });

  ws.on("error", (err) => {
    console.error("[websocket] Client error:", err.message);
  });
}

/**
 * Starts the WebSocket server on WS_PORT and the heartbeat loop.
 *
 * Upgrade requests whose Origin is not in the trusted allowlist are rejected.
 * Calling `start` while a server already exists resolves immediately.
 *
 * @returns A promise that resolves once the server is listening and rejects
 *          if the server fails before it starts listening (e.g. port in use).
 */
export function start(): Promise<void> {
  return new Promise<void>((resolve, reject) => {
    if (wss) {
      resolve();
      return;
    }

    let listening = false;

    const server = new WebSocketServer(
      {
        port: WS_PORT,
        verifyClient: (info: { origin: string; req: IncomingMessage }) => {
          const origin = info.req.headers.origin ?? info.origin;
          if (!isTrustedOrigin(origin, TRUSTED_ORIGINS)) {
            console.warn(
              `[websocket] Rejected untrusted origin: ${origin ?? "(none)"}`,
            );
            return false;
          }
          return true;
        },
      },
      () => {
        listening = true;
        console.log(`[websocket] Server listening on port ${WS_PORT}`);
        resolve();
      },
    );
    wss = server;

    // Without an 'error' listener a failed listen would surface as an
    // unhandled 'error' event and leave this promise pending forever.
    server.on("error", (err) => {
      console.error("[websocket] Server error:", err.message);
      if (!listening) {
        stopHeartbeat();
        if (wss === server) wss = null;
        reject(err);
      }
    });

    server.on("connection", handleConnection);

    startHeartbeat();
  });
}

/**
 * Stops the heartbeat, asks every client to close (code 1001), and shuts the
 * server down.
 *
 * @returns A promise that resolves once the server has closed and the socket
 *          registry has been cleared; resolves immediately if not running.
 */
export function stop(): Promise<void> {
  return new Promise<void>((resolve) => {
    stopHeartbeat();

    if (!wss) {
      resolve();
      return;
    }

    for (const ws of wss.clients) {
      ws.close(1001, "Server shutting down");
    }

    wss.close(() => {
      wss = null;
      userSockets.clear();
      console.log("[websocket] Server stopped");
      resolve();
    });
  });
}