security audit fix start
This commit is contained in:
@@ -1,216 +1,262 @@
|
||||
import { WebSocketServer, WebSocket } from "ws";
|
||||
import type { Server } from "ws";
|
||||
import { IncomingMessage } from "node:http";
|
||||
import { URL } from "node:url";
|
||||
import { verifyJWT } from "~/server/auth/jwt";
|
||||
|
||||
const WS_PORT = parseInt(process.env.WS_PORT ?? "3001", 10);
|
||||
const HEARTBEAT_INTERVAL = 30_000;
|
||||
const PONG_TIMEOUT = 10_000;
|
||||
const AUTH_TIMEOUT = 5_000;
|
||||
|
||||
interface AlertMessage {
|
||||
type: "alert";
|
||||
alert: {
|
||||
id: string;
|
||||
title: string;
|
||||
message: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
category: string;
|
||||
createdAt: string;
|
||||
};
|
||||
type: "alert";
|
||||
alert: {
|
||||
id: string;
|
||||
title: string;
|
||||
message: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
category: string;
|
||||
createdAt: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface WsClient extends WebSocket {
|
||||
userId?: string;
|
||||
isAlive?: boolean;
|
||||
pongTimer?: ReturnType<typeof setTimeout>;
|
||||
userId?: string;
|
||||
isAlive?: boolean;
|
||||
isAuthed?: boolean;
|
||||
pongTimer?: ReturnType<typeof setTimeout>;
|
||||
authTimer?: ReturnType<typeof setTimeout>;
|
||||
}
|
||||
|
||||
const userSockets = new Map<string, Set<WsClient>>();
|
||||
let wss: Server | null = null;
|
||||
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
function getTokenFromRequest(req: IncomingMessage): string | null {
|
||||
const url = new URL(req.url ?? "/", "http://localhost");
|
||||
return url.searchParams.get("token");
|
||||
}
|
||||
|
||||
async function authenticateConnection(
|
||||
ws: WsClient,
|
||||
req: IncomingMessage,
|
||||
): Promise<string | null> {
|
||||
const token = getTokenFromRequest(req);
|
||||
if (!token) return null;
|
||||
|
||||
try {
|
||||
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
|
||||
const userId = payload.sub ?? payload.userId;
|
||||
if (!userId) return null;
|
||||
return userId;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
async function authenticateToken(token: string): Promise<string | null> {
|
||||
try {
|
||||
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
|
||||
const userId = payload.sub ?? payload.userId;
|
||||
if (!userId) return null;
|
||||
return userId;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function addSocket(userId: string, ws: WsClient) {
|
||||
let sockets = userSockets.get(userId);
|
||||
if (!sockets) {
|
||||
sockets = new Set();
|
||||
userSockets.set(userId, sockets);
|
||||
}
|
||||
sockets.add(ws);
|
||||
let sockets = userSockets.get(userId);
|
||||
if (!sockets) {
|
||||
sockets = new Set();
|
||||
userSockets.set(userId, sockets);
|
||||
}
|
||||
sockets.add(ws);
|
||||
}
|
||||
|
||||
function removeSocket(userId: string, ws: WsClient) {
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets) return;
|
||||
sockets.delete(ws);
|
||||
if (sockets.size === 0) {
|
||||
userSockets.delete(userId);
|
||||
}
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets) return;
|
||||
sockets.delete(ws);
|
||||
if (sockets.size === 0) {
|
||||
userSockets.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
function heartbeat(ws: WsClient) {
|
||||
ws.isAlive = true;
|
||||
ws.isAlive = true;
|
||||
}
|
||||
|
||||
function startHeartbeat() {
|
||||
if (heartbeatTimer) clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = setInterval(() => {
|
||||
if (!wss) return;
|
||||
wss.clients.forEach((client) => {
|
||||
const ws = client as WsClient;
|
||||
if (ws.isAlive === false) {
|
||||
ws.terminate();
|
||||
return;
|
||||
}
|
||||
ws.isAlive = false;
|
||||
ws.ping();
|
||||
if (heartbeatTimer) clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = setInterval(() => {
|
||||
if (!wss) return;
|
||||
wss.clients.forEach((client) => {
|
||||
const ws = client as WsClient;
|
||||
if (ws.isAlive === false) {
|
||||
ws.terminate();
|
||||
return;
|
||||
}
|
||||
ws.isAlive = false;
|
||||
ws.ping();
|
||||
|
||||
ws.pongTimer = setTimeout(() => {
|
||||
ws.terminate();
|
||||
}, PONG_TIMEOUT);
|
||||
});
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
ws.pongTimer = setTimeout(() => {
|
||||
ws.terminate();
|
||||
}, PONG_TIMEOUT);
|
||||
});
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
|
||||
if (heartbeatTimer && typeof heartbeatTimer === "object") {
|
||||
heartbeatTimer.unref();
|
||||
}
|
||||
if (heartbeatTimer && typeof heartbeatTimer === "object") {
|
||||
heartbeatTimer.unref();
|
||||
}
|
||||
}
|
||||
|
||||
function stopHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
}
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enforces post-connection auth timeout.
|
||||
* If the client doesn't send an auth message within AUTH_TIMEOUT,
|
||||
* the connection is terminated.
|
||||
*/
|
||||
function enforceAuthTimeout(ws: WsClient): void {
|
||||
ws.authTimer = setTimeout(() => {
|
||||
if (!ws.isAuthed) {
|
||||
console.log(
|
||||
"[websocket] Auth timeout — closing unauthenticated connection",
|
||||
);
|
||||
ws.close(4001, "Authentication timeout");
|
||||
}
|
||||
}, AUTH_TIMEOUT);
|
||||
}
|
||||
|
||||
export function broadcastToUser(userId: string, data: AlertMessage) {
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets || sockets.size === 0) return false;
|
||||
const sockets = userSockets.get(userId);
|
||||
if (!sockets || sockets.size === 0) return false;
|
||||
|
||||
const message = JSON.stringify(data);
|
||||
let sent = false;
|
||||
for (const ws of sockets) {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(message);
|
||||
sent = true;
|
||||
}
|
||||
}
|
||||
return sent;
|
||||
const message = JSON.stringify(data);
|
||||
let sent = false;
|
||||
for (const ws of sockets) {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(message);
|
||||
sent = true;
|
||||
}
|
||||
}
|
||||
return sent;
|
||||
}
|
||||
|
||||
export function getConnectedUsers(): string[] {
|
||||
return Array.from(userSockets.keys());
|
||||
return Array.from(userSockets.keys());
|
||||
}
|
||||
|
||||
export function getConnectionCount(): number {
|
||||
let count = 0;
|
||||
for (const sockets of userSockets.values()) {
|
||||
count += sockets.size;
|
||||
}
|
||||
return count;
|
||||
let count = 0;
|
||||
for (const sockets of userSockets.values()) {
|
||||
count += sockets.size;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
export function start(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
if (wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
wss = new WebSocketServer({ port: WS_PORT }, () => {
|
||||
console.log(`[websocket] Server listening on port ${WS_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
wss = new WebSocketServer({ port: WS_PORT }, () => {
|
||||
console.log(`[websocket] Server listening on port ${WS_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
|
||||
wss.on("connection", async (ws: WsClient, req: IncomingMessage) => {
|
||||
const userId = await authenticateConnection(ws, req);
|
||||
wss.on("connection", async (ws: WsClient) => {
|
||||
// Mark as unauthenticated initially; client must authenticate within timeout
|
||||
ws.isAuthed = false;
|
||||
enforceAuthTimeout(ws);
|
||||
|
||||
if (!userId) {
|
||||
ws.close(4001, "Authentication failed");
|
||||
return;
|
||||
}
|
||||
ws.on("message", async (data) => {
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
|
||||
ws.userId = userId;
|
||||
ws.isAlive = true;
|
||||
addSocket(userId, ws);
|
||||
// Handle auth messages (post-connection JWT authentication)
|
||||
if (
|
||||
msg.type === "auth" &&
|
||||
msg.token &&
|
||||
typeof msg.token === "string"
|
||||
) {
|
||||
const userId = await authenticateToken(msg.token);
|
||||
|
||||
ws.on("pong", () => {
|
||||
heartbeat(ws);
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
ws.pongTimer = undefined;
|
||||
}
|
||||
});
|
||||
if (userId) {
|
||||
ws.isAuthed = true;
|
||||
ws.userId = userId;
|
||||
ws.isAlive = true;
|
||||
|
||||
ws.on("message", (data) => {
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === "ping") {
|
||||
ws.send(JSON.stringify({ type: "pong" }));
|
||||
}
|
||||
} catch {
|
||||
// ignore invalid messages
|
||||
}
|
||||
});
|
||||
// Clear the auth timeout — client is now authenticated
|
||||
if (ws.authTimer) {
|
||||
clearTimeout(ws.authTimer);
|
||||
ws.authTimer = undefined;
|
||||
}
|
||||
|
||||
ws.on("close", () => {
|
||||
if (ws.userId) {
|
||||
removeSocket(ws.userId, ws);
|
||||
}
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
}
|
||||
});
|
||||
addSocket(userId, ws);
|
||||
ws.send(JSON.stringify({ type: "auth_success" }));
|
||||
} else {
|
||||
ws.send(
|
||||
JSON.stringify({
|
||||
type: "auth_error",
|
||||
message: "Invalid token",
|
||||
}),
|
||||
);
|
||||
ws.close(4001, "Authentication failed");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
ws.on("error", (err) => {
|
||||
console.error("[websocket] Client error:", err.message);
|
||||
});
|
||||
});
|
||||
// Only allow messages from authenticated connections
|
||||
if (!ws.isAuthed) {
|
||||
// Ignore ping messages from unauthenticated clients (they might not have sent auth yet)
|
||||
if (msg.type === "ping") {
|
||||
ws.send(JSON.stringify({ type: "pong" }));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
startHeartbeat();
|
||||
});
|
||||
// Handle normal messages from authenticated clients
|
||||
if (msg.type === "ping") {
|
||||
ws.send(JSON.stringify({ type: "pong" }));
|
||||
}
|
||||
} catch {
|
||||
// ignore invalid messages
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("pong", () => {
|
||||
heartbeat(ws);
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
ws.pongTimer = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("close", () => {
|
||||
if (ws.userId) {
|
||||
removeSocket(ws.userId, ws);
|
||||
}
|
||||
if (ws.pongTimer) {
|
||||
clearTimeout(ws.pongTimer);
|
||||
}
|
||||
if (ws.authTimer) {
|
||||
clearTimeout(ws.authTimer);
|
||||
}
|
||||
});
|
||||
|
||||
ws.on("error", (err) => {
|
||||
console.error("[websocket] Client error:", err.message);
|
||||
});
|
||||
});
|
||||
|
||||
startHeartbeat();
|
||||
});
|
||||
}
|
||||
|
||||
export function stop(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
stopHeartbeat();
|
||||
if (!wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
stopHeartbeat();
|
||||
if (!wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
for (const ws of wss.clients) {
|
||||
ws.close(1001, "Server shutting down");
|
||||
}
|
||||
for (const ws of wss.clients) {
|
||||
ws.close(1001, "Server shutting down");
|
||||
}
|
||||
|
||||
wss.close(() => {
|
||||
wss = null;
|
||||
userSockets.clear();
|
||||
console.log("[websocket] Server stopped");
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
wss.close(() => {
|
||||
wss = null;
|
||||
userSockets.clear();
|
||||
console.log("[websocket] Server stopped");
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user