Files
Kordant/web/src/server/websocket.ts
2026-05-29 09:03:47 -04:00

330 lines
7.4 KiB
TypeScript

import { WebSocketServer, WebSocket } from "ws";
import type { Server } from "ws";
import type { IncomingMessage } from "node:http";
import { verifyJWT } from "~/server/auth/jwt";
/**
 * Builds the trusted WebSocket origins allowlist.
 *
 * Always contains the localhost dev origins. Adds the origin of APP_URL when
 * it parses as an http(s) URL with a hostname, plus every non-empty entry of
 * the comma-separated VALID_WEBSOCKET_ORIGINS variable.
 *
 * @returns Origin strings (scheme://host[:port], no path, no trailing slash)
 *   that are compared by exact match against the request's Origin header.
 */
function getTrustedOrigins(): string[] {
  const origins = [
    "http://localhost:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:3001",
  ];

  // Validate APP_URL before trusting it as a WebSocket origin.
  const appUrl = process.env.APP_URL;
  if (appUrl) {
    try {
      const parsed = new URL(appUrl);
      if (/^https?:$/.test(parsed.protocol) && parsed.hostname) {
        // Browsers send a bare origin (no path, no trailing slash). Pushing the
        // raw APP_URL (e.g. "https://app.example.com/") would never match.
        origins.push(parsed.origin);
      }
    } catch {
      // Invalid URL — skip
    }
  }

  // Allow explicit additions via VALID_WEBSOCKET_ORIGINS (comma-separated).
  const explicit = process.env.VALID_WEBSOCKET_ORIGINS;
  if (explicit) {
    for (const entry of explicit.split(",")) {
      // Trailing slashes are stripped for the same exact-match reason as above.
      const origin = entry.trim().replace(/\/+$/, "");
      if (origin) origins.push(origin);
    }
  }

  return origins;
}
/**
 * Checks a request's Origin header against the allowlist.
 *
 * @param origin - Raw Origin header value; may be absent.
 * @param trustedOrigins - Exact-match allowlist of origin strings.
 * @returns false for a missing or whitespace-only origin, otherwise whether
 *   the origin appears verbatim in `trustedOrigins`.
 */
function isTrustedOrigin(
  origin: string | undefined,
  trustedOrigins: string[],
): boolean {
  if (typeof origin !== "string") return false;
  if (origin.trim() === "") return false;
  return trustedOrigins.includes(origin);
}
// Origin allowlist, resolved once at module load: later changes to APP_URL or
// VALID_WEBSOCKET_ORIGINS are not picked up without a restart.
const TRUSTED_ORIGINS = getTrustedOrigins();
// Port of the standalone WebSocket server; WS_PORT overrides the 3001 default.
const WS_PORT = Number.parseInt(process.env.WS_PORT ?? "3001", 10);
// Period of the server-side ping sweep, in milliseconds.
const HEARTBEAT_INTERVAL = 30 * 1000;
// Grace period for a pong after each ping before the socket is terminated (ms).
const PONG_TIMEOUT = 10 * 1000;
// Window in which a new connection must send a valid `auth` message (ms).
const AUTH_TIMEOUT = 5 * 1000;
/**
 * Payload pushed to a user's sockets by broadcastToUser; it is serialized
 * with JSON.stringify before being sent.
 */
interface AlertMessage {
type: "alert";
alert: {
id: string;
title: string;
message: string;
// Free-form string; allowed values are not constrained by this type.
severity: string;
source: string;
category: string;
// NOTE(review): presumably an ISO-8601 timestamp — confirm with producers.
createdAt: string;
};
}
/**
 * A `ws` WebSocket augmented with the per-connection state this module
 * stores directly on the socket object.
 */
interface WsClient extends WebSocket {
// Id of the authenticated user; set after a successful `auth` message.
userId?: string;
// Heartbeat flag: set false before each ping, set true when a pong arrives.
isAlive?: boolean;
// True once the client has presented a valid JWT.
isAuthed?: boolean;
// Pending timer that terminates the socket if no pong arrives in time.
pongTimer?: ReturnType<typeof setTimeout>;
// Pending timer that closes the socket if it never authenticates.
authTimer?: ReturnType<typeof setTimeout>;
}
// Authenticated sockets grouped by user id; one user may hold several open
// connections. A user's entry is deleted when their last socket is removed.
const userSockets = new Map<string, Set<WsClient>>();
// The running server instance, or null when not started / after stop().
let wss: Server | null = null;
// Interval handle of the heartbeat sweep, or null when it is not running.
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
/**
 * Verifies a client-supplied JWT and extracts the user it identifies.
 *
 * @param token - Raw JWT string from the client's `auth` message.
 * @returns The `sub` claim (falling back to `userId`), or null when
 *   verification fails or neither claim holds a non-empty value.
 */
async function authenticateToken(token: string): Promise<string | null> {
  try {
    const claims = await verifyJWT<{ sub?: string; userId?: string }>(token);
    const resolvedId = claims.sub ?? claims.userId;
    return resolvedId ? resolvedId : null;
  } catch {
    return null;
  }
}
/**
 * Registers `ws` under `userId`, creating the user's socket set on first use.
 */
function addSocket(userId: string, ws: WsClient) {
  const existing = userSockets.get(userId);
  if (existing) {
    existing.add(ws);
    return;
  }
  userSockets.set(userId, new Set([ws]));
}
/**
 * Unregisters `ws` from `userId`; drops the user's entry entirely once no
 * sockets remain, so the map only lists users with live connections.
 */
function removeSocket(userId: string, ws: WsClient) {
  const sockets = userSockets.get(userId);
  if (sockets) {
    sockets.delete(ws);
    const isEmpty = sockets.size === 0;
    if (isEmpty) userSockets.delete(userId);
  }
}
/** Marks the client as responsive for the current heartbeat round. */
function heartbeat(ws: WsClient): void {
  const client = ws;
  client.isAlive = true;
}
/**
 * Starts (or restarts) the server-wide heartbeat sweep.
 *
 * Every HEARTBEAT_INTERVAL ms each connected client is visited:
 * - a client still flagged `isAlive === false` from the previous round is
 *   terminated;
 * - otherwise it is flagged not-alive, pinged, and given PONG_TIMEOUT ms to
 *   answer before being terminated.
 * The socket's `pong` handler sets `isAlive` back to true and clears the
 * pending pong timer.
 */
function startHeartbeat() {
  if (heartbeatTimer) clearInterval(heartbeatTimer);
  heartbeatTimer = setInterval(() => {
    if (!wss) return;
    wss.clients.forEach((client) => {
      const ws = client as WsClient;
      if (ws.isAlive === false) {
        ws.terminate();
        return;
      }
      ws.isAlive = false;
      ws.ping();
      // Never overwrite a still-pending handle: once replaced it could no
      // longer be cancelled by the pong/close handlers and would fire later.
      if (ws.pongTimer) clearTimeout(ws.pongTimer);
      const pongTimer = setTimeout(() => {
        ws.pongTimer = undefined;
        ws.terminate();
      }, PONG_TIMEOUT);
      // A pending pong check must not keep the process alive on shutdown.
      if (typeof pongTimer === "object") pongTimer.unref();
      ws.pongTimer = pongTimer;
    });
  }, HEARTBEAT_INTERVAL);
  if (heartbeatTimer && typeof heartbeatTimer === "object") {
    heartbeatTimer.unref();
  }
}
/** Cancels the heartbeat sweep if one is running; safe to call repeatedly. */
function stopHeartbeat() {
  if (!heartbeatTimer) return;
  clearInterval(heartbeatTimer);
  heartbeatTimer = null;
}
/**
 * Arms the post-connection authentication deadline.
 *
 * After AUTH_TIMEOUT ms, a socket that still has not authenticated is closed
 * with code 4001. The timer handle is stored on `ws.authTimer` so a successful
 * auth (or the socket closing) can cancel it.
 */
function enforceAuthTimeout(ws: WsClient): void {
  const closeIfUnauthenticated = () => {
    if (ws.isAuthed) return;
    console.log(
      "[websocket] Auth timeout — closing unauthenticated connection",
    );
    ws.close(4001, "Authentication timeout");
  };
  ws.authTimer = setTimeout(closeIfUnauthenticated, AUTH_TIMEOUT);
}
/**
 * Sends an alert to every open socket registered for `userId`.
 *
 * @returns true if at least one socket was in the OPEN state and received the
 *   message; false if the user has no sockets or none are open.
 */
export function broadcastToUser(userId: string, data: AlertMessage) {
  const sockets = userSockets.get(userId);
  if (!sockets?.size) return false;
  const payload = JSON.stringify(data);
  let delivered = false;
  sockets.forEach((socket) => {
    if (socket.readyState !== WebSocket.OPEN) return;
    socket.send(payload);
    delivered = true;
  });
  return delivered;
}
/** Ids of all users that currently have at least one registered socket. */
export function getConnectedUsers(): string[] {
  const users: string[] = [];
  userSockets.forEach((_sockets, userId) => {
    users.push(userId);
  });
  return users;
}
/** Total number of registered (authenticated) sockets across all users. */
export function getConnectionCount(): number {
  return Array.from(userSockets.values()).reduce(
    (total, sockets) => total + sockets.size,
    0,
  );
}
/** Shape of a parsed client → server message; every field is untrusted. */
type ClientMessage = { type?: unknown; token?: unknown };

/** Parses a raw client frame; returns null for non-JSON or non-object payloads. */
function parseClientMessage(raw: string): ClientMessage | null {
  try {
    const parsed: unknown = JSON.parse(raw);
    return typeof parsed === "object" && parsed !== null
      ? (parsed as ClientMessage)
      : null;
  } catch {
    return null;
  }
}

/**
 * Handles an `auth` message: verifies the JWT and, on success, registers the
 * socket under the token's user id and replies `auth_success`. On failure the
 * client gets `auth_error` and the connection is closed with code 4001.
 */
async function handleAuthMessage(ws: WsClient, token: string): Promise<void> {
  const userId = await authenticateToken(token);

  // The socket may have closed while the token was being verified (auth
  // timeout, client disconnect). Registering it now would leave a dead entry
  // in userSockets that no `close` handler will ever remove.
  if (ws.readyState !== WebSocket.OPEN) return;

  if (!userId) {
    ws.send(
      JSON.stringify({
        type: "auth_error",
        message: "Invalid token",
      }),
    );
    ws.close(4001, "Authentication failed");
    return;
  }

  // Re-authentication as a different user: detach the socket from the
  // previous user, otherwise it keeps receiving that user's alerts.
  if (ws.userId && ws.userId !== userId) {
    removeSocket(ws.userId, ws);
  }

  ws.isAuthed = true;
  ws.userId = userId;
  ws.isAlive = true;
  // Authenticated in time — cancel the auth deadline.
  if (ws.authTimer) {
    clearTimeout(ws.authTimer);
    ws.authTimer = undefined;
  }
  addSocket(userId, ws);
  ws.send(JSON.stringify({ type: "auth_success" }));
}

/** Wires one connection's lifecycle: auth deadline, messages, pong, close, errors. */
function handleConnection(ws: WsClient): void {
  // Unauthenticated until a valid `auth` message arrives within AUTH_TIMEOUT.
  ws.isAuthed = false;
  enforceAuthTimeout(ws);

  ws.on("message", (data) => {
    const msg = parseClientMessage(data.toString());
    if (!msg) return; // ignore malformed frames

    if (msg.type === "auth" && typeof msg.token === "string" && msg.token) {
      void handleAuthMessage(ws, msg.token).catch((err: unknown) => {
        console.error(
          "[websocket] Auth handling failed:",
          err instanceof Error ? err.message : err,
        );
      });
      return;
    }

    // Application-level ping is answered for every client, authenticated or
    // not; the auth deadline still applies to unauthenticated clients.
    // All other message types are currently ignored.
    if (msg.type === "ping") {
      ws.send(JSON.stringify({ type: "pong" }));
    }
  });

  ws.on("pong", () => {
    heartbeat(ws);
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
      ws.pongTimer = undefined;
    }
  });

  ws.on("close", () => {
    if (ws.userId) {
      removeSocket(ws.userId, ws);
    }
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
    }
    if (ws.authTimer) {
      clearTimeout(ws.authTimer);
    }
  });

  ws.on("error", (err) => {
    console.error("[websocket] Client error:", err.message);
  });
}

/**
 * Starts the standalone WebSocket server on WS_PORT. Calling it while the
 * server is already running resolves immediately.
 *
 * Only handshakes whose Origin is in TRUSTED_ORIGINS are accepted. Resolves
 * once the server is listening; rejects if it cannot listen (e.g. port in
 * use), after resetting module state so start() can be retried.
 */
export function start(): Promise<void> {
  return new Promise((resolve, reject) => {
    if (wss) {
      resolve();
      return;
    }

    let listening = false;
    const server = new WebSocketServer(
      {
        port: WS_PORT,
        verifyClient: (info: { origin: string; req: IncomingMessage }) => {
          const origin = info.req.headers.origin ?? info.origin;
          if (!isTrustedOrigin(origin, TRUSTED_ORIGINS)) {
            console.warn(
              `[websocket] Rejected untrusted origin: ${origin ?? "(none)"}`,
            );
            return false;
          }
          return true;
        },
      },
      () => {
        listening = true;
        console.log(`[websocket] Server listening on port ${WS_PORT}`);
        resolve();
      },
    );
    wss = server;

    // Without an `error` listener, a failed listen (EADDRINUSE, EACCES) is
    // rethrown by the EventEmitter and crashes the process instead of
    // rejecting this promise.
    server.on("error", (err) => {
      if (listening) {
        console.error("[websocket] Server error:", err.message);
        return;
      }
      if (wss === server) wss = null;
      stopHeartbeat();
      server.close();
      reject(err);
    });

    server.on("connection", handleConnection);
    startHeartbeat();
  });
}
/**
 * Shuts the server down.
 *
 * Stops the heartbeat and asks every client to close with code 1001. It then
 * closes the server and resolves from its close callback, after forgetting
 * all registered sockets. Resolves immediately when the server is not running.
 */
export function stop(): Promise<void> {
  return new Promise((resolve) => {
    stopHeartbeat();
    const server = wss;
    if (server === null) {
      resolve();
      return;
    }
    server.clients.forEach((client) => {
      client.close(1001, "Server shutting down");
    });
    server.close(() => {
      wss = null;
      userSockets.clear();
      console.log("[websocket] Server stopped");
      resolve();
    });
  });
}