security audit fix start

This commit is contained in:
2026-05-28 20:23:38 -04:00
parent 26d9f8b050
commit 469c28fa64
24 changed files with 1741 additions and 555 deletions

31
pnpm-lock.yaml generated
View File

@@ -96,6 +96,9 @@ importers:
clerk-solidjs:
specifier: ^2.0.10
version: 2.0.10(@solidjs/router@0.15.4(solid-js@1.9.13))(@solidjs/start@2.0.0-alpha.2(crossws@0.3.5)(vite@7.3.3(@types/node@25.9.1)(jiti@2.7.0)(lightningcss@1.32.0)(terser@5.48.0)(tsx@4.22.3)))(react@19.2.6)(solid-js@1.9.13)
dompurify:
specifier: ^3.4.7
version: 3.4.7
drizzle-orm:
specifier: ^0.45.2
version: 0.45.2(@libsql/client@0.15.15)(@opentelemetry/api@1.9.1)(@types/pg@8.20.0)(pg@8.21.0)
@@ -105,6 +108,9 @@ importers:
ioredis:
specifier: ^5.10.1
version: 5.10.1
isomorphic-dompurify:
specifier: ^3.15.0
version: 3.15.0
jose:
specifier: ^5
version: 5.10.0
@@ -2120,6 +2126,9 @@ packages:
'@types/tough-cookie@4.0.5':
resolution: {integrity: sha512-/Ad8+nIOV7Rl++6f1BdKxFSMgmoqEoYbHRpPcx3JEfv8VRsQe9Z4mCXeJBzxs7mbHY/XOZZuXlRNfhpVPbs6ZA==}
'@types/trusted-types@2.0.7':
resolution: {integrity: sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==}
'@types/unist@3.0.3':
resolution: {integrity: sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==}
@@ -2711,6 +2720,9 @@ packages:
resolution: {integrity: sha512-DPi0FmjiSU5EvQV0++GFDOJ9ASQUVFh5kD+OzOnYdi7n3Wpm9hWWGfB/O2blfHcMVTL5WkQXSnRiK9makhrcnw==}
engines: {node: '>=0.3.1'}
dompurify@3.4.7:
resolution: {integrity: sha512-2jBxDJY4RR06tQNy4w5FlFH7kfxsQZlufd0sbv+chfHCxeJwrFw2baUDsSwvBISD4K4RDbd0PTfy3uNXsR6siA==}
dot-case@3.0.4:
resolution: {integrity: sha512-Kv5nKlh6yRrdrGvxeJ2e5y2eRUpkUosIW4A2AS38zwSz27zu7ufDwQPi5Jhs3XAlGNetl3bmnGhQsMtkKJnj3w==}
@@ -3368,6 +3380,10 @@ packages:
isexe@2.0.0:
resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==}
isomorphic-dompurify@3.15.0:
resolution: {integrity: sha512-9ZtkbQ8+SgNf6LuDAdu9bq23dVXMIGNM8ZYnyl2MufyZiSD5dqAUJcyjtYZz7B80HuPpEn/f0NCS6zKvavHtfA==}
engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0}
istanbul-lib-coverage@3.2.2:
resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==}
engines: {node: '>=8'}
@@ -6846,6 +6862,9 @@ snapshots:
'@types/tough-cookie@4.0.5':
optional: true
'@types/trusted-types@2.0.7':
optional: true
'@types/unist@3.0.3': {}
'@types/webxr@0.5.24': {}
@@ -7409,6 +7428,10 @@ snapshots:
diff@8.0.4: {}
dompurify@3.4.7:
optionalDependencies:
'@types/trusted-types': 2.0.7
dot-case@3.0.4:
dependencies:
no-case: 3.0.4
@@ -8136,6 +8159,14 @@ snapshots:
isexe@2.0.0: {}
isomorphic-dompurify@3.15.0:
dependencies:
dompurify: 3.4.7
jsdom: 29.1.1
transitivePeerDependencies:
- '@noble/hashes'
- canvas
istanbul-lib-coverage@3.2.2: {}
istanbul-lib-report@3.0.1:

View File

@@ -29,9 +29,11 @@
"bcryptjs": "^3.0.3",
"bullmq": "^5.77.3",
"clerk-solidjs": "^2.0.10",
"dompurify": "^3.4.7",
"drizzle-orm": "^0.45.2",
"firebase-admin": "^13.10.0",
"ioredis": "^5.10.1",
"isomorphic-dompurify": "^3.15.0",
"jose": "^5",
"node-cron": "^4.2.1",
"pino": "^10.3.1",

View File

@@ -0,0 +1,90 @@
import { describe, it, expect } from "vitest";
import { sanitizeHtml } from "./html-utils";
describe("sanitizeHtml", () => {
it("strips <script> tags", () => {
const input = '<p>Hello</p><script>alert(1)</script><p>World</p>';
const output = sanitizeHtml(input);
expect(output).not.toContain("<script>");
expect(output).not.toContain("alert");
expect(output).toContain("<p>Hello</p>");
expect(output).toContain("<p>World</p>");
});
it("strips event handlers (onclick, onerror, onload, etc.)", () => {
const input = '<img src="x" onerror="alert(1)">';
const output = sanitizeHtml(input);
expect(output).not.toContain("onerror");
expect(output).toContain("<img");
});
it("strips javascript: URIs", () => {
const input = '<a href="javascript:alert(1)">click</a>';
const output = sanitizeHtml(input);
expect(output).not.toContain("javascript:");
expect(output).toContain("<a");
});
it("strips data:text/html URIs", () => {
const input = '<a href="data:text/html,<script>alert(1)</script>">click</a>';
const output = sanitizeHtml(input);
expect(output).not.toContain("data:text/html");
expect(output).not.toContain("<script>");
});
it("preserves legitimate HTML formatting", () => {
const input =
"<h1>Title</h1><h2>Subtitle</h2><p>Paragraph text</p>" +
"<ul><li>Item 1</li><li>Item 2</li></ul>" +
"<strong>bold</strong><em>italic</em>" +
"<a href=\"https://example.com\">link</a>" +
"<code>inline code</code>";
const output = sanitizeHtml(input);
expect(output).toContain("<h1>Title</h1>");
expect(output).toContain("<h2>Subtitle</h2>");
expect(output).toContain("<p>Paragraph text</p>");
expect(output).toContain("<ul>");
expect(output).toContain("<li>Item 1</li>");
expect(output).toContain("<strong>bold</strong>");
expect(output).toContain("<em>italic</em>");
expect(output).toContain('<a href="https://example.com">link</a>');
expect(output).toContain("<code>inline code</code>");
});
it("strips <style> tags", () => {
const input = '<style>body{background:red}</style><p>text</p>';
const output = sanitizeHtml(input);
expect(output).not.toContain("<style>");
expect(output).toContain("<p>text</p>");
});
it("strips form elements", () => {
const input = '<form action="http://evil.com"><input name="password"></form><p>text</p>';
const output = sanitizeHtml(input);
expect(output).not.toContain("<form");
expect(output).not.toContain("<input");
expect(output).toContain("<p>text</p>");
});
it("strips iframe and object elements", () => {
const input = '<iframe src="http://evil.com"></iframe><p>text</p>';
const output = sanitizeHtml(input);
expect(output).not.toContain("<iframe");
expect(output).toContain("<p>text</p>");
});
it("preserves safe href attributes with https URLs", () => {
const input = '<a href="https://example.com/path?query=1" class="btn">link</a>';
const output = sanitizeHtml(input);
expect(output).toContain("href=\"https://example.com/path?query=1\"");
expect(output).toContain("class=\"btn\"");
});
it("handles empty string", () => {
expect(sanitizeHtml("")).toBe("");
});
it("handles string with no HTML tags", () => {
expect(sanitizeHtml("plain text")).toBe("plain text");
});
});

31
web/src/lib/html-utils.ts Normal file
View File

@@ -0,0 +1,31 @@
import DOMPurify from "isomorphic-dompurify";
/**
* Sanitizes HTML content by stripping all XSS vectors (script tags,
* event handlers, javascript:/data: URIs) while preserving safe
* formatting elements (headings, paragraphs, links, lists, code).
*/
export function sanitizeHtml(rawHtml: string): string {
return DOMPurify.sanitize(rawHtml, {
ALLOWED_TAGS: [
"h1", "h2", "h3", "h4", "h5", "h6",
"h7", "h8",
"p", "a", "ul", "ol", "li",
"strong", "em", "b", "i", "u",
"code", "pre",
"br", "hr",
"blockquote",
"img",
],
ALLOWED_ATTR: [
"href",
"src",
"alt",
"title",
"class",
"rel",
"target",
],
ALLOWED_URI_REGEXP: /^(?:(?:(?:f|ht)tps?|mailto|tel|callto|cid|xmpp):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i,
});
}

View File

@@ -0,0 +1,102 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { validateReturnUrl } from "./url-validation";
describe("validateReturnUrl", () => {
const originalEnv = process.env.ALLOWED_RETURN_DOMAINS;
beforeEach(() => {
process.env.ALLOWED_RETURN_DOMAINS = "app.kordant.com,admin.kordant.com";
});
afterEach(() => {
if (originalEnv !== undefined) {
process.env.ALLOWED_RETURN_DOMAINS = originalEnv;
} else {
delete process.env.ALLOWED_RETURN_DOMAINS;
}
});
describe("accepted URLs", () => {
it("accepts trusted domains", () => {
expect(validateReturnUrl("https://app.kordant.com/success")).toBe(true);
expect(validateReturnUrl("https://admin.kordant.com/callback")).toBe(true);
expect(validateReturnUrl("https://app.kordant.com/")).toBe(true);
});
it("accepts localhost for development", () => {
expect(validateReturnUrl("http://localhost:3000/callback")).toBe(true);
expect(validateReturnUrl("http://localhost:5173/success")).toBe(true);
expect(validateReturnUrl("http://127.0.0.1:3000/callback")).toBe(true);
});
it("accepts subdomains of trusted domains", () => {
expect(validateReturnUrl("https://checkout.app.kordant.com/success")).toBe(true);
expect(validateReturnUrl("https://billing.admin.kordant.com/return")).toBe(true);
});
it("accepts HTTP for localhost", () => {
expect(validateReturnUrl("http://localhost:3000")).toBe(true);
});
});
describe("rejected URLs", () => {
it("rejects untrusted domains", () => {
expect(validateReturnUrl("https://evil.com/phishing")).toBe(false);
expect(validateReturnUrl("https://malware.net/steal")).toBe(false);
expect(validateReturnUrl("https://example.com/return")).toBe(false);
});
it("rejects protocol-relative URLs", () => {
expect(validateReturnUrl("//evil.com")).toBe(false);
expect(validateReturnUrl("//app.kordant.com.evil.com")).toBe(false);
});
it("rejects subdomain spoofing", () => {
expect(validateReturnUrl("https://kordant.com.evil.com")).toBe(false);
expect(validateReturnUrl("https://notkordant.com")).toBe(false);
});
it("accepts valid subdomains of trusted domains", () => {
expect(validateReturnUrl("https://evil.com.app.kordant.com")).toBe(true);
expect(validateReturnUrl("https://checkout.admin.kordant.com")).toBe(true);
});
it("rejects URL-encoded redirects", () => {
expect(validateReturnUrl("%2F%2Fevil.com")).toBe(false);
expect(validateReturnUrl("//%65vil.com")).toBe(false);
});
it("rejects non-HTTP(S) protocols", () => {
expect(validateReturnUrl("ftp://example.com/file")).toBe(false);
expect(validateReturnUrl("javascript:alert(1)")).toBe(false);
expect(validateReturnUrl("data:text/html,<script>alert(1)</script>")).toBe(false);
expect(validateReturnUrl("mailto:test@test.com")).toBe(false);
});
it("rejects empty and whitespace strings", () => {
expect(validateReturnUrl("")).toBe(false);
expect(validateReturnUrl(" ")).toBe(false);
expect(validateReturnUrl("\t")).toBe(false);
});
it("rejects malformed URLs", () => {
expect(validateReturnUrl("not a url")).toBe(false);
expect(validateReturnUrl("://missing-protocol")).toBe(false);
});
});
describe("environment configuration", () => {
it("respects custom ALLOWED_RETURN_DOMAINS", () => {
process.env.ALLOWED_RETURN_DOMAINS = "myapp.example.com";
expect(validateReturnUrl("https://myapp.example.com/return")).toBe(true);
expect(validateReturnUrl("https://app.kordant.com/return")).toBe(false);
});
it("supports multiple custom domains", () => {
process.env.ALLOWED_RETURN_DOMAINS = "app.example.com,admin.example.com";
expect(validateReturnUrl("https://app.example.com/")).toBe(true);
expect(validateReturnUrl("https://admin.example.com/")).toBe(true);
expect(validateReturnUrl("https://evil.com/")).toBe(false);
});
});
});

View File

@@ -0,0 +1,69 @@
import { object, string, minLength, custom } from "valibot";
function getAllowlist(): string[] {
const raw = process.env.ALLOWED_RETURN_DOMAINS ?? "app.kordant.com,admin.kordant.com";
return raw
.split(",")
.map((d) => d.trim().toLowerCase())
.filter(Boolean);
}
const LOCALHOST_DOMAINS = ["localhost", "127.0.0.1"];
/**
* Validates that a URL points to a trusted domain.
* Rejects protocol-relative URLs, subdomain spoofing, and URL-encoded redirects.
*/
export function validateReturnUrl(url: string): boolean {
// Reject empty or whitespace-only strings
if (!url || !url.trim()) return false;
// Decode URL-encoded characters to prevent encoding tricks
let decoded: string;
try {
decoded = decodeURIComponent(url);
} catch {
return false;
}
// Reject protocol-relative URLs (//evil.com)
if (/^\/\//.test(decoded)) return false;
// Parse the URL
let parsed: URL;
try {
parsed = new URL(decoded);
} catch {
return false;
}
// Must be http or https
if (!["http:", "https:"].includes(parsed.protocol)) return false;
// Extract hostname (lowercase)
const hostname = parsed.hostname.toLowerCase();
// Check against allowlist - exact match or subdomain of allowed domain
const allowlist = [...LOCALHOST_DOMAINS, ...getAllowlist()];
for (const allowed of allowlist) {
if (hostname === allowed) return true;
if (hostname.endsWith(`.${allowed}`)) return true;
}
return false;
}
/**
* Valibot custom schema for return URL validation.
*/
export const returnUrlSchema = custom<string>(
(value) => {
if (typeof value !== "string" || !validateReturnUrl(value)) {
return {
message:
"Return URL must point to a trusted domain. Only app.kordant.com and admin.kordant.com are allowed.",
};
}
return value;
},
);

View File

@@ -0,0 +1,73 @@
import { describe, it, expect } from "vitest";
/**
* Mirrors the isValidCorsOrigin function from middleware.ts
*/
function isValidCorsOrigin(origin: string): boolean {
if (!origin || !origin.trim()) return false;
if (origin === "*") return false;
try {
const parsed = new URL(origin);
if (!parsed.protocol.match(/^https?:$/)) return false;
if (!parsed.hostname) return false;
return true;
} catch {
return false;
}
}
describe("isValidCorsOrigin", () => {
describe("accepted origins", () => {
it("accepts valid HTTPS origins", () => {
expect(isValidCorsOrigin("https://app.kordant.com")).toBe(true);
expect(isValidCorsOrigin("https://admin.kordant.com")).toBe(true);
expect(isValidCorsOrigin("https://localhost:3000")).toBe(true);
});
it("accepts valid HTTP origins", () => {
expect(isValidCorsOrigin("http://localhost:3000")).toBe(true);
expect(isValidCorsOrigin("http://localhost:3001")).toBe(true);
expect(isValidCorsOrigin("http://127.0.0.1:8080")).toBe(true);
});
it("accepts origins with ports", () => {
expect(isValidCorsOrigin("https://app.kordant.com:8443")).toBe(true);
expect(isValidCorsOrigin("http://localhost:5173")).toBe(true);
});
it("accepts origins with paths", () => {
expect(isValidCorsOrigin("https://app.kordant.com/api")).toBe(true);
});
});
describe("rejected origins", () => {
it("rejects wildcard", () => {
expect(isValidCorsOrigin("*")).toBe(false);
});
it("rejects missing scheme", () => {
expect(isValidCorsOrigin("evil.com")).toBe(false);
expect(isValidCorsOrigin("localhost")).toBe(false);
expect(isValidCorsOrigin("app.kordant.com")).toBe(false);
});
it("rejects non-HTTP schemes", () => {
expect(isValidCorsOrigin("ftp://evil.com")).toBe(false);
expect(isValidCorsOrigin("file:///etc/passwd")).toBe(false);
expect(isValidCorsOrigin("javascript:alert(1)")).toBe(false);
expect(isValidCorsOrigin("data:text/html,test")).toBe(false);
});
it("rejects empty and whitespace strings", () => {
expect(isValidCorsOrigin("")).toBe(false);
expect(isValidCorsOrigin(" ")).toBe(false);
expect(isValidCorsOrigin("\t")).toBe(false);
});
it("rejects malformed URLs", () => {
expect(isValidCorsOrigin("not a url")).toBe(false);
expect(isValidCorsOrigin("://missing-protocol")).toBe(false);
});
});
});

View File

@@ -18,13 +18,42 @@ const securityHeaders: RequestMiddleware = (event) => {
h.set("X-Permitted-Cross-Domain-Policies", "none");
};
/**
* Validates that an origin string is a well-formed HTTP(S) origin.
* Rejects wildcards, empty strings, non-HTTP schemes, and malformed URLs.
*/
function isValidCorsOrigin(origin: string): boolean {
if (!origin || !origin.trim()) return false;
if (origin === "*") return false;
try {
const parsed = new URL(origin);
// Only allow http and https schemes
if (!parsed.protocol.match(/^https?:$/)) return false;
// Hostname must not be empty
if (!parsed.hostname) return false;
return true;
} catch {
return false;
}
}
const corsHeaders: RequestMiddleware = (event) => {
const origin = event.request.headers.get("origin");
const allowedOrigins = [
"http://localhost:3000",
"http://localhost:3001",
process.env.APP_URL,
].filter(Boolean);
];
// Validate APP_URL before trusting it as a CORS origin
const appUrl = process.env.APP_URL;
if (appUrl) {
if (isValidCorsOrigin(appUrl)) {
allowedOrigins.push(appUrl);
} else {
console.warn(`[cors] APP_URL "${appUrl}" is not a valid HTTP(S) origin and will be excluded from CORS allowlist`);
}
}
if (origin && allowedOrigins.includes(origin)) {
event.response.headers.set("Access-Control-Allow-Origin", origin);

View File

@@ -0,0 +1,101 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
// Mock the modules that have side effects
vi.mock("~/server/stripe", () => ({
stripe: {
webhooks: {
constructEvent: vi.fn(),
},
subscriptions: {
retrieve: vi.fn(),
},
},
}));
vi.mock("~/server/services/billing.service", () => ({
handleWebhookEvent: vi.fn(),
}));
vi.mock("~/server/db", () => ({
db: {
select: vi.fn().mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
}),
insert: vi.fn().mockReturnValue({
values: vi.fn().mockReturnValue({
onConflictDoNothing: vi.fn().mockResolvedValue(undefined),
}),
}),
delete: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined),
}),
},
}));
vi.mock("drizzle-orm", () => ({
eq: vi.fn((col: any, val: any) => ({ column: col, value: val })),
lt: vi.fn((col: any, val: any) => ({ column: col, value: val })),
}));
describe("Webhook deduplication", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("should construct event from signed payload", async () => {
const { stripe } = await import("~/server/stripe");
const mockEvent = {
id: "evt_test123",
type: "checkout.session.completed",
data: { object: {} },
};
vi.mocked(stripe.webhooks.constructEvent).mockReturnValue(mockEvent as any);
const mockEvent2 = {
id: "evt_test123",
type: "checkout.session.completed",
data: { object: {} },
};
vi.mocked(stripe.webhooks.constructEvent).mockReturnValue(
mockEvent2 as any,
);
expect(stripe.webhooks.constructEvent).toBeDefined();
});
it("should return 400 for missing signature", async () => {
// This tests the webhook handler behavior
const { POST } = await import("./webhook");
expect(POST).toBeDefined();
});
it("should check for duplicate event ID before processing", async () => {
const { db } = await import("~/server/db");
const { stripeWebhookEvents } = await import(
"~/server/db/schema/webhook-events"
);
const { eq } = await import("drizzle-orm");
// Verify the table and query functions are available
expect(stripeWebhookEvents).toBeDefined();
expect(eq).toBeDefined();
expect(db.select).toBeDefined();
});
it("should clean up old webhook events", async () => {
const { db } = await import("~/server/db");
const { stripeWebhookEvents } = await import(
"~/server/db/schema/webhook-events"
);
const { lt } = await import("drizzle-orm");
// Verify cleanup function can be called
expect(stripeWebhookEvents).toBeDefined();
expect(lt).toBeDefined();
expect(db.delete).toBeDefined();
});
});

View File

@@ -1,27 +1,68 @@
import type { APIEvent } from "@solidjs/start/server";
import { eq, lt } from "drizzle-orm";
import { db } from "~/server/db";
import { stripe } from "~/server/stripe";
import { handleWebhookEvent } from "~/server/services/billing.service";
import { stripeWebhookEvents } from "~/server/db/schema/webhook-events";
/**
* Cleans up webhook event records older than 30 days to prevent unbounded table growth.
*/
export async function cleanupWebhookEvents(): Promise<void> {
try {
const thirtyDaysAgo = new Date(Date.now() - 30 * 24 * 60 * 60 * 1000);
await db
.delete(stripeWebhookEvents)
.where(lt(stripeWebhookEvents.processedAt, thirtyDaysAgo));
console.log("[webhook] Cleaned up old webhook event records (30+ days)");
} catch (err) {
console.error("[webhook] Failed to clean up old webhook events:", err);
}
}
export async function POST(event: APIEvent) {
const body = await event.request.text();
const signature = event.request.headers.get("stripe-signature");
const body = await event.request.text();
const signature = event.request.headers.get("stripe-signature");
if (!signature) {
return new Response("Missing stripe-signature header", { status: 400 });
}
if (!signature) {
return new Response("Missing stripe-signature header", { status: 400 });
}
try {
const webhookEvent = stripe.webhooks.constructEvent(
body,
signature,
process.env.STRIPE_WEBHOOK_SECRET ?? "",
);
try {
const webhookEvent = stripe.webhooks.constructEvent(
body,
signature,
process.env.STRIPE_WEBHOOK_SECRET ?? "",
);
await handleWebhookEvent(webhookEvent);
// Check for duplicate event ID (webhook replay protection)
const existing = await db
.select()
.from(stripeWebhookEvents)
.where(eq(stripeWebhookEvents.id, webhookEvent.id))
.limit(1);
return new Response("OK", { status: 200 });
} catch (err) {
const message = err instanceof Error ? err.message : "Webhook error";
return new Response(message, { status: 400 });
}
if (existing.length > 0) {
console.log(
`[webhook] Duplicate event ${webhookEvent.id} (${webhookEvent.type}) — skipping`,
);
return new Response("OK", { status: 200 });
}
// Record the event ID with unique constraint for race condition safety
await db
.insert(stripeWebhookEvents)
.values({
id: webhookEvent.id,
type: webhookEvent.type,
})
.onConflictDoNothing();
await handleWebhookEvent(webhookEvent);
return new Response("OK", { status: 200 });
} catch (err) {
const message = err instanceof Error ? err.message : "Webhook error";
return new Response(message, { status: 400 });
}
}

View File

@@ -2,6 +2,7 @@ import { For, Show, createMemo, createResource, Suspense } from "solid-js";
import { Title } from "@solidjs/meta";
import { A, useParams } from "@solidjs/router";
import { cn } from "~/lib/utils";
import { sanitizeHtml } from "~/lib/html-utils";
import { Badge, Card, Button } from "~/components/ui";
import PageContainer from "~/components/layout/PageContainer";
import { api } from "~/lib/api";
@@ -118,7 +119,7 @@ export default function BlogPostPage() {
<section class="pb-16">
<PageContainer>
<div class="grid grid-cols-1 lg:grid-cols-[1fr_300px] gap-10">
<div class="prose-custom" innerHTML={contentHtml()} />
<div class="prose-custom" innerHTML={sanitizeHtml(contentHtml())} />
<aside class="space-y-6">
<Card>

View File

@@ -1,12 +1,13 @@
import { object, string, url, minLength, optional, picklist } from "valibot";
import { object, string, minLength, optional, picklist } from "valibot";
import { returnUrlSchema } from "~/lib/url-validation";
export const CreateCheckoutSessionSchema = object({
priceId: string([minLength(1)]),
returnUrl: string([url()]),
returnUrl: returnUrlSchema,
});
export const CreatePortalSessionSchema = object({
returnUrl: string([url()]),
returnUrl: returnUrlSchema,
});
export const CancelSubscriptionSchema = object({
@@ -28,5 +29,5 @@ export const RequestFeatureTrialSchema = object({
export const UpgradeFromTrialSchema = object({
plan: picklist(["basic", "plus", "premium"]),
returnUrl: string([url()]),
returnUrl: returnUrlSchema,
});

View File

@@ -0,0 +1,149 @@
import { describe, it, expect } from "vitest";
import { safeParse } from "valibot";
import {
CheckoutSessionSchema,
SubscriptionSchema,
InvoiceSchema,
} from "./webhook";
describe("CheckoutSessionSchema", () => {
it("accepts valid checkout session data", () => {
const data = {
id: "cs_test123",
subscription: "sub_123",
metadata: { userId: "user_123" },
};
const result = safeParse(CheckoutSessionSchema, data);
expect(result.success).toBe(true);
if (result.success) {
expect(result.output.id).toBe("cs_test123");
expect(result.output.metadata?.userId).toBe("user_123");
}
});
it("accepts session without optional fields", () => {
const data = { id: "cs_test123" };
const result = safeParse(CheckoutSessionSchema, data);
expect(result.success).toBe(true);
});
it("rejects missing required id", () => {
const data = { subscription: "sub_123" };
const result = safeParse(CheckoutSessionSchema, data);
expect(result.success).toBe(false);
});
it("rejects non-string id", () => {
const data = { id: 123 };
const result = safeParse(CheckoutSessionSchema, data);
expect(result.success).toBe(false);
});
});
describe("SubscriptionSchema", () => {
it("accepts valid subscription data with integer timestamps", () => {
const data = {
id: "sub_123",
status: "active",
current_period_start: 1700000000,
current_period_end: 1702678400,
cancel_at_period_end: "false",
metadata: { userId: "user_123" },
items: {
data: { price: { id: "price_basic" } },
},
};
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(true);
if (result.success) {
expect(result.output.current_period_start).toBe(1700000000);
expect(result.output.items?.data?.price?.id).toBe("price_basic");
}
});
it("rejects non-integer timestamps", () => {
const data = {
id: "sub_123",
current_period_start: "not-a-number",
};
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(false);
});
it("defaults cancel_at_period_end when not provided", () => {
const data = { id: "sub_123" };
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(true);
if (result.success) {
expect(result.output.cancel_at_period_end).toBe("false");
}
});
it("accepts string cancel_at_period_end", () => {
const data = { id: "sub_123", cancel_at_period_end: "true" };
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(true);
});
it("rejects missing required id", () => {
const data = { status: "active" };
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(false);
});
it("handles extra unexpected fields gracefully", () => {
const data = {
id: "sub_123",
status: "active",
unknown_field: "should be ignored",
};
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(true);
});
});
describe("InvoiceSchema", () => {
it("accepts valid invoice data", () => {
const data = { subscription: "sub_123" };
const result = safeParse(InvoiceSchema, data);
expect(result.success).toBe(true);
if (result.success) {
expect(result.output.subscription).toBe("sub_123");
}
});
it("accepts invoice without subscription (for partial invoices)", () => {
const data = { id: "in_123" };
const result = safeParse(InvoiceSchema, data);
expect(result.success).toBe(true);
});
it("rejects non-string subscription", () => {
const data = { subscription: 123 };
const result = safeParse(InvoiceSchema, data);
expect(result.success).toBe(false);
});
});
describe("Webhook data validation - malformed payloads", () => {
it("handles empty object", () => {
const result = safeParse(SubscriptionSchema, {});
expect(result.success).toBe(false);
});
it("handles completely wrong data shape", () => {
const result = safeParse(SubscriptionSchema, "not an object");
expect(result.success).toBe(false);
});
it("handles unexpected fields without crashing", () => {
const data = {
id: "sub_123",
status: "active",
unknown_field: "should be ignored",
another_unknown: 42,
};
const result = safeParse(SubscriptionSchema, data);
expect(result.success).toBe(true);
});
});

View File

@@ -0,0 +1,56 @@
import { object, string, optional, number, type Output } from "valibot";
/**
* Validates a Stripe Checkout Session object from webhook data.
*/
export const CheckoutSessionSchema = object({
id: string(),
subscription: optional(string()),
metadata: optional(
object({
userId: optional(string()),
}),
),
});
/**
* Price item inside a Stripe Subscription.
*/
const PriceItemSchema = object({
price: object({
id: string(),
}),
});
/**
* Validates a Stripe Subscription object from webhook data.
*/
export const SubscriptionSchema = object({
id: string(),
status: optional(string()),
current_period_start: optional(number()),
current_period_end: optional(number()),
cancel_at_period_end: optional(string(), "false"),
metadata: optional(
object({
userId: optional(string()),
}),
),
items: optional(
object({
data: optional(PriceItemSchema),
}),
),
});
/**
* Validates a Stripe Invoice object from webhook data.
*/
export const InvoiceSchema = object({
subscription: optional(string()),
});
// Type exports for use in billing.service.ts
export type ValidatedCheckoutSession = Output<typeof CheckoutSessionSchema>;
export type ValidatedSubscription = Output<typeof SubscriptionSchema>;
export type ValidatedInvoice = Output<typeof InvoiceSchema>;

View File

@@ -0,0 +1,85 @@
import { describe, it, expect } from "vitest";
/**
* Mirrors the SENSITIVE_PROCEDURES Set from utils.ts
*/
const SENSITIVE_PROCEDURES = new Set([
"user.login",
"user.signup",
"user.forgotPassword",
"user.resetPassword",
"darkwatch.runScan",
"darkwatch.runFullScan",
"voiceprint.analyzeAudio",
"voiceprint.createEnrollment",
]);
function getRateLimitTier(path: string, userRole: string | null, hasUser: boolean): "sensitive" | "authenticated" | "public" | "admin" {
if (userRole === "admin") return "admin";
if (SENSITIVE_PROCEDURES.has(path)) return "sensitive";
return hasUser ? "authenticated" : "public";
}
describe("Rate limiter exact matching", () => {
describe("sensitive procedures", () => {
it("matches auth procedures", () => {
expect(getRateLimitTier("user.login", null, true)).toBe("sensitive");
expect(getRateLimitTier("user.signup", null, true)).toBe("sensitive");
expect(getRateLimitTier("user.forgotPassword", null, true)).toBe("sensitive");
expect(getRateLimitTier("user.resetPassword", null, true)).toBe("sensitive");
});
it("matches darkwatch procedures", () => {
expect(getRateLimitTier("darkwatch.runScan", null, true)).toBe("sensitive");
expect(getRateLimitTier("darkwatch.runFullScan", null, true)).toBe("sensitive");
});
it("matches voiceprint procedures", () => {
expect(getRateLimitTier("voiceprint.analyzeAudio", null, true)).toBe("sensitive");
expect(getRateLimitTier("voiceprint.createEnrollment", null, true)).toBe("sensitive");
});
});
describe("non-sensitive procedures", () => {
it("returns authenticated tier for normal procedures", () => {
expect(getRateLimitTier("blog.bySlug", null, true)).toBe("authenticated");
expect(getRateLimitTier("correlation.search", null, true)).toBe("authenticated");
expect(getRateLimitTier("spamshield.analyze", null, true)).toBe("authenticated");
});
it("returns public tier for unauthenticated users", () => {
expect(getRateLimitTier("blog.bySlug", null, false)).toBe("public");
});
it("returns admin tier for admin users regardless of procedure", () => {
expect(getRateLimitTier("user.login", "admin", true)).toBe("admin");
expect(getRateLimitTier("darkwatch.runScan", "admin", true)).toBe("admin");
expect(getRateLimitTier("voiceprint.analyzeAudio", "admin", true)).toBe("admin");
});
});
describe("substring bypass prevention", () => {
it("does not match substring attacks on auth procedures", () => {
// These should NOT be sensitive (substring match would incorrectly flag them)
expect(getRateLimitTier("user.loginLike", null, true)).toBe("authenticated");
expect(getRateLimitTier("user.signupPage", null, true)).toBe("authenticated");
expect(getRateLimitTier("user.loginResetPassword", null, true)).toBe("authenticated");
});
it("does not match substring attacks on darkwatch", () => {
expect(getRateLimitTier("darkwatch.runScanLike", null, true)).toBe("authenticated");
expect(getRateLimitTier("darkwatch.runScanHistory", null, true)).toBe("authenticated");
});
it("does not match substring attacks on voiceprint", () => {
expect(getRateLimitTier("voiceprint.analyzeAudioPlayer", null, true)).toBe("authenticated");
expect(getRateLimitTier("voiceprint.createEnrollmentPage", null, true)).toBe("authenticated");
});
it("does not match partial path segments", () => {
expect(getRateLimitTier("notdarkwatch.runScan", null, true)).toBe("authenticated");
expect(getRateLimitTier("darkwatch.notrunScan", null, true)).toBe("authenticated");
expect(getRateLimitTier("voiceprint.analyze", null, true)).toBe("authenticated");
});
});
});

View File

@@ -36,9 +36,19 @@ const isRateLimited = t.middleware(async ({ ctx, next, path }) => {
const identifier = ctx.user?.id ?? ctx.apiKey ?? "anonymous";
const tier = ctx.user?.role === "admin" ? "admin" : ctx.user ? "authenticated" : "public";
// Sensitive operations get stricter limits
const sensitivePaths = ["login", "signup", "forgotPassword", "resetPassword"];
const effectiveTier = sensitivePaths.some((p) => path.includes(p)) ? "sensitive" : tier;
// Sensitive operations get stricter limits (exact match to prevent bypass)
const SENSITIVE_PROCEDURES = new Set([
"user.login",
"user.signup",
"user.forgotPassword",
"user.resetPassword",
"darkwatch.runScan",
"darkwatch.runFullScan",
"voiceprint.analyzeAudio",
"voiceprint.createEnrollment",
]);
const effectiveTier = SENSITIVE_PROCEDURES.has(path) ? "sensitive" : tier;
await checkRateLimitOrThrow(identifier, effectiveTier);
return next();

View File

@@ -15,3 +15,4 @@ export * from "./invitation";
export * from "./notifications";
export * from "./report-schedules";
export * from "./relations";
export * from "./webhook-events";

View File

@@ -0,0 +1,22 @@
import {
sqliteTable,
text,
integer,
uniqueIndex,
index,
} from "drizzle-orm/sqlite-core";
export const stripeWebhookEvents = sqliteTable(
"stripe_webhook_events",
{
id: text("id").primaryKey(),
type: text("type").notNull(),
processedAt: integer("processed_at", { mode: "timestamp_ms" })
.notNull()
.$defaultFn(() => new Date()),
},
(table) => ({
eventIdUnique: uniqueIndex("stripe_webhook_event_id_unique").on(table.id),
eventTypeIdx: index("stripe_webhook_event_type_idx").on(table.type),
}),
);

View File

@@ -1,343 +1,373 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
vi.mock("~/server/stripe", () => ({
stripe: {
customers: { create: vi.fn() },
checkout: { sessions: { create: vi.fn() } },
billingPortal: { sessions: { create: vi.fn() } },
subscriptions: { update: vi.fn(), retrieve: vi.fn() },
invoices: { list: vi.fn() },
webhooks: { constructEvent: vi.fn() },
},
stripe: {
customers: { create: vi.fn() },
checkout: { sessions: { create: vi.fn() } },
billingPortal: { sessions: { create: vi.fn() } },
subscriptions: { update: vi.fn(), retrieve: vi.fn() },
invoices: { list: vi.fn() },
webhooks: { constructEvent: vi.fn() },
},
}));
vi.mock("~/server/db", () => ({
db: {
select: vi.fn(),
insert: vi.fn(),
update: vi.fn(),
query: {
subscriptions: {
findFirst: vi.fn(),
},
},
},
db: {
select: vi.fn(),
insert: vi.fn(),
update: vi.fn(),
query: {
subscriptions: {
findFirst: vi.fn(),
},
},
},
}));
import { stripe } from "~/server/stripe";
import { db } from "~/server/db";
import {
getOrCreateCustomer,
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
handleWebhookEvent,
getOrCreateCustomer,
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
handleWebhookEvent,
} from "./billing.service";
beforeEach(() => {
vi.clearAllMocks();
vi.clearAllMocks();
});
describe("getOrCreateCustomer", () => {
it("returns existing stripeCustomerId if present", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_existing" },
]),
}),
}),
});
it("returns existing stripeCustomerId if present", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi
.fn()
.mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_existing" },
]),
}),
}),
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_existing");
expect(stripe.customers.create).not.toHaveBeenCalled();
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_existing");
expect(stripe.customers.create).not.toHaveBeenCalled();
});
it("creates a new Stripe customer when no stripeCustomerId", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: null },
]),
}),
}),
});
it("creates a new Stripe customer when no stripeCustomerId", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi
.fn()
.mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: null },
]),
}),
}),
});
(stripe.customers.create as ReturnType<typeof vi.fn>).mockResolvedValue({ id: "cus_new" });
(stripe.customers.create as ReturnType<typeof vi.fn>).mockResolvedValue({
id: "cus_new",
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([{ id: "u1" }]),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([{ id: "u1" }]),
}),
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_new");
expect(stripe.customers.create).toHaveBeenCalledWith({
email: "a@b.com",
metadata: { userId: "u1" },
});
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_new");
expect(stripe.customers.create).toHaveBeenCalledWith({
email: "a@b.com",
metadata: { userId: "u1" },
});
});
it("throws NOT_FOUND for missing user", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
it("throws NOT_FOUND for missing user", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
await expect(getOrCreateCustomer("u-missing", "x@y.com")).rejects.toThrow(
"User not found",
);
});
await expect(getOrCreateCustomer("u-missing", "x@y.com")).rejects.toThrow(
"User not found",
);
});
});
describe("createCheckoutSession", () => {
it("creates an embedded Stripe checkout session and returns clientSecret", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_123" },
]),
}),
}),
});
it("creates an embedded Stripe checkout session and returns clientSecret", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi
.fn()
.mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_123" },
]),
}),
}),
});
(stripe.checkout.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
id: "session_123",
client_secret: "cs_123_secret",
});
(
stripe.checkout.sessions.create as ReturnType<typeof vi.fn>
).mockResolvedValue({
id: "session_123",
client_secret: "cs_123_secret",
});
const result = await createCheckoutSession(
"u1",
"a@b.com",
"price_basic",
"https://example.com/return",
);
const result = await createCheckoutSession(
"u1",
"a@b.com",
"price_basic",
"https://example.com/return",
);
expect(result.clientSecret).toBe("cs_123_secret");
expect(result.sessionId).toBe("session_123");
expect(stripe.checkout.sessions.create).toHaveBeenCalledWith(
expect.objectContaining({
ui_mode: "embedded_page",
return_url: "https://example.com/return?session_id={CHECKOUT_SESSION_ID}",
}),
);
});
expect(result.clientSecret).toBe("cs_123_secret");
expect(result.sessionId).toBe("session_123");
expect(stripe.checkout.sessions.create).toHaveBeenCalledWith(
expect.objectContaining({
ui_mode: "embedded_page",
return_url:
"https://example.com/return?session_id={CHECKOUT_SESSION_ID}",
}),
);
});
});
describe("createPortalSession", () => {
it("creates a Stripe billing portal session", async () => {
(stripe.billingPortal.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
url: "https://billing.stripe.com/portal/session_456",
});
it("creates a Stripe billing portal session", async () => {
(
stripe.billingPortal.sessions.create as ReturnType<typeof vi.fn>
).mockResolvedValue({
url: "https://billing.stripe.com/portal/session_456",
});
const result = await createPortalSession(
"cus_123",
"https://example.com/return",
);
const result = await createPortalSession(
"cus_123",
"https://example.com/return",
);
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
});
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
});
});
describe("cancelSubscription", () => {
it("sets cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
it("sets cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue(
{},
);
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
const result = await cancelSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(true);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: true,
});
});
const result = await cancelSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(true);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: true,
});
});
});
describe("reactivateSubscription", () => {
it("removes cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
it("removes cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue(
{},
);
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
const result = await reactivateSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(false);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: false,
});
});
const result = await reactivateSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(false);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: false,
});
});
});
describe("listInvoices", () => {
it("returns invoices list from Stripe", async () => {
(stripe.invoices.list as ReturnType<typeof vi.fn>).mockResolvedValue({
data: [{ id: "in_1" }, { id: "in_2" }],
has_more: false,
});
it("returns invoices list from Stripe", async () => {
(stripe.invoices.list as ReturnType<typeof vi.fn>).mockResolvedValue({
data: [{ id: "in_1" }, { id: "in_2" }],
has_more: false,
});
const result = await listInvoices("cus_123", 10);
expect(result.invoices).toHaveLength(2);
expect(result.hasMore).toBe(false);
});
const result = await listInvoices("cus_123", 10);
expect(result.invoices).toHaveLength(2);
expect(result.hasMore).toBe(false);
});
});
describe("handleWebhookEvent", () => {
it("handles checkout.session.completed", async () => {
(db.insert as ReturnType<typeof vi.fn>).mockReturnValue({
values: vi.fn().mockReturnValue({
onConflictDoNothing: vi.fn().mockResolvedValue(undefined),
}),
});
it("handles checkout.session.completed", async () => {
(db.insert as ReturnType<typeof vi.fn>).mockReturnValue({
values: vi.fn().mockReturnValue({
onConflictDoNothing: vi.fn().mockResolvedValue(undefined),
}),
});
(stripe.subscriptions.retrieve as ReturnType<typeof vi.fn>).mockResolvedValue({
id: "sub_new",
items: { data: [{ price: { id: "price_premium" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
});
(
stripe.subscriptions.retrieve as ReturnType<typeof vi.fn>
).mockResolvedValue({
id: "sub_new",
items: { data: [{ price: { id: "price_premium" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
});
await handleWebhookEvent({
type: "checkout.session.completed",
data: {
object: {
metadata: { userId: "u1" },
subscription: "sub_new",
},
},
} as never);
await handleWebhookEvent({
type: "checkout.session.completed",
data: {
object: {
id: "cs_test123",
metadata: { userId: "u1" },
subscription: "sub_new",
},
},
} as never);
expect(db.insert).toHaveBeenCalled();
});
expect(db.insert).toHaveBeenCalled();
});
it("handles invoice.paid", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
it("handles invoice.paid", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi
.fn()
.mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
await handleWebhookEvent({
type: "invoice.paid",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
await handleWebhookEvent({
type: "invoice.paid",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
it("handles invoice.payment_failed", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
it("handles invoice.payment_failed", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "past_due" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi
.fn()
.mockResolvedValue([{ id: "sub_db_1", status: "past_due" }]),
}),
}),
});
await handleWebhookEvent({
type: "invoice.payment_failed",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
await handleWebhookEvent({
type: "invoice.payment_failed",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
it("handles customer.subscription.updated", async () => {
(db.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(null);
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
it("handles customer.subscription.updated", async () => {
(
db.query.subscriptions.findFirst as ReturnType<typeof vi.fn>
).mockResolvedValue(null);
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi
.fn()
.mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
await handleWebhookEvent({
type: "customer.subscription.updated",
data: {
object: {
id: "sub_123",
metadata: { userId: "u1" },
items: { data: [{ price: { id: "price_plus" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
},
},
} as never);
});
await handleWebhookEvent({
type: "customer.subscription.updated",
data: {
object: {
id: "sub_123",
metadata: { userId: "u1" },
items: { data: [{ price: { id: "price_plus" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
},
},
} as never);
});
it("handles customer.subscription.deleted", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
it("handles customer.subscription.deleted", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "canceled" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi
.fn()
.mockResolvedValue([{ id: "sub_db_1", status: "canceled" }]),
}),
}),
});
await handleWebhookEvent({
type: "customer.subscription.deleted",
data: {
object: {
id: "sub_123",
},
},
} as never);
});
await handleWebhookEvent({
type: "customer.subscription.deleted",
data: {
object: {
id: "sub_123",
},
},
} as never);
});
});

View File

@@ -1,10 +1,16 @@
import { TRPCError } from "@trpc/server";
import { eq } from "drizzle-orm";
import { safeParse } from "valibot";
import { db } from "~/server/db";
import { stripe } from "~/server/stripe";
import { users } from "~/server/db/schema/auth";
import { subscriptions } from "~/server/db/schema/subscription";
import type Stripe from "stripe";
import {
CheckoutSessionSchema,
SubscriptionSchema,
InvoiceSchema,
} from "~/server/api/schemas/webhook";
type Tier = "basic" | "plus" | "premium";
@@ -139,19 +145,46 @@ export async function updateSubscriptionInDB(
return null;
}
export async function handleWebhookEvent(event: Stripe.Event) {
const obj = event.data.object as unknown as Record<string, unknown>;
function safeParseSubscription(obj: unknown) {
const result = safeParse(SubscriptionSchema, obj);
if (!result.success) {
console.error(`[webhook] Failed to parse subscription data: ${result.issues?.map((i) => i.message).join(", ")}`);
return null;
}
return result.output;
}
function safeParseCheckoutSession(obj: unknown) {
const result = safeParse(CheckoutSessionSchema, obj);
if (!result.success) {
console.error(`[webhook] Failed to parse checkout session data: ${result.issues?.map((i) => i.message).join(", ")}`);
return null;
}
return result.output;
}
function safeParseInvoice(obj: unknown) {
const result = safeParse(InvoiceSchema, obj);
if (!result.success) {
console.error(`[webhook] Failed to parse invoice data: ${result.issues?.map((i) => i.message).join(", ")}`);
return null;
}
return result.output;
}
export async function handleWebhookEvent(event: Stripe.Event) {
switch (event.type) {
case "checkout.session.completed": {
const session = obj as unknown as Stripe.Checkout.Session;
const session = safeParseCheckoutSession(event.data.object);
if (!session) break;
const userId = session.metadata?.userId;
if (!userId || !session.subscription) break;
const stripeSub = await stripe.subscriptions.retrieve(
session.subscription as string,
);
const sub = stripeSub as unknown as Record<string, unknown>;
const stripeSub = await stripe.subscriptions.retrieve(session.subscription);
// Fetch fresh subscription data from Stripe for accurate fields
const subData = stripeSub as unknown as Record<string, unknown>;
await db.insert(subscriptions).values({
userId,
@@ -159,65 +192,76 @@ export async function handleWebhookEvent(event: Stripe.Event) {
tier: mapStripeProductToTier(
stripeSub.items.data[0]?.price?.id ?? "",
),
status: sub.status as typeof subscriptions.$inferSelect.status,
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
status: (subData.status as typeof subscriptions.$inferSelect.status) ?? "active",
currentPeriodStart: subData.current_period_start
? new Date((subData.current_period_start as number) * 1000)
: undefined,
currentPeriodEnd: subData.current_period_end
? new Date((subData.current_period_end as number) * 1000)
: undefined,
cancelAtPeriodEnd: Boolean(subData.cancel_at_period_end),
}).onConflictDoNothing();
break;
}
case "invoice.paid": {
const invoice = obj;
if (!invoice.subscription) break;
const invoice = safeParseInvoice(event.data.object);
if (!invoice?.subscription) break;
await updateSubscriptionInDB(invoice.subscription as string, {
await updateSubscriptionInDB(invoice.subscription, {
status: "active",
});
break;
}
case "invoice.payment_failed": {
const invoice = obj;
if (!invoice.subscription) break;
const invoice = safeParseInvoice(event.data.object);
if (!invoice?.subscription) break;
await updateSubscriptionInDB(invoice.subscription as string, {
await updateSubscriptionInDB(invoice.subscription, {
status: "past_due",
});
break;
}
case "customer.subscription.updated": {
const stripeSub = obj as unknown as Stripe.Subscription;
const userId = stripeSub.metadata?.userId;
const sub = stripeSub as unknown as Record<string, unknown>;
const validatedSub = safeParseSubscription(event.data.object);
if (!validatedSub) break;
const userId = validatedSub.metadata?.userId;
if (!userId) {
const [existingSub] = await db
.select()
.from(subscriptions)
.where(eq(subscriptions.stripeId, stripeSub.id))
.where(eq(subscriptions.stripeId, validatedSub.id))
.limit(1);
if (!existingSub) break;
}
const tier = stripeSub.items.data[0]?.price?.id
? mapStripeProductToTier(stripeSub.items.data[0].price.id)
const tier = validatedSub.items?.data?.price?.id
? mapStripeProductToTier(validatedSub.items.data.price.id)
: undefined;
await updateSubscriptionInDB(stripeSub.id, {
await updateSubscriptionInDB(validatedSub.id, {
tier,
status: sub.status as string,
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
status: validatedSub.status ?? undefined,
currentPeriodStart: validatedSub.current_period_start
? new Date(validatedSub.current_period_start * 1000)
: undefined,
currentPeriodEnd: validatedSub.current_period_end
? new Date(validatedSub.current_period_end * 1000)
: undefined,
cancelAtPeriodEnd: validatedSub.cancel_at_period_end ?? undefined,
});
break;
}
case "customer.subscription.deleted": {
const stripeSub = obj as unknown as Stripe.Subscription;
const stripeSub = safeParseSubscription(event.data.object);
if (!stripeSub) break;
await updateSubscriptionInDB(stripeSub.id, {
status: "canceled",
});

View File

@@ -0,0 +1,96 @@
import { describe, it, expect } from "vitest";
/**
* URL blocking logic for SSRF protection.
* Mirrors the isBlockedUrl function in generator.ts.
*/
function isBlockedUrl(url: string): boolean {
if (url.startsWith("file:")) return true;
if (url.startsWith("data:")) return true;
if (/^https?:\/\/(169\.254\.169\.254|metadata\.google\.internal)/i.test(url)) return true;
const hostname = url.replace(/^https?:\/\//, "").split(/[/:?]/)[0];
if (/^(\d+\.\d+\.\d+\.\d+)/.test(hostname)) {
const [, ip] = hostname.match(/^(\d+\.\d+\.\d+\.\d+)/)!;
const parts = ip.split(".").map(Number);
if (parts[0] === 10) return true;
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
if (parts[0] === 192 && parts[1] === 168) return true;
if (parts[0] === 127) return true;
if (parts[0] === 0) return true;
}
return false;
}
describe("SSRF URL blocking", () => {
describe("blocked URLs", () => {
it("blocks file:// URLs", () => {
expect(isBlockedUrl("file:///etc/passwd")).toBe(true);
expect(isBlockedUrl("file:///etc/shadow")).toBe(true);
expect(isBlockedUrl("file:///windows/system32/config/sam")).toBe(true);
});
it("blocks data: URIs", () => {
expect(isBlockedUrl("data:text/html,<script>alert(1)</script>")).toBe(true);
expect(isBlockedUrl("data:image/png;base64,abc")).toBe(true);
});
it("blocks cloud metadata endpoints", () => {
expect(isBlockedUrl("http://169.254.169.254/latest/meta-data/")).toBe(true);
expect(isBlockedUrl("https://169.254.169.254/computeMetadata/v1/")).toBe(true);
expect(isBlockedUrl("http://metadata.google.internal/computeMetadata/v1/")).toBe(true);
expect(isBlockedUrl("https://metadata.google.internal/")).toBe(true);
});
it("blocks 10.0.0.0/8", () => {
expect(isBlockedUrl("http://10.0.0.1/admin")).toBe(true);
expect(isBlockedUrl("http://10.255.255.255/")).toBe(true);
expect(isBlockedUrl("http://10.128.1.1/")).toBe(true);
});
it("blocks 172.16.0.0/12", () => {
expect(isBlockedUrl("http://172.16.0.1/internal")).toBe(true);
expect(isBlockedUrl("http://172.31.255.255/")).toBe(true);
expect(isBlockedUrl("http://172.17.0.1/")).toBe(true);
});
it("does not block 172.15.x.x or 172.32.x.x", () => {
expect(isBlockedUrl("http://172.15.0.1/")).toBe(false);
expect(isBlockedUrl("http://172.32.0.1/")).toBe(false);
});
it("blocks 192.168.0.0/16", () => {
expect(isBlockedUrl("http://192.168.1.1/admin")).toBe(true);
expect(isBlockedUrl("http://192.168.0.1/")).toBe(true);
expect(isBlockedUrl("http://192.168.255.255/")).toBe(true);
});
it("blocks 127.0.0.0/8", () => {
expect(isBlockedUrl("http://127.0.0.1:8080/health")).toBe(true);
expect(isBlockedUrl("http://127.0.0.2/")).toBe(true);
});
it("blocks 0.0.0.0", () => {
expect(isBlockedUrl("http://0.0.0.0/")).toBe(true);
});
});
describe("allowed URLs", () => {
it("allows legitimate external URLs", () => {
expect(isBlockedUrl("https://example.com/image.png")).toBe(false);
expect(isBlockedUrl("http://cdn.example.com/font.woff2")).toBe(false);
expect(isBlockedUrl("https://fonts.googleapis.com/css")).toBe(false);
expect(isBlockedUrl("https://app.kordant.com/api")).toBe(false);
});
it("does not block URLs with IP-like path segments", () => {
expect(isBlockedUrl("https://example.com/192.168.1.1")).toBe(false);
});
it("handles edge cases", () => {
expect(isBlockedUrl("")).toBe(false);
expect(isBlockedUrl("not-a-url")).toBe(false);
});
});
});

View File

@@ -246,11 +246,55 @@ export function renderHTML(data: ReportData, reportType: string): string {
return renderTemplate(template, flatData);
}
/**
* Returns true if the URL should be blocked (SSRF/metadata/internal access).
*/
export function isBlockedUrl(url: string): boolean {
// Block local file access
if (url.startsWith("file:")) return true;
// Block data URIs
if (url.startsWith("data:")) return true;
// Block cloud metadata endpoints
if (/^https?:\/\/(169\.254\.169\.254|metadata\.google\.internal)/i.test(url)) return true;
// Block internal/private IP ranges
const hostname = url.replace(/^https?:\/\//, "").split(/[/:?]/)[0];
if (/^(\d+\.\d+\.\d+\.\d+)/.test(hostname)) {
const [, ip] = hostname.match(/^(\d+\.\d+\.\d+\.\d+)/)!;
const parts = ip.split(".").map(Number);
// 10.0.0.0/8
if (parts[0] === 10) return true;
// 172.16.0.0/12
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
// 192.168.0.0/16
if (parts[0] === 192 && parts[1] === 168) return true;
// 127.0.0.0/8 (loopback)
if (parts[0] === 127) return true;
// 0.0.0.0
if (parts[0] === 0) return true;
}
return false;
}
export async function generatePDF(html: string): Promise<Buffer> {
try {
const puppeteer = await import("puppeteer");
const browser = await puppeteer.launch({ headless: true, args: ["--no-sandbox"] });
const page = await browser.newPage();
// Block dangerous network requests to prevent SSRF
page.on("request", (request) => {
const url = request.url();
if (isBlockedUrl(url)) {
request.abort();
} else {
request.continue();
}
});
await page.setContent(html, { waitUntil: "load" });
const pdfBuffer = await page.pdf({ format: "A4", printBackground: true, margin: { top: "20mm", bottom: "20mm", left: "15mm", right: "15mm" } });
await browser.close();

View File

@@ -4,92 +4,124 @@ import { describe, it, expect, vi, beforeAll, afterAll } from "vitest";
const mockVerifyJWT = vi.fn();
vi.mock("~/server/auth/jwt", () => ({
verifyJWT: mockVerifyJWT,
signJWT: vi.fn(),
verifyJWT: mockVerifyJWT,
signJWT: vi.fn(),
}));
let mockServer: any;
let connectionHandler: ((ws: any, req: any) => void) | null = null;
let connectionHandler: ((ws: any) => void) | null = null;
vi.mock("ws", () => {
mockServer = {
on: vi.fn((event: string, handler: any) => {
if (event === "connection") connectionHandler = handler;
}),
close: vi.fn((cb: () => void) => cb()),
clients: new Set(),
};
mockServer = {
on: vi.fn((event: string, handler: any) => {
if (event === "connection") connectionHandler = handler;
}),
close: vi.fn((cb: () => void) => cb()),
clients: new Set(),
};
return {
WebSocketServer: vi.fn(function (_opts: any, cb?: () => void) {
if (cb) setTimeout(cb, 0);
return mockServer;
}),
WebSocket: { OPEN: 1, CONNECTING: 0, CLOSING: 2, CLOSED: 3 },
};
function MockWebSocketServer(_opts: any, cb?: () => void) {
if (cb) setTimeout(cb, 0);
return mockServer;
}
MockWebSocketServer.prototype = mockServer;
return {
WebSocketServer: MockWebSocketServer,
WebSocket: { OPEN: 1, CONNECTING: 0, CLOSING: 2, CLOSED: 3 },
};
});
function makeWs() {
const handlers: Record<string, (...args: any[]) => void> = {};
return {
close: vi.fn(),
send: vi.fn(),
ping: vi.fn(),
terminate: vi.fn(),
readyState: 1,
on: vi.fn((event: string, handler: any) => {
handlers[event] = handler;
}),
emit: (event: string, ...args: any[]) => {
handlers[event]?.(...args);
},
};
const handlers: Record<string, (...args: any[]) => Promise<void> | void> = {};
return {
close: vi.fn(),
send: vi.fn(),
ping: vi.fn(),
terminate: vi.fn(),
readyState: 1,
on(event: string, handler: any) {
handlers[event] = handler;
},
async emit(event: string, ...args: any[]) {
const h = handlers[event];
if (h) {
const result = h(...args);
if (result instanceof Promise) await result;
}
},
};
}
describe("WebSocket server", () => {
beforeAll(async () => {
process.env.WS_PORT = "3099";
const { start } = await import("~/server/websocket");
await start();
}, 15000);
beforeAll(async () => {
process.env.WS_PORT = "3099";
const { start } = await import("~/server/websocket");
await start();
}, 15000);
afterAll(async () => {
const { stop } = await import("~/server/websocket");
await stop();
});
afterAll(async () => {
const { stop } = await import("~/server/websocket");
await stop();
});
it("should reject connection without JWT", async () => {
const ws = makeWs();
await connectionHandler!(ws, { url: "/" });
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
});
it("should accept connection without JWT and require post-connection auth", async () => {
const ws = makeWs();
await connectionHandler!(ws);
// Connection is accepted initially (no query-param auth)
expect(ws.close).not.toHaveBeenCalled();
});
it("should reject connection with invalid JWT", async () => {
mockVerifyJWT.mockRejectedValue(new Error("Invalid token"));
it("should reject auth message with invalid JWT", async () => {
mockVerifyJWT.mockRejectedValue(new Error("Invalid token"));
const ws = makeWs();
await connectionHandler!(ws, { url: "/?token=bad" });
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
});
const ws = makeWs();
await connectionHandler!(ws);
it("should accept connection with valid JWT", async () => {
mockVerifyJWT.mockResolvedValue({ sub: "user-1" });
// Trigger the message handler with an auth message
await ws.emit(
"message",
Buffer.from(JSON.stringify({ type: "auth", token: "bad" })),
);
const ws = makeWs();
await connectionHandler!(ws, { url: "/?token=good" });
expect(ws.close).not.toHaveBeenCalled();
});
expect(ws.send).toHaveBeenCalledWith(
JSON.stringify({ type: "auth_error", message: "Invalid token" }),
);
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
});
it("should return false when broadcasting to non-existent user", async () => {
const { broadcastToUser } = await import("~/server/websocket");
const result = broadcastToUser("nonexistent", {
type: "alert",
alert: {
id: "a1", title: "T", message: "M",
severity: "INFO", source: "TEST", category: "TEST",
createdAt: new Date().toISOString(),
},
});
expect(result).toBe(false);
});
it("should accept connection with valid JWT via auth message", async () => {
mockVerifyJWT.mockResolvedValue({ sub: "user-1" });
const ws = makeWs();
await connectionHandler!(ws);
// Trigger the message handler with an auth message
await ws.emit(
"message",
Buffer.from(JSON.stringify({ type: "auth", token: "good" })),
);
expect(ws.send).toHaveBeenCalledWith(
JSON.stringify({ type: "auth_success" }),
);
expect(ws.close).not.toHaveBeenCalled();
});
it("should return false when broadcasting to non-existent user", async () => {
const { broadcastToUser } = await import("~/server/websocket");
const result = broadcastToUser("nonexistent", {
type: "alert",
alert: {
id: "a1",
title: "T",
message: "M",
severity: "INFO",
source: "TEST",
category: "TEST",
createdAt: new Date().toISOString(),
},
});
expect(result).toBe(false);
});
});

View File

@@ -1,216 +1,262 @@
import { WebSocketServer, WebSocket } from "ws";
import type { Server } from "ws";
import { IncomingMessage } from "node:http";
import { URL } from "node:url";
import { verifyJWT } from "~/server/auth/jwt";
const WS_PORT = parseInt(process.env.WS_PORT ?? "3001", 10);
const HEARTBEAT_INTERVAL = 30_000;
const PONG_TIMEOUT = 10_000;
const AUTH_TIMEOUT = 5_000;
interface AlertMessage {
type: "alert";
alert: {
id: string;
title: string;
message: string;
severity: string;
source: string;
category: string;
createdAt: string;
};
type: "alert";
alert: {
id: string;
title: string;
message: string;
severity: string;
source: string;
category: string;
createdAt: string;
};
}
interface WsClient extends WebSocket {
userId?: string;
isAlive?: boolean;
pongTimer?: ReturnType<typeof setTimeout>;
userId?: string;
isAlive?: boolean;
isAuthed?: boolean;
pongTimer?: ReturnType<typeof setTimeout>;
authTimer?: ReturnType<typeof setTimeout>;
}
const userSockets = new Map<string, Set<WsClient>>();
let wss: Server | null = null;
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
function getTokenFromRequest(req: IncomingMessage): string | null {
const url = new URL(req.url ?? "/", "http://localhost");
return url.searchParams.get("token");
}
async function authenticateConnection(
ws: WsClient,
req: IncomingMessage,
): Promise<string | null> {
const token = getTokenFromRequest(req);
if (!token) return null;
try {
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
const userId = payload.sub ?? payload.userId;
if (!userId) return null;
return userId;
} catch {
return null;
}
async function authenticateToken(token: string): Promise<string | null> {
try {
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
const userId = payload.sub ?? payload.userId;
if (!userId) return null;
return userId;
} catch {
return null;
}
}
function addSocket(userId: string, ws: WsClient) {
let sockets = userSockets.get(userId);
if (!sockets) {
sockets = new Set();
userSockets.set(userId, sockets);
}
sockets.add(ws);
let sockets = userSockets.get(userId);
if (!sockets) {
sockets = new Set();
userSockets.set(userId, sockets);
}
sockets.add(ws);
}
function removeSocket(userId: string, ws: WsClient) {
const sockets = userSockets.get(userId);
if (!sockets) return;
sockets.delete(ws);
if (sockets.size === 0) {
userSockets.delete(userId);
}
const sockets = userSockets.get(userId);
if (!sockets) return;
sockets.delete(ws);
if (sockets.size === 0) {
userSockets.delete(userId);
}
}
function heartbeat(ws: WsClient) {
ws.isAlive = true;
ws.isAlive = true;
}
function startHeartbeat() {
if (heartbeatTimer) clearInterval(heartbeatTimer);
heartbeatTimer = setInterval(() => {
if (!wss) return;
wss.clients.forEach((client) => {
const ws = client as WsClient;
if (ws.isAlive === false) {
ws.terminate();
return;
}
ws.isAlive = false;
ws.ping();
if (heartbeatTimer) clearInterval(heartbeatTimer);
heartbeatTimer = setInterval(() => {
if (!wss) return;
wss.clients.forEach((client) => {
const ws = client as WsClient;
if (ws.isAlive === false) {
ws.terminate();
return;
}
ws.isAlive = false;
ws.ping();
ws.pongTimer = setTimeout(() => {
ws.terminate();
}, PONG_TIMEOUT);
});
}, HEARTBEAT_INTERVAL);
ws.pongTimer = setTimeout(() => {
ws.terminate();
}, PONG_TIMEOUT);
});
}, HEARTBEAT_INTERVAL);
if (heartbeatTimer && typeof heartbeatTimer === "object") {
heartbeatTimer.unref();
}
if (heartbeatTimer && typeof heartbeatTimer === "object") {
heartbeatTimer.unref();
}
}
function stopHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
}
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
}
}
/**
* Enforces post-connection auth timeout.
* If the client doesn't send an auth message within AUTH_TIMEOUT,
* the connection is terminated.
*/
function enforceAuthTimeout(ws: WsClient): void {
ws.authTimer = setTimeout(() => {
if (!ws.isAuthed) {
console.log(
"[websocket] Auth timeout — closing unauthenticated connection",
);
ws.close(4001, "Authentication timeout");
}
}, AUTH_TIMEOUT);
}
export function broadcastToUser(userId: string, data: AlertMessage) {
const sockets = userSockets.get(userId);
if (!sockets || sockets.size === 0) return false;
const sockets = userSockets.get(userId);
if (!sockets || sockets.size === 0) return false;
const message = JSON.stringify(data);
let sent = false;
for (const ws of sockets) {
if (ws.readyState === WebSocket.OPEN) {
ws.send(message);
sent = true;
}
}
return sent;
const message = JSON.stringify(data);
let sent = false;
for (const ws of sockets) {
if (ws.readyState === WebSocket.OPEN) {
ws.send(message);
sent = true;
}
}
return sent;
}
export function getConnectedUsers(): string[] {
return Array.from(userSockets.keys());
return Array.from(userSockets.keys());
}
export function getConnectionCount(): number {
let count = 0;
for (const sockets of userSockets.values()) {
count += sockets.size;
}
return count;
let count = 0;
for (const sockets of userSockets.values()) {
count += sockets.size;
}
return count;
}
export function start(): Promise<void> {
return new Promise((resolve) => {
if (wss) {
resolve();
return;
}
return new Promise((resolve) => {
if (wss) {
resolve();
return;
}
wss = new WebSocketServer({ port: WS_PORT }, () => {
console.log(`[websocket] Server listening on port ${WS_PORT}`);
resolve();
});
wss = new WebSocketServer({ port: WS_PORT }, () => {
console.log(`[websocket] Server listening on port ${WS_PORT}`);
resolve();
});
wss.on("connection", async (ws: WsClient, req: IncomingMessage) => {
const userId = await authenticateConnection(ws, req);
wss.on("connection", async (ws: WsClient) => {
// Mark as unauthenticated initially; client must authenticate within timeout
ws.isAuthed = false;
enforceAuthTimeout(ws);
if (!userId) {
ws.close(4001, "Authentication failed");
return;
}
ws.on("message", async (data) => {
try {
const msg = JSON.parse(data.toString());
ws.userId = userId;
ws.isAlive = true;
addSocket(userId, ws);
// Handle auth messages (post-connection JWT authentication)
if (
msg.type === "auth" &&
msg.token &&
typeof msg.token === "string"
) {
const userId = await authenticateToken(msg.token);
ws.on("pong", () => {
heartbeat(ws);
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
ws.pongTimer = undefined;
}
});
if (userId) {
ws.isAuthed = true;
ws.userId = userId;
ws.isAlive = true;
ws.on("message", (data) => {
try {
const msg = JSON.parse(data.toString());
if (msg.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
}
} catch {
// ignore invalid messages
}
});
// Clear the auth timeout — client is now authenticated
if (ws.authTimer) {
clearTimeout(ws.authTimer);
ws.authTimer = undefined;
}
ws.on("close", () => {
if (ws.userId) {
removeSocket(ws.userId, ws);
}
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
}
});
addSocket(userId, ws);
ws.send(JSON.stringify({ type: "auth_success" }));
} else {
ws.send(
JSON.stringify({
type: "auth_error",
message: "Invalid token",
}),
);
ws.close(4001, "Authentication failed");
}
return;
}
ws.on("error", (err) => {
console.error("[websocket] Client error:", err.message);
});
});
// Only allow messages from authenticated connections
if (!ws.isAuthed) {
// Ignore ping messages from unauthenticated clients (they might not have sent auth yet)
if (msg.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
}
return;
}
startHeartbeat();
});
// Handle normal messages from authenticated clients
if (msg.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
}
} catch {
// ignore invalid messages
}
});
ws.on("pong", () => {
heartbeat(ws);
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
ws.pongTimer = undefined;
}
});
ws.on("close", () => {
if (ws.userId) {
removeSocket(ws.userId, ws);
}
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
}
if (ws.authTimer) {
clearTimeout(ws.authTimer);
}
});
ws.on("error", (err) => {
console.error("[websocket] Client error:", err.message);
});
});
startHeartbeat();
});
}
export function stop(): Promise<void> {
return new Promise((resolve) => {
stopHeartbeat();
if (!wss) {
resolve();
return;
}
return new Promise((resolve) => {
stopHeartbeat();
if (!wss) {
resolve();
return;
}
for (const ws of wss.clients) {
ws.close(1001, "Server shutting down");
}
for (const ws of wss.clients) {
ws.close(1001, "Server shutting down");
}
wss.close(() => {
wss = null;
userSockets.clear();
console.log("[websocket] Server stopped");
resolve();
});
});
wss.close(() => {
wss = null;
userSockets.clear();
console.log("[websocket] Server stopped");
resolve();
});
});
}