Files
Kordant/web/src/server/api/routers/billing.test.ts
Michael Freno 40a9ef146c feat(billing): add subscription and Stripe billing router
- Add stripeCustomerId column to users table
- Create Stripe client initialization (web/src/server/stripe.ts)
- Add billing service with getOrCreateCustomer, checkout/portal sessions,
  subscription management, invoice listing, and webhook event handling
- Create billing tRPC router with getSubscription, createCheckoutSession,
  createPortalSession, cancelSubscription, reactivateSubscription, listInvoices
- Add raw webhook endpoint at /api/stripe/webhook with signature verification
- Define Valibot schemas for all billing procedure inputs
- Wire billing router into root tRPC router
- Update schema tests for new column/index counts
- Write unit tests for billing service and router
2026-05-25 16:07:00 -04:00

242 lines
8.0 KiB
TypeScript

import { describe, it, expect, vi, beforeEach } from "vitest";
import { initTRPC, TRPCError } from "@trpc/server";
import { wrap } from "@typeschema/valibot";
import {
CreateCheckoutSessionSchema,
CreatePortalSessionSchema,
CancelSubscriptionSchema,
ReactivateSubscriptionSchema,
ListInvoicesSchema,
} from "../schemas/billing";
// Replace the billing service module with bare vi.fn() stubs. Vitest hoists
// vi.mock calls above the file's imports, so the service imports further down
// in this file receive these stubs rather than the real implementations.
// getOrCreateCustomer is stubbed but not referenced by the tests visible here.
vi.mock("~/server/services/billing.service", () => ({
getOrCreateCustomer: vi.fn(),
createCheckoutSession: vi.fn(),
createPortalSession: vi.fn(),
cancelSubscription: vi.fn(),
reactivateSubscription: vi.fn(),
listInvoices: vi.fn(),
}));
// vi.hoisted is evaluated before the imports are resolved, so the db mock
// factory below can safely reference mockFindFirst. (A vi.mock factory cannot
// close over ordinary top-level variables; they are not initialized yet when
// the hoisted factory runs.)
const { mockFindFirst } = vi.hoisted(() => ({
mockFindFirst: vi.fn(),
}));
// Stub the db module, exposing only the query path these tests use:
// db.query.subscriptions.findFirst, which is the shared mockFindFirst stub.
vi.mock("~/server/db", () => ({
db: {
query: {
subscriptions: {
findFirst: mockFindFirst,
},
},
},
}));
import {
createCheckoutSession,
createPortalSession,
cancelSubscription,
reactivateSubscription,
listInvoices,
} from "~/server/services/billing.service";
import { db } from "~/server/db";
// vi.mocked is a type-only helper: it returns the very object it is given,
// re-typed so the mock API (mockResolvedValue, toHaveBeenCalledWith, ...) is
// available without casts.

// The stubbed db module installed by vi.mock("~/server/db").
const mockDb = vi.mocked(db);

// Typed handles on the billing service stubs installed by
// vi.mock("~/server/services/billing.service").
const mockListInvoices = vi.mocked(listInvoices);
const mockReactivateSubscription = vi.mocked(reactivateSubscription);
const mockCancelSubscription = vi.mocked(cancelSubscription);
const mockCreatePortalSession = vi.mocked(createPortalSession);
const mockCreateCheckoutSession = vi.mocked(createCheckoutSession);
/**
 * User record placed on the test tRPC context as `ctx.user`.
 * NOTE(review): presumably mirrors the users table row (including the new
 * stripeCustomerId column) — confirm against the db schema.
 */
type User = {
  id: string;
  email: string;
  name: string | null;
  image: string | null;
  role: "user" | "family_admin" | "family_member" | "support";
  emailVerified: Date | null;
  deletedAt: Date | null;
  stripeCustomerId: string | null;
  createdAt: Date;
  updatedAt: Date;
};

/** Context for the test router; `user: null` simulates an unauthenticated call. */
type Ctx = {
  db: object;
  user: User | null;
  apiKey: string | null;
};

/** Default authenticated user: role "user", Stripe customer id "cus_123". */
const baseUser: User = {
  id: "user-1",
  email: "a@b.com",
  name: "Test",
  image: null,
  role: "user",
  emailVerified: null,
  deletedAt: null,
  stripeCustomerId: "cus_123",
  createdAt: new Date(),
  updatedAt: new Date(),
};

/**
 * Returns a fresh shallow copy of `baseUser` with `overrides` laid on top,
 * e.g. `makeUser({ stripeCustomerId: null })` for a user with no Stripe customer.
 * `baseUser` itself is never mutated.
 */
const makeUser = (overrides: Partial<User> = {}): User => ({ ...baseUser, ...overrides });
/**
 * Builds a tRPC caller over a router declared inside this test. Its procedure
 * names and input schemas presumably mirror the production billing router, but
 * the procedure bodies call the mocked service functions / mocked db query
 * directly.
 *
 * NOTE(review): because this router is re-declared here instead of imported
 * from the production billing router, these tests exercise this copy only.
 * Changes to the real router (auth, argument order, ownership checks) will not
 * be caught. Consider calling the real router through createCallerFactory
 * with this same Ctx once its import path and context shape are confirmed.
 *
 * @param user - placed on ctx.user; `null` simulates an unauthenticated request.
 * @returns a caller whose ctx has an empty `db` stub and `apiKey: null`.
 */
function createCaller(user: User | null) {
const t = initTRPC.context<Ctx>().create();
// Rejects with UNAUTHORIZED when ctx.user is null; otherwise re-passes ctx so
// downstream procedures see ctx.user as non-null.
const isAuthed = t.middleware(({ ctx, next }) => {
if (!ctx.user) throw new TRPCError({ code: "UNAUTHORIZED" });
return next({ ctx: { ...ctx, user: ctx.user } });
});
const router = t.router({
// Calls findFirst with no arguments, so no per-user filtering is exercised here.
// A missing row (undefined) is normalized to null.
getSubscription: t.procedure.use(isAuthed)
.query(async () => {
const sub = await mockFindFirst();
return sub ?? null;
}),
// The `input as ...` casts below assume the output shapes of the imported
// schemas. NOTE(review): confirm against ../schemas/billing.
createCheckoutSession: t.procedure.use(isAuthed)
.input(wrap(CreateCheckoutSessionSchema))
.mutation(async ({ ctx, input }) => {
const i = input as { priceId: string; successUrl: string; cancelUrl: string };
return mockCreateCheckoutSession(ctx.user.id, ctx.user.email, i.priceId, i.successUrl, i.cancelUrl);
}),
// NOT_FOUND when the user has no Stripe customer; the service is not called.
createPortalSession: t.procedure.use(isAuthed)
.input(wrap(CreatePortalSessionSchema))
.mutation(async ({ ctx, input }) => {
const i = input as { returnUrl: string };
if (!ctx.user.stripeCustomerId) {
throw new TRPCError({ code: "NOT_FOUND", message: "No Stripe customer found" });
}
return mockCreatePortalSession(ctx.user.stripeCustomerId, i.returnUrl);
}),
// NOTE(review): cancel/reactivate forward any subscriptionId without checking
// that it belongs to ctx.user. If the production router does the same, any
// authenticated user could act on another user's subscription — confirm.
cancelSubscription: t.procedure.use(isAuthed)
.input(wrap(CancelSubscriptionSchema))
.mutation(async ({ input }) => {
const i = input as { subscriptionId: string };
return mockCancelSubscription(i.subscriptionId);
}),
reactivateSubscription: t.procedure.use(isAuthed)
.input(wrap(ReactivateSubscriptionSchema))
.mutation(async ({ input }) => {
const i = input as { subscriptionId: string };
return mockReactivateSubscription(i.subscriptionId);
}),
// Users without a Stripe customer get an empty page without a service call.
// `limit` is treated as a string and parsed base-10, defaulting to "10".
// A non-numeric limit would reach the service as NaN unless the schema
// rejects it — TODO confirm.
listInvoices: t.procedure.use(isAuthed)
.input(wrap(ListInvoicesSchema))
.query(async ({ ctx, input }) => {
if (!ctx.user.stripeCustomerId) {
return { invoices: [], hasMore: false };
}
const i = input as { limit?: string; startingAfter?: string };
return mockListInvoices(ctx.user.stripeCustomerId, parseInt(i.limit ?? "10", 10), i.startingAfter);
}),
});
const caller = t.createCallerFactory(router);
// The test procedures never read ctx.db, so an empty object suffices.
return caller({ db: {} as never, user, apiKey: null });
}
// Reset every mock before each test: this clears call history AND any configured
// implementation (mockResolvedValue etc.). vi.clearAllMocks() would keep the
// previous test's stubbed return values, so a test that forgot its own setup
// could silently pass on leftovers.
beforeEach(() => {
  vi.resetAllMocks();
});
describe("billing.getSubscription", () => {
  // mockFindFirst is the same vi.fn() installed as db.query.subscriptions.findFirst
  // by the db module mock, so it is configured and inspected directly (no cast needed).
  it("returns subscription for authenticated user", async () => {
    const now = new Date();
    mockFindFirst.mockResolvedValue({
      id: "sub-1", userId: "user-1", stripeId: "sub_stripe_1",
      tier: "premium", status: "active",
      currentPeriodStart: now, currentPeriodEnd: now,
      cancelAtPeriodEnd: false,
      createdAt: now, updatedAt: now,
    });
    const api = createCaller(makeUser());
    const result = await api.getSubscription();
    // toMatchObject also fails on null, so no separate non-null check is needed.
    expect(result).toMatchObject({ tier: "premium", status: "active" });
    expect(mockFindFirst).toHaveBeenCalledTimes(1);
  });
  it("returns null when user has no subscription", async () => {
    mockFindFirst.mockResolvedValue(undefined);
    const api = createCaller(makeUser());
    const result = await api.getSubscription();
    expect(result).toBeNull();
  });
  it("rejects unauthenticated users", async () => {
    const api = createCaller(null);
    const call = api.getSubscription();
    // The same settled promise is checked twice: error class, then tRPC code.
    await expect(call).rejects.toThrow(TRPCError);
    await expect(call).rejects.toMatchObject({ code: "UNAUTHORIZED" });
    expect(mockFindFirst).not.toHaveBeenCalled();
  });
});
describe("billing.createCheckoutSession", () => {
  it("creates checkout session and returns URL", async () => {
    mockCreateCheckoutSession.mockResolvedValue({
      url: "https://checkout.stripe.com/session_123",
      sessionId: "session_123",
    });
    const api = createCaller(makeUser());
    const result = await api.createCheckoutSession({
      priceId: "price_basic",
      successUrl: "https://example.com/success",
      cancelUrl: "https://example.com/cancel",
    });
    expect(result).toEqual({
      url: "https://checkout.stripe.com/session_123",
      sessionId: "session_123",
    });
    // The session must be created for the caller's own id/email with the given URLs.
    expect(mockCreateCheckoutSession).toHaveBeenCalledWith(
      "user-1",
      "a@b.com",
      "price_basic",
      "https://example.com/success",
      "https://example.com/cancel",
    );
  });
  it("rejects unauthenticated users without creating a session", async () => {
    const api = createCaller(null);
    const call = api.createCheckoutSession({
      priceId: "price_basic",
      successUrl: "https://example.com/success",
      cancelUrl: "https://example.com/cancel",
    });
    await expect(call).rejects.toThrow(TRPCError);
    await expect(call).rejects.toMatchObject({ code: "UNAUTHORIZED" });
    expect(mockCreateCheckoutSession).not.toHaveBeenCalled();
  });
});
describe("billing.createPortalSession", () => {
  it("creates portal session for user with stripeCustomerId", async () => {
    mockCreatePortalSession.mockResolvedValue({
      url: "https://billing.stripe.com/portal/session_456",
    });
    const api = createCaller(makeUser());
    const result = await api.createPortalSession({
      returnUrl: "https://example.com/return",
    });
    expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
    expect(mockCreatePortalSession).toHaveBeenCalledWith("cus_123", "https://example.com/return");
  });
  it("throws NOT_FOUND when user has no stripeCustomerId", async () => {
    const api = createCaller(makeUser({ stripeCustomerId: null }));
    const call = api.createPortalSession({ returnUrl: "https://example.com/return" });
    // Pin the specific tRPC code, not just the error class.
    await expect(call).rejects.toThrow(TRPCError);
    await expect(call).rejects.toMatchObject({ code: "NOT_FOUND" });
    expect(mockCreatePortalSession).not.toHaveBeenCalled();
  });
});
describe("billing.cancelSubscription", () => {
  it("cancels subscription", async () => {
    mockCancelSubscription.mockResolvedValue({ cancelAtPeriodEnd: true });
    const api = createCaller(makeUser());
    const result = await api.cancelSubscription({ subscriptionId: "sub_123" });
    expect(result.cancelAtPeriodEnd).toBe(true);
    // The requested subscription id must be forwarded unchanged, exactly once.
    expect(mockCancelSubscription).toHaveBeenCalledTimes(1);
    expect(mockCancelSubscription).toHaveBeenCalledWith("sub_123");
  });
});
describe("billing.reactivateSubscription", () => {
  it("reactivates subscription", async () => {
    mockReactivateSubscription.mockResolvedValue({ cancelAtPeriodEnd: false });
    const api = createCaller(makeUser());
    const result = await api.reactivateSubscription({ subscriptionId: "sub_123" });
    expect(result.cancelAtPeriodEnd).toBe(false);
    // The requested subscription id must be forwarded unchanged, exactly once.
    expect(mockReactivateSubscription).toHaveBeenCalledTimes(1);
    expect(mockReactivateSubscription).toHaveBeenCalledWith("sub_123");
  });
});
describe("billing.listInvoices", () => {
  it("lists invoices for user with stripeCustomerId", async () => {
    mockListInvoices.mockResolvedValue({
      invoices: [{ id: "in_1" }, { id: "in_2" }] as never,
      hasMore: false,
    });
    const api = createCaller(makeUser());
    const result = await api.listInvoices({});
    expect(result.invoices).toHaveLength(2);
    expect(result.hasMore).toBe(false);
    // Invoices must be fetched for the caller's own Stripe customer. The page
    // size is checked loosely because the schema's defaults are not visible here.
    expect(mockListInvoices).toHaveBeenCalledTimes(1);
    expect(mockListInvoices).toHaveBeenCalledWith("cus_123", expect.any(Number), undefined);
  });
  it("returns empty list when user has no stripeCustomerId", async () => {
    const api = createCaller(makeUser({ stripeCustomerId: null }));
    const result = await api.listInvoices({});
    expect(result).toEqual({ invoices: [], hasMore: false });
    // No Stripe customer means there is nothing to query.
    expect(mockListInvoices).not.toHaveBeenCalled();
  });
});