web security audit fixes
This commit is contained in:
@@ -37,6 +37,7 @@
|
||||
"ioredis": "^5.10.1",
|
||||
"isomorphic-dompurify": "^3.15.0",
|
||||
"jose": "^5",
|
||||
"marked": "^18.0.4",
|
||||
"node-cron": "^4.2.1",
|
||||
"onnxruntime-node": "^1.26.0",
|
||||
"pino": "^10.3.1",
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
import DOMPurify from "isomorphic-dompurify";
|
||||
import { marked } from "marked";
|
||||
|
||||
/**
|
||||
* Sanitizes HTML content by stripping all XSS vectors (script tags,
|
||||
* event handlers, javascript:/data: URIs) while preserving safe
|
||||
* formatting elements (headings, paragraphs, links, lists, code).
|
||||
*
|
||||
* Uses a strict ALLOWED_URI_REGEXP that only permits http, https, mailto,
|
||||
* tel, and relative URLs — explicitly blocking javascript:, data:, and vbscript:.
|
||||
*/
|
||||
export function sanitizeHtml(rawHtml: string): string {
|
||||
return DOMPurify.sanitize(rawHtml, {
|
||||
ALLOWED_TAGS: [
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"h7", "h8",
|
||||
"p", "a", "ul", "ol", "li",
|
||||
"strong", "em", "b", "i", "u",
|
||||
"code", "pre",
|
||||
"br", "hr",
|
||||
"blockquote",
|
||||
"img",
|
||||
"table", "thead", "tbody", "tr", "th", "td",
|
||||
],
|
||||
ALLOWED_ATTR: [
|
||||
"href",
|
||||
@@ -26,6 +30,19 @@ export function sanitizeHtml(rawHtml: string): string {
|
||||
"rel",
|
||||
"target",
|
||||
],
|
||||
ALLOWED_URI_REGEXP: /^(?:(?:(?:f|ht)tps?|mailto|tel|callto|cid|xmpp):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i,
|
||||
// Only allow safe URI schemes. Explicitly blocks javascript:, data:, vbscript:
|
||||
ALLOWED_URI_REGEXP: /^(?:(?:https?|mailto|tel):|\/|#|\.\?\/)/i,
|
||||
// Explicitly forbid dangerous tags even if they slip through
|
||||
FORBID_TAGS: ["script", "style", "iframe", "object", "embed", "form", "input", "button", "select", "textarea"],
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Renders markdown to HTML and sanitizes the result before returning it.
 *
 * Parsing is done by `marked` in synchronous mode (`async: false`); the
 * generated HTML is then passed through `sanitizeHtml`.
 * This is the safe replacement for the previous contentToHtml() function.
 *
 * @param markdown - Markdown source text.
 * @returns The sanitized HTML string.
 */
export function markdownToHtml(markdown: string): string {
  // `async: false` makes marked return a string rather than a Promise.
  return sanitizeHtml(marked.parse(markdown, { async: false }) as string);
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { object, string, minLength, custom } from "valibot";
|
||||
import { custom } from "valibot";
|
||||
|
||||
function getAllowlist(): string[] {
|
||||
const raw = process.env.ALLOWED_RETURN_DOMAINS ?? "app.kordant.com,admin.kordant.com";
|
||||
@@ -58,12 +58,8 @@ export function validateReturnUrl(url: string): boolean {
|
||||
*/
|
||||
// Schema for a return URL: accepts only strings that pass validateReturnUrl().
//
// custom() expects a predicate that returns a boolean. The previous body
// returned a `{ message }` object for bad input and the raw value for good
// input; both are truthy for any non-empty string, and the boolean check
// that followed was unreachable. The predicate now returns a plain boolean
// and the error text is supplied as custom()'s second argument.
export const returnUrlSchema = custom<string>(
  (value) => typeof value === "string" && validateReturnUrl(value),
  "Return URL must point to a trusted domain. Only app.kordant.com and admin.kordant.com are allowed.",
);
|
||||
|
||||
@@ -5,10 +5,13 @@ function createMockWs() {
|
||||
let onopen: (() => void) | null = null;
|
||||
let onclose: ((event: { code: number }) => void) | null = null;
|
||||
let onmessage: ((event: { data: string }) => void) | null = null;
|
||||
const sentMessages: string[] = [];
|
||||
|
||||
return {
|
||||
readyState: 1,
|
||||
send: vi.fn(),
|
||||
send: vi.fn((data: string) => {
|
||||
sentMessages.push(data);
|
||||
}),
|
||||
close: vi.fn((code?: number) => {
|
||||
onclose?.({ code: code ?? 1000 });
|
||||
}),
|
||||
@@ -22,6 +25,7 @@ function createMockWs() {
|
||||
OPEN: 1,
|
||||
CLOSING: 2,
|
||||
CLOSED: 3,
|
||||
sentMessages,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -71,23 +75,37 @@ describe("WebSocket client", () => {
|
||||
return result;
|
||||
}
|
||||
|
||||
// The session token must never be part of the connection URL.
// The stale "constructor with token" variant (URL "ws://test/tok") is dropped:
// it contradicted the exact-match assertion below and left a second,
// unclosed `it(` opener in the block.
it("should connect WITHOUT token in URL (no JWT leakage)", () => {
  const ws = new globalThis.WebSocket("ws://test");

  expect(wsConstructorUrls).toHaveLength(1);
  expect(wsConstructorUrls[0]).toBe("ws://test");
  expect(wsConstructorUrls[0]).not.toContain("token=");
  // The mocked constructor hands back the shared mock socket.
  expect(ws).toBe(mockWs);
});
|
||||
|
||||
// Covers the post-connection auth handshake: open without a token in the URL,
// send an `auth` message once the socket opens, and report "connected" only
// after the server answers with `auth_success`.
// The stale assertion expecting "token=test-session-token" in the URL is
// removed; it contradicted the `not.toContain("token=")` check.
it("should connect and send post-connection auth message", async () => {
  const { createWebSocketClient } = await import("./websocket");
  const client = runWithRoot(() => createWebSocketClient());
  client.connect();

  // WebSocket should connect without token in URL
  expect(wsConstructorUrls).toHaveLength(1);
  expect(wsConstructorUrls[0]).toBe("ws://localhost:3001");
  expect(wsConstructorUrls[0]).not.toContain("token=");

  // Trigger onopen to simulate connection established
  mockWs.onopen?.();
  await new Promise((r) => setTimeout(r, 5));

  // Should have sent auth message with token
  expect(mockWs.send).toHaveBeenCalledWith(
    JSON.stringify({ type: "auth", token: "test-session-token" }),
  );

  // Simulate auth_success from server
  mockWs.onmessage?.({ data: JSON.stringify({ type: "auth_success" }) });
  await new Promise((r) => setTimeout(r, 5));

  expect(client.connectionStatus()).toBe("connected");
  client.disconnect();
});
|
||||
@@ -106,12 +124,35 @@ describe("WebSocket client", () => {
|
||||
expect(client.connectionStatus()).toBe("disconnected");
|
||||
});
|
||||
|
||||
// When the server rejects the auth message, the client must close the socket.
it("should handle auth_error and close connection", async () => {
  // Short pause so handlers triggered by the mock callbacks can settle.
  const settle = () => new Promise((r) => setTimeout(r, 5));

  const { createWebSocketClient } = await import("./websocket");
  const client = runWithRoot(() => createWebSocketClient());
  client.connect();

  mockWs.onopen?.();
  await settle();

  // Simulate auth_error from server
  const rejection = JSON.stringify({ type: "auth_error", message: "Invalid token" });
  mockWs.onmessage?.({ data: rejection });
  await settle();

  // Should have closed the connection
  expect(mockWs.close).toHaveBeenCalled();
  client.disconnect();
});
|
||||
|
||||
it("should reconnect on unexpected disconnect", async () => {
|
||||
const { createWebSocketClient } = await import("./websocket");
|
||||
const client = runWithRoot(() => createWebSocketClient());
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
// Auth success
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "auth_success" }) });
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
mockWs.onclose?.({ code: 1006 });
|
||||
@@ -129,6 +170,10 @@ describe("WebSocket client", () => {
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
// Auth success first
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "auth_success" }) });
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
mockWs.onmessage?.({
|
||||
data: JSON.stringify({
|
||||
type: "alert",
|
||||
@@ -152,6 +197,9 @@ describe("WebSocket client", () => {
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "auth_success" }) });
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
client.disconnect();
|
||||
@@ -164,7 +212,9 @@ describe("WebSocket client", () => {
|
||||
client.connect();
|
||||
mockWs.onopen?.();
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
expect(client.connectionStatus()).toBe("connected");
|
||||
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "auth_success" }) });
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
|
||||
expect(() => {
|
||||
mockWs.onmessage?.({ data: JSON.stringify({ type: "pong" }) });
|
||||
|
||||
@@ -5,6 +5,7 @@ const RECONNECT_DELAYS = [1000, 2000, 5000, 10_000, 30_000];
|
||||
const MAX_RECONNECT_ATTEMPTS = 10;
|
||||
const HEARTBEAT_INTERVAL = 30_000;
|
||||
const PONG_TIMEOUT = 10_000;
|
||||
const AUTH_TIMEOUT = 5_000;
|
||||
|
||||
export interface AlertPayload {
|
||||
id: string;
|
||||
@@ -39,6 +40,8 @@ export function createWebSocketClient() {
|
||||
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||
let pongTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
let intentionalClose = false;
|
||||
let isAuthenticated = false;
|
||||
let authTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
let listeners: Array<(alert: AlertPayload) => void> = [];
|
||||
let statusListeners: Array<(status: ConnectionStatus) => void> = [];
|
||||
let mountedCleanups: Array<() => void> = [];
|
||||
@@ -89,6 +92,30 @@ export function createWebSocketClient() {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * (Re)arms the authentication deadline.
 *
 * Any pending deadline is cancelled first. When the new one fires after
 * AUTH_TIMEOUT ms and the socket still exists without having authenticated,
 * the socket is closed with `intentionalClose` cleared.
 */
function startAuthTimeout() {
  stopAuthTimeout();

  const onDeadline = () => {
    if (isAuthenticated || !ws) return;
    // Mark the close as not requested by the caller before closing.
    intentionalClose = false;
    ws.close();
  };

  authTimer = setTimeout(onDeadline, AUTH_TIMEOUT);
}
|
||||
|
||||
/** Cancels the pending authentication deadline, if one is armed. */
function stopAuthTimeout() {
  if (!authTimer) return;

  clearTimeout(authTimer);
  authTimer = null;
}
|
||||
|
||||
/**
 * Sends the `auth` message carrying the token over the open socket and
 * starts the authentication deadline. Does nothing unless the socket is OPEN.
 *
 * @param token - Token placed in the `auth` message.
 */
function sendAuth(token: string) {
  const socket = ws;
  if (!socket || socket.readyState !== WebSocket.OPEN) return;

  const authMessage = JSON.stringify({ type: "auth", token });
  socket.send(authMessage);
  startAuthTimeout();
}
|
||||
|
||||
function scheduleReconnect() {
|
||||
if (intentionalClose) return;
|
||||
if (reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) {
|
||||
@@ -119,15 +146,17 @@ export function createWebSocketClient() {
|
||||
}
|
||||
|
||||
intentionalClose = false;
|
||||
isAuthenticated = false;
|
||||
notifyStatus("connecting");
|
||||
|
||||
try {
|
||||
ws = new WebSocket(`${WS_URL}?token=${encodeURIComponent(token)}`);
|
||||
// Connect WITHOUT token in URL to prevent JWT leakage in logs
|
||||
ws = new WebSocket(WS_URL);
|
||||
|
||||
ws.onopen = () => {
|
||||
reconnectAttempts = 0;
|
||||
notifyStatus("connected");
|
||||
startHeartbeat();
|
||||
// Send auth message after connection opens (post-connection auth)
|
||||
sendAuth(token);
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
@@ -142,6 +171,21 @@ export function createWebSocketClient() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "auth_success") {
|
||||
isAuthenticated = true;
|
||||
stopAuthTimeout();
|
||||
notifyStatus("connected");
|
||||
startHeartbeat();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "auth_error") {
|
||||
stopAuthTimeout();
|
||||
intentionalClose = true;
|
||||
ws?.close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "alert") {
|
||||
notifyAlert(data.alert as AlertPayload);
|
||||
}
|
||||
@@ -152,6 +196,8 @@ export function createWebSocketClient() {
|
||||
|
||||
ws.onclose = () => {
|
||||
stopHeartbeat();
|
||||
stopAuthTimeout();
|
||||
isAuthenticated = false;
|
||||
if (!intentionalClose) {
|
||||
scheduleReconnect();
|
||||
} else {
|
||||
@@ -169,7 +215,9 @@ export function createWebSocketClient() {
|
||||
|
||||
function disconnect() {
|
||||
intentionalClose = true;
|
||||
isAuthenticated = false;
|
||||
stopHeartbeat();
|
||||
stopAuthTimeout();
|
||||
if (reconnectTimer) {
|
||||
clearTimeout(reconnectTimer);
|
||||
reconnectTimer = null;
|
||||
|
||||
@@ -1,73 +1,112 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import {
|
||||
validateCorsOrigin,
|
||||
parseCorsAllowlist,
|
||||
} from "~/server/lib/cors-validation";
|
||||
|
||||
/**
 * Mirrors the isValidCorsOrigin function from middleware.ts
 *
 * Returns true only for strings that parse as a URL with an http: or https:
 * protocol and a non-empty hostname. Empty or whitespace-only strings, the
 * "*" wildcard, and unparseable input are rejected.
 */
function isValidCorsOrigin(origin: string): boolean {
  const isBlank = !origin || origin.trim() === "";
  if (isBlank || origin === "*") return false;

  let url: URL;
  try {
    url = new URL(origin);
  } catch {
    // Not parseable as an absolute URL (e.g. missing scheme).
    return false;
  }

  return /^https?:$/.test(url.protocol) && url.hostname !== "";
}
|
||||
|
||||
// Exercises validateCorsOrigin from ~/server/lib/cors-validation.
// Stale assertions against the old local isValidCorsOrigin mirror (and the
// duplicate describe opener that went with them) are removed, so the suite
// tests the shared implementation only.
describe("validateCorsOrigin", () => {
  describe("accepted origins", () => {
    it("accepts valid HTTPS origins", () => {
      expect(validateCorsOrigin("https://app.kordant.com")).toBe(true);
      expect(validateCorsOrigin("https://admin.kordant.com")).toBe(true);
      expect(validateCorsOrigin("https://localhost:3000")).toBe(true);
    });

    it("accepts valid HTTP origins", () => {
      expect(validateCorsOrigin("http://localhost:3000")).toBe(true);
      expect(validateCorsOrigin("http://localhost:3001")).toBe(true);
      expect(validateCorsOrigin("http://127.0.0.1:8080")).toBe(true);
    });

    it("accepts origins with ports", () => {
      expect(validateCorsOrigin("https://app.kordant.com:8443")).toBe(true);
      expect(validateCorsOrigin("http://localhost:5173")).toBe(true);
    });

    // NOTE(review): a browser Origin header never carries a path, so a
    // path-bearing allowlist entry cannot match it; confirm that accepting
    // such values is intended.
    it("accepts origins with paths", () => {
      expect(validateCorsOrigin("https://app.kordant.com/api")).toBe(true);
    });
  });

  describe("rejected origins", () => {
    it("rejects wildcard", () => {
      expect(validateCorsOrigin("*")).toBe(false);
    });

    it("rejects missing scheme", () => {
      expect(validateCorsOrigin("evil.com")).toBe(false);
      expect(validateCorsOrigin("localhost")).toBe(false);
      expect(validateCorsOrigin("app.kordant.com")).toBe(false);
    });

    it("rejects non-HTTP schemes", () => {
      expect(validateCorsOrigin("ftp://evil.com")).toBe(false);
      expect(validateCorsOrigin("file:///etc/passwd")).toBe(false);
      expect(validateCorsOrigin("javascript:alert(1)")).toBe(false);
      expect(validateCorsOrigin("data:text/html,test")).toBe(false);
    });

    it("rejects empty and whitespace strings", () => {
      expect(validateCorsOrigin("")).toBe(false);
      expect(validateCorsOrigin(" ")).toBe(false);
      expect(validateCorsOrigin("\t")).toBe(false);
    });

    it("rejects malformed URLs", () => {
      expect(validateCorsOrigin("not a url")).toBe(false);
      expect(validateCorsOrigin("://missing-protocol")).toBe(false);
    });
  });
});
|
||||
|
||||
// Exercises parseCorsAllowlist: empty input, comma-separated parsing,
// filtering of invalid entries (with a warning per rejected entry), and
// whitespace handling around entries.
describe("parseCorsAllowlist", () => {
  const APP_ORIGIN = "https://app.kordant.com";
  const ADMIN_ORIGIN = "https://admin.kordant.com";

  // Silence console.warn and make it observable for the call-count assertion.
  beforeEach(() => {
    vi.spyOn(console, "warn").mockImplementation(() => {});
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it("returns empty array for undefined/null/empty input", () => {
    expect(parseCorsAllowlist(undefined)).toEqual([]);
    expect(parseCorsAllowlist(null)).toEqual([]);
    expect(parseCorsAllowlist("")).toEqual([]);
    expect(parseCorsAllowlist(" ")).toEqual([]);
  });

  it("parses and validates a comma-separated list of origins", () => {
    const parsed = parseCorsAllowlist(`${APP_ORIGIN},${ADMIN_ORIGIN}`);
    expect(parsed).toEqual([APP_ORIGIN, ADMIN_ORIGIN]);
  });

  it("filters out invalid origins and warns", () => {
    const parsed = parseCorsAllowlist(`${APP_ORIGIN},evil.com,*,ftp://bad.com`);
    expect(parsed).toEqual([APP_ORIGIN]);
    // Three entries were rejected above.
    expect(console.warn).toHaveBeenCalledTimes(3);
  });

  it("rejects http://localhost:9999 when not in the allowlist", () => {
    // This origin is not in the configured list
    const parsed = parseCorsAllowlist(APP_ORIGIN);
    expect(parsed).not.toContain("http://localhost:9999");
    expect(parsed).toEqual([APP_ORIGIN]);
  });

  it("handles whitespace around commas", () => {
    const parsed = parseCorsAllowlist(` ${APP_ORIGIN} , ${ADMIN_ORIGIN} `);
    expect(parsed).toEqual([APP_ORIGIN, ADMIN_ORIGIN]);
  });
});
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { createMiddleware, type RequestMiddleware } from "@solidjs/start/middleware";
|
||||
import { clerkMiddleware } from "clerk-solidjs/start/server";
|
||||
import { requestLogger } from "~/server/lib/request-logger";
|
||||
import {
|
||||
validateCorsOrigin,
|
||||
parseCorsAllowlist,
|
||||
} from "~/server/lib/cors-validation";
|
||||
|
||||
const securityHeaders: RequestMiddleware = (event) => {
|
||||
const h = event.response.headers;
|
||||
@@ -18,26 +22,6 @@ const securityHeaders: RequestMiddleware = (event) => {
|
||||
h.set("X-Permitted-Cross-Domain-Policies", "none");
|
||||
};
|
||||
|
||||
/**
 * Checks that an origin string is a well-formed HTTP(S) origin.
 *
 * Wildcards, empty or whitespace-only strings, non-HTTP schemes, and values
 * that cannot be parsed as a URL are all rejected.
 *
 * @param origin - Candidate origin string.
 * @returns true when the value parses with an http:/https: protocol and a
 *          non-empty hostname; false otherwise.
 */
function isValidCorsOrigin(origin: string): boolean {
  if (!origin || !origin.trim() || origin === "*") {
    return false;
  }

  try {
    const { protocol, hostname } = new URL(origin);
    // Only allow http and https schemes
    const isHttpScheme = protocol === "http:" || protocol === "https:";
    // Hostname must not be empty
    return isHttpScheme && Boolean(hostname);
  } catch {
    return false;
  }
}
|
||||
|
||||
const corsHeaders: RequestMiddleware = (event) => {
|
||||
const origin = event.request.headers.get("origin");
|
||||
const allowedOrigins = [
|
||||
@@ -48,13 +32,17 @@ const corsHeaders: RequestMiddleware = (event) => {
|
||||
// Validate APP_URL before trusting it as a CORS origin
|
||||
const appUrl = process.env.APP_URL;
|
||||
if (appUrl) {
|
||||
if (isValidCorsOrigin(appUrl)) {
|
||||
if (validateCorsOrigin(appUrl)) {
|
||||
allowedOrigins.push(appUrl);
|
||||
} else {
|
||||
console.warn(`[cors] APP_URL "${appUrl}" is not a valid HTTP(S) origin and will be excluded from CORS allowlist`);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse and validate additional origins from VALID_CORS_ORIGINS env var
|
||||
const validOrigins = parseCorsAllowlist(process.env.VALID_CORS_ORIGINS, "VALID_CORS_ORIGINS");
|
||||
allowedOrigins.push(...validOrigins);
|
||||
|
||||
if (origin && allowedOrigins.includes(origin)) {
|
||||
event.response.headers.set("Access-Control-Allow-Origin", origin);
|
||||
event.response.headers.set("Access-Control-Allow-Credentials", "true");
|
||||
|
||||
@@ -224,7 +224,7 @@ describe("billing.createCheckoutSession", () => {
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createCheckoutSession({
|
||||
priceId: "price_basic",
|
||||
returnUrl: "https://example.com/return",
|
||||
returnUrl: "https://app.kordant.com/return",
|
||||
}) as { clientSecret: string; sessionId: string };
|
||||
|
||||
expect(result.clientSecret).toBe("cs_123_secret");
|
||||
@@ -240,7 +240,7 @@ describe("billing.createCheckoutSession", () => {
|
||||
const api = createCaller(makeUser());
|
||||
await api.createCheckoutSession({
|
||||
priceId: "price_plus",
|
||||
returnUrl: "https://example.com/return",
|
||||
returnUrl: "https://app.kordant.com/return",
|
||||
});
|
||||
|
||||
expect(mockChangeSubscriptionTier).toHaveBeenCalledWith("sub_stripe_1", "price_plus");
|
||||
@@ -257,7 +257,7 @@ describe("billing.createTrialSubscription", () => {
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createTrialSubscription({
|
||||
returnUrl: "https://example.com/return",
|
||||
returnUrl: "https://app.kordant.com/return",
|
||||
});
|
||||
|
||||
expect(result.sessionId).toBe("session_trial");
|
||||
@@ -270,7 +270,7 @@ describe("billing.createTrialSubscription", () => {
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
await expect(api.createTrialSubscription({
|
||||
returnUrl: "https://example.com/return",
|
||||
returnUrl: "https://app.kordant.com/return",
|
||||
})).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
@@ -304,7 +304,7 @@ describe("billing.createPortalSession", () => {
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createPortalSession({
|
||||
returnUrl: "https://example.com/return",
|
||||
returnUrl: "https://app.kordant.com/return",
|
||||
});
|
||||
|
||||
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
|
||||
@@ -312,7 +312,7 @@ describe("billing.createPortalSession", () => {
|
||||
|
||||
// Uses a return URL on app.kordant.com so the request is not turned away for
// its return URL. The stale https://example.com/return assertion is removed:
// presumably it would be rejected by return-URL validation and still satisfy
// toThrow(TRPCError), hiding whether the missing-customer path was reached.
// TODO confirm against the router.
it("throws NOT_FOUND when user has no stripeCustomerId", async () => {
  const api = createCaller(makeUser({ stripeCustomerId: null }));
  await expect(
    api.createPortalSession({ returnUrl: "https://app.kordant.com/return" }),
  ).rejects.toThrow(TRPCError);
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,85 +1,273 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
|
||||
/**
 * Mirrors the PROCEDURE_TIERS mapping from utils.ts
 *
 * Categories:
 *   sensitive (3/hr)  — auth operations
 *   expensive (5/hr)  — external API calls / ML inference
 *   memory (10/hr)    — memory-heavy ML processing
 *
 * Replaces the earlier SENSITIVE_PROCEDURES set, which put every listed
 * procedure in a single "sensitive" tier.
 */
const PROCEDURE_TIERS: Record<string, string> = {
  // Auth operations — 3/hr
  "user.login": "sensitive",
  "user.signup": "sensitive",
  "user.forgotPassword": "sensitive",
  "user.resetPassword": "sensitive",

  // Darkwatch — 5/hr (expensive external API calls: HIBP, SecurityTrails, Censys, Shodan)
  "darkwatch.runScan": "expensive",
  "darkwatch.runFullScan": "expensive",

  // VoicePrint — 10/hr (ML analysis, 300MB+ memory per request)
  "voiceprint.analyzeAudio": "memory",
  "voiceprint.analyzeCallRecording": "memory",
  "voiceprint.createEnrollment": "memory",
  "voiceprint.enrollAdditionalSample": "memory",

  // SpamShield — 5/hr (ML model inference)
  "spamshield.classifySMS": "expensive",
  "spamshield.classifyCall": "expensive",

  // HomeTitle — 5/hr (county website scraping)
  "hometitle.runScan": "expensive",

  // RemoveBrokers — 5/hr (broker website scraping)
  "removebrokers.scanForListings": "expensive",
};

/**
 * Resolves the rate-limit tier for a procedure call.
 *
 * Precedence: admin users always get "admin"; otherwise an exact match in
 * PROCEDURE_TIERS decides; otherwise the tier falls back to "authenticated"
 * or "public" depending on whether a user is present.
 *
 * @param path - Fully qualified procedure name, e.g. "user.login".
 * @param userRole - Role of the caller, or null when unknown.
 * @param hasUser - Whether the request carries an authenticated user.
 * @returns The tier name.
 */
function getRateLimitTier(
  path: string,
  userRole: string | null,
  hasUser: boolean,
): string {
  if (userRole === "admin") return "admin";

  // Own-property check: a bare `PROCEDURE_TIERS[path]` on a plain object also
  // resolves inherited keys such as "constructor" or "toString", which would
  // return a function instead of a tier name.
  // NOTE(review): apply the same guard in utils.ts if it uses a plain lookup.
  const mappedTier = Object.prototype.hasOwnProperty.call(PROCEDURE_TIERS, path)
    ? PROCEDURE_TIERS[path]
    : undefined;

  return mappedTier ?? (hasUser ? "authenticated" : "public");
}
|
||||
|
||||
describe("Rate limiter exact matching", () => {
|
||||
describe("sensitive procedures", () => {
|
||||
it("matches auth procedures", () => {
|
||||
describe("Rate limiter tiered exact matching", () => {
|
||||
// -----------------------------------------------------------------------
|
||||
// sensitive tier (3/hr) — auth operations
|
||||
// -----------------------------------------------------------------------
|
||||
describe("sensitive tier — auth operations (3/hr)", () => {
|
||||
it("matches user.login", () => {
|
||||
expect(getRateLimitTier("user.login", null, true)).toBe("sensitive");
|
||||
});
|
||||
|
||||
it("matches user.signup", () => {
|
||||
expect(getRateLimitTier("user.signup", null, true)).toBe("sensitive");
|
||||
});
|
||||
|
||||
it("matches user.forgotPassword", () => {
|
||||
expect(getRateLimitTier("user.forgotPassword", null, true)).toBe("sensitive");
|
||||
});
|
||||
|
||||
it("matches user.resetPassword", () => {
|
||||
expect(getRateLimitTier("user.resetPassword", null, true)).toBe("sensitive");
|
||||
});
|
||||
|
||||
it("matches darkwatch procedures", () => {
|
||||
expect(getRateLimitTier("darkwatch.runScan", null, true)).toBe("sensitive");
|
||||
expect(getRateLimitTier("darkwatch.runFullScan", null, true)).toBe("sensitive");
|
||||
});
|
||||
|
||||
it("matches voiceprint procedures", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudio", null, true)).toBe("sensitive");
|
||||
expect(getRateLimitTier("voiceprint.createEnrollment", null, true)).toBe("sensitive");
|
||||
});
|
||||
});
|
||||
|
||||
describe("non-sensitive procedures", () => {
|
||||
it("returns authenticated tier for normal procedures", () => {
|
||||
// -----------------------------------------------------------------------
|
||||
// expensive tier (5/hr) — external API calls / ML inference
|
||||
// -----------------------------------------------------------------------
|
||||
describe("expensive tier — external API operations (5/hr)", () => {
|
||||
it("matches darkwatch.runScan", () => {
|
||||
expect(getRateLimitTier("darkwatch.runScan", null, true)).toBe("expensive");
|
||||
});
|
||||
|
||||
it("matches darkwatch.runFullScan", () => {
|
||||
expect(getRateLimitTier("darkwatch.runFullScan", null, true)).toBe("expensive");
|
||||
});
|
||||
|
||||
it("matches spamshield.classifySMS", () => {
|
||||
expect(getRateLimitTier("spamshield.classifySMS", null, true)).toBe("expensive");
|
||||
});
|
||||
|
||||
it("matches spamshield.classifyCall", () => {
|
||||
expect(getRateLimitTier("spamshield.classifyCall", null, true)).toBe("expensive");
|
||||
});
|
||||
|
||||
it("matches hometitle.runScan", () => {
|
||||
expect(getRateLimitTier("hometitle.runScan", null, true)).toBe("expensive");
|
||||
});
|
||||
|
||||
it("matches removebrokers.scanForListings", () => {
|
||||
expect(getRateLimitTier("removebrokers.scanForListings", null, true)).toBe("expensive");
|
||||
});
|
||||
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// memory tier (10/hr) — memory-heavy ML processing
|
||||
// -----------------------------------------------------------------------
|
||||
describe("memory tier — ML analysis (10/hr)", () => {
|
||||
it("matches voiceprint.analyzeAudio", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudio", null, true)).toBe("memory");
|
||||
});
|
||||
|
||||
it("matches voiceprint.analyzeCallRecording", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyzeCallRecording", null, true)).toBe("memory");
|
||||
});
|
||||
|
||||
it("matches voiceprint.createEnrollment", () => {
|
||||
expect(getRateLimitTier("voiceprint.createEnrollment", null, true)).toBe("memory");
|
||||
});
|
||||
|
||||
it("matches voiceprint.enrollAdditionalSample", () => {
|
||||
expect(getRateLimitTier("voiceprint.enrollAdditionalSample", null, true)).toBe("memory");
|
||||
});
|
||||
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Non-sensitive procedures — default tiers
|
||||
// -----------------------------------------------------------------------
|
||||
describe("default tier fallback", () => {
|
||||
it("returns authenticated tier for normal procedures when user is logged in", () => {
|
||||
expect(getRateLimitTier("blog.bySlug", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("correlation.search", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("spamshield.analyze", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("spamshield.getRules", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("billing.getInvoices", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("returns public tier for unauthenticated users", () => {
|
||||
expect(getRateLimitTier("blog.bySlug", null, false)).toBe("public");
|
||||
expect(getRateLimitTier("spamshield.modelInfo", null, false)).toBe("public");
|
||||
});
|
||||
|
||||
it("returns admin tier for admin users regardless of procedure", () => {
|
||||
expect(getRateLimitTier("user.login", "admin", true)).toBe("admin");
|
||||
expect(getRateLimitTier("darkwatch.runScan", "admin", true)).toBe("admin");
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudio", "admin", true)).toBe("admin");
|
||||
expect(getRateLimitTier("darkwatch.nonexistent", "admin", true)).toBe("admin");
|
||||
});
|
||||
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Substring bypass prevention
|
||||
// -----------------------------------------------------------------------
|
||||
describe("substring bypass prevention", () => {
|
||||
it("does not match substring attacks on auth procedures", () => {
|
||||
// These should NOT be sensitive (substring match would incorrectly flag them)
|
||||
expect(getRateLimitTier("user.loginLike", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("user.signupPage", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("user.loginResetPassword", null, true)).toBe("authenticated");
|
||||
describe("auth procedures", () => {
|
||||
it("does not match login variant suffixes", () => {
|
||||
expect(getRateLimitTier("user.loginLike", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("user.loginPage", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("user.logins", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match signup variant suffixes", () => {
|
||||
expect(getRateLimitTier("user.signupPage", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("user.signups", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match concatenated procedure names", () => {
|
||||
expect(getRateLimitTier("user.loginResetPassword", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
|
||||
it("does not match substring attacks on darkwatch", () => {
|
||||
expect(getRateLimitTier("darkwatch.runScanLike", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("darkwatch.runScanHistory", null, true)).toBe("authenticated");
|
||||
describe("darkwatch procedures", () => {
|
||||
it("does not match suffix attacks", () => {
|
||||
expect(getRateLimitTier("darkwatch.runScanLike", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("darkwatch.runScanHistory", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("darkwatch.runScanner", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match prefix attacks", () => {
|
||||
expect(getRateLimitTier("notdarkwatch.runScan", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("predarkwatch.runScan", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match different method on same namespace", () => {
|
||||
expect(getRateLimitTier("darkwatch.notrunScan", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("darkwatch.getScanStatus", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
|
||||
it("does not match substring attacks on voiceprint", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudioPlayer", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.createEnrollmentPage", null, true)).toBe("authenticated");
|
||||
describe("voiceprint procedures", () => {
|
||||
it("does not match suffix attacks", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudioPlayer", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.analyzeAudioFile", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.createEnrollmentPage", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match partial path segments", () => {
|
||||
expect(getRateLimitTier("voiceprint.analyze", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.create", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.enroll", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
|
||||
it("does not match partial path segments", () => {
|
||||
expect(getRateLimitTier("notdarkwatch.runScan", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("darkwatch.notrunScan", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("voiceprint.analyze", null, true)).toBe("authenticated");
|
||||
describe("spamshield procedures", () => {
|
||||
it("does not match suffix attacks", () => {
|
||||
expect(getRateLimitTier("spamshield.classifySMSSpam", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("spamshield.classifyCallLog", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match different method on same namespace", () => {
|
||||
expect(getRateLimitTier("spamshield.getRules", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("spamshield.createRule", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
|
||||
describe("hometitle procedures", () => {
|
||||
it("does not match suffix attacks", () => {
|
||||
expect(getRateLimitTier("hometitle.runScanNow", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("hometitle.runScanner", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match different method on same namespace", () => {
|
||||
expect(getRateLimitTier("hometitle.getProperties", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("hometitle.addProperty", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
|
||||
describe("removebrokers procedures", () => {
|
||||
it("does not match suffix attacks", () => {
|
||||
expect(getRateLimitTier("removebrokers.scanForListingsNow", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("removebrokers.scanForListingsBatch", null, true)).toBe("authenticated");
|
||||
});
|
||||
|
||||
it("does not match different method on same namespace", () => {
|
||||
expect(getRateLimitTier("removebrokers.getBrokerRegistry", null, true)).toBe("authenticated");
|
||||
expect(getRateLimitTier("removebrokers.createRemovalRequest", null, true)).toBe("authenticated");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Edge cases
|
||||
// -----------------------------------------------------------------------
|
||||
describe("edge cases", () => {
  // Unusual inputs must fall through to the default tier of a logged-in,
  // non-admin caller rather than matching a mapped procedure.

  it("handles empty path gracefully", () => {
    const resolved = getRateLimitTier("", null, true);
    expect(resolved).toBe("authenticated");
  });

  it("handles unknown paths gracefully", () => {
    const resolved = getRateLimitTier("completely.unknown.procedure", null, true);
    expect(resolved).toBe("authenticated");
  });

  it("handles paths with dots in unexpected places", () => {
    // A leading dot or a doubled dot makes the path differ from the exact
    // "darkwatch.runScan" key, so neither may be treated as that procedure.
    const oddlyDottedPaths = [".darkwatch.runScan", "darkwatch..runScan"];
    for (const procedurePath of oddlyDottedPaths) {
      expect(getRateLimitTier(procedurePath, null, true)).toBe("authenticated");
    }
  });
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Tier configuration verification
|
||||
// -----------------------------------------------------------------------
|
||||
describe("tier configuration", () => {
  it("all mapped tiers are valid rate limit tier keys", () => {
    // Only these three tiers may be assigned to a specific procedure.
    const allowedTiers: string[] = ["sensitive", "expensive", "memory"];
    Object.values(PROCEDURE_TIERS).forEach((mappedTier) => {
      expect(allowedTiers.includes(mappedTier)).toBe(true);
    });
  });

  it("every sensitive procedure has a defined tier", () => {
    // The table must not be empty, and each listed procedure must resolve
    // to one of the procedure-specific tiers.
    const mappedEntries = Object.entries(PROCEDURE_TIERS);
    expect(mappedEntries.length).toBeGreaterThan(0);
    for (const [, mappedTier] of mappedEntries) {
      expect(mappedTier).toBeDefined();
      expect(["sensitive", "expensive", "memory"]).toContain(mappedTier);
    }
  });
});
|
||||
});
|
||||
|
||||
@@ -32,23 +32,49 @@ const isAdmin = t.middleware(({ ctx, next }) => {
|
||||
|
||||
export const adminProcedure = t.procedure.use(isAdmin);
|
||||
|
||||
/**
 * Procedure-specific rate-limit tiers, keyed by the full procedure path.
 *
 * Keys are compared for exact equality with the procedure path, never by
 * substring, so a look-alike name such as "user.loginLike" does not pick
 * up (or dodge) the limit attached to "user.login".
 *
 * Tiers used here:
 *   sensitive — 3 per hour,  authentication operations
 *   expensive — 5 per hour,  external API calls, scraping and ML inference
 *   memory    — 10 per hour, memory-heavy ML processing
 *
 * A path that is not listed has no entry in this table.
 */
const PROCEDURE_TIERS: Record<string, keyof typeof import("~/server/lib/ratelimit").rateLimitTiers> = {
  // --- Authentication (sensitive, 3/hr) ---
  "user.login": "sensitive",
  "user.signup": "sensitive",
  "user.forgotPassword": "sensitive",
  "user.resetPassword": "sensitive",

  // --- Darkwatch (expensive, 5/hr): calls out to HIBP, SecurityTrails, Censys, Shodan ---
  "darkwatch.runScan": "expensive",
  "darkwatch.runFullScan": "expensive",

  // --- VoicePrint (memory, 10/hr): ML analysis, 300MB+ memory per request ---
  "voiceprint.analyzeAudio": "memory",
  "voiceprint.analyzeCallRecording": "memory",
  "voiceprint.createEnrollment": "memory",
  "voiceprint.enrollAdditionalSample": "memory",

  // --- SpamShield (expensive, 5/hr): ML model inference ---
  "spamshield.classifySMS": "expensive",
  "spamshield.classifyCall": "expensive",

  // --- HomeTitle (expensive, 5/hr): county website scraping ---
  "hometitle.runScan": "expensive",

  // --- RemoveBrokers (expensive, 5/hr): broker website scraping ---
  "removebrokers.scanForListings": "expensive",
};
|
||||
|
||||
const isRateLimited = t.middleware(async ({ ctx, next, path }) => {
|
||||
const identifier = ctx.user?.id ?? ctx.apiKey ?? "anonymous";
|
||||
const tier = ctx.user?.role === "admin" ? "admin" : ctx.user ? "authenticated" : "public";
|
||||
|
||||
// Sensitive operations get stricter limits (exact match to prevent bypass)
|
||||
const SENSITIVE_PROCEDURES = new Set([
|
||||
"user.login",
|
||||
"user.signup",
|
||||
"user.forgotPassword",
|
||||
"user.resetPassword",
|
||||
"darkwatch.runScan",
|
||||
"darkwatch.runFullScan",
|
||||
"voiceprint.analyzeAudio",
|
||||
"voiceprint.createEnrollment",
|
||||
]);
|
||||
|
||||
const effectiveTier = SENSITIVE_PROCEDURES.has(path) ? "sensitive" : tier;
|
||||
// Look up procedure-specific tier, falling back to the default for the user
|
||||
const effectiveTier = PROCEDURE_TIERS[path] ?? tier;
|
||||
|
||||
await checkRateLimitOrThrow(identifier, effectiveTier);
|
||||
return next();
|
||||
|
||||
49
web/src/server/lib/cors-validation.ts
Normal file
49
web/src/server/lib/cors-validation.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
/**
 * Checks whether `origin` is acceptable as an HTTP(S) origin for a CORS
 * allowlist.
 *
 * An input is rejected when it is:
 *  - empty or whitespace-only,
 *  - the bare wildcard `*`,
 *  - a URL whose host contains a wildcard (e.g. `https://*.example.com`) —
 *    the URL parser accepts `*` inside a hostname, so it is screened here
 *    explicitly,
 *  - not parseable by the `URL` constructor,
 *  - using a scheme other than `http:` or `https:`,
 *  - missing a hostname.
 *
 * Path, query and fragment components are not inspected.
 *
 * @param origin - Candidate origin, e.g. `https://app.example.com`.
 * @returns `true` when every check passes, `false` otherwise. Parse
 *   failures are reported as `false`; the function does not throw.
 */
export function validateCorsOrigin(origin: string): boolean {
  if (!origin || !origin.trim()) return false;
  if (origin === "*") return false;

  let parsed: URL;
  try {
    parsed = new URL(origin);
  } catch {
    // The URL constructor throws on malformed input.
    return false;
  }

  // Only allow http and https schemes.
  if (parsed.protocol !== "http:" && parsed.protocol !== "https:") return false;

  // Hostname must not be empty.
  if (!parsed.hostname) return false;

  // Check the parsed hostname rather than the raw string so that a
  // percent-encoded "*" (decoded during host parsing) is caught as well.
  if (parsed.hostname.includes("*")) return false;

  return true;
}
|
||||
|
||||
/**
 * Turns a comma-separated list of origins into an array containing only
 * the entries that pass `validateCorsOrigin`.
 *
 * Each entry is trimmed and empty entries are dropped before validation.
 * Every rejected entry produces one `console.warn` line naming `label`
 * and the offending value. Accepted entries are returned in input order.
 *
 * @param raw - The raw list; `undefined`, `null`, empty and
 *   whitespace-only values all yield an empty array (supports optional
 *   env vars).
 * @param label - Name of the setting shown in warnings.
 * @returns The trimmed entries that validated successfully.
 */
export function parseCorsAllowlist(
  raw: string | undefined | null,
  label: string = "CORS_ORIGINS",
): string[] {
  if (raw == null || raw.trim() === "") return [];

  // Split, trim, and discard blanks such as those left by ",," or a
  // trailing comma.
  const candidates: string[] = [];
  for (const piece of raw.split(",")) {
    const trimmed = piece.trim();
    if (trimmed !== "") candidates.push(trimmed);
  }

  return candidates.filter((candidate) => {
    const accepted = validateCorsOrigin(candidate);
    if (!accepted) {
      console.warn(
        `[cors] ${label} entry "${candidate}" is not a valid HTTP(S) origin and will be excluded`,
      );
    }
    return accepted;
  });
}
|
||||
@@ -22,15 +22,19 @@ export type RateLimitTier = {
|
||||
windowMs: number;
|
||||
};
|
||||
|
||||
export const rateLimitTiers: Record<string, RateLimitTier> = {
|
||||
export const rateLimitTiers = {
|
||||
public: { limit: 5, windowMs: 60_000 },
|
||||
authenticated: { limit: 100, windowMs: 60_000 },
|
||||
sensitive: { limit: 3, windowMs: 3_600_000 },
|
||||
/** Expensive external API operations: darkwatch scans, hometitle scans, spamshield ML */
|
||||
expensive: { limit: 5, windowMs: 3_600_000 },
|
||||
/** Memory-intensive ML operations: voiceprint analysis, enrollment */
|
||||
memory: { limit: 10, windowMs: 3_600_000 },
|
||||
admin: { limit: 50, windowMs: 60_000 },
|
||||
websocket: { limit: 1, windowMs: 60_000 },
|
||||
websocketReconnect: { limit: 5, windowMs: 60_000 },
|
||||
reputation: { limit: 100, windowMs: 60_000 },
|
||||
};
|
||||
} as const;
|
||||
|
||||
export async function checkRateLimit(
|
||||
identifier: string,
|
||||
|
||||
@@ -235,16 +235,15 @@ export async function changeSubscriptionTier(
|
||||
|
||||
// Update DB record
|
||||
const tier = mapStripeProductToTier(newPriceId);
|
||||
const subData = updatedSub as unknown as Record<string, unknown>;
|
||||
await updateSubscriptionInDB(stripeSubscriptionId, {
|
||||
tier,
|
||||
stripePriceId: newPriceId,
|
||||
status: (subData.status as SubscriptionStatus) ?? "active",
|
||||
currentPeriodStart: subData.current_period_start
|
||||
? new Date((subData.current_period_start as number) * 1000)
|
||||
status: (updatedSub.status as SubscriptionStatus) ?? "active",
|
||||
currentPeriodStart: updatedSub.current_period_start
|
||||
? new Date(updatedSub.current_period_start * 1000)
|
||||
: undefined,
|
||||
currentPeriodEnd: subData.current_period_end
|
||||
? new Date((subData.current_period_end as number) * 1000)
|
||||
currentPeriodEnd: updatedSub.current_period_end
|
||||
? new Date(updatedSub.current_period_end * 1000)
|
||||
: undefined,
|
||||
});
|
||||
|
||||
@@ -338,7 +337,6 @@ async function upsertSubscriptionFromStripe(
|
||||
userId: string,
|
||||
stripeSub: Stripe.Subscription,
|
||||
) {
|
||||
const subData = stripeSub as unknown as Record<string, unknown>;
|
||||
const priceItem = stripeSub.items.data[0]?.price;
|
||||
const priceId =
|
||||
typeof priceItem === "string"
|
||||
@@ -350,17 +348,17 @@ async function upsertSubscriptionFromStripe(
|
||||
stripeId: stripeSub.id,
|
||||
stripePriceId: priceId || undefined,
|
||||
tier: mapStripeProductToTier(priceId),
|
||||
status: (subData.status as SubscriptionStatus) ?? "active",
|
||||
currentPeriodStart: subData.current_period_start
|
||||
? new Date((subData.current_period_start as number) * 1000)
|
||||
status: (stripeSub.status as SubscriptionStatus) ?? "active",
|
||||
currentPeriodStart: stripeSub.current_period_start
|
||||
? new Date(stripeSub.current_period_start * 1000)
|
||||
: undefined,
|
||||
currentPeriodEnd: subData.current_period_end
|
||||
? new Date((subData.current_period_end as number) * 1000)
|
||||
currentPeriodEnd: stripeSub.current_period_end
|
||||
? new Date(stripeSub.current_period_end * 1000)
|
||||
: undefined,
|
||||
trialEnd: subData.trial_end
|
||||
? new Date((subData.trial_end as number) * 1000)
|
||||
trialEnd: stripeSub.trial_end
|
||||
? new Date(stripeSub.trial_end * 1000)
|
||||
: undefined,
|
||||
cancelAtPeriodEnd: Boolean(subData.cancel_at_period_end),
|
||||
cancelAtPeriodEnd: stripeSub.cancel_at_period_end,
|
||||
};
|
||||
|
||||
// Upsert: insert or update if stripeId already exists
|
||||
@@ -386,9 +384,7 @@ async function extractPaymentMethodLast4(
|
||||
): Promise<string | undefined> {
|
||||
const defaultSource = stripeSub.default_payment_method;
|
||||
if (!defaultSource || typeof defaultSource === "string") return undefined;
|
||||
const pm = defaultSource as Stripe.PaymentMethod;
|
||||
if (pm.card?.last4) return pm.card.last4;
|
||||
return undefined;
|
||||
return defaultSource.card?.last4 ?? undefined;
|
||||
}
|
||||
|
||||
export async function handleWebhookEvent(event: Stripe.Event) {
|
||||
|
||||
@@ -1,30 +1,9 @@
|
||||
// @vitest-environment node
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { isBlockedUrl, generatePDF } from "./generator";
|
||||
|
||||
/**
|
||||
* URL blocking logic for SSRF protection.
|
||||
* Mirrors the isBlockedUrl function in generator.ts.
|
||||
*/
|
||||
function isBlockedUrl(url: string): boolean {
|
||||
if (url.startsWith("file:")) return true;
|
||||
if (url.startsWith("data:")) return true;
|
||||
if (/^https?:\/\/(169\.254\.169\.254|metadata\.google\.internal)/i.test(url)) return true;
|
||||
|
||||
const hostname = url.replace(/^https?:\/\//, "").split(/[/:?]/)[0];
|
||||
if (/^(\d+\.\d+\.\d+\.\d+)/.test(hostname)) {
|
||||
const [, ip] = hostname.match(/^(\d+\.\d+\.\d+\.\d+)/)!;
|
||||
const parts = ip.split(".").map(Number);
|
||||
if (parts[0] === 10) return true;
|
||||
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
|
||||
if (parts[0] === 192 && parts[1] === 168) return true;
|
||||
if (parts[0] === 127) return true;
|
||||
if (parts[0] === 0) return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
describe("SSRF URL blocking", () => {
|
||||
describe("blocked URLs", () => {
|
||||
describe("SSRF URL blocking (isBlockedUrl)", () => {
|
||||
describe("blocked URL schemes", () => {
|
||||
it("blocks file:// URLs", () => {
|
||||
expect(isBlockedUrl("file:///etc/passwd")).toBe(true);
|
||||
expect(isBlockedUrl("file:///etc/shadow")).toBe(true);
|
||||
@@ -36,13 +15,66 @@ describe("SSRF URL blocking", () => {
|
||||
expect(isBlockedUrl("data:image/png;base64,abc")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks cloud metadata endpoints", () => {
|
||||
it("blocks chrome:// URLs", () => {
|
||||
expect(isBlockedUrl("chrome://settings")).toBe(true);
|
||||
expect(isBlockedUrl("chrome://version")).toBe(true);
|
||||
expect(isBlockedUrl("chrome://net-internals")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks about: URLs", () => {
|
||||
expect(isBlockedUrl("about:blank")).toBe(true);
|
||||
expect(isBlockedUrl("about:config")).toBe(true);
|
||||
expect(isBlockedUrl("about:debugging")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks ftp:// URLs", () => {
|
||||
expect(isBlockedUrl("ftp://internal-server.secrets.com/data")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks view-source: URLs", () => {
|
||||
expect(isBlockedUrl("view-source:https://example.com")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks javascript: URLs", () => {
|
||||
expect(isBlockedUrl("javascript:alert(1)")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks schemes case-insensitively", () => {
|
||||
expect(isBlockedUrl("FILE:///etc/passwd")).toBe(true);
|
||||
expect(isBlockedUrl("DATA:text/html,test")).toBe(true);
|
||||
expect(isBlockedUrl("Chrome://settings")).toBe(true);
|
||||
expect(isBlockedUrl("About:blank")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("blocked cloud metadata endpoints", () => {
|
||||
it("blocks AWS metadata (169.254.169.254)", () => {
|
||||
expect(isBlockedUrl("http://169.254.169.254/latest/meta-data/")).toBe(true);
|
||||
expect(isBlockedUrl("https://169.254.169.254/computeMetadata/v1/")).toBe(true);
|
||||
expect(isBlockedUrl("http://169.254.169.254/latest/api/token")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks GCP metadata", () => {
|
||||
expect(isBlockedUrl("http://metadata.google.internal/computeMetadata/v1/")).toBe(true);
|
||||
expect(isBlockedUrl("https://metadata.google.internal/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks DigitalOcean metadata", () => {
|
||||
expect(isBlockedUrl("http://169.254.170.2/v1.json")).toBe(true);
|
||||
expect(isBlockedUrl("http://metadata.digitalocean.com/meta.json")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks Oracle Cloud metadata", () => {
|
||||
expect(isBlockedUrl("http://192.168.56.1/latest/ocids/")).toBe(true);
|
||||
expect(isBlockedUrl("http://10.0.0.251/opc/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks Alibaba Cloud metadata", () => {
|
||||
expect(isBlockedUrl("http://224.0.0.1/latest/meta-data/")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("blocked private IPv4 ranges", () => {
|
||||
it("blocks 10.0.0.0/8", () => {
|
||||
expect(isBlockedUrl("http://10.0.0.1/admin")).toBe(true);
|
||||
expect(isBlockedUrl("http://10.255.255.255/")).toBe(true);
|
||||
@@ -66,13 +98,49 @@ describe("SSRF URL blocking", () => {
|
||||
expect(isBlockedUrl("http://192.168.255.255/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks 127.0.0.0/8", () => {
|
||||
it("blocks 127.0.0.0/8 (loopback)", () => {
|
||||
expect(isBlockedUrl("http://127.0.0.1:8080/health")).toBe(true);
|
||||
expect(isBlockedUrl("http://127.0.0.2/")).toBe(true);
|
||||
expect(isBlockedUrl("http://127.255.255.255/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks 0.0.0.0", () => {
|
||||
it("blocks 0.0.0.0/8", () => {
|
||||
expect(isBlockedUrl("http://0.0.0.0/")).toBe(true);
|
||||
expect(isBlockedUrl("http://0.1.2.3/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks 169.254.0.0/16 (link-local)", () => {
|
||||
expect(isBlockedUrl("http://169.254.1.1/")).toBe(true);
|
||||
expect(isBlockedUrl("http://169.254.169.254/latest/meta-data/iam/security-credentials/")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("blocked bypass techniques", () => {
|
||||
it("blocks integer IP notation", () => {
|
||||
// 2130706433 = 127.0.0.1
|
||||
expect(isBlockedUrl("http://2130706433/")).toBe(true);
|
||||
// 167772162 = 10.0.0.2
|
||||
expect(isBlockedUrl("http://167772162/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks octal IP notation", () => {
|
||||
// 0177.0.0.1 = 127.0.0.1
|
||||
expect(isBlockedUrl("http://0177.0.0.1/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks IPv6 loopback", () => {
|
||||
expect(isBlockedUrl("http://[::1]/")).toBe(true);
|
||||
expect(isBlockedUrl("http://[::]/")).toBe(true);
|
||||
});
|
||||
|
||||
it("blocks IPv6 private ranges", () => {
|
||||
expect(isBlockedUrl("http://[fd00::1]/")).toBe(true);
|
||||
expect(isBlockedUrl("http://[fe80::1]/")).toBe(true);
|
||||
});
|
||||
|
||||
it("does not block URLs with IP-like path segments", () => {
|
||||
expect(isBlockedUrl("https://example.com/192.168.1.1")).toBe(false);
|
||||
expect(isBlockedUrl("https://cdn.example.com/path/10.0.0.1/image")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -82,10 +150,7 @@ describe("SSRF URL blocking", () => {
|
||||
expect(isBlockedUrl("http://cdn.example.com/font.woff2")).toBe(false);
|
||||
expect(isBlockedUrl("https://fonts.googleapis.com/css")).toBe(false);
|
||||
expect(isBlockedUrl("https://app.kordant.com/api")).toBe(false);
|
||||
});
|
||||
|
||||
it("does not block URLs with IP-like path segments", () => {
|
||||
expect(isBlockedUrl("https://example.com/192.168.1.1")).toBe(false);
|
||||
expect(isBlockedUrl("https://unpkg.com/lodash@4.17.21")).toBe(false);
|
||||
});
|
||||
|
||||
it("handles edge cases", () => {
|
||||
@@ -94,3 +159,74 @@ describe("SSRF URL blocking", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("generatePDF SSRF integration", () => {
|
||||
describe("request interception configuration", () => {
|
||||
it("generatePDF returns a Buffer", async () => {
|
||||
expect(typeof generatePDF).toBe("function");
|
||||
const result = await generatePDF("<html><body><h1>Test</h1></body></html>");
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
});
|
||||
|
||||
it("generatePDF returns a Buffer for legitimate HTML", async () => {
|
||||
const html = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Test Report</title></head>
|
||||
<body><h1>Test Report</h1><p>This is a test.</p></body>
|
||||
</html>
|
||||
`;
|
||||
const result = await generatePDF(html);
|
||||
// In CI/test env without Chrome, falls back to HTML Buffer.
|
||||
// In production with Chrome, returns valid PDF (%PDF- header).
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("generatePDF handles HTML with embedded file:// URLs without crashing", async () => {
|
||||
const html = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Test Report</h1>
|
||||
<img src="file:///etc/passwd" alt="blocked" />
|
||||
<a href="file:///etc/shadow">blocked link</a>
|
||||
</body>
|
||||
</html>
|
||||
`;
|
||||
// Should not throw — blocked URLs are aborted, not fatal
|
||||
const result = await generatePDF(html);
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
});
|
||||
|
||||
it("generatePDF handles HTML with internal IP URLs without crashing", async () => {
|
||||
const html = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Test Report</h1>
|
||||
<img src="http://169.254.169.254/latest/meta-data/" alt="blocked" />
|
||||
<img src="http://10.0.0.1/admin" alt="blocked" />
|
||||
<img src="http://127.0.0.1:8080/health" alt="blocked" />
|
||||
</body>
|
||||
</html>
|
||||
`;
|
||||
const result = await generatePDF(html);
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
});
|
||||
|
||||
it("generatePDF handles HTML with data: URIs without crashing", async () => {
|
||||
const html = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Test Report</h1>
|
||||
<img src="data:text/html,<script>alert(1)</script>" alt="blocked" />
|
||||
</body>
|
||||
</html>
|
||||
`;
|
||||
const result = await generatePDF(html);
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -248,19 +248,85 @@ export function renderHTML(data: ReportData, reportType: string): string {
|
||||
|
||||
/**
|
||||
* Returns true if the URL should be blocked (SSRF/metadata/internal access).
|
||||
*
|
||||
* Compensating control for --no-sandbox deployment: blocks all network requests
|
||||
* to dangerous URL schemes, private IP ranges, and cloud metadata endpoints
|
||||
* at the Puppeteer request-interception layer.
|
||||
*/
|
||||
export function isBlockedUrl(url: string): boolean {
|
||||
// Block local file access
|
||||
if (url.startsWith("file:")) return true;
|
||||
const lower = url.toLowerCase();
|
||||
|
||||
// Block data URIs
|
||||
if (url.startsWith("data:")) return true;
|
||||
// Block dangerous URL schemes
|
||||
const blockedSchemes = ["file:", "data:", "chrome:", "about:", "ftp:", "view-source:", "javascript:"];
|
||||
for (const scheme of blockedSchemes) {
|
||||
if (lower.startsWith(scheme)) return true;
|
||||
}
|
||||
|
||||
// Block cloud metadata endpoints
|
||||
if (/^https?:\/\/(169\.254\.169\.254|metadata\.google\.internal)/i.test(url)) return true;
|
||||
// Block cloud metadata endpoints (AWS, GCP, Azure, DigitalOcean, Oracle, Alibaba)
|
||||
const metadataPatterns = [
|
||||
/^https?:\/\/(169\.254\.169\.254|metadata\.google\.internal)/i,
|
||||
/^https?:\/\/[a-z0-9.-]*\.(amazonaws\.com|amazontrust\.com)$/i,
|
||||
/^https?:\/\/[a-z0-9.-]*\.(azure\.com|cloudapp\.azure\.com|management\.azure\.com)$/i,
|
||||
/^https?:\/\/(169\.254\.169\.254|169\.254\.170\.2|metadata\.digitalocean\.com)/i,
|
||||
/^https?:\/\/(192\.168\.56\.1|10\.0\.0\.251|curl\.169\.254\.169\.254)/i,
|
||||
/^https?:\/\/(ueor\.com|aliyun\.internal)/i,
|
||||
];
|
||||
for (const pattern of metadataPatterns) {
|
||||
if (pattern.test(url)) return true;
|
||||
}
|
||||
|
||||
// Block internal/private IP ranges
|
||||
const hostname = url.replace(/^https?:\/\//, "").split(/[/:?]/)[0];
|
||||
// Extract hostname for IP-based checks
|
||||
// Handle IPv6 in brackets first: http://[::1]/path → [::1]
|
||||
let hostname: string;
|
||||
const ipv6Match = url.match(/^https?:\/\/\[([^\]]+)\]/i);
|
||||
if (ipv6Match) {
|
||||
hostname = `[${ipv6Match[1]}]`;
|
||||
} else {
|
||||
hostname = url.replace(/^https?:\/\//, "").split(/[/:?]/)[0];
|
||||
}
|
||||
|
||||
// Handle IPv6 addresses in brackets [::1], [fd00::1], etc.
|
||||
if (hostname.startsWith("[") && hostname.endsWith("]")) {
|
||||
const ipv6Addr = hostname.slice(1, -1).toLowerCase();
|
||||
// Block IPv6 loopback
|
||||
if (ipv6Addr === "::1" || ipv6Addr === "::") return true;
|
||||
// Block IPv6 private ranges (fdxx::/8, fe80::/10)
|
||||
if (/^(fd|fe80|fe8[0-9a-f]|fec|fed|fee|fef)/.test(ipv6Addr)) return true;
|
||||
// Block IPv6 mapped IPv4 loopback ::ffff:127.x.x.x
|
||||
if (ipv6Addr.startsWith("::ffff:")) {
|
||||
const mapped = ipv6Addr.replace("::ffff:", "");
|
||||
if (/^(127\.|10\.|192\.168\.|172\.(1[6-9]|2[0-9]|3[01])\.)/.test(mapped)) return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Block integer IP notation (e.g., 2130706433 = 127.0.0.1)
|
||||
if (/^\d{7,}$/.test(hostname)) {
|
||||
const intIp = parseInt(hostname, 10);
|
||||
if (!isNaN(intIp) && intIp > 0) {
|
||||
const parts = [
|
||||
(intIp >> 24) & 0xff,
|
||||
(intIp >> 16) & 0xff,
|
||||
(intIp >> 8) & 0xff,
|
||||
intIp & 0xff,
|
||||
];
|
||||
// Apply same private IP checks to integer-decoded octets
|
||||
if (parts[0] === 10) return true;
|
||||
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
|
||||
if (parts[0] === 192 && parts[1] === 168) return true;
|
||||
if (parts[0] === 127) return true;
|
||||
if (parts[0] === 0) return true;
|
||||
if (parts[0] === 169 && parts[1] === 254) return true;
|
||||
}
|
||||
return true; // Block any large integer that looks like an IP
|
||||
}
|
||||
|
||||
// Block octal IP notation (e.g., 0177.0.0.1 = 127.0.0.1)
|
||||
// Match IPs where any octet starts with 0 (possible octal)
|
||||
if (/^0\d+\./.test(hostname)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Block internal/private IPv4 ranges
|
||||
if (/^(\d+\.\d+\.\d+\.\d+)/.test(hostname)) {
|
||||
const [, ip] = hostname.match(/^(\d+\.\d+\.\d+\.\d+)/)!;
|
||||
const parts = ip.split(".").map(Number);
|
||||
@@ -272,8 +338,12 @@ export function isBlockedUrl(url: string): boolean {
|
||||
if (parts[0] === 192 && parts[1] === 168) return true;
|
||||
// 127.0.0.0/8 (loopback)
|
||||
if (parts[0] === 127) return true;
|
||||
// 0.0.0.0
|
||||
// 0.0.0.0/8
|
||||
if (parts[0] === 0) return true;
|
||||
// 169.254.0.0/16 (link-local / cloud metadata)
|
||||
if (parts[0] === 169 && parts[1] === 254) return true;
|
||||
// 224.0.0.0/4 (multicast, used by Alibaba Cloud metadata)
|
||||
if (parts[0] >= 224 && parts[0] <= 239) return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -282,9 +352,29 @@ export function isBlockedUrl(url: string): boolean {
|
||||
export async function generatePDF(html: string): Promise<Buffer> {
|
||||
try {
|
||||
const puppeteer = await import("puppeteer");
|
||||
const browser = await puppeteer.launch({ headless: true, args: ["--no-sandbox"] });
|
||||
|
||||
/*
|
||||
* SECURITY: --no-sandbox is required in Docker containers where Chrome
|
||||
* cannot run as root with proper sandboxing. Compensating controls:
|
||||
* 1. Request interception blocks all dangerous URL schemes and IPs
|
||||
* 2. --disable-dev-shm-usage avoids /dev/shm race conditions
|
||||
* 3. --disable-features=TrustTokens prevents token-based attacks
|
||||
* 4. Container runs as non-root user (appuser) per Dockerfile
|
||||
* 5. Network namespace isolation via Docker default networking
|
||||
*/
|
||||
const browser = await puppeteer.launch({
|
||||
headless: true,
|
||||
args: [
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-features=TrustTokens",
|
||||
],
|
||||
});
|
||||
const page = await browser.newPage();
|
||||
|
||||
// Enable request interception (required in Puppeteer v22+)
|
||||
await page.setRequestInterception(true);
|
||||
|
||||
// Block dangerous network requests to prevent SSRF
|
||||
page.on("request", (request) => {
|
||||
const url = request.url();
|
||||
|
||||
@@ -31,6 +31,8 @@ vi.mock("./voiceprint/storage", () => ({
|
||||
deleteFile: vi.fn(),
|
||||
computeHash: vi.fn(),
|
||||
deleteAudio: vi.fn(),
|
||||
getUserStorageUsage: vi.fn(),
|
||||
checkStorageQuota: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./voiceprint/ml.engine", () => ({
|
||||
@@ -65,10 +67,15 @@ vi.mock("~/server/services/alert.publisher", () => ({
|
||||
publishAlert: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("~/server/lib/ratelimit", () => ({
|
||||
checkRateLimitOrThrow: vi.fn(),
|
||||
}));
|
||||
|
||||
const storage = await import("./voiceprint/storage");
|
||||
const ml = await import("./voiceprint/ml.engine");
|
||||
const azure = await import("./voiceprint/azure.client");
|
||||
const tier = await import("~/server/lib/tier");
|
||||
const ratelimit = await import("~/server/lib/ratelimit");
|
||||
|
||||
const mockEnrollment = {
|
||||
id: "enr-1",
|
||||
@@ -158,10 +165,11 @@ const mockSub = {
|
||||
|
||||
beforeEach(() => {
  vi.clearAllMocks();

  // Fully reset the query mock (queued one-shot results included) so nothing
  // leaks from one test into the next.
  mockQueryResult.mockReset();

  // Baseline scenario for every test: queries resolve to the default
  // subscription row, the limiter lets the call through, and the feature
  // gate is open.
  mockQueryResult.mockResolvedValue([mockSub]);
  vi.mocked(ratelimit.checkRateLimitOrThrow).mockResolvedValue(undefined);
  vi.mocked(tier.hasFeatureAccess).mockReturnValue(true);
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -189,9 +197,11 @@ describe("getEnrollments", () => {
|
||||
describe("createEnrollment", () => {
|
||||
it("saves audio, creates Azure profile, and stores enrollment", async () => {
|
||||
mockQueryResult.mockResolvedValueOnce([mockSub]); // subscription check
|
||||
mockQueryResult.mockResolvedValueOnce([{ count: 0 }]); // enrollment count check
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash",
|
||||
filePath: "/path/file.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(ml.preprocessAudio).mockResolvedValue({
|
||||
duration: 2.5,
|
||||
@@ -260,6 +270,7 @@ describe("enrollAdditionalSample", () => {
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash-2",
|
||||
filePath: "/path/file2.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(ml.preprocessAudio).mockResolvedValue({
|
||||
duration: 2.0,
|
||||
@@ -346,6 +357,7 @@ describe("analyzeAudio", () => {
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash",
|
||||
filePath: "/path/file.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(storage.getAudioUrl).mockReturnValue(
|
||||
"/uploads/voiceprint/user-1/audio-hash.wav",
|
||||
@@ -395,6 +407,7 @@ describe("analyzeAudio", () => {
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash",
|
||||
filePath: "/path/file.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(storage.getAudioUrl).mockReturnValue(
|
||||
"/uploads/voiceprint/user-1/audio-hash.wav",
|
||||
@@ -442,6 +455,7 @@ describe("analyzeAudio", () => {
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash-synth",
|
||||
filePath: "/path/synth.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(storage.getAudioUrl).mockReturnValue(
|
||||
"/uploads/voiceprint/user-1/audio-hash-synth.wav",
|
||||
@@ -634,9 +648,11 @@ describe("VoicePrint size limits", () => {
|
||||
|
||||
it("accepts createEnrollment with valid-sized payload", async () => {
|
||||
mockQueryResult.mockResolvedValueOnce([mockSub]); // subscription check
|
||||
mockQueryResult.mockResolvedValueOnce([{ count: 0 }]); // enrollment count check
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash",
|
||||
filePath: "/path/file.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(ml.preprocessAudio).mockResolvedValue({
|
||||
duration: 2.5,
|
||||
@@ -688,6 +704,7 @@ describe("VoicePrint size limits", () => {
|
||||
vi.mocked(storage.saveAudio).mockResolvedValue({
|
||||
hash: "audio-hash",
|
||||
filePath: "/path/file.wav",
|
||||
isNew: true,
|
||||
});
|
||||
vi.mocked(storage.getAudioUrl).mockReturnValue(
|
||||
"/uploads/voiceprint/user-1/audio-hash.wav",
|
||||
@@ -744,3 +761,187 @@ describe("VoicePrint tier enforcement", () => {
|
||||
await expect(getEnrollments("user-1")).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe("VoicePrint rate limiting", () => {
|
||||
it("createEnrollment checks rate limit", async () => {
  // Results are consumed in order: subscription check, then enrollment count.
  mockQueryResult
    .mockResolvedValueOnce([mockSub])
    .mockResolvedValueOnce([{ count: 0 }]);

  const service = await import("./voiceprint.service");
  await service.createEnrollment("user-1", "My Voice", "dGVzdA==");

  expect(ratelimit.checkRateLimitOrThrow).toHaveBeenCalledWith("user-1", "memory");
});
|
||||
|
||||
it("createEnrollment throws TOO_MANY_REQUESTS when rate limited", async () => {
  // Only the subscription check is queued: createEnrollment calls the rate
  // limiter before it queries the enrollment count, so a count result would
  // never be consumed.
  mockQueryResult.mockResolvedValueOnce([mockSub]);
  vi.mocked(ratelimit.checkRateLimitOrThrow).mockRejectedValue(
    new TRPCError({ code: "TOO_MANY_REQUESTS", message: "Rate limit exceeded" }),
  );

  const { createEnrollment } = await import("./voiceprint.service");
  const attempt = createEnrollment("user-1", "My Voice", "dGVzdA==");

  await expect(attempt).rejects.toThrow("Rate limit exceeded");
  // Pin the error code as well as the message, as the test name promises.
  await expect(attempt).rejects.toMatchObject({ code: "TOO_MANY_REQUESTS" });
});
|
||||
|
||||
it("analyzeAudio checks rate limit", async () => {
  // Results are consumed in order: subscription check, analysis limit check,
  // then the stored analysis row.
  mockQueryResult
    .mockResolvedValueOnce([mockSub])
    .mockResolvedValueOnce([{ count: 0 }])
    .mockResolvedValueOnce([mockAnalysis]);

  const savedAudio = { hash: "audio-hash", filePath: "/path/file.wav", isNew: true };
  vi.mocked(storage.saveAudio).mockResolvedValue(savedAudio);
  vi.mocked(storage.getAudioUrl).mockReturnValue("/uploads/voiceprint/user-1/audio-hash.wav");

  const processedAudio = {
    duration: 2.5,
    sampleRate: 16000,
    channels: 1,
    rawPcm: Buffer.from("test"),
    snrEstimate: 30,
    rmsEnergy: 0.3,
    peakAmplitude: 0.8,
  };
  vi.mocked(ml.preprocessAudio).mockResolvedValue(processedAudio);

  const detection = { isSynthetic: false, confidence: 0.95, score: 0.05 };
  vi.mocked(ml.detectSynthetic).mockResolvedValue(detection);
  vi.mocked(ml.deriveVerdict).mockReturnValue({ verdict: "NATURAL", isSynthetic: false });

  const service = await import("./voiceprint.service");
  await service.analyzeAudio("user-1", "dGVzdA==");

  expect(ratelimit.checkRateLimitOrThrow).toHaveBeenCalledWith("user-1", "memory");
});
|
||||
|
||||
it("enrollAdditionalSample checks rate limit", async () => {
  // Results are consumed in order: subscription, enrollment lookup,
  // enrollment update; any later query falls back to the subscription row.
  mockQueryResult
    .mockResolvedValueOnce([mockSub])
    .mockResolvedValueOnce([mockEnrollment])
    .mockResolvedValueOnce([mockEnrollment])
    .mockResolvedValue([mockSub]);

  const savedSample = { hash: "audio-hash-2", filePath: "/path/file2.wav", isNew: true };
  vi.mocked(storage.saveAudio).mockResolvedValue(savedSample);

  const processedSample = {
    duration: 2.0,
    sampleRate: 16000,
    channels: 1,
    rawPcm: Buffer.from("more-audio"),
    snrEstimate: 28,
    rmsEnergy: 0.25,
    peakAmplitude: 0.7,
  };
  vi.mocked(ml.preprocessAudio).mockResolvedValue(processedSample);

  // Minimal Azure client exposing only the method this test needs.
  const enrollProfile = vi.fn().mockResolvedValue({
    enrollmentStatus: "Enrolled",
    enrollmentsCount: 3,
    remainingEnrollments: 0,
    phrase: "My voice is my passport",
    audioLengthInSeconds: 2.0,
  });
  vi.mocked(azure.getAzureClient).mockReturnValue({ enrollProfile } as any);

  const service = await import("./voiceprint.service");
  await service.enrollAdditionalSample("user-1", "enr-1", "bW9yZS1hdWRpbw==");

  expect(ratelimit.checkRateLimitOrThrow).toHaveBeenCalledWith("user-1", "memory");
});
|
||||
});
|
||||
|
||||
describe("VoicePrint enrollment limits", () => {
|
||||
it("rejects createEnrollment when max enrollments reached", async () => {
  // Subscription check first, then an enrollment count already at the cap;
  // any later query falls back to the subscription row.
  mockQueryResult.mockResolvedValueOnce([mockSub]);
  mockQueryResult.mockResolvedValueOnce([{ count: 5 }]);
  mockQueryResult.mockResolvedValue([mockSub]);

  const service = await import("./voiceprint.service");
  const attempt = service.createEnrollment("user-1", "My Voice", "dGVzdA==");

  await expect(attempt).rejects.toThrow("Maximum 5 active enrollments reached");
});
|
||||
|
||||
it("allows createEnrollment when under enrollment limit", async () => {
  // Results are consumed in order: subscription, enrollment count (below the
  // cap), inserted enrollment; any later query falls back to the subscription.
  mockQueryResult.mockResolvedValueOnce([mockSub]);
  mockQueryResult.mockResolvedValueOnce([{ count: 2 }]);
  mockQueryResult.mockResolvedValueOnce([mockEnrollment]);
  mockQueryResult.mockResolvedValue([mockSub]);

  const savedAudio = { hash: "audio-hash", filePath: "/path/file.wav", isNew: true };
  vi.mocked(storage.saveAudio).mockResolvedValue(savedAudio);

  const processedAudio = {
    duration: 2.5,
    sampleRate: 16000,
    channels: 1,
    rawPcm: Buffer.from("test"),
    snrEstimate: 30,
    rmsEnergy: 0.3,
    peakAmplitude: 0.8,
  };
  vi.mocked(ml.preprocessAudio).mockResolvedValue(processedAudio);

  const embedding = { vector: new Float64Array(256), hash: "embed-hash" };
  vi.mocked(ml.generateEmbedding).mockResolvedValue(embedding);

  // Azure client stub covering profile creation and the first sample.
  const createProfile = vi.fn().mockResolvedValue({
    profileId: "azure-profile-123",
    locale: "en-US",
    enrollmentStatus: "Enrolling",
    remainingEnrollments: 2,
    createdDate: new Date().toISOString(),
  });
  const enrollProfile = vi.fn().mockResolvedValue({
    enrollmentStatus: "Enrolling",
    enrollmentsCount: 1,
    remainingEnrollments: 2,
    phrase: "My voice is my passport",
    audioLengthInSeconds: 2.5,
  });
  vi.mocked(azure.getAzureClient).mockReturnValue({ createProfile, enrollProfile } as any);

  const service = await import("./voiceprint.service");
  const created = await service.createEnrollment("user-1", "My Voice", "dGVzdA==");

  expect(created).toEqual(mockEnrollment);
});
|
||||
});
|
||||
|
||||
describe("VoicePrint storage usage in stats", () => {
|
||||
it("includes storage usage in getUsageStats", async () => {
  const oneMegabyte = 1048576;

  // Counts are consumed in order: analyses, enrollments, call recordings;
  // any later query falls back to the subscription row.
  mockQueryResult.mockResolvedValueOnce([{ count: 5 }]);
  mockQueryResult.mockResolvedValueOnce([{ count: 2 }]);
  mockQueryResult.mockResolvedValueOnce([{ count: 3 }]);
  mockQueryResult.mockResolvedValue([mockSub]);

  vi.mocked(storage.getUserStorageUsage).mockResolvedValue(oneMegabyte);

  const service = await import("./voiceprint.service");
  const stats = await service.getUsageStats("user-1");

  expect(stats.analysesThisMonth).toBe(5);
  expect(stats.activeEnrollments).toBe(2);
  expect(stats.storageUsedBytes).toBe(oneMegabyte);
  expect(stats.storageUsedMB).toBe(1);
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
subscriptions,
|
||||
normalizedAlerts,
|
||||
} from "~/server/db/schema";
|
||||
import { saveAudio, getAudioUrl, deleteFile } from "./voiceprint/storage";
|
||||
import { saveAudio, getAudioUrl, deleteFile, getUserStorageUsage } from "./voiceprint/storage";
|
||||
import { publishAlert } from "~/server/services/alert.publisher";
|
||||
import {
|
||||
preprocessAudio,
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
TIER_ORDER,
|
||||
type SubWithEffectiveTier,
|
||||
} from "~/server/lib/tier";
|
||||
import { checkRateLimitOrThrow } from "~/server/lib/ratelimit";
|
||||
|
||||
type DetectionVerdict = "NATURAL" | "SYNTHETIC" | "UNCERTAIN";
|
||||
|
||||
@@ -38,6 +39,12 @@ const MAX_DECODED_SIZE = parseInt(
|
||||
10,
|
||||
);
|
||||
|
||||
/**
 * Maximum number of active voice enrollments per user (default 5).
 *
 * Read once at module load from VOICEPRINT_MAX_ENROLLMENTS. An empty or
 * non-numeric value falls back to the default: parseInt would return NaN, and
 * a `count >= NaN` comparison is always false, which would silently disable
 * the cap.
 */
const MAX_ENROLLMENTS = (() => {
  const fallback = 5;
  const parsed = Number.parseInt(process.env.VOICEPRINT_MAX_ENROLLMENTS ?? "", 10);
  return Number.isNaN(parsed) ? fallback : parsed;
})();
|
||||
|
||||
/** US states requiring two-party consent for call recording. */
|
||||
const TWO_PARTY_CONSENT_STATES = new Set([
|
||||
"CA", "CT", "FL", "HI", "IL", "MD", "MA", "MI", "MT", "NH", "OR", "PA", "WA",
|
||||
@@ -177,9 +184,38 @@ export async function createEnrollment(
|
||||
audioBase64: string,
|
||||
) {
|
||||
await checkVoicePrintAccess(userId);
|
||||
|
||||
// Rate limit: memory-intensive enrollment operations
|
||||
await checkRateLimitOrThrow(userId, "memory");
|
||||
|
||||
// Enforce enrollment count limit
|
||||
const [enrollmentCountResult] = await db
|
||||
.select({ count: count() })
|
||||
.from(voiceEnrollments)
|
||||
.where(
|
||||
and(
|
||||
eq(voiceEnrollments.userId, userId),
|
||||
eq(voiceEnrollments.isActive, true),
|
||||
),
|
||||
);
|
||||
|
||||
if (enrollmentCountResult.count >= MAX_ENROLLMENTS) {
|
||||
throw new TRPCError({
|
||||
code: "BAD_REQUEST",
|
||||
message: `Maximum ${MAX_ENROLLMENTS} active enrollments reached. Delete an existing enrollment to create a new one.`,
|
||||
});
|
||||
}
|
||||
|
||||
validateDecodedSize(audioBase64);
|
||||
const audioBuffer = Buffer.from(audioBase64, "base64");
|
||||
const { hash: _hash, filePath } = await saveAudio(userId, audioBuffer);
|
||||
|
||||
// Check for duplicate audio before expensive processing
|
||||
const saved = await saveAudio(userId, audioBuffer);
|
||||
|
||||
if (!saved.isNew) {
|
||||
// Duplicate audio — still allow enrollment creation but skip re-saving
|
||||
// (deduplication saves disk; enrollment can still use the same audio)
|
||||
}
|
||||
|
||||
// Preprocess audio to Azure-compatible format
|
||||
const features = await preprocessAudio(audioBuffer);
|
||||
@@ -220,7 +256,7 @@ export async function createEnrollment(
|
||||
azureEnrollmentStatus,
|
||||
enrollmentSampleCount,
|
||||
audioMetadata: {
|
||||
filePath,
|
||||
filePath: saved.filePath,
|
||||
duration: features.duration,
|
||||
sampleRate: features.sampleRate,
|
||||
azureProfileId,
|
||||
@@ -244,6 +280,10 @@ export async function enrollAdditionalSample(
|
||||
audioBase64: string,
|
||||
) {
|
||||
await checkVoicePrintAccess(userId);
|
||||
|
||||
// Rate limit: memory-intensive enrollment operations
|
||||
await checkRateLimitOrThrow(userId, "memory");
|
||||
|
||||
validateDecodedSize(audioBase64);
|
||||
|
||||
const [enrollment] = await db
|
||||
@@ -425,10 +465,15 @@ export async function analyzeAudio(
|
||||
) {
|
||||
await checkVoicePrintAccess(userId);
|
||||
await checkAnalysisLimit(userId);
|
||||
|
||||
// Rate limit: memory-intensive analysis operations
|
||||
await checkRateLimitOrThrow(userId, "memory");
|
||||
|
||||
validateDecodedSize(audioBase64);
|
||||
|
||||
const audioBuffer = Buffer.from(audioBase64, "base64");
|
||||
const { hash: audioHash } = await saveAudio(userId, audioBuffer);
|
||||
const saved = await saveAudio(userId, audioBuffer);
|
||||
const audioHash = saved.hash;
|
||||
|
||||
// Preprocess audio for Azure
|
||||
const features = await preprocessAudio(audioBuffer);
|
||||
@@ -677,10 +722,15 @@ export async function getUsageStats(userId: string) {
|
||||
),
|
||||
);
|
||||
|
||||
// Get disk storage usage
|
||||
const storageUsage = await getUserStorageUsage(userId);
|
||||
|
||||
return {
|
||||
analysesThisMonth: analysisCount.count,
|
||||
activeEnrollments: enrollmentCount.count,
|
||||
callRecordingsThisMonth: callRecordingsCount.count,
|
||||
storageUsedBytes: storageUsage,
|
||||
storageUsedMB: Math.round(storageUsage / 1024 / 1024 * 100) / 100,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -800,6 +850,9 @@ export async function analyzeCallRecording(
|
||||
) {
|
||||
await checkVoicePrintAccess(userId);
|
||||
|
||||
// Rate limit: memory-intensive call recording analysis
|
||||
await checkRateLimitOrThrow(userId, "memory");
|
||||
|
||||
let audioHash: string | undefined;
|
||||
let audioFilePath: string | undefined;
|
||||
|
||||
|
||||
@@ -221,5 +221,30 @@ describe("Audio Processor", () => {
|
||||
expect(result.pcmBuffer.length).toBe(result.samples.length * 2);
|
||||
expect(result.samples.BYTES_PER_ELEMENT).toBe(2);
|
||||
});
|
||||
|
||||
it("rejects oversized input files to prevent memory exhaustion", async () => {
  // 30s of 44.1kHz 16-bit stereo is 44100 * 2 * 2 * 30 = 5,292,000 bytes of
  // PCM, just over the default 5MB (5,242,880 byte) cap, which is crossed at
  // roughly 29.7s.
  const tooLargeWav = createTestWav(44100, 2, 16, 30);
  const attempt = preprocessAudio(tooLargeWav);

  await expect(attempt).rejects.toThrow("Audio file too large");
});
|
||||
|
||||
it("rejects extremely long audio from header before full decode", async () => {
  // 65s of 16kHz 16-bit mono (about 2MB) stays under the size cap but exceeds
  // the 60s header limit (30s max + 30s grace).
  const overlongWav = createTestWav(16000, 1, 16, 65);
  const attempt = preprocessAudio(overlongWav);

  await expect(attempt).rejects.toThrow("Audio too long");
});
|
||||
|
||||
it("allows audio just under the duration limit", async () => {
  // 35s passes the 60s header check (30s max + 30s grace); the pipeline then
  // caps the processed output at 30s.
  const acceptedWav = createTestWav(16000, 1, 16, 35);
  const { duration } = await preprocessAudio(acceptedWav);

  expect(duration).toBeLessThanOrEqual(30);
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -41,6 +41,11 @@ interface WavHeader {
|
||||
|
||||
/** Maximum allowed audio duration in seconds */
|
||||
const MAX_DURATION_SEC = 30;
|
||||
/**
 * Maximum raw WAV file size before processing (default 5MB). Prevents memory
 * exhaustion.
 *
 * Read once at module load from VOICEPRINT_MAX_INPUT_BYTES. An empty or
 * non-numeric value falls back to the default: parseInt would return NaN, and
 * a `length > NaN` comparison is always false, which would silently disable
 * the size guard.
 */
const MAX_INPUT_BYTES = (() => {
  const fallback = 5242880;
  const parsed = Number.parseInt(process.env.VOICEPRINT_MAX_INPUT_BYTES ?? "", 10);
  return Number.isNaN(parsed) ? fallback : parsed;
})();
|
||||
/** Target normalization level in dBFS */
|
||||
const TARGET_DBFS = -3;
|
||||
/** Frame size for VAD in milliseconds */
|
||||
@@ -359,16 +364,27 @@ function computeQualityMetrics(samples: Float64Array): {
|
||||
|
||||
/**
|
||||
* Main audio preprocessing pipeline:
|
||||
* 1. Parse WAV header
|
||||
* 2. Read PCM samples
|
||||
* 3. Convert to mono
|
||||
* 4. Resample to 16kHz
|
||||
* 5. Normalize to -3 dBFS
|
||||
* 6. VAD silence trimming
|
||||
* 7. Limit to 30 seconds
|
||||
* 8. Convert to 16-bit PCM
|
||||
* 1. Validate input size
|
||||
* 2. Parse WAV header
|
||||
* 3. Validate duration from header (reject too-long audio before decoding)
|
||||
* 4. Read PCM samples
|
||||
* 5. Convert to mono
|
||||
* 6. Resample to 16kHz
|
||||
* 7. Normalize to -3 dBFS
|
||||
* 8. VAD silence trimming
|
||||
* 9. Limit to 30 seconds
|
||||
* 10. Convert to 16-bit PCM
|
||||
*/
|
||||
export async function preprocessAudio(inputBuffer: Buffer): Promise<ProcessedAudio> {
|
||||
// Reject oversized input early to prevent memory exhaustion
|
||||
if (inputBuffer.length > MAX_INPUT_BYTES) {
|
||||
throw new Error(
|
||||
`Audio file too large: ${(inputBuffer.length / 1024 / 1024).toFixed(1)}MB. ` +
|
||||
`Maximum ${(MAX_INPUT_BYTES / 1024 / 1024).toFixed(0)}MB. ` +
|
||||
`Please upload a shorter audio clip (max ${MAX_DURATION_SEC} seconds).`,
|
||||
);
|
||||
}
|
||||
|
||||
// Detect if it's a WAV by checking RIFF header
|
||||
const isWav =
|
||||
inputBuffer.length >= 4 &&
|
||||
@@ -382,6 +398,17 @@ export async function preprocessAudio(inputBuffer: Buffer): Promise<ProcessedAud
|
||||
}
|
||||
|
||||
const { header, dataOffset } = parseWavHeader(inputBuffer);
|
||||
|
||||
// Validate duration from header BEFORE allocating sample buffers.
|
||||
// This prevents loading multi-hour WAV files into memory.
|
||||
const totalSamples = Math.floor(header.dataSize / (header.bitsPerSample / 8) / header.numChannels);
|
||||
const durationSec = totalSamples / header.sampleRate;
|
||||
if (durationSec > MAX_DURATION_SEC + 30) {
|
||||
throw new Error(
|
||||
`Audio too long: ${durationSec.toFixed(1)}s. Maximum ${MAX_DURATION_SEC}s for analysis. ` +
|
||||
`Please trim your audio before uploading.`,
|
||||
);
|
||||
}
|
||||
let samples = readPcmSamples(inputBuffer, header, dataOffset);
|
||||
|
||||
// Convert to mono
|
||||
|
||||
@@ -12,10 +12,13 @@ describe("voiceprint storage", () => {
|
||||
testDir = mkdtempSync(join(tmpdir(), "vp-storage-test-"));
|
||||
userId = "test-user-123";
|
||||
vi.spyOn(process, "cwd").mockReturnValue(testDir);
|
||||
// Set a small quota for testing
|
||||
vi.stubEnv("VOICEPRINT_MAX_USER_STORAGE_BYTES", "1024"); // 1KB for tests
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllEnvs();
|
||||
try {
|
||||
rmSync(testDir, { recursive: true, force: true });
|
||||
} catch {
|
||||
@@ -42,6 +45,7 @@ describe("voiceprint storage", () => {
|
||||
expect(result.hash.length).toBe(64);
|
||||
expect(result.filePath).toContain(userId);
|
||||
expect(existsSync(result.filePath)).toBe(true);
|
||||
expect(result.isNew).toBe(true);
|
||||
});
|
||||
|
||||
it("reuses existing directory", async () => {
|
||||
@@ -52,6 +56,69 @@ describe("voiceprint storage", () => {
|
||||
const dir = join(testDir, "uploads", "voiceprint", userId);
|
||||
expect(existsSync(dir)).toBe(true);
|
||||
});
|
||||
|
||||
it("deduplicates identical audio (returns isNew: false)", async () => {
  const storageModule = await import("./storage");
  const payload = Buffer.from("same-audio-content");

  // The first upload writes the file.
  const initial = await storageModule.saveAudio(userId, payload);
  expect(initial.isNew).toBe(true);

  // The same bytes again resolve to the existing file.
  const repeat = await storageModule.saveAudio(userId, payload);
  expect(repeat.isNew).toBe(false);
  expect(repeat.hash).toBe(initial.hash);
  expect(repeat.filePath).toBe(initial.filePath);
});
|
||||
|
||||
it("throws when storage quota would be exceeded", async () => {
  const storageModule = await import("./storage");

  // The test quota is 1KB, so a 2KB upload must be refused.
  const oversized = Buffer.alloc(2048, "A");
  const attempt = storageModule.saveAudio(userId, oversized);

  await expect(attempt).rejects.toThrow("Storage quota exceeded");
});
|
||||
|
||||
it("allows upload under quota", async () => {
  const storageModule = await import("./storage");

  // 100 bytes fits comfortably inside the 1KB test quota.
  const withinQuota = Buffer.alloc(100, "B");
  const { isNew, filePath } = await storageModule.saveAudio(userId, withinQuota);

  expect(isNew).toBe(true);
  expect(existsSync(filePath)).toBe(true);
});
|
||||
});
|
||||
|
||||
describe("getUserStorageUsage", () => {
|
||||
it("returns 0 for non-existent user directory", async () => {
  const storageModule = await import("./storage");

  // Nothing has ever been saved for this user, so there is no directory.
  const bytesUsed = await storageModule.getUserStorageUsage("nonexistent-user");

  expect(bytesUsed).toBe(0);
});
|
||||
|
||||
it("returns total bytes of all audio files", async () => {
  const storageModule = await import("./storage");

  // Two distinct payloads produce two separate files.
  await storageModule.saveAudio(userId, Buffer.from("file-one-content"));
  await storageModule.saveAudio(userId, Buffer.from("file-two-content"));

  const bytesUsed = await storageModule.getUserStorageUsage(userId);

  expect(bytesUsed).toBeGreaterThan(0);
  // Usage is exactly the combined size of both payloads.
  expect(bytesUsed).toBe("file-one-content".length + "file-two-content".length);
});
|
||||
|
||||
it("deduplication means second identical upload doesn\'t increase usage", async () => {
  const storageModule = await import("./storage");
  const payload = Buffer.from("identical-content-for-usage-test");

  // Measure usage after each of two uploads of the same bytes.
  await storageModule.saveAudio(userId, payload);
  const afterFirstUpload = await storageModule.getUserStorageUsage(userId);

  await storageModule.saveAudio(userId, payload);
  const afterSecondUpload = await storageModule.getUserStorageUsage(userId);

  expect(afterFirstUpload).toBe(afterSecondUpload);
});
|
||||
});
|
||||
|
||||
describe("getAudioUrl", () => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createHash } from "node:crypto";
|
||||
import { writeFile, unlink, mkdir } from "node:fs/promises";
|
||||
import { writeFile, unlink, mkdir, stat, readdir } from "node:fs/promises";
|
||||
import { existsSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
|
||||
@@ -11,18 +11,89 @@ function getUserDir(userId: string): string {
|
||||
return join(process.cwd(), "uploads", "voiceprint", userId);
|
||||
}
|
||||
|
||||
/**
 * Maximum total disk storage per user (default 50MB).
 * Prevents disk exhaustion via unlimited audio uploads.
 *
 * Read once at module load from VOICEPRINT_MAX_USER_STORAGE_BYTES. An empty or
 * non-numeric value falls back to the default: parseInt would return NaN, and
 * a `projected > NaN` comparison is always false, which would silently disable
 * the quota.
 */
const MAX_USER_STORAGE_BYTES = (() => {
  const fallback = 52428800;
  const parsed = Number.parseInt(process.env.VOICEPRINT_MAX_USER_STORAGE_BYTES ?? "", 10);
  return Number.isNaN(parsed) ? fallback : parsed;
})();
|
||||
|
||||
/**
 * Calculate total disk usage for a user's voiceprint audio files.
 *
 * Sums the sizes of the regular files directly inside the user's upload
 * directory; subdirectories are not descended into.
 *
 * @param userId - Owner of the upload directory to measure.
 * @returns Total size in bytes, or 0 when the directory does not exist.
 * @throws Any `readdir` failure other than the directory being absent.
 */
export async function getUserStorageUsage(userId: string): Promise<number> {
  const userDir = getUserDir(userId);

  let entries: string[];
  try {
    entries = await readdir(userDir);
  } catch (err: unknown) {
    // A missing directory means nothing has been uploaded yet. Reading
    // directly, rather than pre-checking with existsSync, also covers the
    // directory disappearing between the check and the read.
    if (err instanceof Error && "code" in err && err.code === "ENOENT") {
      return 0;
    }
    throw err;
  }

  // The stat calls are independent, so issue them concurrently.
  const sizes = await Promise.all(
    entries.map(async (entry) => {
      try {
        const stats = await stat(join(userDir, entry));
        return stats.isFile() ? stats.size : 0;
      } catch {
        // Best effort: an entry that cannot be stat'ed (for example, removed
        // mid-scan) contributes nothing.
        return 0;
      }
    }),
  );

  return sizes.reduce((total, size) => total + size, 0);
}
|
||||
|
||||
/**
 * Guard against exceeding the per-user storage quota.
 *
 * Adds the size of the pending upload to the user's current disk usage and
 * rejects when the projected total is above MAX_USER_STORAGE_BYTES.
 *
 * @param userId - User whose current usage is measured.
 * @param fileSizeBytes - Size in bytes of the file about to be written.
 * @throws Error describing current usage, the quota, and the upload size.
 */
export async function checkStorageQuota(
  userId: string,
  fileSizeBytes: number,
): Promise<void> {
  const usedBytes = await getUserStorageUsage(userId);
  const overQuota = usedBytes + fileSizeBytes > MAX_USER_STORAGE_BYTES;

  if (!overQuota) {
    return;
  }

  // Pre-format the figures quoted in the error message.
  const usedMb = (usedBytes / 1024 / 1024).toFixed(1);
  const quotaMb = (MAX_USER_STORAGE_BYTES / 1024 / 1024).toFixed(0);
  const uploadKb = (fileSizeBytes / 1024).toFixed(0);

  throw new Error(
    `Storage quota exceeded. User has ${usedMb}MB of ${quotaMb}MB allocated. Upload is ${uploadKb}KB. Delete old audio files to free space.`,
  );
}
|
||||
|
||||
/**
 * Save an audio file with deduplication.
 *
 * The file name is derived from `computeHash(audioBuffer)`, so identical
 * content for the same user maps to the same path. If that file already
 * exists, nothing is written and the existing path is returned. Otherwise the
 * storage quota is checked before the file is written.
 *
 * @param userId - Owner of the upload; selects the per-user directory.
 * @param audioBuffer - Raw audio bytes to persist.
 * @returns `{ hash, filePath, isNew }`; `isNew` is false if the file already existed.
 * @throws Whatever `checkStorageQuota` throws when the upload would exceed the quota.
 */
export async function saveAudio(
  userId: string,
  audioBuffer: Buffer,
): Promise<{ hash: string; filePath: string; isNew: boolean }> {
  const hash = computeHash(audioBuffer);
  const userDir = getUserDir(userId);
  const filePath = join(userDir, `${hash}.wav`);

  // Deduplication: identical content is already on disk, so skip the write.
  // Duplicates do not count against the quota because no new bytes are stored.
  if (existsSync(filePath)) {
    return { hash, filePath, isNew: false };
  }

  // Check the storage quota before writing anything.
  await checkStorageQuota(userId, audioBuffer.length);

  // A recursive mkdir succeeds when the directory already exists, so no
  // existence pre-check is needed.
  await mkdir(userDir, { recursive: true });
  await writeFile(filePath, audioBuffer);

  return { hash, filePath, isNew: true };
}
|
||||
|
||||
export function getAudioUrl(userId: string, audioHash: string): string {
|
||||
|
||||
Reference in New Issue
Block a user