diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c046a2c..fb61688 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -67,6 +67,9 @@ importers: solid-js: specifier: ^1.9.5 version: 1.9.13 + stripe: + specifier: ^22.1.1 + version: 22.1.1(@types/node@25.9.1) tailwindcss: specifier: ^4.0.0 version: 4.3.0 @@ -3130,6 +3133,15 @@ packages: strip-literal@3.1.0: resolution: {integrity: sha512-8r3mkIM/2+PpjHoOtiAW8Rg3jJLHaV7xPwG+YRGrv6FP0wwk/toTpATxWYOW0BKdWwl82VT2tFYi5DlROa0Mxg==} + stripe@22.1.1: + resolution: {integrity: sha512-cmodIYP27tBkJ8G7DuGgWw0PFuemlFZbuF3Wwr1TrjFjUa3T7NIgCe6TVwX8BO2ynu+xtTuDGfHafNDCPt9lXA==} + engines: {node: '>=18'} + peerDependencies: + '@types/node': '>=18' + peerDependenciesMeta: + '@types/node': + optional: true + supports-color@10.2.2: resolution: {integrity: sha512-SS+jx45GF1QjgEXQx4NJZV9ImqmO2NPz5FNsIHrsDjh2YsHnawpan7SNQ1o8NuhrbHZy9AZhIoCUiCeaW/C80g==} engines: {node: '>=18'} @@ -6431,6 +6443,10 @@ snapshots: dependencies: js-tokens: 9.0.1 + stripe@22.1.1(@types/node@25.9.1): + optionalDependencies: + '@types/node': 25.9.1 + supports-color@10.2.2: {} supports-color@7.2.0: diff --git a/web/.env.example b/web/.env.example index 1baeffc..e3e51e8 100644 --- a/web/.env.example +++ b/web/.env.example @@ -1 +1,8 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/shieldai" + +# Stripe (get test keys from https://dashboard.stripe.com/test/apikeys) +STRIPE_SECRET_KEY="sk_test_..." +STRIPE_WEBHOOK_SECRET="whsec_..." +STRIPE_PRICE_BASIC="price_basic" +STRIPE_PRICE_PLUS="price_plus" +STRIPE_PRICE_PREMIUM="price_premium" diff --git a/web/package.json b/web/package.json index c51806a..f63a5fe 100644 --- a/web/package.json +++ b/web/package.json @@ -28,6 +28,7 @@ "jose": "^5", "pg": "^8.21.0", "solid-js": "^1.9.5", + "stripe": "^22.1.1", "tailwindcss": "^4.0.0", "three": "^0.184.0", "valibot": "^0.29.0", diff --git a/web/src/routes/api/stripe/webhook.ts b/web/src/routes/api/stripe/webhook.ts new file mode 100644 index 0000000..9773a65 --- /dev/null +++ b/web/src/routes/api/stripe/webhook.ts @@ -0,0 +1,27 @@ +import type { APIEvent } from "@solidjs/start/server"; +import { stripe } from "~/server/stripe"; +import { handleWebhookEvent } from "~/server/services/billing.service"; + +export async function POST(event: APIEvent) { + const body = await event.request.text(); + const signature = event.request.headers.get("stripe-signature"); + + if (!signature) { + return new Response("Missing stripe-signature header", { status: 400 }); + } + + try { + const webhookEvent = stripe.webhooks.constructEvent( + body, + signature, + process.env.STRIPE_WEBHOOK_SECRET ?? "", + ); + + await handleWebhookEvent(webhookEvent); + + return new Response("OK", { status: 200 }); + } catch (err) { + const message = err instanceof Error ? err.message : "Webhook error"; + return new Response(message, { status: 400 }); + } +} diff --git a/web/src/server/api/root.ts b/web/src/server/api/root.ts index adeb0f5..980f2b9 100644 --- a/web/src/server/api/root.ts +++ b/web/src/server/api/root.ts @@ -1,10 +1,12 @@ import { exampleRouter } from "./routers/example"; import { userRouter } from "./routers/user"; +import { billingRouter } from "./routers/billing"; import { createTRPCRouter } from "./utils"; export const appRouter = createTRPCRouter({ example: exampleRouter, user: userRouter, + billing: billingRouter, }); export type AppRouter = typeof appRouter; diff --git a/web/src/server/api/routers/billing.test.ts b/web/src/server/api/routers/billing.test.ts new file mode 100644 index 0000000..4a5319e --- /dev/null +++ b/web/src/server/api/routers/billing.test.ts @@ -0,0 +1,241 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { initTRPC, TRPCError } from "@trpc/server"; +import { wrap } from "@typeschema/valibot"; +import { + CreateCheckoutSessionSchema, + CreatePortalSessionSchema, + CancelSubscriptionSchema, + ReactivateSubscriptionSchema, + ListInvoicesSchema, +} from "../schemas/billing"; + +vi.mock("~/server/services/billing.service", () => ({ + getOrCreateCustomer: vi.fn(), + createCheckoutSession: vi.fn(), + createPortalSession: vi.fn(), + cancelSubscription: vi.fn(), + reactivateSubscription: vi.fn(), + listInvoices: vi.fn(), +})); + +const { mockFindFirst } = vi.hoisted(() => ({ + mockFindFirst: vi.fn(), +})); + +vi.mock("~/server/db", () => ({ + db: { + query: { + subscriptions: { + findFirst: mockFindFirst, + }, + }, + }, +})); + +import { + createCheckoutSession, + createPortalSession, + cancelSubscription, + reactivateSubscription, + listInvoices, +} from "~/server/services/billing.service"; +import { db } from "~/server/db"; + +const mockCreateCheckoutSession = vi.mocked(createCheckoutSession); +const mockCreatePortalSession = vi.mocked(createPortalSession); +const mockCancelSubscription = vi.mocked(cancelSubscription); +const mockReactivateSubscription = vi.mocked(reactivateSubscription); +const mockListInvoices = vi.mocked(listInvoices); +const mockDb = vi.mocked(db); + +type User = { + id: string; email: string; name: string | null; image: string | null; + role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null; + stripeCustomerId: string | null; + createdAt: Date; updatedAt: Date; +}; + +type Ctx = { db: object; user: User | null; apiKey: string | null }; + +const baseUser: User = { + id: "user-1", email: "a@b.com", name: "Test", image: null, + role: "user", emailVerified: null, deletedAt: null, + stripeCustomerId: "cus_123", + createdAt: new Date(), updatedAt: new Date(), +}; + +function makeUser(overrides: Partial = {}): User { + return { ...baseUser, ...overrides }; +} + +function createCaller(user: User | null) { + const t = initTRPC.context().create(); + const isAuthed = t.middleware(({ ctx, next }) => { + if (!ctx.user) throw new TRPCError({ code: "UNAUTHORIZED" }); + return next({ ctx: { ...ctx, user: ctx.user } }); + }); + + const router = t.router({ + getSubscription: t.procedure.use(isAuthed) + .query(async () => { + const sub = await mockFindFirst(); + return sub ?? null; + }), + createCheckoutSession: t.procedure.use(isAuthed) + .input(wrap(CreateCheckoutSessionSchema)) + .mutation(async ({ ctx, input }) => { + const i = input as { priceId: string; successUrl: string; cancelUrl: string }; + return mockCreateCheckoutSession(ctx.user.id, ctx.user.email, i.priceId, i.successUrl, i.cancelUrl); + }), + createPortalSession: t.procedure.use(isAuthed) + .input(wrap(CreatePortalSessionSchema)) + .mutation(async ({ ctx, input }) => { + const i = input as { returnUrl: string }; + if (!ctx.user.stripeCustomerId) { + throw new TRPCError({ code: "NOT_FOUND", message: "No Stripe customer found" }); + } + return mockCreatePortalSession(ctx.user.stripeCustomerId, i.returnUrl); + }), + cancelSubscription: t.procedure.use(isAuthed) + .input(wrap(CancelSubscriptionSchema)) + .mutation(async ({ input }) => { + const i = input as { subscriptionId: string }; + return mockCancelSubscription(i.subscriptionId); + }), + reactivateSubscription: t.procedure.use(isAuthed) + .input(wrap(ReactivateSubscriptionSchema)) + .mutation(async ({ input }) => { + const i = input as { subscriptionId: string }; + return mockReactivateSubscription(i.subscriptionId); + }), + listInvoices: t.procedure.use(isAuthed) + .input(wrap(ListInvoicesSchema)) + .query(async ({ ctx, input }) => { + if (!ctx.user.stripeCustomerId) { + return { invoices: [], hasMore: false }; + } + const i = input as { limit?: string; startingAfter?: string }; + return mockListInvoices(ctx.user.stripeCustomerId, parseInt(i.limit ?? "10", 10), i.startingAfter); + }), + }); + + const caller = t.createCallerFactory(router); + return caller({ db: {} as never, user, apiKey: null }); +} + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("billing.getSubscription", () => { + it("returns subscription for authenticated user", async () => { + const now = new Date(); + (mockDb.query.subscriptions.findFirst as ReturnType).mockResolvedValue({ + id: "sub-1", userId: "user-1", stripeId: "sub_stripe_1", + tier: "premium", status: "active", + currentPeriodStart: now, currentPeriodEnd: now, + cancelAtPeriodEnd: false, + createdAt: now, updatedAt: now, + }); + + const api = createCaller(makeUser()); + const result = await api.getSubscription(); + expect(result).not.toBeNull(); + expect(result!.tier).toBe("premium"); + expect(result!.status).toBe("active"); + }); + + it("returns null when user has no subscription", async () => { + (mockDb.query.subscriptions.findFirst as ReturnType).mockResolvedValue(undefined); + const api = createCaller(makeUser()); + const result = await api.getSubscription(); + expect(result).toBeNull(); + }); + + it("rejects unauthenticated users", async () => { + const api = createCaller(null); + await expect(api.getSubscription()).rejects.toThrow(TRPCError); + }); +}); + +describe("billing.createCheckoutSession", () => { + it("creates checkout session and returns URL", async () => { + mockCreateCheckoutSession.mockResolvedValue({ + url: "https://checkout.stripe.com/session_123", + sessionId: "session_123", + }); + + const api = createCaller(makeUser()); + const result = await api.createCheckoutSession({ + priceId: "price_basic", + successUrl: "https://example.com/success", + cancelUrl: "https://example.com/cancel", + }); + + expect(result.url).toBe("https://checkout.stripe.com/session_123"); + }); +}); + +describe("billing.createPortalSession", () => { + it("creates portal session for user with stripeCustomerId", async () => { + mockCreatePortalSession.mockResolvedValue({ + url: "https://billing.stripe.com/portal/session_456", + }); + + const api = createCaller(makeUser()); + const result = await api.createPortalSession({ + returnUrl: "https://example.com/return", + }); + + expect(result.url).toBe("https://billing.stripe.com/portal/session_456"); + }); + + it("throws NOT_FOUND when user has no stripeCustomerId", async () => { + const api = createCaller(makeUser({ stripeCustomerId: null })); + await expect(api.createPortalSession({ returnUrl: "https://example.com/return" })).rejects.toThrow(TRPCError); + }); +}); + +describe("billing.cancelSubscription", () => { + it("cancels subscription", async () => { + mockCancelSubscription.mockResolvedValue({ cancelAtPeriodEnd: true }); + + const api = createCaller(makeUser()); + const result = await api.cancelSubscription({ subscriptionId: "sub_123" }); + + expect(result.cancelAtPeriodEnd).toBe(true); + }); +}); + +describe("billing.reactivateSubscription", () => { + it("reactivates subscription", async () => { + mockReactivateSubscription.mockResolvedValue({ cancelAtPeriodEnd: false }); + + const api = createCaller(makeUser()); + const result = await api.reactivateSubscription({ subscriptionId: "sub_123" }); + + expect(result.cancelAtPeriodEnd).toBe(false); + }); +}); + +describe("billing.listInvoices", () => { + it("lists invoices for user with stripeCustomerId", async () => { + mockListInvoices.mockResolvedValue({ + invoices: [{ id: "in_1" }, { id: "in_2" }] as never, + hasMore: false, + }); + + const api = createCaller(makeUser()); + const result = await api.listInvoices({}); + + expect(result.invoices).toHaveLength(2); + }); + + it("returns empty list when user has no stripeCustomerId", async () => { + const api = createCaller(makeUser({ stripeCustomerId: null })); + const result = await api.listInvoices({}); + + expect(result.invoices).toHaveLength(0); + expect(result.hasMore).toBe(false); + }); +}); diff --git a/web/src/server/api/routers/billing.ts b/web/src/server/api/routers/billing.ts new file mode 100644 index 0000000..75e76b5 --- /dev/null +++ b/web/src/server/api/routers/billing.ts @@ -0,0 +1,100 @@ +import { TRPCError } from "@trpc/server"; +import { eq } from "drizzle-orm"; +import { wrap } from "@typeschema/valibot"; +import { createTRPCRouter, protectedProcedure } from "../utils"; +import { + CreateCheckoutSessionSchema, + CreatePortalSessionSchema, + CancelSubscriptionSchema, + ReactivateSubscriptionSchema, + ListInvoicesSchema, +} from "../schemas/billing"; +import { + getOrCreateCustomer, + createCheckoutSession, + createPortalSession, + cancelSubscription, + reactivateSubscription, + listInvoices, +} from "~/server/services/billing.service"; +import { db } from "~/server/db"; +import { subscriptions } from "~/server/db/schema/subscription"; + +export const billingRouter = createTRPCRouter({ + getSubscription: protectedProcedure.query(async ({ ctx }) => { + const sub = await db.query.subscriptions.findFirst({ + where: eq(subscriptions.userId, ctx.user.id), + }); + return sub ?? null; + }), + + createCheckoutSession: protectedProcedure + .input(wrap(CreateCheckoutSessionSchema)) + .mutation(async ({ ctx, input }) => { + const allowedPrices = [ + process.env.STRIPE_PRICE_BASIC, + process.env.STRIPE_PRICE_PLUS, + process.env.STRIPE_PRICE_PREMIUM, + ].filter(Boolean); + + if (!allowedPrices.includes(input.priceId)) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Invalid price ID", + }); + } + + return createCheckoutSession( + ctx.user.id, + ctx.user.email, + input.priceId, + input.successUrl, + input.cancelUrl, + ); + }), + + createPortalSession: protectedProcedure + .input(wrap(CreatePortalSessionSchema)) + .mutation(async ({ ctx, input }) => { + const user = ctx.user; + const stripeCustomerId = user.stripeCustomerId; + + if (!stripeCustomerId) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "No Stripe customer found", + }); + } + + return createPortalSession(stripeCustomerId, input.returnUrl); + }), + + cancelSubscription: protectedProcedure + .input(wrap(CancelSubscriptionSchema)) + .mutation(async ({ input }) => { + return cancelSubscription(input.subscriptionId); + }), + + reactivateSubscription: protectedProcedure + .input(wrap(ReactivateSubscriptionSchema)) + .mutation(async ({ input }) => { + return reactivateSubscription(input.subscriptionId); + }), + + listInvoices: protectedProcedure + .input(wrap(ListInvoicesSchema)) + .query(async ({ ctx, input }) => { + const user = ctx.user; + const stripeCustomerId = user.stripeCustomerId; + + if (!stripeCustomerId) { + return { invoices: [], hasMore: false }; + } + + return listInvoices( + stripeCustomerId, + parseInt(input.limit ?? "10", 10), + input.startingAfter, + ); + }), +}); diff --git a/web/src/server/api/routers/user.test.ts b/web/src/server/api/routers/user.test.ts index 5b3d229..d38afda 100644 --- a/web/src/server/api/routers/user.test.ts +++ b/web/src/server/api/routers/user.test.ts @@ -30,6 +30,7 @@ const mockUpdateMemberRole = vi.mocked(updateMemberRole); type User = { id: string; email: string; name: string | null; image: string | null; role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null; + stripeCustomerId: string | null; createdAt: Date; updatedAt: Date; }; type Ctx = { db: object; user: User | null; apiKey: string | null }; @@ -97,6 +98,7 @@ function createCaller(user: User | null) { const baseUser: User = { id: "user-1", email: "a@b.com", name: "Test", image: null, role: "user", emailVerified: null, deletedAt: null, + stripeCustomerId: null, createdAt: new Date(), updatedAt: new Date(), }; diff --git a/web/src/server/api/schemas/billing.ts b/web/src/server/api/schemas/billing.ts new file mode 100644 index 0000000..ed1bf19 --- /dev/null +++ b/web/src/server/api/schemas/billing.ts @@ -0,0 +1,24 @@ +import { object, string, url, minLength, optional, picklist } from "valibot"; + +export const CreateCheckoutSessionSchema = object({ + priceId: string([minLength(1)]), + successUrl: string([url()]), + cancelUrl: string([url()]), +}); + +export const CreatePortalSessionSchema = object({ + returnUrl: string([url()]), +}); + +export const CancelSubscriptionSchema = object({ + subscriptionId: string([minLength(1)]), +}); + +export const ReactivateSubscriptionSchema = object({ + subscriptionId: string([minLength(1)]), +}); + +export const ListInvoicesSchema = object({ + limit: optional(string(), "10"), + startingAfter: optional(string()), +}); diff --git a/web/src/server/db/schema.test.ts b/web/src/server/db/schema.test.ts index 2f438a7..0c7b17b 100644 --- a/web/src/server/db/schema.test.ts +++ b/web/src/server/db/schema.test.ts @@ -76,17 +76,18 @@ describe("users table", () => { expect(colNames).toContain("name"); expect(colNames).toContain("image"); expect(colNames).toContain("role"); + expect(colNames).toContain("stripe_customer_id"); expect(colNames).toContain("deleted_at"); expect(colNames).toContain("created_at"); expect(colNames).toContain("updated_at"); }); - it("has 9 columns", () => { - expect(config.columns).toHaveLength(9); + it("has 10 columns", () => { + expect(config.columns).toHaveLength(10); }); - it("has 2 indexes", () => { - expect(config.indexes.length).toBe(2); + it("has 3 indexes", () => { + expect(config.indexes.length).toBe(3); }); }); diff --git a/web/src/server/db/schema/auth.ts b/web/src/server/db/schema/auth.ts index eb104e8..be9d211 100644 --- a/web/src/server/db/schema/auth.ts +++ b/web/src/server/db/schema/auth.ts @@ -8,12 +8,14 @@ export const users = pgTable("users", { name: text("name"), image: text("image"), role: userRole("role").default("user").notNull(), + stripeCustomerId: text("stripe_customer_id"), deletedAt: timestamp("deleted_at", { withTimezone: true, mode: "date" }), createdAt: timestamp("created_at", { withTimezone: true, mode: "date" }).defaultNow().notNull(), updatedAt: timestamp("updated_at", { withTimezone: true, mode: "date" }).defaultNow().notNull().$onUpdate(() => new Date()), }, (table) => ({ emailIdx: index("users_email_idx").on(table.email), roleIdx: index("users_role_idx").on(table.role), + stripeCustomerIdIdx: index("users_stripe_customer_id_idx").on(table.stripeCustomerId), })); export const accounts = pgTable("accounts", { diff --git a/web/src/server/services/billing.service.test.ts b/web/src/server/services/billing.service.test.ts new file mode 100644 index 0000000..f23967e --- /dev/null +++ b/web/src/server/services/billing.service.test.ts @@ -0,0 +1,338 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("~/server/stripe", () => ({ + stripe: { + customers: { create: vi.fn() }, + checkout: { sessions: { create: vi.fn() } }, + billingPortal: { sessions: { create: vi.fn() } }, + subscriptions: { update: vi.fn(), retrieve: vi.fn() }, + invoices: { list: vi.fn() }, + webhooks: { constructEvent: vi.fn() }, + }, +})); + +vi.mock("~/server/db", () => ({ + db: { + select: vi.fn(), + insert: vi.fn(), + update: vi.fn(), + query: { + subscriptions: { + findFirst: vi.fn(), + }, + }, + }, +})); + +import { stripe } from "~/server/stripe"; +import { db } from "~/server/db"; +import { + getOrCreateCustomer, + createCheckoutSession, + createPortalSession, + cancelSubscription, + reactivateSubscription, + listInvoices, + handleWebhookEvent, +} from "./billing.service"; + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("getOrCreateCustomer", () => { + it("returns existing stripeCustomerId if present", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([ + { id: "u1", email: "a@b.com", stripeCustomerId: "cus_existing" }, + ]), + }), + }), + }); + + const result = await getOrCreateCustomer("u1", "a@b.com"); + expect(result).toBe("cus_existing"); + expect(stripe.customers.create).not.toHaveBeenCalled(); + }); + + it("creates a new Stripe customer when no stripeCustomerId", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([ + { id: "u1", email: "a@b.com", stripeCustomerId: null }, + ]), + }), + }), + }); + + (stripe.customers.create as ReturnType).mockResolvedValue({ id: "cus_new" }); + + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue([{ id: "u1" }]), + }), + }); + + const result = await getOrCreateCustomer("u1", "a@b.com"); + expect(result).toBe("cus_new"); + expect(stripe.customers.create).toHaveBeenCalledWith({ + email: "a@b.com", + metadata: { userId: "u1" }, + }); + }); + + it("throws NOT_FOUND for missing user", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + }); + + await expect(getOrCreateCustomer("u-missing", "x@y.com")).rejects.toThrow( + "User not found", + ); + }); +}); + +describe("createCheckoutSession", () => { + it("creates a Stripe checkout session and returns URL", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([ + { id: "u1", email: "a@b.com", stripeCustomerId: "cus_123" }, + ]), + }), + }), + }); + + (stripe.checkout.sessions.create as ReturnType).mockResolvedValue({ + url: "https://checkout.stripe.com/session_123", + id: "session_123", + }); + + const result = await createCheckoutSession( + "u1", + "a@b.com", + "price_basic", + "https://example.com/success", + "https://example.com/cancel", + ); + + expect(result.url).toBe("https://checkout.stripe.com/session_123"); + expect(result.sessionId).toBe("session_123"); + }); +}); + +describe("createPortalSession", () => { + it("creates a Stripe billing portal session", async () => { + (stripe.billingPortal.sessions.create as ReturnType).mockResolvedValue({ + url: "https://billing.stripe.com/portal/session_456", + }); + + const result = await createPortalSession( + "cus_123", + "https://example.com/return", + ); + + expect(result.url).toBe("https://billing.stripe.com/portal/session_456"); + }); +}); + +describe("cancelSubscription", () => { + it("sets cancel_at_period_end on Stripe subscription", async () => { + (stripe.subscriptions.update as ReturnType).mockResolvedValue({}); + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue([]), + }), + }); + + const result = await cancelSubscription("sub_123"); + expect(result.cancelAtPeriodEnd).toBe(true); + expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", { + cancel_at_period_end: true, + }); + }); +}); + +describe("reactivateSubscription", () => { + it("removes cancel_at_period_end on Stripe subscription", async () => { + (stripe.subscriptions.update as ReturnType).mockResolvedValue({}); + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue([]), + }), + }); + + const result = await reactivateSubscription("sub_123"); + expect(result.cancelAtPeriodEnd).toBe(false); + expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", { + cancel_at_period_end: false, + }); + }); +}); + +describe("listInvoices", () => { + it("returns invoices list from Stripe", async () => { + (stripe.invoices.list as ReturnType).mockResolvedValue({ + data: [{ id: "in_1" }, { id: "in_2" }], + has_more: false, + }); + + const result = await listInvoices("cus_123", 10); + expect(result.invoices).toHaveLength(2); + expect(result.hasMore).toBe(false); + }); +}); + +describe("handleWebhookEvent", () => { + it("handles checkout.session.completed", async () => { + (db.insert as ReturnType).mockReturnValue({ + values: vi.fn().mockReturnValue({ + onConflictDoNothing: vi.fn().mockResolvedValue(undefined), + }), + }); + + (stripe.subscriptions.retrieve as ReturnType).mockResolvedValue({ + id: "sub_new", + items: { data: [{ price: { id: "price_premium" } }] }, + current_period_start: 1700000000, + current_period_end: 1702592000, + status: "active", + cancel_at_period_end: false, + }); + + await handleWebhookEvent({ + type: "checkout.session.completed", + data: { + object: { + metadata: { userId: "u1" }, + subscription: "sub_new", + }, + }, + } as never); + + expect(db.insert).toHaveBeenCalled(); + }); + + it("handles invoice.paid", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]), + }), + }), + }); + + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]), + }), + }), + }); + + await handleWebhookEvent({ + type: "invoice.paid", + data: { + object: { + subscription: "sub_123", + }, + }, + } as never); + }); + + it("handles invoice.payment_failed", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]), + }), + }), + }); + + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "past_due" }]), + }), + }), + }); + + await handleWebhookEvent({ + type: "invoice.payment_failed", + data: { + object: { + subscription: "sub_123", + }, + }, + } as never); + }); + + it("handles customer.subscription.updated", async () => { + (db.query.subscriptions.findFirst as ReturnType).mockResolvedValue(null); + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + }); + + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]), + }), + }), + }); + + await handleWebhookEvent({ + type: "customer.subscription.updated", + data: { + object: { + id: "sub_123", + metadata: { userId: "u1" }, + items: { data: [{ price: { id: "price_plus" } }] }, + current_period_start: 1700000000, + current_period_end: 1702592000, + status: "active", + cancel_at_period_end: false, + }, + }, + } as never); + }); + + it("handles customer.subscription.deleted", async () => { + (db.select as ReturnType).mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]), + }), + }), + }); + + (db.update as ReturnType).mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "canceled" }]), + }), + }), + }); + + await handleWebhookEvent({ + type: "customer.subscription.deleted", + data: { + object: { + id: "sub_123", + }, + }, + } as never); + }); +}); diff --git a/web/src/server/services/billing.service.ts b/web/src/server/services/billing.service.ts new file mode 100644 index 0000000..08dba7c --- /dev/null +++ b/web/src/server/services/billing.service.ts @@ -0,0 +1,235 @@ +import { TRPCError } from "@trpc/server"; +import { eq } from "drizzle-orm"; +import { db } from "~/server/db"; +import { stripe } from "~/server/stripe"; +import { users } from "~/server/db/schema/auth"; +import { subscriptions } from "~/server/db/schema/subscription"; +import type Stripe from "stripe"; + +type Tier = "basic" | "plus" | "premium"; + +export async function getOrCreateCustomer(userId: string, email: string) { + const [user] = await db + .select() + .from(users) + .where(eq(users.id, userId)) + .limit(1); + + if (!user) { + throw new TRPCError({ code: "NOT_FOUND", message: "User not found" }); + } + + if (user.stripeCustomerId) { + return user.stripeCustomerId; + } + + const customer = await stripe.customers.create({ + email, + metadata: { userId }, + }); + + await db + .update(users) + .set({ stripeCustomerId: customer.id }) + .where(eq(users.id, userId)); + + return customer.id; +} + +export async function createCheckoutSession( + userId: string, + email: string, + priceId: string, + successUrl: string, + cancelUrl: string, +) { + const customerId = await getOrCreateCustomer(userId, email); + + const session = await stripe.checkout.sessions.create({ + customer: customerId, + mode: "subscription", + line_items: [{ price: priceId, quantity: 1 }], + success_url: successUrl, + cancel_url: cancelUrl, + metadata: { userId }, + }); + + return { url: session.url, sessionId: session.id }; +} + +export async function createPortalSession(customerId: string, returnUrl: string) { + const session = await stripe.billingPortal.sessions.create({ + customer: customerId, + return_url: returnUrl, + }); + + return { url: session.url }; +} + +export async function cancelSubscription(stripeSubscriptionId: string) { + await stripe.subscriptions.update(stripeSubscriptionId, { + cancel_at_period_end: true, + }); + + await db + .update(subscriptions) + .set({ cancelAtPeriodEnd: true }) + .where(eq(subscriptions.stripeId, stripeSubscriptionId)); + + return { cancelAtPeriodEnd: true }; +} + +export async function reactivateSubscription(stripeSubscriptionId: string) { + await stripe.subscriptions.update(stripeSubscriptionId, { + cancel_at_period_end: false, + }); + + await db + .update(subscriptions) + .set({ cancelAtPeriodEnd: false }) + .where(eq(subscriptions.stripeId, stripeSubscriptionId)); + + return { cancelAtPeriodEnd: false }; +} + +export async function listInvoices( + customerId: string, + limit: number = 10, + startingAfter?: string, +) { + const params: Stripe.InvoiceListParams = { + customer: customerId, + limit, + }; + if (startingAfter) { + params.starting_after = startingAfter; + } + + const invoices = await stripe.invoices.list(params); + return { + invoices: invoices.data, + hasMore: invoices.has_more, + }; +} + +export async function updateSubscriptionInDB( + stripeId: string, + data: { + tier?: Tier; + status?: string; + currentPeriodStart?: Date; + currentPeriodEnd?: Date; + cancelAtPeriodEnd?: boolean; + }, +) { + const [existing] = await db + .select() + .from(subscriptions) + .where(eq(subscriptions.stripeId, stripeId)) + .limit(1); + + if (existing) { + const [updated] = await db + .update(subscriptions) + .set(data as Record) + .where(eq(subscriptions.stripeId, stripeId)) + .returning(); + return updated; + } + + return null; +} + +export async function handleWebhookEvent(event: Stripe.Event) { + const obj = event.data.object as unknown as Record; + + switch (event.type) { + case "checkout.session.completed": { + const session = obj as unknown as Stripe.Checkout.Session; + const userId = session.metadata?.userId; + if (!userId || !session.subscription) break; + + const stripeSub = await stripe.subscriptions.retrieve( + session.subscription as string, + ); + const sub = stripeSub as unknown as Record; + + await db.insert(subscriptions).values({ + userId, + stripeId: stripeSub.id, + tier: mapStripeProductToTier( + stripeSub.items.data[0]?.price?.id ?? "", + ), + status: sub.status as typeof subscriptions.$inferSelect.status, + currentPeriodStart: new Date((sub.current_period_start as number) * 1000), + currentPeriodEnd: new Date((sub.current_period_end as number) * 1000), + cancelAtPeriodEnd: sub.cancel_at_period_end as boolean, + }).onConflictDoNothing(); + break; + } + + case "invoice.paid": { + const invoice = obj; + if (!invoice.subscription) break; + + await updateSubscriptionInDB(invoice.subscription as string, { + status: "active", + }); + break; + } + + case "invoice.payment_failed": { + const invoice = obj; + if (!invoice.subscription) break; + + await updateSubscriptionInDB(invoice.subscription as string, { + status: "past_due", + }); + break; + } + + case "customer.subscription.updated": { + const stripeSub = obj as unknown as Stripe.Subscription; + const userId = stripeSub.metadata?.userId; + const sub = stripeSub as unknown as Record; + + if (!userId) { + const [existingSub] = await db + .select() + .from(subscriptions) + .where(eq(subscriptions.stripeId, stripeSub.id)) + .limit(1); + + if (!existingSub) break; + } + + const tier = stripeSub.items.data[0]?.price?.id + ? mapStripeProductToTier(stripeSub.items.data[0].price.id) + : undefined; + + await updateSubscriptionInDB(stripeSub.id, { + tier, + status: sub.status as string, + currentPeriodStart: new Date((sub.current_period_start as number) * 1000), + currentPeriodEnd: new Date((sub.current_period_end as number) * 1000), + cancelAtPeriodEnd: sub.cancel_at_period_end as boolean, + }); + break; + } + + case "customer.subscription.deleted": { + const stripeSub = obj as unknown as Stripe.Subscription; + await updateSubscriptionInDB(stripeSub.id, { + status: "canceled", + }); + break; + } + } +} + +function mapStripeProductToTier(priceId: string): Tier { + if (priceId === process.env.STRIPE_PRICE_BASIC) return "basic"; + if (priceId === process.env.STRIPE_PRICE_PLUS) return "plus"; + if (priceId === process.env.STRIPE_PRICE_PREMIUM) return "premium"; + return "basic"; +} diff --git a/web/src/server/stripe.ts b/web/src/server/stripe.ts new file mode 100644 index 0000000..0cf0594 --- /dev/null +++ b/web/src/server/stripe.ts @@ -0,0 +1,6 @@ +import Stripe from "stripe"; + +export const stripe = new Stripe(process.env.STRIPE_SECRET_KEY ?? "", { + apiVersion: "2026-04-22.dahlia" as const, + typescript: true, +});