feat(billing): add subscription and Stripe billing router
- Add stripeCustomerId column to users table - Create Stripe client initialization (web/src/server/stripe.ts) - Add billing service with getOrCreateCustomer, checkout/portal sessions, subscription management, invoice listing, and webhook event handling - Create billing tRPC router with getSubscription, createCheckoutSession, createPortalSession, cancelSubscription, reactivateSubscription, listInvoices - Add raw webhook endpoint at /api/stripe/webhook with signature verification - Define Valibot schemas for all billing procedure inputs - Wire billing router into root tRPC router - Update schema tests for new column/index counts - Write unit tests for billing service and router
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import { exampleRouter } from "./routers/example";
|
||||
import { userRouter } from "./routers/user";
|
||||
import { billingRouter } from "./routers/billing";
|
||||
import { createTRPCRouter } from "./utils";
|
||||
|
||||
export const appRouter = createTRPCRouter({
|
||||
example: exampleRouter,
|
||||
user: userRouter,
|
||||
billing: billingRouter,
|
||||
});
|
||||
|
||||
export type AppRouter = typeof appRouter;
|
||||
|
||||
241
web/src/server/api/routers/billing.test.ts
Normal file
241
web/src/server/api/routers/billing.test.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { initTRPC, TRPCError } from "@trpc/server";
|
||||
import { wrap } from "@typeschema/valibot";
|
||||
import {
|
||||
CreateCheckoutSessionSchema,
|
||||
CreatePortalSessionSchema,
|
||||
CancelSubscriptionSchema,
|
||||
ReactivateSubscriptionSchema,
|
||||
ListInvoicesSchema,
|
||||
} from "../schemas/billing";
|
||||
|
||||
vi.mock("~/server/services/billing.service", () => ({
|
||||
getOrCreateCustomer: vi.fn(),
|
||||
createCheckoutSession: vi.fn(),
|
||||
createPortalSession: vi.fn(),
|
||||
cancelSubscription: vi.fn(),
|
||||
reactivateSubscription: vi.fn(),
|
||||
listInvoices: vi.fn(),
|
||||
}));
|
||||
|
||||
const { mockFindFirst } = vi.hoisted(() => ({
|
||||
mockFindFirst: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("~/server/db", () => ({
|
||||
db: {
|
||||
query: {
|
||||
subscriptions: {
|
||||
findFirst: mockFindFirst,
|
||||
},
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
import {
|
||||
createCheckoutSession,
|
||||
createPortalSession,
|
||||
cancelSubscription,
|
||||
reactivateSubscription,
|
||||
listInvoices,
|
||||
} from "~/server/services/billing.service";
|
||||
import { db } from "~/server/db";
|
||||
|
||||
const mockCreateCheckoutSession = vi.mocked(createCheckoutSession);
|
||||
const mockCreatePortalSession = vi.mocked(createPortalSession);
|
||||
const mockCancelSubscription = vi.mocked(cancelSubscription);
|
||||
const mockReactivateSubscription = vi.mocked(reactivateSubscription);
|
||||
const mockListInvoices = vi.mocked(listInvoices);
|
||||
const mockDb = vi.mocked(db);
|
||||
|
||||
type User = {
|
||||
id: string; email: string; name: string | null; image: string | null;
|
||||
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
|
||||
stripeCustomerId: string | null;
|
||||
createdAt: Date; updatedAt: Date;
|
||||
};
|
||||
|
||||
type Ctx = { db: object; user: User | null; apiKey: string | null };
|
||||
|
||||
const baseUser: User = {
|
||||
id: "user-1", email: "a@b.com", name: "Test", image: null,
|
||||
role: "user", emailVerified: null, deletedAt: null,
|
||||
stripeCustomerId: "cus_123",
|
||||
createdAt: new Date(), updatedAt: new Date(),
|
||||
};
|
||||
|
||||
function makeUser(overrides: Partial<User> = {}): User {
|
||||
return { ...baseUser, ...overrides };
|
||||
}
|
||||
|
||||
function createCaller(user: User | null) {
|
||||
const t = initTRPC.context<Ctx>().create();
|
||||
const isAuthed = t.middleware(({ ctx, next }) => {
|
||||
if (!ctx.user) throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
return next({ ctx: { ...ctx, user: ctx.user } });
|
||||
});
|
||||
|
||||
const router = t.router({
|
||||
getSubscription: t.procedure.use(isAuthed)
|
||||
.query(async () => {
|
||||
const sub = await mockFindFirst();
|
||||
return sub ?? null;
|
||||
}),
|
||||
createCheckoutSession: t.procedure.use(isAuthed)
|
||||
.input(wrap(CreateCheckoutSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const i = input as { priceId: string; successUrl: string; cancelUrl: string };
|
||||
return mockCreateCheckoutSession(ctx.user.id, ctx.user.email, i.priceId, i.successUrl, i.cancelUrl);
|
||||
}),
|
||||
createPortalSession: t.procedure.use(isAuthed)
|
||||
.input(wrap(CreatePortalSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const i = input as { returnUrl: string };
|
||||
if (!ctx.user.stripeCustomerId) {
|
||||
throw new TRPCError({ code: "NOT_FOUND", message: "No Stripe customer found" });
|
||||
}
|
||||
return mockCreatePortalSession(ctx.user.stripeCustomerId, i.returnUrl);
|
||||
}),
|
||||
cancelSubscription: t.procedure.use(isAuthed)
|
||||
.input(wrap(CancelSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
const i = input as { subscriptionId: string };
|
||||
return mockCancelSubscription(i.subscriptionId);
|
||||
}),
|
||||
reactivateSubscription: t.procedure.use(isAuthed)
|
||||
.input(wrap(ReactivateSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
const i = input as { subscriptionId: string };
|
||||
return mockReactivateSubscription(i.subscriptionId);
|
||||
}),
|
||||
listInvoices: t.procedure.use(isAuthed)
|
||||
.input(wrap(ListInvoicesSchema))
|
||||
.query(async ({ ctx, input }) => {
|
||||
if (!ctx.user.stripeCustomerId) {
|
||||
return { invoices: [], hasMore: false };
|
||||
}
|
||||
const i = input as { limit?: string; startingAfter?: string };
|
||||
return mockListInvoices(ctx.user.stripeCustomerId, parseInt(i.limit ?? "10", 10), i.startingAfter);
|
||||
}),
|
||||
});
|
||||
|
||||
const caller = t.createCallerFactory(router);
|
||||
return caller({ db: {} as never, user, apiKey: null });
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("billing.getSubscription", () => {
|
||||
it("returns subscription for authenticated user", async () => {
|
||||
const now = new Date();
|
||||
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
id: "sub-1", userId: "user-1", stripeId: "sub_stripe_1",
|
||||
tier: "premium", status: "active",
|
||||
currentPeriodStart: now, currentPeriodEnd: now,
|
||||
cancelAtPeriodEnd: false,
|
||||
createdAt: now, updatedAt: now,
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.getSubscription();
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.tier).toBe("premium");
|
||||
expect(result!.status).toBe("active");
|
||||
});
|
||||
|
||||
it("returns null when user has no subscription", async () => {
|
||||
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(undefined);
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.getSubscription();
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("rejects unauthenticated users", async () => {
|
||||
const api = createCaller(null);
|
||||
await expect(api.getSubscription()).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.createCheckoutSession", () => {
|
||||
it("creates checkout session and returns URL", async () => {
|
||||
mockCreateCheckoutSession.mockResolvedValue({
|
||||
url: "https://checkout.stripe.com/session_123",
|
||||
sessionId: "session_123",
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createCheckoutSession({
|
||||
priceId: "price_basic",
|
||||
successUrl: "https://example.com/success",
|
||||
cancelUrl: "https://example.com/cancel",
|
||||
});
|
||||
|
||||
expect(result.url).toBe("https://checkout.stripe.com/session_123");
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.createPortalSession", () => {
|
||||
it("creates portal session for user with stripeCustomerId", async () => {
|
||||
mockCreatePortalSession.mockResolvedValue({
|
||||
url: "https://billing.stripe.com/portal/session_456",
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createPortalSession({
|
||||
returnUrl: "https://example.com/return",
|
||||
});
|
||||
|
||||
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
|
||||
});
|
||||
|
||||
it("throws NOT_FOUND when user has no stripeCustomerId", async () => {
|
||||
const api = createCaller(makeUser({ stripeCustomerId: null }));
|
||||
await expect(api.createPortalSession({ returnUrl: "https://example.com/return" })).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.cancelSubscription", () => {
|
||||
it("cancels subscription", async () => {
|
||||
mockCancelSubscription.mockResolvedValue({ cancelAtPeriodEnd: true });
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.cancelSubscription({ subscriptionId: "sub_123" });
|
||||
|
||||
expect(result.cancelAtPeriodEnd).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.reactivateSubscription", () => {
|
||||
it("reactivates subscription", async () => {
|
||||
mockReactivateSubscription.mockResolvedValue({ cancelAtPeriodEnd: false });
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.reactivateSubscription({ subscriptionId: "sub_123" });
|
||||
|
||||
expect(result.cancelAtPeriodEnd).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.listInvoices", () => {
|
||||
it("lists invoices for user with stripeCustomerId", async () => {
|
||||
mockListInvoices.mockResolvedValue({
|
||||
invoices: [{ id: "in_1" }, { id: "in_2" }] as never,
|
||||
hasMore: false,
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.listInvoices({});
|
||||
|
||||
expect(result.invoices).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("returns empty list when user has no stripeCustomerId", async () => {
|
||||
const api = createCaller(makeUser({ stripeCustomerId: null }));
|
||||
const result = await api.listInvoices({});
|
||||
|
||||
expect(result.invoices).toHaveLength(0);
|
||||
expect(result.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
100
web/src/server/api/routers/billing.ts
Normal file
100
web/src/server/api/routers/billing.ts
Normal file
@@ -0,0 +1,100 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { wrap } from "@typeschema/valibot";
|
||||
import { createTRPCRouter, protectedProcedure } from "../utils";
|
||||
import {
|
||||
CreateCheckoutSessionSchema,
|
||||
CreatePortalSessionSchema,
|
||||
CancelSubscriptionSchema,
|
||||
ReactivateSubscriptionSchema,
|
||||
ListInvoicesSchema,
|
||||
} from "../schemas/billing";
|
||||
import {
|
||||
getOrCreateCustomer,
|
||||
createCheckoutSession,
|
||||
createPortalSession,
|
||||
cancelSubscription,
|
||||
reactivateSubscription,
|
||||
listInvoices,
|
||||
} from "~/server/services/billing.service";
|
||||
import { db } from "~/server/db";
|
||||
import { subscriptions } from "~/server/db/schema/subscription";
|
||||
|
||||
export const billingRouter = createTRPCRouter({
|
||||
getSubscription: protectedProcedure.query(async ({ ctx }) => {
|
||||
const sub = await db.query.subscriptions.findFirst({
|
||||
where: eq(subscriptions.userId, ctx.user.id),
|
||||
});
|
||||
return sub ?? null;
|
||||
}),
|
||||
|
||||
createCheckoutSession: protectedProcedure
|
||||
.input(wrap(CreateCheckoutSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const allowedPrices = [
|
||||
process.env.STRIPE_PRICE_BASIC,
|
||||
process.env.STRIPE_PRICE_PLUS,
|
||||
process.env.STRIPE_PRICE_PREMIUM,
|
||||
].filter(Boolean);
|
||||
|
||||
if (!allowedPrices.includes(input.priceId)) {
|
||||
throw new TRPCError({
|
||||
code: "BAD_REQUEST",
|
||||
message: "Invalid price ID",
|
||||
});
|
||||
}
|
||||
|
||||
return createCheckoutSession(
|
||||
ctx.user.id,
|
||||
ctx.user.email,
|
||||
input.priceId,
|
||||
input.successUrl,
|
||||
input.cancelUrl,
|
||||
);
|
||||
}),
|
||||
|
||||
createPortalSession: protectedProcedure
|
||||
.input(wrap(CreatePortalSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const user = ctx.user;
|
||||
const stripeCustomerId = user.stripeCustomerId;
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
throw new TRPCError({
|
||||
code: "NOT_FOUND",
|
||||
message: "No Stripe customer found",
|
||||
});
|
||||
}
|
||||
|
||||
return createPortalSession(stripeCustomerId, input.returnUrl);
|
||||
}),
|
||||
|
||||
cancelSubscription: protectedProcedure
|
||||
.input(wrap(CancelSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
return cancelSubscription(input.subscriptionId);
|
||||
}),
|
||||
|
||||
reactivateSubscription: protectedProcedure
|
||||
.input(wrap(ReactivateSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
return reactivateSubscription(input.subscriptionId);
|
||||
}),
|
||||
|
||||
listInvoices: protectedProcedure
|
||||
.input(wrap(ListInvoicesSchema))
|
||||
.query(async ({ ctx, input }) => {
|
||||
const user = ctx.user;
|
||||
const stripeCustomerId = user.stripeCustomerId;
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
return { invoices: [], hasMore: false };
|
||||
}
|
||||
|
||||
return listInvoices(
|
||||
stripeCustomerId,
|
||||
parseInt(input.limit ?? "10", 10),
|
||||
input.startingAfter,
|
||||
);
|
||||
}),
|
||||
});
|
||||
@@ -30,6 +30,7 @@ const mockUpdateMemberRole = vi.mocked(updateMemberRole);
|
||||
type User = {
|
||||
id: string; email: string; name: string | null; image: string | null;
|
||||
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
|
||||
stripeCustomerId: string | null;
|
||||
createdAt: Date; updatedAt: Date;
|
||||
};
|
||||
type Ctx = { db: object; user: User | null; apiKey: string | null };
|
||||
@@ -97,6 +98,7 @@ function createCaller(user: User | null) {
|
||||
const baseUser: User = {
|
||||
id: "user-1", email: "a@b.com", name: "Test", image: null,
|
||||
role: "user", emailVerified: null, deletedAt: null,
|
||||
stripeCustomerId: null,
|
||||
createdAt: new Date(), updatedAt: new Date(),
|
||||
};
|
||||
|
||||
|
||||
24
web/src/server/api/schemas/billing.ts
Normal file
24
web/src/server/api/schemas/billing.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import { object, string, url, minLength, optional, picklist } from "valibot";
|
||||
|
||||
export const CreateCheckoutSessionSchema = object({
|
||||
priceId: string([minLength(1)]),
|
||||
successUrl: string([url()]),
|
||||
cancelUrl: string([url()]),
|
||||
});
|
||||
|
||||
export const CreatePortalSessionSchema = object({
|
||||
returnUrl: string([url()]),
|
||||
});
|
||||
|
||||
export const CancelSubscriptionSchema = object({
|
||||
subscriptionId: string([minLength(1)]),
|
||||
});
|
||||
|
||||
export const ReactivateSubscriptionSchema = object({
|
||||
subscriptionId: string([minLength(1)]),
|
||||
});
|
||||
|
||||
export const ListInvoicesSchema = object({
|
||||
limit: optional(string(), "10"),
|
||||
startingAfter: optional(string()),
|
||||
});
|
||||
Reference in New Issue
Block a user