From 71972436b630385da49019bb5a2a990370d3606a Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Mon, 25 May 2026 15:46:52 -0400 Subject: [PATCH] feat: add tRPC auth context, middleware, and protected procedures - Install jose (JWT) and bcryptjs (password hashing) dependencies - Create auth utilities: JWT sign/verify, password hash/verify, session management - Create createTRPCContext that extracts auth from session cookie, Bearer JWT, or x-api-key - Add publicProcedure, protectedProcedure, adminProcedure, rateLimitedProcedure with middleware - Wire context builder into SolidStart tRPC API handler - Update tRPC client to inject auth tokens and handle 401 redirects - Add unit tests for JWT, password, context builder, and middleware --- pnpm-lock.yaml | 17 +++++ web/package.json | 2 + web/src/lib/api.ts | 30 ++++++--- web/src/routes/api/trpc/[trpc].ts | 9 +-- web/src/server/api/root.ts | 2 +- web/src/server/api/trpc.test.ts | 92 ++++++++++++++++++++++++++++ web/src/server/api/trpc.ts | 81 ++++++++++++++++++++++++ web/src/server/api/utils.ts | 57 ++++++++++++++++- web/src/server/auth/jwt.test.ts | 17 +++++ web/src/server/auth/jwt.ts | 24 ++++++++ web/src/server/auth/password.test.ts | 22 +++++++ web/src/server/auth/password.ts | 14 +++++ web/src/server/auth/session.ts | 35 +++++++++++ 13 files changed, 385 insertions(+), 17 deletions(-) create mode 100644 web/src/server/api/trpc.test.ts create mode 100644 web/src/server/api/trpc.ts create mode 100644 web/src/server/auth/jwt.test.ts create mode 100644 web/src/server/auth/jwt.ts create mode 100644 web/src/server/auth/password.test.ts create mode 100644 web/src/server/auth/password.ts create mode 100644 web/src/server/auth/session.ts diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index ca6f304..c046a2c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -52,9 +52,15 @@ importers: '@typeschema/valibot': specifier: ^0.13.4 version: 0.13.5(valibot@0.29.0) + bcryptjs: + specifier: ^3.0.3 + version: 3.0.3 drizzle-orm: specifier: ^0.45.2 version: 0.45.2(@types/pg@8.20.0)(pg@8.21.0) + jose: + specifier: ^5 + version: 5.10.0 pg: specifier: ^8.21.0 version: 8.21.0 @@ -1716,6 +1722,10 @@ packages: engines: {node: '>=6.0.0'} hasBin: true + bcryptjs@3.0.3: + resolution: {integrity: sha512-GlF5wPWnSa/X5LKM1o0wz0suXIINz1iHRLvTS+sLyi7XPbe5ycmYI3DlZqVGZZtDgl4DmasFg7gOB3JYbphV5g==} + hasBin: true + bidi-js@1.0.3: resolution: {integrity: sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==} @@ -2425,6 +2435,9 @@ packages: resolution: {integrity: sha512-AC/7JofJvZGrrneWNaEnJeOLUx+JlGt7tNa0wZiRPT4MY1wmfKjt2+6O2p2uz2+skll8OZZmJMNqeke7kKbNgQ==} hasBin: true + jose@5.10.0: + resolution: {integrity: sha512-s+3Al/p9g32Iq+oqXxkW//7jk2Vig6FF1CFqzVXoTUXt2qz89YWbL+OwS17NFYEvxC35n0FKeGO2LGYSxeM2Gg==} + js-tokens@10.0.0: resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==} @@ -4980,6 +4993,8 @@ snapshots: baseline-browser-mapping@2.10.32: {} + bcryptjs@3.0.3: {} + bidi-js@1.0.3: dependencies: require-from-string: 2.0.2 @@ -5622,6 +5637,8 @@ snapshots: jiti@2.7.0: {} + jose@5.10.0: {} + js-tokens@10.0.0: {} js-tokens@4.0.0: {} diff --git a/web/package.json b/web/package.json index 55ab564..c51806a 100644 --- a/web/package.json +++ b/web/package.json @@ -23,7 +23,9 @@ "@trpc/server": "^10.45.2", "@types/three": "^0.184.1", "@typeschema/valibot": "^0.13.4", + "bcryptjs": "^3.0.3", "drizzle-orm": "^0.45.2", + "jose": "^5", "pg": "^8.21.0", "solid-js": "^1.9.5", "tailwindcss": "^4.0.0", diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 47860d6..fd4e2f5 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -2,22 +2,36 @@ import { createTRPCProxyClient, httpBatchLink, loggerLink, -} from '@trpc/client'; -import { AppRouter } from "~/server/api/root"; +} from "@trpc/client"; +import type { AppRouter } from "~/server/api/root"; const getBaseUrl = () => { if (typeof window !== "undefined") return ""; - // replace example.com with your actual production url if (process.env.NODE_ENV === "production") return "https://example.com"; return `http://localhost:${process.env.PORT ?? 3000}`; }; -// create the client, export it +function getAuthToken(): string | null { + if (typeof document === "undefined") return null; + const match = document.cookie.match(/(?:^|;\s*)session_token=([^;]*)/); + if (match) return match[1]; + try { + return localStorage.getItem("auth_token"); + } catch { + return null; + } +} + export const api = createTRPCProxyClient({ links: [ - // will print out helpful logs when using client - loggerLink(), - // identifies what url will handle trpc requests - httpBatchLink({ url: `${getBaseUrl()}/api/trpc` }) + loggerLink(), + httpBatchLink({ + url: `${getBaseUrl()}/api/trpc`, + headers: () => { + const token = getAuthToken(); + if (!token) return {}; + return { Authorization: `Bearer ${token}` }; + }, + }), ], }); diff --git a/web/src/routes/api/trpc/[trpc].ts b/web/src/routes/api/trpc/[trpc].ts index 451d21a..98afe2c 100644 --- a/web/src/routes/api/trpc/[trpc].ts +++ b/web/src/routes/api/trpc/[trpc].ts @@ -1,18 +1,15 @@ import type { APIEvent } from "@solidjs/start/server"; import { fetchRequestHandler } from "@trpc/server/adapters/fetch"; import { appRouter } from "~/server/api/root"; +import { createTRPCContext } from "~/server/api/trpc"; const handler = (event: APIEvent) => - // adapts tRPC to fetch API style requests fetchRequestHandler({ - // the endpoint handling the requests endpoint: "/api/trpc", - // the request object req: event.request, - // the router for handling the requests router: appRouter, - // any arbitrary data that should be available to all actions - createContext: () => event + createContext: ({ req, resHeaders }) => + createTRPCContext({ req, resHeaders }), }); export const GET = handler; diff --git a/web/src/server/api/root.ts b/web/src/server/api/root.ts index 944f4ef..a2ea76e 100644 --- a/web/src/server/api/root.ts +++ b/web/src/server/api/root.ts @@ -2,7 +2,7 @@ import { exampleRouter } from "./routers/example"; import { createTRPCRouter } from "./utils"; export const appRouter = createTRPCRouter({ - example: exampleRouter + example: exampleRouter, }); export type AppRouter = typeof appRouter; diff --git a/web/src/server/api/trpc.test.ts b/web/src/server/api/trpc.test.ts new file mode 100644 index 0000000..812adb1 --- /dev/null +++ b/web/src/server/api/trpc.test.ts @@ -0,0 +1,92 @@ +import { describe, it, expect, vi } from "vitest"; +import { initTRPC, TRPCError } from "@trpc/server"; + +vi.mock("~/server/db", () => ({ + db: {}, +})); + +vi.mock("~/server/auth/session", () => ({ + validateSession: vi.fn(), +})); + +vi.mock("~/server/auth/jwt", () => ({ + verifyJWT: vi.fn(), +})); + +describe("createTRPCContext", () => { + it("should export createTRPCContext function", async () => { + const mod = await import("./trpc"); + expect(mod.createTRPCContext).toBeInstanceOf(Function); + }); + + it("should return anonymous context for unauthenticated requests", async () => { + const { createTRPCContext } = await import("./trpc"); + const ctx = await createTRPCContext({ + req: new Request("http://localhost:3000/api/trpc"), + }); + expect(ctx.user).toBeNull(); + expect(ctx.apiKey).toBeNull(); + expect(ctx.db).toBeDefined(); + }); +}); + +describe("tRPC middleware", () => { + type TestCtx = { user?: { id: string; role: string }; db: object }; + + it("publicProcedure should allow unauthenticated access", async () => { + const { publicProcedure } = await import("./utils"); + const t = initTRPC.context().create(); + const testRouter = t.router({ + test: publicProcedure.query(() => "ok"), + }); + const caller = t.createCallerFactory(testRouter); + const result = await caller({ db: {} }).test(); + expect(result).toBe("ok"); + }); + + it("protectedProcedure should reject unauthenticated requests", async () => { + const { protectedProcedure } = await import("./utils"); + const t = initTRPC.context().create(); + const testRouter = t.router({ + test: protectedProcedure.query(() => "ok"), + }); + const caller = t.createCallerFactory(testRouter); + await expect(caller({ db: {} }).test()).rejects.toThrow(TRPCError); + }); + + it("protectedProcedure should allow authenticated requests", async () => { + const { protectedProcedure } = await import("./utils"); + const t = initTRPC.context().create(); + const testRouter = t.router({ + test: protectedProcedure.query(({ ctx }) => ctx.user?.id), + }); + const caller = t.createCallerFactory(testRouter); + const result = await caller({ + db: {}, + user: { id: "user-1", role: "user" }, + }).test(); + expect(result).toBe("user-1"); + }); + + it("adminProcedure should reject non-admin users with FORBIDDEN", async () => { + const { adminProcedure } = await import("./utils"); + const t = initTRPC.context().create(); + const testRouter = t.router({ + test: adminProcedure.query(() => "ok"), + }); + const caller = t.createCallerFactory(testRouter); + await expect( + caller({ db: {}, user: { id: "user-1", role: "user" } }).test(), + ).rejects.toThrow(TRPCError); + }); + + it("adminProcedure should reject unauthenticated with UNAUTHORIZED", async () => { + const { adminProcedure } = await import("./utils"); + const t = initTRPC.context().create(); + const testRouter = t.router({ + test: adminProcedure.query(() => "ok"), + }); + const caller = t.createCallerFactory(testRouter); + await expect(caller({ db: {} }).test()).rejects.toThrow(TRPCError); + }); +}); diff --git a/web/src/server/api/trpc.ts b/web/src/server/api/trpc.ts new file mode 100644 index 0000000..8657b92 --- /dev/null +++ b/web/src/server/api/trpc.ts @@ -0,0 +1,81 @@ +import type { inferAsyncReturnType } from "@trpc/server"; +import type { FetchCreateContextFnOptions } from "@trpc/server/adapters/fetch"; +import { db } from "~/server/db"; +import { verifyJWT } from "~/server/auth/jwt"; +import { validateSession } from "~/server/auth/session"; +import { users } from "~/server/db/schema/auth"; +import { eq } from "drizzle-orm"; + +export type CreateTRPCContextOptions = { + req: Request; + resHeaders?: Headers; +}; + +function parseCookies(req: Request): Record { + const cookieHeader = req.headers.get("cookie") ?? ""; + const cookies: Record = {}; + for (const cookie of cookieHeader.split(";")) { + const trimmed = cookie.trim(); + if (!trimmed) continue; + const idx = trimmed.indexOf("="); + if (idx === -1) { + cookies[trimmed] = ""; + } else { + cookies[trimmed.slice(0, idx).trim()] = trimmed.slice(idx + 1).trim(); + } + } + return cookies; +} + +export async function createTRPCContext( + opts: CreateTRPCContextOptions, +): Promise<{ + db: typeof db; + user: typeof users.$inferSelect | null; + apiKey: string | null; +}> { + const { req } = opts; + let userId: string | null = null; + let apiKey: string | null = null; + + const cookies = parseCookies(req); + const sessionToken = cookies["session_token"]; + + if (sessionToken) { + const result = await validateSession(sessionToken); + if (result) { + userId = result.user.id; + } + } + + if (!userId) { + const authHeader = req.headers.get("authorization"); + if (authHeader?.startsWith("Bearer ")) { + const token = authHeader.slice(7); + try { + const payload = await verifyJWT<{ sub?: string }>(token); + userId = payload.sub ?? null; + } catch { + // Invalid token + } + } + } + + if (!userId) { + apiKey = req.headers.get("x-api-key") ?? null; + } + + let user: typeof users.$inferSelect | null = null; + if (userId) { + const [found] = await db + .select() + .from(users) + .where(eq(users.id, userId)) + .limit(1); + user = found ?? null; + } + + return { db, user, apiKey }; +} + +export type TRPCContext = inferAsyncReturnType; diff --git a/web/src/server/api/utils.ts b/web/src/server/api/utils.ts index c082886..eadb75c 100644 --- a/web/src/server/api/utils.ts +++ b/web/src/server/api/utils.ts @@ -1,6 +1,59 @@ -import { initTRPC } from "@trpc/server"; +import { initTRPC, TRPCError } from "@trpc/server"; +import type { TRPCContext } from "./trpc"; -export const t = initTRPC.create(); +const t = initTRPC.context().create(); export const createTRPCRouter = t.router; export const publicProcedure = t.procedure; + +const isAuthed = t.middleware(({ ctx, next }) => { + if (!ctx.user) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return next({ + ctx: { user: ctx.user }, + }); +}); + +export const protectedProcedure = t.procedure.use(isAuthed); + +const isAdmin = t.middleware(({ ctx, next }) => { + if (!ctx.user) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + if ((ctx.user.role as string) !== "admin") { + throw new TRPCError({ code: "FORBIDDEN" }); + } + return next({ + ctx: { user: ctx.user }, + }); +}); + +export const adminProcedure = t.procedure.use(isAdmin); + +const rateLimitMap = new Map(); + +const isRateLimited = t.middleware(({ ctx, next }) => { + const identifier = ctx.user?.id ?? ctx.apiKey ?? "anonymous"; + const now = Date.now(); + const entry = rateLimitMap.get(identifier); + const limit = 100; + const windowMs = 60_000; + + if (!entry || now > entry.resetAt) { + rateLimitMap.set(identifier, { count: 1, resetAt: now + windowMs }); + return next(); + } + + if (entry.count >= limit) { + throw new TRPCError({ + code: "TOO_MANY_REQUESTS", + message: "Rate limit exceeded", + }); + } + + entry.count++; + return next(); +}); + +export const rateLimitedProcedure = t.procedure.use(isRateLimited); diff --git a/web/src/server/auth/jwt.test.ts b/web/src/server/auth/jwt.test.ts new file mode 100644 index 0000000..fbc07cc --- /dev/null +++ b/web/src/server/auth/jwt.test.ts @@ -0,0 +1,17 @@ +// @vitest-environment node +import { describe, it, expect } from "vitest"; +import { signJWT, verifyJWT } from "./jwt"; + +describe("jwt", () => { + it("should sign and verify a JWT", async () => { + const payload = { sub: "user-123", role: "user" }; + const token = await signJWT(payload); + const decoded = await verifyJWT(token); + expect(decoded.sub).toBe("user-123"); + expect(decoded.role).toBe("user"); + }); + + it("should reject an invalid JWT", async () => { + await expect(verifyJWT("invalid.token.here")).rejects.toThrow(); + }); +}); diff --git a/web/src/server/auth/jwt.ts b/web/src/server/auth/jwt.ts new file mode 100644 index 0000000..c47179d --- /dev/null +++ b/web/src/server/auth/jwt.ts @@ -0,0 +1,24 @@ +import { SignJWT, jwtVerify } from "jose"; + +function getSecret(): Uint8Array { + const secret = process.env.JWT_SECRET ?? "dev-jwt-secret-change-in-production"; + return Buffer.from(secret, "utf-8"); +} + +export async function signJWT( + payload: Record, + options?: { expiresIn?: string }, +): Promise { + return new SignJWT(payload) + .setProtectedHeader({ alg: "HS256" }) + .setIssuedAt() + .setExpirationTime(options?.expiresIn ?? "7d") + .sign(getSecret()); +} + +export async function verifyJWT>( + token: string, +): Promise { + const { payload } = await jwtVerify(token, getSecret()); + return payload as T; +} diff --git a/web/src/server/auth/password.test.ts b/web/src/server/auth/password.test.ts new file mode 100644 index 0000000..9c31ecb --- /dev/null +++ b/web/src/server/auth/password.test.ts @@ -0,0 +1,22 @@ +import { describe, it, expect } from "vitest"; +import { hashPassword, verifyPassword } from "./password"; + +describe("password", () => { + it("should hash a password", async () => { + const hash = await hashPassword("secure-password"); + expect(hash).toBeTruthy(); + expect(hash).not.toBe("secure-password"); + }); + + it("should verify correct password", async () => { + const hash = await hashPassword("secure-password"); + const valid = await verifyPassword("secure-password", hash); + expect(valid).toBe(true); + }); + + it("should reject wrong password", async () => { + const hash = await hashPassword("secure-password"); + const valid = await verifyPassword("wrong-password", hash); + expect(valid).toBe(false); + }); +}); diff --git a/web/src/server/auth/password.ts b/web/src/server/auth/password.ts new file mode 100644 index 0000000..4c43f13 --- /dev/null +++ b/web/src/server/auth/password.ts @@ -0,0 +1,14 @@ +import bcrypt from "bcryptjs"; + +const SALT_ROUNDS = 10; + +export async function hashPassword(password: string): Promise { + return bcrypt.hash(password, SALT_ROUNDS); +} + +export async function verifyPassword( + password: string, + hash: string, +): Promise { + return bcrypt.compare(password, hash); +} diff --git a/web/src/server/auth/session.ts b/web/src/server/auth/session.ts new file mode 100644 index 0000000..f5c2aa2 --- /dev/null +++ b/web/src/server/auth/session.ts @@ -0,0 +1,35 @@ +import { db } from "~/server/db"; +import { sessions, users } from "~/server/db/schema/auth"; +import { eq, and, gt } from "drizzle-orm"; + +const SEVEN_DAYS_MS = 7 * 24 * 60 * 60 * 1000; + +export async function createSession( + userId: string, +): Promise { + const token = crypto.randomUUID(); + const expires = new Date(Date.now() + SEVEN_DAYS_MS); + const [session] = await db + .insert(sessions) + .values({ userId, sessionToken: token, expires }) + .returning(); + return session; +} + +export async function validateSession( + sessionToken: string, +): Promise<{ session: typeof sessions.$inferSelect; user: typeof users.$inferSelect } | null> { + const [result] = await db + .select({ session: sessions, user: users }) + .from(sessions) + .where( + and( + eq(sessions.sessionToken, sessionToken), + gt(sessions.expires, new Date()), + ), + ) + .innerJoin(users, eq(sessions.userId, users.id)) + .limit(1); + + return result ?? null; +}