feat: real-time alerts via WebSocket push notifications

- Add ws WebSocket server (port 3001) with JWT auth and user-socket mapping
- Add WebSocket client with exponential backoff reconnection and heartbeat
- Add useRealtimeAlerts hook with toast notifications and unread badge
- Add alert.publisher service (WS → push → email fallback)
- Integrate publisher into DarkWatch, VoicePrint, HomeTitle, SpamShield, RemoveBrokers
- Update Navbar with connection status indicator and unread count
- Add comprehensive tests (14 passing) for server, client, and publisher
This commit is contained in:
2026-05-25 17:58:47 -04:00
parent 3a8e329f02
commit c02457c66a
16 changed files with 1197 additions and 26 deletions

View File

@@ -20,17 +20,18 @@ export const spamshieldRouter = createTRPCRouter({
classifySMS: publicProcedure
.input(wrap(ClassifySMSSchema))
.query(async ({ input }) => {
return spamshieldService.classifySMS(input.text);
.query(async ({ input, ctx }) => {
return spamshieldService.classifySMS(input.text, ctx.user?.id);
}),
classifyCall: publicProcedure
.input(wrap(ClassifyCallSchema))
.query(async ({ input }) => {
.query(async ({ input, ctx }) => {
return spamshieldService.classifyCall(
input.callerNumber,
input.duration,
input.timeOfDay,
ctx.user?.id,
);
}),

View File

@@ -0,0 +1,111 @@
// @vitest-environment node
import { describe, it, expect, vi, beforeEach } from "vitest";
const mockBroadcastToUser = vi.fn();
vi.mock("~/server/websocket", () => ({
broadcastToUser: mockBroadcastToUser,
}));
const mockSendPush = vi.fn();
const mockSendEmail = vi.fn();
vi.mock("~/server/services/notification.service", () => ({
sendPush: mockSendPush,
sendEmail: mockSendEmail,
}));
vi.mock("~/server/db", () => ({
db: {
select: vi.fn(),
insert: vi.fn(),
update: vi.fn(),
},
}));
beforeEach(() => {
vi.clearAllMocks();
});
describe("alert.publisher", () => {
it("should send alert via WebSocket when user is connected", async () => {
mockBroadcastToUser.mockReturnValue(true);
const { publishAlert } = await import("./alert.publisher");
await publishAlert("user-1", {
id: "alert-1",
title: "Test Alert",
message: "Test message",
severity: "HIGH",
source: "DARKWATCH",
category: "EXPOSURE_DETECTED",
createdAt: new Date(),
});
expect(mockBroadcastToUser).toHaveBeenCalledWith("user-1", {
type: "alert",
alert: {
id: "alert-1",
title: "Test Alert",
message: "Test message",
severity: "HIGH",
source: "DARKWATCH",
category: "EXPOSURE_DETECTED",
createdAt: expect.any(String),
},
});
expect(mockSendPush).not.toHaveBeenCalled();
expect(mockSendEmail).not.toHaveBeenCalled();
});
it("should fall back to push notification when user is not connected", async () => {
mockBroadcastToUser.mockReturnValue(false);
mockSendPush.mockResolvedValue({ successCount: 1 });
const { publishAlert } = await import("./alert.publisher");
await publishAlert("user-1", {
id: "alert-2",
title: "Offline Alert",
message: "Offline message",
severity: "WARNING",
source: "VOICEPRINT",
category: "SYNTHETIC_VOICE",
createdAt: new Date(),
});
expect(mockBroadcastToUser).toHaveBeenCalled();
expect(mockSendPush).toHaveBeenCalledWith(
"user-1",
"Offline Alert",
"Offline message",
{ alertId: "alert-2", source: "VOICEPRINT", severity: "WARNING" },
);
});
it("should publish alert to multiple users", async () => {
mockBroadcastToUser.mockReturnValue(false);
mockSendPush.mockResolvedValue({ successCount: 0 });
const db = await import("~/server/db");
(db.db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
const { publishToGroup } = await import("./alert.publisher");
await publishToGroup(["user-1", "user-2"], {
id: "alert-3",
title: "Group Alert",
message: "Group message",
severity: "INFO",
source: "HOME_TITLE",
category: "HOME_TITLE",
createdAt: new Date(),
});
expect(mockBroadcastToUser).toHaveBeenCalledTimes(2);
});
});

View File

@@ -0,0 +1,66 @@
import { broadcastToUser } from "~/server/websocket";
import { sendPush, sendEmail } from "~/server/services/notification.service";
import { db } from "~/server/db";
import { users } from "~/server/db/schema/auth";
import { eq } from "drizzle-orm";
export interface PublishableAlert {
id: string;
title: string;
message: string;
severity: string;
source: string;
category: string;
createdAt: Date;
}
export async function publishAlert(userId: string, alert: PublishableAlert): Promise<void> {
const message = {
type: "alert" as const,
alert: {
id: alert.id,
title: alert.title,
message: alert.message,
severity: alert.severity,
source: alert.source,
category: alert.category,
createdAt: alert.createdAt.toISOString(),
},
};
const sent = broadcastToUser(userId, message);
if (!sent) {
try {
const pushResult = await sendPush(userId, alert.title, alert.message, {
alertId: alert.id,
source: alert.source,
severity: alert.severity,
});
if (pushResult.successCount === 0) {
const [user] = await db
.select()
.from(users)
.where(eq(users.id, userId))
.limit(1);
if (user?.email) {
await sendEmail(
user.email,
`[ShieldAI] ${alert.title}`,
`<p>${alert.message}</p>`,
alert.message,
);
}
}
} catch (err) {
console.error("[alert.publisher] Fallback notification failed:", err);
}
}
}
export async function publishToGroup(userIds: string[], alert: PublishableAlert): Promise<void> {
const promises = userIds.map((userId) => publishAlert(userId, alert));
await Promise.allSettled(promises);
}

View File

@@ -1,6 +1,7 @@
import { eq, and } from "drizzle-orm";
import { db } from "~/server/db";
import { exposures, alerts, subscriptions } from "~/server/db/schema";
import { publishAlert } from "~/server/services/alert.publisher";
export function severityScore(exposure: {
source: string;
@@ -105,14 +106,27 @@ async function createAlertForExposure(
if (!sub) return;
await db.insert(alerts).values({
subscriptionId: exposure.subscriptionId,
userId: sub.userId,
exposureId: exposure.id,
type: "exposure_detected",
const [alert] = await db
.insert(alerts)
.values({
subscriptionId: exposure.subscriptionId,
userId: sub.userId,
exposureId: exposure.id,
type: "exposure_detected",
title,
message,
severity: alertSeverityMap[severity] ?? "info",
channel: ["email", "push"],
})
.returning();
publishAlert(sub.userId, {
id: alert.id,
title,
message,
severity: alertSeverityMap[severity] ?? "info",
channel: ["email", "push"],
});
source: "DARKWATCH",
category: "EXPOSURE_DETECTED",
createdAt: alert.createdAt,
}).catch((err) => console.error("[darkwatch] Failed to publish alert:", err));
}

View File

@@ -14,6 +14,7 @@ import {
type DetectedChange,
type SnapshotData,
} from "./hometitle/change.detector";
import { publishAlert } from "~/server/services/alert.publisher";
import {
geocodeAddress,
fetchCountyRecords,
@@ -433,7 +434,7 @@ async function generateAlert(
critical: "CRITICAL",
};
await db
const [normalized] = await db
.insert(normalizedAlerts)
.values({
source: "HOME_TITLE",
@@ -454,5 +455,15 @@ async function generateAlert(
})
.returning();
publishAlert(sub.userId, {
id: alert.id,
title,
message,
severity: change.severity,
source: "HOME_TITLE",
category: "HOME_TITLE",
createdAt: alert.createdAt,
}).catch((err) => console.error("[hometitle] Failed to publish alert:", err));
return alert;
}

View File

@@ -1,8 +1,9 @@
import { TRPCError } from "@trpc/server";
import { eq, and, desc, count, inArray, or, isNull, lt } from "drizzle-orm";
import { db } from "~/server/db";
import { infoBrokers, removalRequests, brokerListings, subscriptions } from "~/server/db/schema";
import { infoBrokers, removalRequests, brokerListings, subscriptions, normalizedAlerts } from "~/server/db/schema";
import { getActiveBrokers } from "./removebrokers/broker.registry";
import { publishAlert } from "~/server/services/alert.publisher";
import { submitAutomatedRemoval, sendRemovalEmail } from "./removebrokers/removal.engine";
import type { PersonalInfo } from "./removebrokers/removal.engine";
import type { RemovalMethod } from "./removebrokers/broker.registry";
@@ -273,6 +274,38 @@ export async function scanForListings(userId: string, brokerId?: string) {
createdListings.push({ id: listing.id, brokerId: listing.brokerId, url: listing.url });
}
if (createdListings.length > 0) {
const sourceAlertId = `removebrokers:scan:${sub.id}:${Date.now()}`;
try {
const [na] = await db
.insert(normalizedAlerts)
.values({
source: "REMOVEBROKERS",
category: "BROKER_LISTING",
severity: "INFO",
userId: sub.userId,
title: "New Broker Listings Found",
description: `Found ${createdListings.length} new broker listing(s)`,
entities: { listingCount: createdListings.length, brokerIds: createdListings.map((l) => l.brokerId) },
sourceAlertId,
createdAt: new Date(),
})
.returning();
publishAlert(sub.userId, {
id: na.id,
title: "New Broker Listings Found",
message: `Found ${createdListings.length} new broker listing(s)`,
severity: "INFO",
source: "REMOVEBROKERS",
category: "BROKER_LISTING",
createdAt: na.createdAt,
}).catch(() => {});
} catch {
// alert creation is best-effort
}
}
return {
scanned: brokers.length,
listingsFound: createdListings.length,

View File

@@ -1,10 +1,11 @@
import { TRPCError } from "@trpc/server";
import { eq, and, desc, count, sql, gte } from "drizzle-orm";
import { db } from "~/server/db";
import { spamFeedback, spamRules, auditLogs } from "~/server/db/schema";
import { spamFeedback, spamRules, auditLogs, normalizedAlerts } from "~/server/db/schema";
import { classifyTextBERT, extractFeatures, ruleEngine } from "./spamshield/ml.engine";
import { checkReputation } from "./spamshield/reputation.api";
import type { ReputationResult } from "./spamshield/reputation.api";
import { publishAlert } from "~/server/services/alert.publisher";
function normalizePhoneNumber(phone: string): string {
let cleaned = phone.replace(/[^\d+]/g, "");
@@ -72,7 +73,7 @@ export async function checkNumberReputation(phoneNumber: string) {
return response;
}
export async function classifySMS(text: string) {
export async function classifySMS(text: string, userId?: string) {
const classification = await classifyTextBERT(text);
const response = {
@@ -83,6 +84,38 @@ export async function classifySMS(text: string) {
await logAudit(undefined, "classifySMS", { textLength: text.length }, response, response.confidence);
if (classification.isSpam && classification.confidence >= 0.8 && userId) {
const sourceAlertId = `spamshield:sms:${Date.now()}`;
try {
const [na] = await db
.insert(normalizedAlerts)
.values({
source: "SPAMSHIELD",
category: "SPAM_SMS",
severity: "LOW",
userId,
title: "Spam SMS Detected",
description: "Suspected spam SMS detected",
entities: { textLength: text.length, confidence: classification.confidence },
sourceAlertId,
createdAt: new Date(),
})
.returning();
publishAlert(userId, {
id: na.id,
title: "Spam SMS Detected",
message: "Suspected spam SMS detected",
severity: "LOW",
source: "SPAMSHIELD",
category: "SPAM_SMS",
createdAt: na.createdAt,
}).catch(() => {});
} catch {
// alert creation is best-effort
}
}
return response;
}
@@ -90,6 +123,7 @@ export async function classifyCall(
callerNumber: string,
duration?: number,
timeOfDay?: number,
userId?: string,
) {
const normalized = normalizePhoneNumber(callerNumber);
const features = await extractFeatures({ callerNumber: normalized, duration, timeOfDay });
@@ -138,6 +172,38 @@ export async function classifyCall(
await logAudit(undefined, "classifyCall", { callerNumber: normalized, duration, timeOfDay }, response, confidence);
if (isSpam && confidence >= 0.8 && userId) {
const sourceAlertId = `spamshield:call:${normalized}:${Date.now()}`;
try {
const [na] = await db
.insert(normalizedAlerts)
.values({
source: "SPAMSHIELD",
category: "SPAM_CALL",
severity: "MEDIUM",
userId,
title: "Spam Call Blocked",
description: `Blocked spam call from ${normalized}`,
entities: { callerNumber: normalized, confidence, ruleMatch: !!ruleMatch },
sourceAlertId,
createdAt: new Date(),
})
.returning();
publishAlert(userId, {
id: na.id,
title: "Spam Call Blocked",
message: `Blocked spam call from ${normalized}`,
severity: "MEDIUM",
source: "SPAMSHIELD",
category: "SPAM_CALL",
createdAt: na.createdAt,
}).catch(() => {});
} catch {
// alert creation is best-effort
}
}
return response;
}

View File

@@ -10,6 +10,7 @@ import {
normalizedAlerts,
} from "~/server/db/schema";
import { saveAudio, getAudioUrl, deleteFile } from "./voiceprint/storage";
import { publishAlert } from "~/server/services/alert.publisher";
import {
preprocessAudio,
detectSynthetic,
@@ -100,17 +101,33 @@ async function createVoiceAlert(
if (sub) {
const category = verdict === "SYNTHETIC" ? "SYNTHETIC_VOICE" : "VOICE_MISMATCH";
await db.insert(normalizedAlerts).values({
const title = verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected";
const description = `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`;
const [alert] = await db
.insert(normalizedAlerts)
.values({
source: "VOICEPRINT",
category,
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
userId,
title,
description,
entities: { analysisId, verdict, confidence },
sourceAlertId: `voiceprint-${analysisId}`,
createdAt: new Date(),
})
.returning();
publishAlert(userId, {
id: alert.id,
title,
message: description,
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
source: "VOICEPRINT",
category,
severity: verdict === "SYNTHETIC" ? "HIGH" : "MEDIUM",
userId,
title: verdict === "SYNTHETIC" ? "Synthetic Voice Detected" : "Voice Mismatch Detected",
description: `Analysis ${analysisId} returned verdict ${verdict} with ${(confidence * 100).toFixed(1)}% confidence`,
entities: { analysisId, verdict, confidence },
sourceAlertId: `voiceprint-${analysisId}`,
createdAt: new Date(),
});
createdAt: alert.createdAt,
}).catch((err) => console.error("[voiceprint] Failed to publish alert:", err));
}
} catch (err) {
console.error("[voiceprint] Failed to create alert:", err);

View File

@@ -0,0 +1,95 @@
// @vitest-environment node
import { describe, it, expect, vi, beforeAll, afterAll } from "vitest";
const mockVerifyJWT = vi.fn();
vi.mock("~/server/auth/jwt", () => ({
verifyJWT: mockVerifyJWT,
signJWT: vi.fn(),
}));
let mockServer: any;
let connectionHandler: ((ws: any, req: any) => void) | null = null;
vi.mock("ws", () => {
mockServer = {
on: vi.fn((event: string, handler: any) => {
if (event === "connection") connectionHandler = handler;
}),
close: vi.fn((cb: () => void) => cb()),
clients: new Set(),
};
return {
WebSocketServer: vi.fn(function (_opts: any, cb?: () => void) {
if (cb) setTimeout(cb, 0);
return mockServer;
}),
WebSocket: { OPEN: 1, CONNECTING: 0, CLOSING: 2, CLOSED: 3 },
};
});
function makeWs() {
const handlers: Record<string, (...args: any[]) => void> = {};
return {
close: vi.fn(),
send: vi.fn(),
ping: vi.fn(),
terminate: vi.fn(),
readyState: 1,
on: vi.fn((event: string, handler: any) => {
handlers[event] = handler;
}),
emit: (event: string, ...args: any[]) => {
handlers[event]?.(...args);
},
};
}
describe("WebSocket server", () => {
beforeAll(async () => {
process.env.WS_PORT = "3099";
const { start } = await import("~/server/websocket");
await start();
}, 15000);
afterAll(async () => {
const { stop } = await import("~/server/websocket");
await stop();
});
it("should reject connection without JWT", async () => {
const ws = makeWs();
await connectionHandler!(ws, { url: "/" });
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
});
it("should reject connection with invalid JWT", async () => {
mockVerifyJWT.mockRejectedValue(new Error("Invalid token"));
const ws = makeWs();
await connectionHandler!(ws, { url: "/?token=bad" });
expect(ws.close).toHaveBeenCalledWith(4001, "Authentication failed");
});
it("should accept connection with valid JWT", async () => {
mockVerifyJWT.mockResolvedValue({ sub: "user-1" });
const ws = makeWs();
await connectionHandler!(ws, { url: "/?token=good" });
expect(ws.close).not.toHaveBeenCalled();
});
it("should return false when broadcasting to non-existent user", async () => {
const { broadcastToUser } = await import("~/server/websocket");
const result = broadcastToUser("nonexistent", {
type: "alert",
alert: {
id: "a1", title: "T", message: "M",
severity: "INFO", source: "TEST", category: "TEST",
createdAt: new Date().toISOString(),
},
});
expect(result).toBe(false);
});
});

216
web/src/server/websocket.ts Normal file
View File

@@ -0,0 +1,216 @@
import { WebSocketServer, WebSocket } from "ws";
import type { Server } from "ws";
import { IncomingMessage } from "node:http";
import { URL } from "node:url";
import { verifyJWT } from "~/server/auth/jwt";
const WS_PORT = parseInt(process.env.WS_PORT ?? "3001", 10);
const HEARTBEAT_INTERVAL = 30_000;
const PONG_TIMEOUT = 10_000;
interface AlertMessage {
type: "alert";
alert: {
id: string;
title: string;
message: string;
severity: string;
source: string;
category: string;
createdAt: string;
};
}
interface WsClient extends WebSocket {
userId?: string;
isAlive?: boolean;
pongTimer?: ReturnType<typeof setTimeout>;
}
const userSockets = new Map<string, Set<WsClient>>();
let wss: Server | null = null;
let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
function getTokenFromRequest(req: IncomingMessage): string | null {
const url = new URL(req.url ?? "/", "http://localhost");
return url.searchParams.get("token");
}
async function authenticateConnection(
ws: WsClient,
req: IncomingMessage,
): Promise<string | null> {
const token = getTokenFromRequest(req);
if (!token) return null;
try {
const payload = await verifyJWT<{ sub?: string; userId?: string }>(token);
const userId = payload.sub ?? payload.userId;
if (!userId) return null;
return userId;
} catch {
return null;
}
}
function addSocket(userId: string, ws: WsClient) {
let sockets = userSockets.get(userId);
if (!sockets) {
sockets = new Set();
userSockets.set(userId, sockets);
}
sockets.add(ws);
}
function removeSocket(userId: string, ws: WsClient) {
const sockets = userSockets.get(userId);
if (!sockets) return;
sockets.delete(ws);
if (sockets.size === 0) {
userSockets.delete(userId);
}
}
function heartbeat(ws: WsClient) {
ws.isAlive = true;
}
function startHeartbeat() {
if (heartbeatTimer) clearInterval(heartbeatTimer);
heartbeatTimer = setInterval(() => {
if (!wss) return;
wss.clients.forEach((client) => {
const ws = client as WsClient;
if (ws.isAlive === false) {
ws.terminate();
return;
}
ws.isAlive = false;
ws.ping();
ws.pongTimer = setTimeout(() => {
ws.terminate();
}, PONG_TIMEOUT);
});
}, HEARTBEAT_INTERVAL);
if (heartbeatTimer && typeof heartbeatTimer === "object") {
heartbeatTimer.unref();
}
}
function stopHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
}
}
export function broadcastToUser(userId: string, data: AlertMessage) {
const sockets = userSockets.get(userId);
if (!sockets || sockets.size === 0) return false;
const message = JSON.stringify(data);
let sent = false;
for (const ws of sockets) {
if (ws.readyState === WebSocket.OPEN) {
ws.send(message);
sent = true;
}
}
return sent;
}
export function getConnectedUsers(): string[] {
return Array.from(userSockets.keys());
}
export function getConnectionCount(): number {
let count = 0;
for (const sockets of userSockets.values()) {
count += sockets.size;
}
return count;
}
export function start(): Promise<void> {
return new Promise((resolve) => {
if (wss) {
resolve();
return;
}
wss = new WebSocketServer({ port: WS_PORT }, () => {
console.log(`[websocket] Server listening on port ${WS_PORT}`);
resolve();
});
wss.on("connection", async (ws: WsClient, req: IncomingMessage) => {
const userId = await authenticateConnection(ws, req);
if (!userId) {
ws.close(4001, "Authentication failed");
return;
}
ws.userId = userId;
ws.isAlive = true;
addSocket(userId, ws);
ws.on("pong", () => {
heartbeat(ws);
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
ws.pongTimer = undefined;
}
});
ws.on("message", (data) => {
try {
const msg = JSON.parse(data.toString());
if (msg.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
}
} catch {
// ignore invalid messages
}
});
ws.on("close", () => {
if (ws.userId) {
removeSocket(ws.userId, ws);
}
if (ws.pongTimer) {
clearTimeout(ws.pongTimer);
}
});
ws.on("error", (err) => {
console.error("[websocket] Client error:", err.message);
});
});
startHeartbeat();
});
}
export function stop(): Promise<void> {
return new Promise((resolve) => {
stopHeartbeat();
if (!wss) {
resolve();
return;
}
for (const ws of wss.clients) {
ws.close(1001, "Server shutting down");
}
wss.close(() => {
wss = null;
userSockets.clear();
console.log("[websocket] Server stopped");
resolve();
});
});
}