feat(billing): add subscription and Stripe billing router

- Add stripeCustomerId column to users table
- Create Stripe client initialization (web/src/server/stripe.ts)
- Add billing service with getOrCreateCustomer, checkout/portal sessions,
  subscription management, invoice listing, and webhook event handling
- Create billing tRPC router with getSubscription, createCheckoutSession,
  createPortalSession, cancelSubscription, reactivateSubscription, listInvoices
- Add raw webhook endpoint at /api/stripe/webhook with signature verification
- Define Valibot schemas for all billing procedure inputs
- Wire billing router into root tRPC router
- Update schema tests for new column/index counts
- Write unit tests for billing service and router
This commit is contained in:
2026-05-25 16:07:00 -04:00
parent 28c33a930d
commit 40a9ef146c
14 changed files with 1006 additions and 4 deletions

16
pnpm-lock.yaml generated
View File

@@ -67,6 +67,9 @@ importers:
solid-js:
specifier: ^1.9.5
version: 1.9.13
stripe:
specifier: ^22.1.1
version: 22.1.1(@types/node@25.9.1)
tailwindcss:
specifier: ^4.0.0
version: 4.3.0
@@ -3130,6 +3133,15 @@ packages:
strip-literal@3.1.0:
resolution: {integrity: sha512-8r3mkIM/2+PpjHoOtiAW8Rg3jJLHaV7xPwG+YRGrv6FP0wwk/toTpATxWYOW0BKdWwl82VT2tFYi5DlROa0Mxg==}
stripe@22.1.1:
resolution: {integrity: sha512-cmodIYP27tBkJ8G7DuGgWw0PFuemlFZbuF3Wwr1TrjFjUa3T7NIgCe6TVwX8BO2ynu+xtTuDGfHafNDCPt9lXA==}
engines: {node: '>=18'}
peerDependencies:
'@types/node': '>=18'
peerDependenciesMeta:
'@types/node':
optional: true
supports-color@10.2.2:
resolution: {integrity: sha512-SS+jx45GF1QjgEXQx4NJZV9ImqmO2NPz5FNsIHrsDjh2YsHnawpan7SNQ1o8NuhrbHZy9AZhIoCUiCeaW/C80g==}
engines: {node: '>=18'}
@@ -6431,6 +6443,10 @@ snapshots:
dependencies:
js-tokens: 9.0.1
stripe@22.1.1(@types/node@25.9.1):
optionalDependencies:
'@types/node': 25.9.1
supports-color@10.2.2: {}
supports-color@7.2.0:

View File

@@ -1 +1,8 @@
DATABASE_URL="postgresql://postgres:postgres@localhost:5432/shieldai"
# Stripe (get test keys from https://dashboard.stripe.com/test/apikeys)
STRIPE_SECRET_KEY="sk_test_..."
STRIPE_WEBHOOK_SECRET="whsec_..."
STRIPE_PRICE_BASIC="price_basic"
STRIPE_PRICE_PLUS="price_plus"
STRIPE_PRICE_PREMIUM="price_premium"

View File

@@ -28,6 +28,7 @@
"jose": "^5",
"pg": "^8.21.0",
"solid-js": "^1.9.5",
"stripe": "^22.1.1",
"tailwindcss": "^4.0.0",
"three": "^0.184.0",
"valibot": "^0.29.0",

View File

@@ -0,0 +1,27 @@
import type { APIEvent } from "@solidjs/start/server";
import { stripe } from "~/server/stripe";
import { handleWebhookEvent } from "~/server/services/billing.service";
export async function POST(event: APIEvent) {
const body = await event.request.text();
const signature = event.request.headers.get("stripe-signature");
if (!signature) {
return new Response("Missing stripe-signature header", { status: 400 });
}
try {
const webhookEvent = stripe.webhooks.constructEvent(
body,
signature,
process.env.STRIPE_WEBHOOK_SECRET ?? "",
);
await handleWebhookEvent(webhookEvent);
return new Response("OK", { status: 200 });
} catch (err) {
const message = err instanceof Error ? err.message : "Webhook error";
return new Response(message, { status: 400 });
}
}

View File

@@ -1,10 +1,12 @@
import { exampleRouter } from "./routers/example";
import { userRouter } from "./routers/user";
import { billingRouter } from "./routers/billing";
import { createTRPCRouter } from "./utils";
export const appRouter = createTRPCRouter({
example: exampleRouter,
user: userRouter,
billing: billingRouter,
});
export type AppRouter = typeof appRouter;

View File

@@ -0,0 +1,241 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { initTRPC, TRPCError } from "@trpc/server";
import { wrap } from "@typeschema/valibot";
import {
CreateCheckoutSessionSchema,
CreatePortalSessionSchema,
CancelSubscriptionSchema,
ReactivateSubscriptionSchema,
ListInvoicesSchema,
} from "../schemas/billing";
vi.mock("~/server/services/billing.service", () => ({
getOrCreateCustomer: vi.fn(),
createCheckoutSession: vi.fn(),
createPortalSession: vi.fn(),
cancelSubscription: vi.fn(),
reactivateSubscription: vi.fn(),
listInvoices: vi.fn(),
}));
const { mockFindFirst } = vi.hoisted(() => ({
mockFindFirst: vi.fn(),
}));
vi.mock("~/server/db", () => ({
db: {
query: {
subscriptions: {
findFirst: mockFindFirst,
},
},
},
}));
import {
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
} from "~/server/services/billing.service";
import { db } from "~/server/db";
const mockCreateCheckoutSession = vi.mocked(createCheckoutSession);
const mockCreatePortalSession = vi.mocked(createPortalSession);
const mockCancelSubscription = vi.mocked(cancelSubscription);
const mockReactivateSubscription = vi.mocked(reactivateSubscription);
const mockListInvoices = vi.mocked(listInvoices);
const mockDb = vi.mocked(db);
type User = {
id: string; email: string; name: string | null; image: string | null;
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
stripeCustomerId: string | null;
createdAt: Date; updatedAt: Date;
};
type Ctx = { db: object; user: User | null; apiKey: string | null };
const baseUser: User = {
id: "user-1", email: "a@b.com", name: "Test", image: null,
role: "user", emailVerified: null, deletedAt: null,
stripeCustomerId: "cus_123",
createdAt: new Date(), updatedAt: new Date(),
};
function makeUser(overrides: Partial<User> = {}): User {
return { ...baseUser, ...overrides };
}
function createCaller(user: User | null) {
const t = initTRPC.context<Ctx>().create();
const isAuthed = t.middleware(({ ctx, next }) => {
if (!ctx.user) throw new TRPCError({ code: "UNAUTHORIZED" });
return next({ ctx: { ...ctx, user: ctx.user } });
});
const router = t.router({
getSubscription: t.procedure.use(isAuthed)
.query(async () => {
const sub = await mockFindFirst();
return sub ?? null;
}),
createCheckoutSession: t.procedure.use(isAuthed)
.input(wrap(CreateCheckoutSessionSchema))
.mutation(async ({ ctx, input }) => {
const i = input as { priceId: string; successUrl: string; cancelUrl: string };
return mockCreateCheckoutSession(ctx.user.id, ctx.user.email, i.priceId, i.successUrl, i.cancelUrl);
}),
createPortalSession: t.procedure.use(isAuthed)
.input(wrap(CreatePortalSessionSchema))
.mutation(async ({ ctx, input }) => {
const i = input as { returnUrl: string };
if (!ctx.user.stripeCustomerId) {
throw new TRPCError({ code: "NOT_FOUND", message: "No Stripe customer found" });
}
return mockCreatePortalSession(ctx.user.stripeCustomerId, i.returnUrl);
}),
cancelSubscription: t.procedure.use(isAuthed)
.input(wrap(CancelSubscriptionSchema))
.mutation(async ({ input }) => {
const i = input as { subscriptionId: string };
return mockCancelSubscription(i.subscriptionId);
}),
reactivateSubscription: t.procedure.use(isAuthed)
.input(wrap(ReactivateSubscriptionSchema))
.mutation(async ({ input }) => {
const i = input as { subscriptionId: string };
return mockReactivateSubscription(i.subscriptionId);
}),
listInvoices: t.procedure.use(isAuthed)
.input(wrap(ListInvoicesSchema))
.query(async ({ ctx, input }) => {
if (!ctx.user.stripeCustomerId) {
return { invoices: [], hasMore: false };
}
const i = input as { limit?: string; startingAfter?: string };
return mockListInvoices(ctx.user.stripeCustomerId, parseInt(i.limit ?? "10", 10), i.startingAfter);
}),
});
const caller = t.createCallerFactory(router);
return caller({ db: {} as never, user, apiKey: null });
}
beforeEach(() => {
vi.clearAllMocks();
});
describe("billing.getSubscription", () => {
it("returns subscription for authenticated user", async () => {
const now = new Date();
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue({
id: "sub-1", userId: "user-1", stripeId: "sub_stripe_1",
tier: "premium", status: "active",
currentPeriodStart: now, currentPeriodEnd: now,
cancelAtPeriodEnd: false,
createdAt: now, updatedAt: now,
});
const api = createCaller(makeUser());
const result = await api.getSubscription();
expect(result).not.toBeNull();
expect(result!.tier).toBe("premium");
expect(result!.status).toBe("active");
});
it("returns null when user has no subscription", async () => {
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(undefined);
const api = createCaller(makeUser());
const result = await api.getSubscription();
expect(result).toBeNull();
});
it("rejects unauthenticated users", async () => {
const api = createCaller(null);
await expect(api.getSubscription()).rejects.toThrow(TRPCError);
});
});
describe("billing.createCheckoutSession", () => {
it("creates checkout session and returns URL", async () => {
mockCreateCheckoutSession.mockResolvedValue({
url: "https://checkout.stripe.com/session_123",
sessionId: "session_123",
});
const api = createCaller(makeUser());
const result = await api.createCheckoutSession({
priceId: "price_basic",
successUrl: "https://example.com/success",
cancelUrl: "https://example.com/cancel",
});
expect(result.url).toBe("https://checkout.stripe.com/session_123");
});
});
describe("billing.createPortalSession", () => {
it("creates portal session for user with stripeCustomerId", async () => {
mockCreatePortalSession.mockResolvedValue({
url: "https://billing.stripe.com/portal/session_456",
});
const api = createCaller(makeUser());
const result = await api.createPortalSession({
returnUrl: "https://example.com/return",
});
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
});
it("throws NOT_FOUND when user has no stripeCustomerId", async () => {
const api = createCaller(makeUser({ stripeCustomerId: null }));
await expect(api.createPortalSession({ returnUrl: "https://example.com/return" })).rejects.toThrow(TRPCError);
});
});
describe("billing.cancelSubscription", () => {
it("cancels subscription", async () => {
mockCancelSubscription.mockResolvedValue({ cancelAtPeriodEnd: true });
const api = createCaller(makeUser());
const result = await api.cancelSubscription({ subscriptionId: "sub_123" });
expect(result.cancelAtPeriodEnd).toBe(true);
});
});
describe("billing.reactivateSubscription", () => {
it("reactivates subscription", async () => {
mockReactivateSubscription.mockResolvedValue({ cancelAtPeriodEnd: false });
const api = createCaller(makeUser());
const result = await api.reactivateSubscription({ subscriptionId: "sub_123" });
expect(result.cancelAtPeriodEnd).toBe(false);
});
});
describe("billing.listInvoices", () => {
it("lists invoices for user with stripeCustomerId", async () => {
mockListInvoices.mockResolvedValue({
invoices: [{ id: "in_1" }, { id: "in_2" }] as never,
hasMore: false,
});
const api = createCaller(makeUser());
const result = await api.listInvoices({});
expect(result.invoices).toHaveLength(2);
});
it("returns empty list when user has no stripeCustomerId", async () => {
const api = createCaller(makeUser({ stripeCustomerId: null }));
const result = await api.listInvoices({});
expect(result.invoices).toHaveLength(0);
expect(result.hasMore).toBe(false);
});
});

View File

@@ -0,0 +1,100 @@
import { TRPCError } from "@trpc/server";
import { eq } from "drizzle-orm";
import { wrap } from "@typeschema/valibot";
import { createTRPCRouter, protectedProcedure } from "../utils";
import {
CreateCheckoutSessionSchema,
CreatePortalSessionSchema,
CancelSubscriptionSchema,
ReactivateSubscriptionSchema,
ListInvoicesSchema,
} from "../schemas/billing";
import {
getOrCreateCustomer,
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
} from "~/server/services/billing.service";
import { db } from "~/server/db";
import { subscriptions } from "~/server/db/schema/subscription";
export const billingRouter = createTRPCRouter({
getSubscription: protectedProcedure.query(async ({ ctx }) => {
const sub = await db.query.subscriptions.findFirst({
where: eq(subscriptions.userId, ctx.user.id),
});
return sub ?? null;
}),
createCheckoutSession: protectedProcedure
.input(wrap(CreateCheckoutSessionSchema))
.mutation(async ({ ctx, input }) => {
const allowedPrices = [
process.env.STRIPE_PRICE_BASIC,
process.env.STRIPE_PRICE_PLUS,
process.env.STRIPE_PRICE_PREMIUM,
].filter(Boolean);
if (!allowedPrices.includes(input.priceId)) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Invalid price ID",
});
}
return createCheckoutSession(
ctx.user.id,
ctx.user.email,
input.priceId,
input.successUrl,
input.cancelUrl,
);
}),
createPortalSession: protectedProcedure
.input(wrap(CreatePortalSessionSchema))
.mutation(async ({ ctx, input }) => {
const user = ctx.user;
const stripeCustomerId = user.stripeCustomerId;
if (!stripeCustomerId) {
throw new TRPCError({
code: "NOT_FOUND",
message: "No Stripe customer found",
});
}
return createPortalSession(stripeCustomerId, input.returnUrl);
}),
cancelSubscription: protectedProcedure
.input(wrap(CancelSubscriptionSchema))
.mutation(async ({ input }) => {
return cancelSubscription(input.subscriptionId);
}),
reactivateSubscription: protectedProcedure
.input(wrap(ReactivateSubscriptionSchema))
.mutation(async ({ input }) => {
return reactivateSubscription(input.subscriptionId);
}),
listInvoices: protectedProcedure
.input(wrap(ListInvoicesSchema))
.query(async ({ ctx, input }) => {
const user = ctx.user;
const stripeCustomerId = user.stripeCustomerId;
if (!stripeCustomerId) {
return { invoices: [], hasMore: false };
}
return listInvoices(
stripeCustomerId,
parseInt(input.limit ?? "10", 10),
input.startingAfter,
);
}),
});

View File

@@ -30,6 +30,7 @@ const mockUpdateMemberRole = vi.mocked(updateMemberRole);
type User = {
id: string; email: string; name: string | null; image: string | null;
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
stripeCustomerId: string | null;
createdAt: Date; updatedAt: Date;
};
type Ctx = { db: object; user: User | null; apiKey: string | null };
@@ -97,6 +98,7 @@ function createCaller(user: User | null) {
const baseUser: User = {
id: "user-1", email: "a@b.com", name: "Test", image: null,
role: "user", emailVerified: null, deletedAt: null,
stripeCustomerId: null,
createdAt: new Date(), updatedAt: new Date(),
};

View File

@@ -0,0 +1,24 @@
import { object, string, url, minLength, optional, picklist } from "valibot";
export const CreateCheckoutSessionSchema = object({
priceId: string([minLength(1)]),
successUrl: string([url()]),
cancelUrl: string([url()]),
});
export const CreatePortalSessionSchema = object({
returnUrl: string([url()]),
});
export const CancelSubscriptionSchema = object({
subscriptionId: string([minLength(1)]),
});
export const ReactivateSubscriptionSchema = object({
subscriptionId: string([minLength(1)]),
});
export const ListInvoicesSchema = object({
limit: optional(string(), "10"),
startingAfter: optional(string()),
});

View File

@@ -76,17 +76,18 @@ describe("users table", () => {
expect(colNames).toContain("name");
expect(colNames).toContain("image");
expect(colNames).toContain("role");
expect(colNames).toContain("stripe_customer_id");
expect(colNames).toContain("deleted_at");
expect(colNames).toContain("created_at");
expect(colNames).toContain("updated_at");
});
it("has 9 columns", () => {
expect(config.columns).toHaveLength(9);
it("has 10 columns", () => {
expect(config.columns).toHaveLength(10);
});
it("has 2 indexes", () => {
expect(config.indexes.length).toBe(2);
it("has 3 indexes", () => {
expect(config.indexes.length).toBe(3);
});
});

View File

@@ -8,12 +8,14 @@ export const users = pgTable("users", {
name: text("name"),
image: text("image"),
role: userRole("role").default("user").notNull(),
stripeCustomerId: text("stripe_customer_id"),
deletedAt: timestamp("deleted_at", { withTimezone: true, mode: "date" }),
createdAt: timestamp("created_at", { withTimezone: true, mode: "date" }).defaultNow().notNull(),
updatedAt: timestamp("updated_at", { withTimezone: true, mode: "date" }).defaultNow().notNull().$onUpdate(() => new Date()),
}, (table) => ({
emailIdx: index("users_email_idx").on(table.email),
roleIdx: index("users_role_idx").on(table.role),
stripeCustomerIdIdx: index("users_stripe_customer_id_idx").on(table.stripeCustomerId),
}));
export const accounts = pgTable("accounts", {

View File

@@ -0,0 +1,338 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
vi.mock("~/server/stripe", () => ({
stripe: {
customers: { create: vi.fn() },
checkout: { sessions: { create: vi.fn() } },
billingPortal: { sessions: { create: vi.fn() } },
subscriptions: { update: vi.fn(), retrieve: vi.fn() },
invoices: { list: vi.fn() },
webhooks: { constructEvent: vi.fn() },
},
}));
vi.mock("~/server/db", () => ({
db: {
select: vi.fn(),
insert: vi.fn(),
update: vi.fn(),
query: {
subscriptions: {
findFirst: vi.fn(),
},
},
},
}));
import { stripe } from "~/server/stripe";
import { db } from "~/server/db";
import {
getOrCreateCustomer,
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
handleWebhookEvent,
} from "./billing.service";
beforeEach(() => {
vi.clearAllMocks();
});
describe("getOrCreateCustomer", () => {
it("returns existing stripeCustomerId if present", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_existing" },
]),
}),
}),
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_existing");
expect(stripe.customers.create).not.toHaveBeenCalled();
});
it("creates a new Stripe customer when no stripeCustomerId", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: null },
]),
}),
}),
});
(stripe.customers.create as ReturnType<typeof vi.fn>).mockResolvedValue({ id: "cus_new" });
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([{ id: "u1" }]),
}),
});
const result = await getOrCreateCustomer("u1", "a@b.com");
expect(result).toBe("cus_new");
expect(stripe.customers.create).toHaveBeenCalledWith({
email: "a@b.com",
metadata: { userId: "u1" },
});
});
it("throws NOT_FOUND for missing user", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
await expect(getOrCreateCustomer("u-missing", "x@y.com")).rejects.toThrow(
"User not found",
);
});
});
describe("createCheckoutSession", () => {
it("creates a Stripe checkout session and returns URL", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_123" },
]),
}),
}),
});
(stripe.checkout.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
url: "https://checkout.stripe.com/session_123",
id: "session_123",
});
const result = await createCheckoutSession(
"u1",
"a@b.com",
"price_basic",
"https://example.com/success",
"https://example.com/cancel",
);
expect(result.url).toBe("https://checkout.stripe.com/session_123");
expect(result.sessionId).toBe("session_123");
});
});
describe("createPortalSession", () => {
it("creates a Stripe billing portal session", async () => {
(stripe.billingPortal.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
url: "https://billing.stripe.com/portal/session_456",
});
const result = await createPortalSession(
"cus_123",
"https://example.com/return",
);
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
});
});
describe("cancelSubscription", () => {
it("sets cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
const result = await cancelSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(true);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: true,
});
});
});
describe("reactivateSubscription", () => {
it("removes cancel_at_period_end on Stripe subscription", async () => {
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue([]),
}),
});
const result = await reactivateSubscription("sub_123");
expect(result.cancelAtPeriodEnd).toBe(false);
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
cancel_at_period_end: false,
});
});
});
describe("listInvoices", () => {
it("returns invoices list from Stripe", async () => {
(stripe.invoices.list as ReturnType<typeof vi.fn>).mockResolvedValue({
data: [{ id: "in_1" }, { id: "in_2" }],
has_more: false,
});
const result = await listInvoices("cus_123", 10);
expect(result.invoices).toHaveLength(2);
expect(result.hasMore).toBe(false);
});
});
describe("handleWebhookEvent", () => {
it("handles checkout.session.completed", async () => {
(db.insert as ReturnType<typeof vi.fn>).mockReturnValue({
values: vi.fn().mockReturnValue({
onConflictDoNothing: vi.fn().mockResolvedValue(undefined),
}),
});
(stripe.subscriptions.retrieve as ReturnType<typeof vi.fn>).mockResolvedValue({
id: "sub_new",
items: { data: [{ price: { id: "price_premium" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
});
await handleWebhookEvent({
type: "checkout.session.completed",
data: {
object: {
metadata: { userId: "u1" },
subscription: "sub_new",
},
},
} as never);
expect(db.insert).toHaveBeenCalled();
});
it("handles invoice.paid", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
await handleWebhookEvent({
type: "invoice.paid",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
it("handles invoice.payment_failed", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "past_due" }]),
}),
}),
});
await handleWebhookEvent({
type: "invoice.payment_failed",
data: {
object: {
subscription: "sub_123",
},
},
} as never);
});
it("handles customer.subscription.updated", async () => {
(db.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(null);
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
}),
}),
});
await handleWebhookEvent({
type: "customer.subscription.updated",
data: {
object: {
id: "sub_123",
metadata: { userId: "u1" },
items: { data: [{ price: { id: "price_plus" } }] },
current_period_start: 1700000000,
current_period_end: 1702592000,
status: "active",
cancel_at_period_end: false,
},
},
} as never);
});
it("handles customer.subscription.deleted", async () => {
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
}),
}),
});
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "canceled" }]),
}),
}),
});
await handleWebhookEvent({
type: "customer.subscription.deleted",
data: {
object: {
id: "sub_123",
},
},
} as never);
});
});

View File

@@ -0,0 +1,235 @@
import { TRPCError } from "@trpc/server";
import { eq } from "drizzle-orm";
import { db } from "~/server/db";
import { stripe } from "~/server/stripe";
import { users } from "~/server/db/schema/auth";
import { subscriptions } from "~/server/db/schema/subscription";
import type Stripe from "stripe";
type Tier = "basic" | "plus" | "premium";
export async function getOrCreateCustomer(userId: string, email: string) {
const [user] = await db
.select()
.from(users)
.where(eq(users.id, userId))
.limit(1);
if (!user) {
throw new TRPCError({ code: "NOT_FOUND", message: "User not found" });
}
if (user.stripeCustomerId) {
return user.stripeCustomerId;
}
const customer = await stripe.customers.create({
email,
metadata: { userId },
});
await db
.update(users)
.set({ stripeCustomerId: customer.id })
.where(eq(users.id, userId));
return customer.id;
}
export async function createCheckoutSession(
userId: string,
email: string,
priceId: string,
successUrl: string,
cancelUrl: string,
) {
const customerId = await getOrCreateCustomer(userId, email);
const session = await stripe.checkout.sessions.create({
customer: customerId,
mode: "subscription",
line_items: [{ price: priceId, quantity: 1 }],
success_url: successUrl,
cancel_url: cancelUrl,
metadata: { userId },
});
return { url: session.url, sessionId: session.id };
}
export async function createPortalSession(customerId: string, returnUrl: string) {
const session = await stripe.billingPortal.sessions.create({
customer: customerId,
return_url: returnUrl,
});
return { url: session.url };
}
export async function cancelSubscription(stripeSubscriptionId: string) {
await stripe.subscriptions.update(stripeSubscriptionId, {
cancel_at_period_end: true,
});
await db
.update(subscriptions)
.set({ cancelAtPeriodEnd: true })
.where(eq(subscriptions.stripeId, stripeSubscriptionId));
return { cancelAtPeriodEnd: true };
}
export async function reactivateSubscription(stripeSubscriptionId: string) {
await stripe.subscriptions.update(stripeSubscriptionId, {
cancel_at_period_end: false,
});
await db
.update(subscriptions)
.set({ cancelAtPeriodEnd: false })
.where(eq(subscriptions.stripeId, stripeSubscriptionId));
return { cancelAtPeriodEnd: false };
}
export async function listInvoices(
customerId: string,
limit: number = 10,
startingAfter?: string,
) {
const params: Stripe.InvoiceListParams = {
customer: customerId,
limit,
};
if (startingAfter) {
params.starting_after = startingAfter;
}
const invoices = await stripe.invoices.list(params);
return {
invoices: invoices.data,
hasMore: invoices.has_more,
};
}
export async function updateSubscriptionInDB(
stripeId: string,
data: {
tier?: Tier;
status?: string;
currentPeriodStart?: Date;
currentPeriodEnd?: Date;
cancelAtPeriodEnd?: boolean;
},
) {
const [existing] = await db
.select()
.from(subscriptions)
.where(eq(subscriptions.stripeId, stripeId))
.limit(1);
if (existing) {
const [updated] = await db
.update(subscriptions)
.set(data as Record<string, unknown>)
.where(eq(subscriptions.stripeId, stripeId))
.returning();
return updated;
}
return null;
}
export async function handleWebhookEvent(event: Stripe.Event) {
const obj = event.data.object as unknown as Record<string, unknown>;
switch (event.type) {
case "checkout.session.completed": {
const session = obj as unknown as Stripe.Checkout.Session;
const userId = session.metadata?.userId;
if (!userId || !session.subscription) break;
const stripeSub = await stripe.subscriptions.retrieve(
session.subscription as string,
);
const sub = stripeSub as unknown as Record<string, unknown>;
await db.insert(subscriptions).values({
userId,
stripeId: stripeSub.id,
tier: mapStripeProductToTier(
stripeSub.items.data[0]?.price?.id ?? "",
),
status: sub.status as typeof subscriptions.$inferSelect.status,
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
}).onConflictDoNothing();
break;
}
case "invoice.paid": {
const invoice = obj;
if (!invoice.subscription) break;
await updateSubscriptionInDB(invoice.subscription as string, {
status: "active",
});
break;
}
case "invoice.payment_failed": {
const invoice = obj;
if (!invoice.subscription) break;
await updateSubscriptionInDB(invoice.subscription as string, {
status: "past_due",
});
break;
}
case "customer.subscription.updated": {
const stripeSub = obj as unknown as Stripe.Subscription;
const userId = stripeSub.metadata?.userId;
const sub = stripeSub as unknown as Record<string, unknown>;
if (!userId) {
const [existingSub] = await db
.select()
.from(subscriptions)
.where(eq(subscriptions.stripeId, stripeSub.id))
.limit(1);
if (!existingSub) break;
}
const tier = stripeSub.items.data[0]?.price?.id
? mapStripeProductToTier(stripeSub.items.data[0].price.id)
: undefined;
await updateSubscriptionInDB(stripeSub.id, {
tier,
status: sub.status as string,
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
});
break;
}
case "customer.subscription.deleted": {
const stripeSub = obj as unknown as Stripe.Subscription;
await updateSubscriptionInDB(stripeSub.id, {
status: "canceled",
});
break;
}
}
}
function mapStripeProductToTier(priceId: string): Tier {
if (priceId === process.env.STRIPE_PRICE_BASIC) return "basic";
if (priceId === process.env.STRIPE_PRICE_PLUS) return "plus";
if (priceId === process.env.STRIPE_PRICE_PREMIUM) return "premium";
return "basic";
}

6
web/src/server/stripe.ts Normal file
View File

@@ -0,0 +1,6 @@
import Stripe from "stripe";
export const stripe = new Stripe(process.env.STRIPE_SECRET_KEY ?? "", {
apiVersion: "2026-04-22.dahlia" as const,
typescript: true,
});