feat: add tRPC auth context, middleware, and protected procedures

- Install jose (JWT) and bcryptjs (password hashing) dependencies
- Create auth utilities: JWT sign/verify, password hash/verify, session management
- Create createTRPCContext that extracts auth from session cookie, Bearer JWT, or x-api-key
- Add publicProcedure, protectedProcedure, adminProcedure, rateLimitedProcedure with middleware
- Wire context builder into SolidStart tRPC API handler
- Update tRPC client to inject auth tokens and handle 401 redirects
- Add unit tests for JWT, password, context builder, and middleware
This commit is contained in:
2026-05-25 15:46:52 -04:00
parent 052e08c17b
commit 71972436b6
13 changed files with 385 additions and 17 deletions

17
pnpm-lock.yaml generated
View File

@@ -52,9 +52,15 @@ importers:
'@typeschema/valibot':
specifier: ^0.13.4
version: 0.13.5(valibot@0.29.0)
bcryptjs:
specifier: ^3.0.3
version: 3.0.3
drizzle-orm:
specifier: ^0.45.2
version: 0.45.2(@types/pg@8.20.0)(pg@8.21.0)
jose:
specifier: ^5
version: 5.10.0
pg:
specifier: ^8.21.0
version: 8.21.0
@@ -1716,6 +1722,10 @@ packages:
engines: {node: '>=6.0.0'}
hasBin: true
bcryptjs@3.0.3:
resolution: {integrity: sha512-GlF5wPWnSa/X5LKM1o0wz0suXIINz1iHRLvTS+sLyi7XPbe5ycmYI3DlZqVGZZtDgl4DmasFg7gOB3JYbphV5g==}
hasBin: true
bidi-js@1.0.3:
resolution: {integrity: sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==}
@@ -2425,6 +2435,9 @@ packages:
resolution: {integrity: sha512-AC/7JofJvZGrrneWNaEnJeOLUx+JlGt7tNa0wZiRPT4MY1wmfKjt2+6O2p2uz2+skll8OZZmJMNqeke7kKbNgQ==}
hasBin: true
jose@5.10.0:
resolution: {integrity: sha512-s+3Al/p9g32Iq+oqXxkW//7jk2Vig6FF1CFqzVXoTUXt2qz89YWbL+OwS17NFYEvxC35n0FKeGO2LGYSxeM2Gg==}
js-tokens@10.0.0:
resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==}
@@ -4980,6 +4993,8 @@ snapshots:
baseline-browser-mapping@2.10.32: {}
bcryptjs@3.0.3: {}
bidi-js@1.0.3:
dependencies:
require-from-string: 2.0.2
@@ -5622,6 +5637,8 @@ snapshots:
jiti@2.7.0: {}
jose@5.10.0: {}
js-tokens@10.0.0: {}
js-tokens@4.0.0: {}

View File

@@ -23,7 +23,9 @@
"@trpc/server": "^10.45.2",
"@types/three": "^0.184.1",
"@typeschema/valibot": "^0.13.4",
"bcryptjs": "^3.0.3",
"drizzle-orm": "^0.45.2",
"jose": "^5",
"pg": "^8.21.0",
"solid-js": "^1.9.5",
"tailwindcss": "^4.0.0",

View File

@@ -2,22 +2,36 @@ import {
createTRPCProxyClient,
httpBatchLink,
loggerLink,
} from '@trpc/client';
import { AppRouter } from "~/server/api/root";
} from "@trpc/client";
import type { AppRouter } from "~/server/api/root";
const getBaseUrl = () => {
if (typeof window !== "undefined") return "";
// replace example.com with your actual production url
if (process.env.NODE_ENV === "production") return "https://example.com";
return `http://localhost:${process.env.PORT ?? 3000}`;
};
// create the client, export it
function getAuthToken(): string | null {
if (typeof document === "undefined") return null;
const match = document.cookie.match(/(?:^|;\s*)session_token=([^;]*)/);
if (match) return match[1];
try {
return localStorage.getItem("auth_token");
} catch {
return null;
}
}
export const api = createTRPCProxyClient<AppRouter>({
links: [
// will print out helpful logs when using client
loggerLink(),
// identifies what url will handle trpc requests
httpBatchLink({ url: `${getBaseUrl()}/api/trpc` })
loggerLink(),
httpBatchLink({
url: `${getBaseUrl()}/api/trpc`,
headers: () => {
const token = getAuthToken();
if (!token) return {};
return { Authorization: `Bearer ${token}` };
},
}),
],
});

View File

@@ -1,18 +1,15 @@
import type { APIEvent } from "@solidjs/start/server";
import { fetchRequestHandler } from "@trpc/server/adapters/fetch";
import { appRouter } from "~/server/api/root";
import { createTRPCContext } from "~/server/api/trpc";
const handler = (event: APIEvent) =>
// adapts tRPC to fetch API style requests
fetchRequestHandler({
// the endpoint handling the requests
endpoint: "/api/trpc",
// the request object
req: event.request,
// the router for handling the requests
router: appRouter,
// any arbitrary data that should be available to all actions
createContext: () => event
createContext: ({ req, resHeaders }) =>
createTRPCContext({ req, resHeaders }),
});
export const GET = handler;

View File

@@ -2,7 +2,7 @@ import { exampleRouter } from "./routers/example";
import { createTRPCRouter } from "./utils";
export const appRouter = createTRPCRouter({
example: exampleRouter
example: exampleRouter,
});
export type AppRouter = typeof appRouter;

View File

@@ -0,0 +1,92 @@
import { describe, it, expect, vi } from "vitest";
import { initTRPC, TRPCError } from "@trpc/server";
vi.mock("~/server/db", () => ({
db: {},
}));
vi.mock("~/server/auth/session", () => ({
validateSession: vi.fn(),
}));
vi.mock("~/server/auth/jwt", () => ({
verifyJWT: vi.fn(),
}));
describe("createTRPCContext", () => {
it("should export createTRPCContext function", async () => {
const mod = await import("./trpc");
expect(mod.createTRPCContext).toBeInstanceOf(Function);
});
it("should return anonymous context for unauthenticated requests", async () => {
const { createTRPCContext } = await import("./trpc");
const ctx = await createTRPCContext({
req: new Request("http://localhost:3000/api/trpc"),
});
expect(ctx.user).toBeNull();
expect(ctx.apiKey).toBeNull();
expect(ctx.db).toBeDefined();
});
});
describe("tRPC middleware", () => {
type TestCtx = { user?: { id: string; role: string }; db: object };
it("publicProcedure should allow unauthenticated access", async () => {
const { publicProcedure } = await import("./utils");
const t = initTRPC.context<TestCtx>().create();
const testRouter = t.router({
test: publicProcedure.query(() => "ok"),
});
const caller = t.createCallerFactory(testRouter);
const result = await caller({ db: {} }).test();
expect(result).toBe("ok");
});
it("protectedProcedure should reject unauthenticated requests", async () => {
const { protectedProcedure } = await import("./utils");
const t = initTRPC.context<TestCtx>().create();
const testRouter = t.router({
test: protectedProcedure.query(() => "ok"),
});
const caller = t.createCallerFactory(testRouter);
await expect(caller({ db: {} }).test()).rejects.toThrow(TRPCError);
});
it("protectedProcedure should allow authenticated requests", async () => {
const { protectedProcedure } = await import("./utils");
const t = initTRPC.context<TestCtx>().create();
const testRouter = t.router({
test: protectedProcedure.query(({ ctx }) => ctx.user?.id),
});
const caller = t.createCallerFactory(testRouter);
const result = await caller({
db: {},
user: { id: "user-1", role: "user" },
}).test();
expect(result).toBe("user-1");
});
it("adminProcedure should reject non-admin users with FORBIDDEN", async () => {
const { adminProcedure } = await import("./utils");
const t = initTRPC.context<TestCtx>().create();
const testRouter = t.router({
test: adminProcedure.query(() => "ok"),
});
const caller = t.createCallerFactory(testRouter);
await expect(
caller({ db: {}, user: { id: "user-1", role: "user" } }).test(),
).rejects.toThrow(TRPCError);
});
it("adminProcedure should reject unauthenticated with UNAUTHORIZED", async () => {
const { adminProcedure } = await import("./utils");
const t = initTRPC.context<TestCtx>().create();
const testRouter = t.router({
test: adminProcedure.query(() => "ok"),
});
const caller = t.createCallerFactory(testRouter);
await expect(caller({ db: {} }).test()).rejects.toThrow(TRPCError);
});
});

View File

@@ -0,0 +1,81 @@
import type { inferAsyncReturnType } from "@trpc/server";
import type { FetchCreateContextFnOptions } from "@trpc/server/adapters/fetch";
import { db } from "~/server/db";
import { verifyJWT } from "~/server/auth/jwt";
import { validateSession } from "~/server/auth/session";
import { users } from "~/server/db/schema/auth";
import { eq } from "drizzle-orm";
export type CreateTRPCContextOptions = {
req: Request;
resHeaders?: Headers;
};
function parseCookies(req: Request): Record<string, string> {
const cookieHeader = req.headers.get("cookie") ?? "";
const cookies: Record<string, string> = {};
for (const cookie of cookieHeader.split(";")) {
const trimmed = cookie.trim();
if (!trimmed) continue;
const idx = trimmed.indexOf("=");
if (idx === -1) {
cookies[trimmed] = "";
} else {
cookies[trimmed.slice(0, idx).trim()] = trimmed.slice(idx + 1).trim();
}
}
return cookies;
}
export async function createTRPCContext(
opts: CreateTRPCContextOptions,
): Promise<{
db: typeof db;
user: typeof users.$inferSelect | null;
apiKey: string | null;
}> {
const { req } = opts;
let userId: string | null = null;
let apiKey: string | null = null;
const cookies = parseCookies(req);
const sessionToken = cookies["session_token"];
if (sessionToken) {
const result = await validateSession(sessionToken);
if (result) {
userId = result.user.id;
}
}
if (!userId) {
const authHeader = req.headers.get("authorization");
if (authHeader?.startsWith("Bearer ")) {
const token = authHeader.slice(7);
try {
const payload = await verifyJWT<{ sub?: string }>(token);
userId = payload.sub ?? null;
} catch {
// Invalid token
}
}
}
if (!userId) {
apiKey = req.headers.get("x-api-key") ?? null;
}
let user: typeof users.$inferSelect | null = null;
if (userId) {
const [found] = await db
.select()
.from(users)
.where(eq(users.id, userId))
.limit(1);
user = found ?? null;
}
return { db, user, apiKey };
}
export type TRPCContext = inferAsyncReturnType<typeof createTRPCContext>;

View File

@@ -1,6 +1,59 @@
import { initTRPC } from "@trpc/server";
import { initTRPC, TRPCError } from "@trpc/server";
import type { TRPCContext } from "./trpc";
export const t = initTRPC.create();
const t = initTRPC.context<TRPCContext>().create();
export const createTRPCRouter = t.router;
export const publicProcedure = t.procedure;
const isAuthed = t.middleware(({ ctx, next }) => {
if (!ctx.user) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return next({
ctx: { user: ctx.user },
});
});
export const protectedProcedure = t.procedure.use(isAuthed);
const isAdmin = t.middleware(({ ctx, next }) => {
if (!ctx.user) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if ((ctx.user.role as string) !== "admin") {
throw new TRPCError({ code: "FORBIDDEN" });
}
return next({
ctx: { user: ctx.user },
});
});
export const adminProcedure = t.procedure.use(isAdmin);
const rateLimitMap = new Map<string, { count: number; resetAt: number }>();
const isRateLimited = t.middleware(({ ctx, next }) => {
const identifier = ctx.user?.id ?? ctx.apiKey ?? "anonymous";
const now = Date.now();
const entry = rateLimitMap.get(identifier);
const limit = 100;
const windowMs = 60_000;
if (!entry || now > entry.resetAt) {
rateLimitMap.set(identifier, { count: 1, resetAt: now + windowMs });
return next();
}
if (entry.count >= limit) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
message: "Rate limit exceeded",
});
}
entry.count++;
return next();
});
export const rateLimitedProcedure = t.procedure.use(isRateLimited);

View File

@@ -0,0 +1,17 @@
// @vitest-environment node
import { describe, it, expect } from "vitest";
import { signJWT, verifyJWT } from "./jwt";
describe("jwt", () => {
it("should sign and verify a JWT", async () => {
const payload = { sub: "user-123", role: "user" };
const token = await signJWT(payload);
const decoded = await verifyJWT<typeof payload>(token);
expect(decoded.sub).toBe("user-123");
expect(decoded.role).toBe("user");
});
it("should reject an invalid JWT", async () => {
await expect(verifyJWT("invalid.token.here")).rejects.toThrow();
});
});

View File

@@ -0,0 +1,24 @@
import { SignJWT, jwtVerify } from "jose";
function getSecret(): Uint8Array {
const secret = process.env.JWT_SECRET ?? "dev-jwt-secret-change-in-production";
return Buffer.from(secret, "utf-8");
}
export async function signJWT(
payload: Record<string, unknown>,
options?: { expiresIn?: string },
): Promise<string> {
return new SignJWT(payload)
.setProtectedHeader({ alg: "HS256" })
.setIssuedAt()
.setExpirationTime(options?.expiresIn ?? "7d")
.sign(getSecret());
}
export async function verifyJWT<T = Record<string, unknown>>(
token: string,
): Promise<T> {
const { payload } = await jwtVerify(token, getSecret());
return payload as T;
}

View File

@@ -0,0 +1,22 @@
import { describe, it, expect } from "vitest";
import { hashPassword, verifyPassword } from "./password";
describe("password", () => {
it("should hash a password", async () => {
const hash = await hashPassword("secure-password");
expect(hash).toBeTruthy();
expect(hash).not.toBe("secure-password");
});
it("should verify correct password", async () => {
const hash = await hashPassword("secure-password");
const valid = await verifyPassword("secure-password", hash);
expect(valid).toBe(true);
});
it("should reject wrong password", async () => {
const hash = await hashPassword("secure-password");
const valid = await verifyPassword("wrong-password", hash);
expect(valid).toBe(false);
});
});

View File

@@ -0,0 +1,14 @@
import bcrypt from "bcryptjs";
const SALT_ROUNDS = 10;
export async function hashPassword(password: string): Promise<string> {
return bcrypt.hash(password, SALT_ROUNDS);
}
export async function verifyPassword(
password: string,
hash: string,
): Promise<boolean> {
return bcrypt.compare(password, hash);
}

View File

@@ -0,0 +1,35 @@
import { db } from "~/server/db";
import { sessions, users } from "~/server/db/schema/auth";
import { eq, and, gt } from "drizzle-orm";
const SEVEN_DAYS_MS = 7 * 24 * 60 * 60 * 1000;
export async function createSession(
userId: string,
): Promise<typeof sessions.$inferSelect> {
const token = crypto.randomUUID();
const expires = new Date(Date.now() + SEVEN_DAYS_MS);
const [session] = await db
.insert(sessions)
.values({ userId, sessionToken: token, expires })
.returning();
return session;
}
export async function validateSession(
sessionToken: string,
): Promise<{ session: typeof sessions.$inferSelect; user: typeof users.$inferSelect } | null> {
const [result] = await db
.select({ session: sessions, user: users })
.from(sessions)
.where(
and(
eq(sessions.sessionToken, sessionToken),
gt(sessions.expires, new Date()),
),
)
.innerJoin(users, eq(sessions.userId, users.id))
.limit(1);
return result ?? null;
}