330 lines
7.4 KiB
TypeScript
330 lines
7.4 KiB
TypeScript
import { WebSocketServer, WebSocket } from "ws";
|
|
import type { Server } from "ws";
|
|
import type { IncomingMessage } from "node:http";
|
|
import { verifyJWT } from "~/server/auth/jwt";
|
|
|
|
/**
 * Builds the trusted WebSocket origins allowlist.
 *
 * Always contains the localhost / 127.0.0.1 dev origins on ports 3000 and
 * 3001. APP_URL (when it is a valid http(s) URL with a hostname) and every
 * non-empty entry of the comma-separated VALID_WEBSOCKET_ORIGINS are added on
 * top of those defaults; nothing is ever removed.
 *
 * http(s) values are reduced to their serialized origin
 * (`scheme://host[:port]`), which is the exact form browsers put in the
 * `Origin` header. Without this, an APP_URL such as
 * `https://app.example.com/` (trailing slash or path) would never match.
 * Explicit entries that are not http(s) URLs are kept verbatim (trimmed).
 *
 * @returns De-duplicated origins in insertion order, defaults first.
 */
function getTrustedOrigins(): string[] {
  const origins = new Set<string>([
    "http://localhost:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:3001",
  ]);

  // Validate APP_URL before trusting it as a WebSocket origin
  const appOrigin = normalizeHttpOrigin(process.env.APP_URL);
  if (appOrigin) origins.add(appOrigin);

  // Additional origins via VALID_WEBSOCKET_ORIGINS (comma-separated); these
  // extend the list above rather than replacing it.
  const explicit = process.env.VALID_WEBSOCKET_ORIGINS;
  if (explicit) {
    for (const entry of explicit.split(",")) {
      const trimmed = entry.trim();
      if (!trimmed) continue;
      origins.add(normalizeHttpOrigin(trimmed) ?? trimmed);
    }
  }

  return Array.from(origins);
}

/**
 * Returns the serialized origin of `value` when it is an http(s) URL with a
 * hostname; returns null for undefined, empty, unparseable or non-http(s)
 * input.
 */
function normalizeHttpOrigin(value: string | undefined): string | null {
  if (!value) return null;
  try {
    const parsed = new URL(value);
    return /^https?:$/.test(parsed.protocol) && parsed.hostname
      ? parsed.origin
      : null;
  } catch {
    // Invalid URL — skip
    return null;
  }
}
|
|
|
|
/**
 * Checks a request's Origin value against the allowlist.
 *
 * An absent, empty or whitespace-only origin is never trusted. Any other
 * value must equal one of `trustedOrigins` exactly: the comparison is
 * case-sensitive and no normalization is applied here.
 */
function isTrustedOrigin(
  origin: string | undefined,
  trustedOrigins: string[],
): boolean {
  const present = typeof origin === "string" && origin.trim() !== "";
  if (!present) return false;
  return trustedOrigins.some((allowed) => allowed === origin);
}
|
|
|
|
// Resolved once at module load: later changes to APP_URL or
// VALID_WEBSOCKET_ORIGINS are not picked up without a restart.
const TRUSTED_ORIGINS: string[] = getTrustedOrigins();
|
|
|
|
/** Port used when WS_PORT is unset, blank or invalid. */
const DEFAULT_WS_PORT = 3001;

/**
 * Parses a WS_PORT environment value into a TCP port number.
 *
 * Plain `parseInt` has two problems:
 * - It turns "" or "abc" into NaN, which makes the server fail to listen.
 * - It silently accepts trailing junk such as "8080abc".
 *
 * This function accepts only a plain decimal integer in 0–65535; surrounding
 * whitespace is ignored. A missing or blank value quietly falls back to
 * DEFAULT_WS_PORT. Any other invalid value also falls back, but logs a
 * warning.
 *
 * @param raw Raw environment value, possibly undefined.
 * @returns The port to listen on.
 */
function parseWsPort(raw: string | undefined): number {
  const trimmed = raw?.trim() ?? "";
  if (trimmed === "") return DEFAULT_WS_PORT;
  if (/^\d+$/.test(trimmed)) {
    const port = Number(trimmed);
    if (port <= 65_535) return port;
  }
  console.warn(
    `[websocket] Invalid WS_PORT "${raw}", falling back to ${DEFAULT_WS_PORT}`,
  );
  return DEFAULT_WS_PORT;
}

const WS_PORT = parseWsPort(process.env.WS_PORT);

/** Interval between heartbeat ping sweeps, in ms. */
const HEARTBEAT_INTERVAL = 30_000;

/** Time a pinged client has to answer with a pong before termination, in ms. */
const PONG_TIMEOUT = 10_000;

/** Time a new connection has to send a valid auth message, in ms. */
const AUTH_TIMEOUT = 5_000;
|
|
|
|
/** Fields of a single alert as serialized to the client. */
interface AlertPayload {
  id: string;
  title: string;
  message: string;
  severity: string;
  source: string;
  category: string;
  createdAt: string;
}

/**
 * Envelope pushed to clients by broadcastToUser. The literal `type` tag
 * distinguishes it from the other frames this server sends
 * (`auth_success`, `auth_error`, `pong`).
 */
interface AlertMessage {
  type: "alert";
  alert: AlertPayload;
}
|
|
|
|
/**
 * A `ws` socket augmented with the per-connection state this module tracks.
 * The fields are optional because the socket is created by the `ws` library;
 * this module fills them in as the connection progresses.
 */
interface WsClient extends WebSocket {
  /** Whether a valid auth message has been processed on this socket. */
  isAuthed?: boolean;
  /** User id resolved from the auth token; the key used in userSockets. */
  userId?: string;
  /** Timer that closes the socket if it does not authenticate in time. */
  authTimer?: ReturnType<typeof setTimeout>;
  /** Set true on auth or pong; reset to false before each heartbeat ping. */
  isAlive?: boolean;
  /** Timer that terminates the socket if the pong does not arrive in time. */
  pongTimer?: ReturnType<typeof setTimeout>;
}
|
|
|
|
/**
 * Registry of authenticated sockets keyed by user id. A user may have several
 * sockets open at once; a user's entry is deleted once its set becomes empty.
 */
const userSockets: Map<string, Set<WsClient>> = new Map();

/** The running server instance, or null while stopped. */
let wss: Server | null = null;

/** Handle of the heartbeat sweep interval, or null when it is not running. */
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
|
|
|
/**
 * Verifies a JWT and extracts the user id it identifies.
 *
 * The id comes from the `sub` claim, falling back to a `userId` claim when
 * `sub` is null or undefined. The generic type passed to verifyJWT is not
 * enforced at runtime, so the chosen claim is re-checked here. Only a
 * non-empty string is accepted. The id is later used as a key in userSockets
 * and looked up with string ids, so a numeric or object claim would register
 * the socket under a key that can never match.
 *
 * @param token Raw JWT sent by the client in its auth message.
 * @returns The user id, or null if verifyJWT throws/rejects or no usable
 *   string id claim is present.
 */
async function authenticateToken(token: string): Promise<string | null> {
  try {
    const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
    const userId: unknown = payload?.sub ?? payload?.userId;
    return typeof userId === "string" && userId !== "" ? userId : null;
  } catch {
    // verifyJWT rejected the token; treat the client as unauthenticated.
    return null;
  }
}
|
|
|
|
/**
 * Registers `ws` under `userId`, creating the user's socket set on first use.
 * Adding the same socket twice is harmless (set semantics).
 */
function addSocket(userId: string, ws: WsClient) {
  const existing = userSockets.get(userId);
  if (existing) {
    existing.add(ws);
    return;
  }
  userSockets.set(userId, new Set([ws]));
}
|
|
|
|
/**
 * Unregisters `ws` from `userId`. When that leaves the user with no sockets,
 * the user's entry is dropped entirely so getConnectedUsers stays accurate.
 */
function removeSocket(userId: string, ws: WsClient) {
  const sockets = userSockets.get(userId);
  if (sockets === undefined) return;

  sockets.delete(ws);
  if (sockets.size > 0) return;

  userSockets.delete(userId);
}
|
|
|
|
/**
 * Records that `ws` answered a ping, so the next heartbeat sweep keeps it.
 */
function heartbeat(ws: WsClient): void {
  ws.isAlive = true;
}
|
|
|
|
/**
 * (Re)starts the liveness sweep; any sweep already running is replaced.
 *
 * Every HEARTBEAT_INTERVAL ms, each socket in `wss.clients` is checked:
 * - A socket still flagged `isAlive === false` from the previous sweep is
 *   terminated.
 * - Otherwise the flag is cleared and the socket is pinged.
 * - A PONG_TIMEOUT deadline is then armed. It terminates the socket unless
 *   the socket's pong handler clears it first.
 */
function startHeartbeat() {
  if (heartbeatTimer) clearInterval(heartbeatTimer);

  heartbeatTimer = setInterval(() => {
    if (!wss) return;

    wss.clients.forEach((client) => {
      const ws = client as WsClient;

      if (ws.isAlive === false) {
        ws.terminate();
        return;
      }

      ws.isAlive = false;
      ws.ping();

      // Only the latest deadline handle is stored. A deadline from an earlier
      // sweep can still be pending when PONG_TIMEOUT >= HEARTBEAT_INTERVAL.
      // If we overwrote it, the pong handler could no longer cancel it and it
      // would terminate a healthy socket.
      if (ws.pongTimer) clearTimeout(ws.pongTimer);
      ws.pongTimer = setTimeout(() => {
        ws.terminate();
      }, PONG_TIMEOUT);
    });
  }, HEARTBEAT_INTERVAL);

  // In Node the handle is an object with unref(), so the sweep alone does not
  // keep the process alive. Other runtimes may return a plain number.
  if (heartbeatTimer && typeof heartbeatTimer === "object") {
    heartbeatTimer.unref();
  }
}
|
|
|
|
/** Cancels the heartbeat sweep if one is running; otherwise does nothing. */
function stopHeartbeat() {
  if (heartbeatTimer === null) return;
  clearInterval(heartbeatTimer);
  heartbeatTimer = null;
}
|
|
|
|
/**
 * Arms the post-connection authentication deadline for `ws`.
 *
 * After AUTH_TIMEOUT ms, a socket that still has not authenticated is closed
 * with code 4001. The timer handle is stored on `ws.authTimer` so a
 * successful auth (or the socket closing) can cancel it.
 */
function enforceAuthTimeout(ws: WsClient): void {
  const closeIfUnauthenticated = () => {
    if (ws.isAuthed) return;
    console.log(
      "[websocket] Auth timeout — closing unauthenticated connection",
    );
    ws.close(4001, "Authentication timeout");
  };

  ws.authTimer = setTimeout(closeIfUnauthenticated, AUTH_TIMEOUT);
}
|
|
|
|
/**
 * Sends an alert to every OPEN socket registered for `userId`.
 *
 * The message is serialized once and reused for all sockets. Sockets that
 * are not OPEN are skipped.
 *
 * @returns true if at least one socket was sent the message; false if the
 *   user has no registered sockets or none of them is open.
 */
export function broadcastToUser(userId: string, data: AlertMessage) {
  const sockets = userSockets.get(userId);
  if (sockets === undefined || sockets.size === 0) return false;

  const payload = JSON.stringify(data);
  let delivered = 0;
  sockets.forEach((socket) => {
    if (socket.readyState !== WebSocket.OPEN) return;
    socket.send(payload);
    delivered += 1;
  });

  return delivered > 0;
}
|
|
|
|
/** Ids of all users that currently have at least one registered socket. */
export function getConnectedUsers(): string[] {
  return [...userSockets.keys()];
}
|
|
|
|
/** Total number of registered (authenticated) sockets across all users. */
export function getConnectionCount(): number {
  return Array.from(userSockets.values()).reduce(
    (total, sockets) => total + sockets.size,
    0,
  );
}
|
|
|
|
/**
 * `verifyClient` hook for the WebSocketServer. Accepts an upgrade request
 * only when its Origin header is in TRUSTED_ORIGINS; when the header is
 * absent, it falls back to the origin `ws` passes in `info`. Rejections are
 * logged.
 */
function verifyClientOrigin(info: {
  origin: string;
  req: IncomingMessage;
}): boolean {
  const origin = info.req.headers.origin ?? info.origin;
  if (!isTrustedOrigin(origin, TRUSTED_ORIGINS)) {
    console.warn(
      `[websocket] Rejected untrusted origin: ${origin ?? "(none)"}`,
    );
    return false;
  }
  return true;
}

/**
 * Processes an auth message for `ws`.
 *
 * On success, the socket is registered under the resolved user id and the
 * server replies `auth_success`. On failure, it replies `auth_error` and
 * closes the socket with code 4001.
 */
async function handleAuthMessage(ws: WsClient, token: string): Promise<void> {
  const userId = await authenticateToken(token);

  // The socket may have closed while the token was being verified (client
  // hang-up, or the auth timeout fired). Its close handler has already run,
  // so registering it now would leave a dead socket in userSockets forever.
  if (ws.readyState !== WebSocket.OPEN) return;

  if (!userId) {
    ws.send(
      JSON.stringify({
        type: "auth_error",
        message: "Invalid token",
      }),
    );
    ws.close(4001, "Authentication failed");
    return;
  }

  // Re-authenticating as a different user: drop the old registration. This
  // stops the socket from receiving the previous user's alerts, and the close
  // handler (which only knows the latest userId) would otherwise leave it
  // behind.
  if (ws.userId && ws.userId !== userId) {
    removeSocket(ws.userId, ws);
  }

  ws.isAuthed = true;
  ws.userId = userId;
  ws.isAlive = true;

  // Clear the auth timeout — client is now authenticated
  if (ws.authTimer) {
    clearTimeout(ws.authTimer);
    ws.authTimer = undefined;
  }

  addSocket(userId, ws);
  ws.send(JSON.stringify({ type: "auth_success" }));
}

/**
 * Dispatches one incoming text frame.
 *
 * Malformed JSON is ignored. `auth` messages carrying a non-empty string
 * token go to handleAuthMessage. `ping` is answered with `pong` whether or
 * not the client has authenticated yet. All other message types are ignored.
 */
async function handleClientMessage(ws: WsClient, raw: string): Promise<void> {
  let msg: unknown;
  try {
    msg = JSON.parse(raw);
  } catch {
    return; // ignore invalid messages
  }
  if (typeof msg !== "object" || msg === null) return;

  const { type, token } = msg as { type?: unknown; token?: unknown };

  // Handle auth messages (post-connection JWT authentication)
  if (type === "auth" && typeof token === "string" && token !== "") {
    await handleAuthMessage(ws, token);
    return;
  }

  if (type === "ping") {
    ws.send(JSON.stringify({ type: "pong" }));
  }
}

/**
 * Sets up a freshly accepted socket: arms the auth deadline and attaches the
 * message, pong, close and error handlers.
 */
function handleConnection(ws: WsClient): void {
  // Mark as unauthenticated initially; client must authenticate within timeout
  ws.isAuthed = false;
  enforceAuthTimeout(ws);

  ws.on("message", (data) => {
    // Unexpected failures (e.g. in send) are logged rather than swallowed.
    handleClientMessage(ws, data.toString()).catch((err: unknown) => {
      console.error(
        "[websocket] Failed to handle message:",
        err instanceof Error ? err.message : err,
      );
    });
  });

  ws.on("pong", () => {
    heartbeat(ws);
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
      ws.pongTimer = undefined;
    }
  });

  ws.on("close", () => {
    if (ws.userId) {
      removeSocket(ws.userId, ws);
    }
    if (ws.pongTimer) {
      clearTimeout(ws.pongTimer);
    }
    if (ws.authTimer) {
      clearTimeout(ws.authTimer);
    }
  });

  ws.on("error", (err) => {
    console.error("[websocket] Client error:", err.message);
  });
}

/**
 * Starts the WebSocket server on WS_PORT together with the heartbeat sweep.
 *
 * Calling it while a server instance already exists resolves immediately.
 *
 * @returns A promise that resolves once the server is listening. It rejects
 *   if the server emits `'error'` before listening (e.g. the port is already
 *   in use). In that case the server and heartbeat are torn down, so a later
 *   call can retry.
 */
export function start(): Promise<void> {
  return new Promise<void>((resolve, reject) => {
    if (wss) {
      resolve();
      return;
    }

    let listening = false;
    const server = new WebSocketServer(
      { port: WS_PORT, verifyClient: verifyClientOrigin },
      () => {
        listening = true;
        console.log(`[websocket] Server listening on port ${WS_PORT}`);
        resolve();
      },
    );
    wss = server;

    // Without an 'error' listener, a listen failure would be thrown as an
    // unhandled event and the returned promise would never settle.
    server.on("error", (err) => {
      if (listening) {
        console.error("[websocket] Server error:", err.message);
        return;
      }
      console.error(
        `[websocket] Failed to listen on port ${WS_PORT}:`,
        err.message,
      );
      stopHeartbeat();
      if (wss === server) wss = null;
      server.close();
      reject(err);
    });

    server.on("connection", (ws: WsClient) => {
      handleConnection(ws);
    });

    startHeartbeat();
  });
}
|
|
|
|
/**
 * Shuts the server down.
 *
 * Stops the heartbeat sweep, asks every connected client to close (code 1001,
 * "Server shutting down"), then closes the server itself. Resolves
 * immediately when no server is running. Otherwise it resolves from the
 * server's close callback, which also clears the module state and the user
 * registry.
 */
export function stop(): Promise<void> {
  return new Promise((resolve) => {
    stopHeartbeat();

    const server = wss;
    if (server === null) {
      resolve();
      return;
    }

    server.clients.forEach((client) => {
      client.close(1001, "Server shutting down");
    });

    server.close(() => {
      wss = null;
      userSockets.clear();
      console.log("[websocket] Server stopped");
      resolve();
    });
  });
}
|