feat(billing): add subscription and Stripe billing router
- Add stripeCustomerId column to users table - Create Stripe client initialization (web/src/server/stripe.ts) - Add billing service with getOrCreateCustomer, checkout/portal sessions, subscription management, invoice listing, and webhook event handling - Create billing tRPC router with getSubscription, createCheckoutSession, createPortalSession, cancelSubscription, reactivateSubscription, listInvoices - Add raw webhook endpoint at /api/stripe/webhook with signature verification - Define Valibot schemas for all billing procedure inputs - Wire billing router into root tRPC router - Update schema tests for new column/index counts - Write unit tests for billing service and router
This commit is contained in:
16
pnpm-lock.yaml
generated
16
pnpm-lock.yaml
generated
@@ -67,6 +67,9 @@ importers:
|
||||
solid-js:
|
||||
specifier: ^1.9.5
|
||||
version: 1.9.13
|
||||
stripe:
|
||||
specifier: ^22.1.1
|
||||
version: 22.1.1(@types/node@25.9.1)
|
||||
tailwindcss:
|
||||
specifier: ^4.0.0
|
||||
version: 4.3.0
|
||||
@@ -3130,6 +3133,15 @@ packages:
|
||||
strip-literal@3.1.0:
|
||||
resolution: {integrity: sha512-8r3mkIM/2+PpjHoOtiAW8Rg3jJLHaV7xPwG+YRGrv6FP0wwk/toTpATxWYOW0BKdWwl82VT2tFYi5DlROa0Mxg==}
|
||||
|
||||
stripe@22.1.1:
|
||||
resolution: {integrity: sha512-cmodIYP27tBkJ8G7DuGgWw0PFuemlFZbuF3Wwr1TrjFjUa3T7NIgCe6TVwX8BO2ynu+xtTuDGfHafNDCPt9lXA==}
|
||||
engines: {node: '>=18'}
|
||||
peerDependencies:
|
||||
'@types/node': '>=18'
|
||||
peerDependenciesMeta:
|
||||
'@types/node':
|
||||
optional: true
|
||||
|
||||
supports-color@10.2.2:
|
||||
resolution: {integrity: sha512-SS+jx45GF1QjgEXQx4NJZV9ImqmO2NPz5FNsIHrsDjh2YsHnawpan7SNQ1o8NuhrbHZy9AZhIoCUiCeaW/C80g==}
|
||||
engines: {node: '>=18'}
|
||||
@@ -6431,6 +6443,10 @@ snapshots:
|
||||
dependencies:
|
||||
js-tokens: 9.0.1
|
||||
|
||||
stripe@22.1.1(@types/node@25.9.1):
|
||||
optionalDependencies:
|
||||
'@types/node': 25.9.1
|
||||
|
||||
supports-color@10.2.2: {}
|
||||
|
||||
supports-color@7.2.0:
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
DATABASE_URL="postgresql://postgres:postgres@localhost:5432/shieldai"
|
||||
|
||||
# Stripe (get test keys from https://dashboard.stripe.com/test/apikeys)
|
||||
STRIPE_SECRET_KEY="sk_test_..."
|
||||
STRIPE_WEBHOOK_SECRET="whsec_..."
|
||||
STRIPE_PRICE_BASIC="price_basic"
|
||||
STRIPE_PRICE_PLUS="price_plus"
|
||||
STRIPE_PRICE_PREMIUM="price_premium"
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
"jose": "^5",
|
||||
"pg": "^8.21.0",
|
||||
"solid-js": "^1.9.5",
|
||||
"stripe": "^22.1.1",
|
||||
"tailwindcss": "^4.0.0",
|
||||
"three": "^0.184.0",
|
||||
"valibot": "^0.29.0",
|
||||
|
||||
27
web/src/routes/api/stripe/webhook.ts
Normal file
27
web/src/routes/api/stripe/webhook.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import type { APIEvent } from "@solidjs/start/server";
|
||||
import { stripe } from "~/server/stripe";
|
||||
import { handleWebhookEvent } from "~/server/services/billing.service";
|
||||
|
||||
export async function POST(event: APIEvent) {
|
||||
const body = await event.request.text();
|
||||
const signature = event.request.headers.get("stripe-signature");
|
||||
|
||||
if (!signature) {
|
||||
return new Response("Missing stripe-signature header", { status: 400 });
|
||||
}
|
||||
|
||||
try {
|
||||
const webhookEvent = stripe.webhooks.constructEvent(
|
||||
body,
|
||||
signature,
|
||||
process.env.STRIPE_WEBHOOK_SECRET ?? "",
|
||||
);
|
||||
|
||||
await handleWebhookEvent(webhookEvent);
|
||||
|
||||
return new Response("OK", { status: 200 });
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Webhook error";
|
||||
return new Response(message, { status: 400 });
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
import { exampleRouter } from "./routers/example";
|
||||
import { userRouter } from "./routers/user";
|
||||
import { billingRouter } from "./routers/billing";
|
||||
import { createTRPCRouter } from "./utils";
|
||||
|
||||
export const appRouter = createTRPCRouter({
|
||||
example: exampleRouter,
|
||||
user: userRouter,
|
||||
billing: billingRouter,
|
||||
});
|
||||
|
||||
export type AppRouter = typeof appRouter;
|
||||
|
||||
241
web/src/server/api/routers/billing.test.ts
Normal file
241
web/src/server/api/routers/billing.test.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { initTRPC, TRPCError } from "@trpc/server";
|
||||
import { wrap } from "@typeschema/valibot";
|
||||
import {
|
||||
CreateCheckoutSessionSchema,
|
||||
CreatePortalSessionSchema,
|
||||
CancelSubscriptionSchema,
|
||||
ReactivateSubscriptionSchema,
|
||||
ListInvoicesSchema,
|
||||
} from "../schemas/billing";
|
||||
|
||||
vi.mock("~/server/services/billing.service", () => ({
|
||||
getOrCreateCustomer: vi.fn(),
|
||||
createCheckoutSession: vi.fn(),
|
||||
createPortalSession: vi.fn(),
|
||||
cancelSubscription: vi.fn(),
|
||||
reactivateSubscription: vi.fn(),
|
||||
listInvoices: vi.fn(),
|
||||
}));
|
||||
|
||||
const { mockFindFirst } = vi.hoisted(() => ({
|
||||
mockFindFirst: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("~/server/db", () => ({
|
||||
db: {
|
||||
query: {
|
||||
subscriptions: {
|
||||
findFirst: mockFindFirst,
|
||||
},
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
import {
|
||||
createCheckoutSession,
|
||||
createPortalSession,
|
||||
cancelSubscription,
|
||||
reactivateSubscription,
|
||||
listInvoices,
|
||||
} from "~/server/services/billing.service";
|
||||
import { db } from "~/server/db";
|
||||
|
||||
const mockCreateCheckoutSession = vi.mocked(createCheckoutSession);
|
||||
const mockCreatePortalSession = vi.mocked(createPortalSession);
|
||||
const mockCancelSubscription = vi.mocked(cancelSubscription);
|
||||
const mockReactivateSubscription = vi.mocked(reactivateSubscription);
|
||||
const mockListInvoices = vi.mocked(listInvoices);
|
||||
const mockDb = vi.mocked(db);
|
||||
|
||||
type User = {
|
||||
id: string; email: string; name: string | null; image: string | null;
|
||||
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
|
||||
stripeCustomerId: string | null;
|
||||
createdAt: Date; updatedAt: Date;
|
||||
};
|
||||
|
||||
type Ctx = { db: object; user: User | null; apiKey: string | null };
|
||||
|
||||
const baseUser: User = {
|
||||
id: "user-1", email: "a@b.com", name: "Test", image: null,
|
||||
role: "user", emailVerified: null, deletedAt: null,
|
||||
stripeCustomerId: "cus_123",
|
||||
createdAt: new Date(), updatedAt: new Date(),
|
||||
};
|
||||
|
||||
function makeUser(overrides: Partial<User> = {}): User {
|
||||
return { ...baseUser, ...overrides };
|
||||
}
|
||||
|
||||
function createCaller(user: User | null) {
|
||||
const t = initTRPC.context<Ctx>().create();
|
||||
const isAuthed = t.middleware(({ ctx, next }) => {
|
||||
if (!ctx.user) throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
return next({ ctx: { ...ctx, user: ctx.user } });
|
||||
});
|
||||
|
||||
const router = t.router({
|
||||
getSubscription: t.procedure.use(isAuthed)
|
||||
.query(async () => {
|
||||
const sub = await mockFindFirst();
|
||||
return sub ?? null;
|
||||
}),
|
||||
createCheckoutSession: t.procedure.use(isAuthed)
|
||||
.input(wrap(CreateCheckoutSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const i = input as { priceId: string; successUrl: string; cancelUrl: string };
|
||||
return mockCreateCheckoutSession(ctx.user.id, ctx.user.email, i.priceId, i.successUrl, i.cancelUrl);
|
||||
}),
|
||||
createPortalSession: t.procedure.use(isAuthed)
|
||||
.input(wrap(CreatePortalSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const i = input as { returnUrl: string };
|
||||
if (!ctx.user.stripeCustomerId) {
|
||||
throw new TRPCError({ code: "NOT_FOUND", message: "No Stripe customer found" });
|
||||
}
|
||||
return mockCreatePortalSession(ctx.user.stripeCustomerId, i.returnUrl);
|
||||
}),
|
||||
cancelSubscription: t.procedure.use(isAuthed)
|
||||
.input(wrap(CancelSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
const i = input as { subscriptionId: string };
|
||||
return mockCancelSubscription(i.subscriptionId);
|
||||
}),
|
||||
reactivateSubscription: t.procedure.use(isAuthed)
|
||||
.input(wrap(ReactivateSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
const i = input as { subscriptionId: string };
|
||||
return mockReactivateSubscription(i.subscriptionId);
|
||||
}),
|
||||
listInvoices: t.procedure.use(isAuthed)
|
||||
.input(wrap(ListInvoicesSchema))
|
||||
.query(async ({ ctx, input }) => {
|
||||
if (!ctx.user.stripeCustomerId) {
|
||||
return { invoices: [], hasMore: false };
|
||||
}
|
||||
const i = input as { limit?: string; startingAfter?: string };
|
||||
return mockListInvoices(ctx.user.stripeCustomerId, parseInt(i.limit ?? "10", 10), i.startingAfter);
|
||||
}),
|
||||
});
|
||||
|
||||
const caller = t.createCallerFactory(router);
|
||||
return caller({ db: {} as never, user, apiKey: null });
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("billing.getSubscription", () => {
|
||||
it("returns subscription for authenticated user", async () => {
|
||||
const now = new Date();
|
||||
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
id: "sub-1", userId: "user-1", stripeId: "sub_stripe_1",
|
||||
tier: "premium", status: "active",
|
||||
currentPeriodStart: now, currentPeriodEnd: now,
|
||||
cancelAtPeriodEnd: false,
|
||||
createdAt: now, updatedAt: now,
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.getSubscription();
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.tier).toBe("premium");
|
||||
expect(result!.status).toBe("active");
|
||||
});
|
||||
|
||||
it("returns null when user has no subscription", async () => {
|
||||
(mockDb.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(undefined);
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.getSubscription();
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("rejects unauthenticated users", async () => {
|
||||
const api = createCaller(null);
|
||||
await expect(api.getSubscription()).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.createCheckoutSession", () => {
|
||||
it("creates checkout session and returns URL", async () => {
|
||||
mockCreateCheckoutSession.mockResolvedValue({
|
||||
url: "https://checkout.stripe.com/session_123",
|
||||
sessionId: "session_123",
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createCheckoutSession({
|
||||
priceId: "price_basic",
|
||||
successUrl: "https://example.com/success",
|
||||
cancelUrl: "https://example.com/cancel",
|
||||
});
|
||||
|
||||
expect(result.url).toBe("https://checkout.stripe.com/session_123");
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.createPortalSession", () => {
|
||||
it("creates portal session for user with stripeCustomerId", async () => {
|
||||
mockCreatePortalSession.mockResolvedValue({
|
||||
url: "https://billing.stripe.com/portal/session_456",
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.createPortalSession({
|
||||
returnUrl: "https://example.com/return",
|
||||
});
|
||||
|
||||
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
|
||||
});
|
||||
|
||||
it("throws NOT_FOUND when user has no stripeCustomerId", async () => {
|
||||
const api = createCaller(makeUser({ stripeCustomerId: null }));
|
||||
await expect(api.createPortalSession({ returnUrl: "https://example.com/return" })).rejects.toThrow(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.cancelSubscription", () => {
|
||||
it("cancels subscription", async () => {
|
||||
mockCancelSubscription.mockResolvedValue({ cancelAtPeriodEnd: true });
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.cancelSubscription({ subscriptionId: "sub_123" });
|
||||
|
||||
expect(result.cancelAtPeriodEnd).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.reactivateSubscription", () => {
|
||||
it("reactivates subscription", async () => {
|
||||
mockReactivateSubscription.mockResolvedValue({ cancelAtPeriodEnd: false });
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.reactivateSubscription({ subscriptionId: "sub_123" });
|
||||
|
||||
expect(result.cancelAtPeriodEnd).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("billing.listInvoices", () => {
|
||||
it("lists invoices for user with stripeCustomerId", async () => {
|
||||
mockListInvoices.mockResolvedValue({
|
||||
invoices: [{ id: "in_1" }, { id: "in_2" }] as never,
|
||||
hasMore: false,
|
||||
});
|
||||
|
||||
const api = createCaller(makeUser());
|
||||
const result = await api.listInvoices({});
|
||||
|
||||
expect(result.invoices).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("returns empty list when user has no stripeCustomerId", async () => {
|
||||
const api = createCaller(makeUser({ stripeCustomerId: null }));
|
||||
const result = await api.listInvoices({});
|
||||
|
||||
expect(result.invoices).toHaveLength(0);
|
||||
expect(result.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
100
web/src/server/api/routers/billing.ts
Normal file
100
web/src/server/api/routers/billing.ts
Normal file
@@ -0,0 +1,100 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { wrap } from "@typeschema/valibot";
|
||||
import { createTRPCRouter, protectedProcedure } from "../utils";
|
||||
import {
|
||||
CreateCheckoutSessionSchema,
|
||||
CreatePortalSessionSchema,
|
||||
CancelSubscriptionSchema,
|
||||
ReactivateSubscriptionSchema,
|
||||
ListInvoicesSchema,
|
||||
} from "../schemas/billing";
|
||||
import {
|
||||
getOrCreateCustomer,
|
||||
createCheckoutSession,
|
||||
createPortalSession,
|
||||
cancelSubscription,
|
||||
reactivateSubscription,
|
||||
listInvoices,
|
||||
} from "~/server/services/billing.service";
|
||||
import { db } from "~/server/db";
|
||||
import { subscriptions } from "~/server/db/schema/subscription";
|
||||
|
||||
export const billingRouter = createTRPCRouter({
|
||||
getSubscription: protectedProcedure.query(async ({ ctx }) => {
|
||||
const sub = await db.query.subscriptions.findFirst({
|
||||
where: eq(subscriptions.userId, ctx.user.id),
|
||||
});
|
||||
return sub ?? null;
|
||||
}),
|
||||
|
||||
createCheckoutSession: protectedProcedure
|
||||
.input(wrap(CreateCheckoutSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const allowedPrices = [
|
||||
process.env.STRIPE_PRICE_BASIC,
|
||||
process.env.STRIPE_PRICE_PLUS,
|
||||
process.env.STRIPE_PRICE_PREMIUM,
|
||||
].filter(Boolean);
|
||||
|
||||
if (!allowedPrices.includes(input.priceId)) {
|
||||
throw new TRPCError({
|
||||
code: "BAD_REQUEST",
|
||||
message: "Invalid price ID",
|
||||
});
|
||||
}
|
||||
|
||||
return createCheckoutSession(
|
||||
ctx.user.id,
|
||||
ctx.user.email,
|
||||
input.priceId,
|
||||
input.successUrl,
|
||||
input.cancelUrl,
|
||||
);
|
||||
}),
|
||||
|
||||
createPortalSession: protectedProcedure
|
||||
.input(wrap(CreatePortalSessionSchema))
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
const user = ctx.user;
|
||||
const stripeCustomerId = user.stripeCustomerId;
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
throw new TRPCError({
|
||||
code: "NOT_FOUND",
|
||||
message: "No Stripe customer found",
|
||||
});
|
||||
}
|
||||
|
||||
return createPortalSession(stripeCustomerId, input.returnUrl);
|
||||
}),
|
||||
|
||||
cancelSubscription: protectedProcedure
|
||||
.input(wrap(CancelSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
return cancelSubscription(input.subscriptionId);
|
||||
}),
|
||||
|
||||
reactivateSubscription: protectedProcedure
|
||||
.input(wrap(ReactivateSubscriptionSchema))
|
||||
.mutation(async ({ input }) => {
|
||||
return reactivateSubscription(input.subscriptionId);
|
||||
}),
|
||||
|
||||
listInvoices: protectedProcedure
|
||||
.input(wrap(ListInvoicesSchema))
|
||||
.query(async ({ ctx, input }) => {
|
||||
const user = ctx.user;
|
||||
const stripeCustomerId = user.stripeCustomerId;
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
return { invoices: [], hasMore: false };
|
||||
}
|
||||
|
||||
return listInvoices(
|
||||
stripeCustomerId,
|
||||
parseInt(input.limit ?? "10", 10),
|
||||
input.startingAfter,
|
||||
);
|
||||
}),
|
||||
});
|
||||
@@ -30,6 +30,7 @@ const mockUpdateMemberRole = vi.mocked(updateMemberRole);
|
||||
type User = {
|
||||
id: string; email: string; name: string | null; image: string | null;
|
||||
role: "user" | "family_admin" | "family_member" | "support"; emailVerified: Date | null; deletedAt: Date | null;
|
||||
stripeCustomerId: string | null;
|
||||
createdAt: Date; updatedAt: Date;
|
||||
};
|
||||
type Ctx = { db: object; user: User | null; apiKey: string | null };
|
||||
@@ -97,6 +98,7 @@ function createCaller(user: User | null) {
|
||||
const baseUser: User = {
|
||||
id: "user-1", email: "a@b.com", name: "Test", image: null,
|
||||
role: "user", emailVerified: null, deletedAt: null,
|
||||
stripeCustomerId: null,
|
||||
createdAt: new Date(), updatedAt: new Date(),
|
||||
};
|
||||
|
||||
|
||||
24
web/src/server/api/schemas/billing.ts
Normal file
24
web/src/server/api/schemas/billing.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import { object, string, url, minLength, optional, picklist } from "valibot";
|
||||
|
||||
export const CreateCheckoutSessionSchema = object({
|
||||
priceId: string([minLength(1)]),
|
||||
successUrl: string([url()]),
|
||||
cancelUrl: string([url()]),
|
||||
});
|
||||
|
||||
export const CreatePortalSessionSchema = object({
|
||||
returnUrl: string([url()]),
|
||||
});
|
||||
|
||||
export const CancelSubscriptionSchema = object({
|
||||
subscriptionId: string([minLength(1)]),
|
||||
});
|
||||
|
||||
export const ReactivateSubscriptionSchema = object({
|
||||
subscriptionId: string([minLength(1)]),
|
||||
});
|
||||
|
||||
export const ListInvoicesSchema = object({
|
||||
limit: optional(string(), "10"),
|
||||
startingAfter: optional(string()),
|
||||
});
|
||||
@@ -76,17 +76,18 @@ describe("users table", () => {
|
||||
expect(colNames).toContain("name");
|
||||
expect(colNames).toContain("image");
|
||||
expect(colNames).toContain("role");
|
||||
expect(colNames).toContain("stripe_customer_id");
|
||||
expect(colNames).toContain("deleted_at");
|
||||
expect(colNames).toContain("created_at");
|
||||
expect(colNames).toContain("updated_at");
|
||||
});
|
||||
|
||||
it("has 9 columns", () => {
|
||||
expect(config.columns).toHaveLength(9);
|
||||
it("has 10 columns", () => {
|
||||
expect(config.columns).toHaveLength(10);
|
||||
});
|
||||
|
||||
it("has 2 indexes", () => {
|
||||
expect(config.indexes.length).toBe(2);
|
||||
it("has 3 indexes", () => {
|
||||
expect(config.indexes.length).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -8,12 +8,14 @@ export const users = pgTable("users", {
|
||||
name: text("name"),
|
||||
image: text("image"),
|
||||
role: userRole("role").default("user").notNull(),
|
||||
stripeCustomerId: text("stripe_customer_id"),
|
||||
deletedAt: timestamp("deleted_at", { withTimezone: true, mode: "date" }),
|
||||
createdAt: timestamp("created_at", { withTimezone: true, mode: "date" }).defaultNow().notNull(),
|
||||
updatedAt: timestamp("updated_at", { withTimezone: true, mode: "date" }).defaultNow().notNull().$onUpdate(() => new Date()),
|
||||
}, (table) => ({
|
||||
emailIdx: index("users_email_idx").on(table.email),
|
||||
roleIdx: index("users_role_idx").on(table.role),
|
||||
stripeCustomerIdIdx: index("users_stripe_customer_id_idx").on(table.stripeCustomerId),
|
||||
}));
|
||||
|
||||
export const accounts = pgTable("accounts", {
|
||||
|
||||
338
web/src/server/services/billing.service.test.ts
Normal file
338
web/src/server/services/billing.service.test.ts
Normal file
@@ -0,0 +1,338 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
|
||||
vi.mock("~/server/stripe", () => ({
|
||||
stripe: {
|
||||
customers: { create: vi.fn() },
|
||||
checkout: { sessions: { create: vi.fn() } },
|
||||
billingPortal: { sessions: { create: vi.fn() } },
|
||||
subscriptions: { update: vi.fn(), retrieve: vi.fn() },
|
||||
invoices: { list: vi.fn() },
|
||||
webhooks: { constructEvent: vi.fn() },
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("~/server/db", () => ({
|
||||
db: {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
query: {
|
||||
subscriptions: {
|
||||
findFirst: vi.fn(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
import { stripe } from "~/server/stripe";
|
||||
import { db } from "~/server/db";
|
||||
import {
|
||||
getOrCreateCustomer,
|
||||
createCheckoutSession,
|
||||
createPortalSession,
|
||||
cancelSubscription,
|
||||
reactivateSubscription,
|
||||
listInvoices,
|
||||
handleWebhookEvent,
|
||||
} from "./billing.service";
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("getOrCreateCustomer", () => {
|
||||
it("returns existing stripeCustomerId if present", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([
|
||||
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_existing" },
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await getOrCreateCustomer("u1", "a@b.com");
|
||||
expect(result).toBe("cus_existing");
|
||||
expect(stripe.customers.create).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("creates a new Stripe customer when no stripeCustomerId", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([
|
||||
{ id: "u1", email: "a@b.com", stripeCustomerId: null },
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(stripe.customers.create as ReturnType<typeof vi.fn>).mockResolvedValue({ id: "cus_new" });
|
||||
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue([{ id: "u1" }]),
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await getOrCreateCustomer("u1", "a@b.com");
|
||||
expect(result).toBe("cus_new");
|
||||
expect(stripe.customers.create).toHaveBeenCalledWith({
|
||||
email: "a@b.com",
|
||||
metadata: { userId: "u1" },
|
||||
});
|
||||
});
|
||||
|
||||
it("throws NOT_FOUND for missing user", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
await expect(getOrCreateCustomer("u-missing", "x@y.com")).rejects.toThrow(
|
||||
"User not found",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createCheckoutSession", () => {
|
||||
it("creates a Stripe checkout session and returns URL", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([
|
||||
{ id: "u1", email: "a@b.com", stripeCustomerId: "cus_123" },
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(stripe.checkout.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
url: "https://checkout.stripe.com/session_123",
|
||||
id: "session_123",
|
||||
});
|
||||
|
||||
const result = await createCheckoutSession(
|
||||
"u1",
|
||||
"a@b.com",
|
||||
"price_basic",
|
||||
"https://example.com/success",
|
||||
"https://example.com/cancel",
|
||||
);
|
||||
|
||||
expect(result.url).toBe("https://checkout.stripe.com/session_123");
|
||||
expect(result.sessionId).toBe("session_123");
|
||||
});
|
||||
});
|
||||
|
||||
describe("createPortalSession", () => {
|
||||
it("creates a Stripe billing portal session", async () => {
|
||||
(stripe.billingPortal.sessions.create as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
url: "https://billing.stripe.com/portal/session_456",
|
||||
});
|
||||
|
||||
const result = await createPortalSession(
|
||||
"cus_123",
|
||||
"https://example.com/return",
|
||||
);
|
||||
|
||||
expect(result.url).toBe("https://billing.stripe.com/portal/session_456");
|
||||
});
|
||||
});
|
||||
|
||||
describe("cancelSubscription", () => {
|
||||
it("sets cancel_at_period_end on Stripe subscription", async () => {
|
||||
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await cancelSubscription("sub_123");
|
||||
expect(result.cancelAtPeriodEnd).toBe(true);
|
||||
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
|
||||
cancel_at_period_end: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("reactivateSubscription", () => {
|
||||
it("removes cancel_at_period_end on Stripe subscription", async () => {
|
||||
(stripe.subscriptions.update as ReturnType<typeof vi.fn>).mockResolvedValue({});
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await reactivateSubscription("sub_123");
|
||||
expect(result.cancelAtPeriodEnd).toBe(false);
|
||||
expect(stripe.subscriptions.update).toHaveBeenCalledWith("sub_123", {
|
||||
cancel_at_period_end: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("listInvoices", () => {
|
||||
it("returns invoices list from Stripe", async () => {
|
||||
(stripe.invoices.list as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
data: [{ id: "in_1" }, { id: "in_2" }],
|
||||
has_more: false,
|
||||
});
|
||||
|
||||
const result = await listInvoices("cus_123", 10);
|
||||
expect(result.invoices).toHaveLength(2);
|
||||
expect(result.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("handleWebhookEvent", () => {
|
||||
it("handles checkout.session.completed", async () => {
|
||||
(db.insert as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoNothing: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
});
|
||||
|
||||
(stripe.subscriptions.retrieve as ReturnType<typeof vi.fn>).mockResolvedValue({
|
||||
id: "sub_new",
|
||||
items: { data: [{ price: { id: "price_premium" } }] },
|
||||
current_period_start: 1700000000,
|
||||
current_period_end: 1702592000,
|
||||
status: "active",
|
||||
cancel_at_period_end: false,
|
||||
});
|
||||
|
||||
await handleWebhookEvent({
|
||||
type: "checkout.session.completed",
|
||||
data: {
|
||||
object: {
|
||||
metadata: { userId: "u1" },
|
||||
subscription: "sub_new",
|
||||
},
|
||||
},
|
||||
} as never);
|
||||
|
||||
expect(db.insert).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("handles invoice.paid", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
await handleWebhookEvent({
|
||||
type: "invoice.paid",
|
||||
data: {
|
||||
object: {
|
||||
subscription: "sub_123",
|
||||
},
|
||||
},
|
||||
} as never);
|
||||
});
|
||||
|
||||
it("handles invoice.payment_failed", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "past_due" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
await handleWebhookEvent({
|
||||
type: "invoice.payment_failed",
|
||||
data: {
|
||||
object: {
|
||||
subscription: "sub_123",
|
||||
},
|
||||
},
|
||||
} as never);
|
||||
});
|
||||
|
||||
it("handles customer.subscription.updated", async () => {
|
||||
(db.query.subscriptions.findFirst as ReturnType<typeof vi.fn>).mockResolvedValue(null);
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "active" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
await handleWebhookEvent({
|
||||
type: "customer.subscription.updated",
|
||||
data: {
|
||||
object: {
|
||||
id: "sub_123",
|
||||
metadata: { userId: "u1" },
|
||||
items: { data: [{ price: { id: "price_plus" } }] },
|
||||
current_period_start: 1700000000,
|
||||
current_period_end: 1702592000,
|
||||
status: "active",
|
||||
cancel_at_period_end: false,
|
||||
},
|
||||
},
|
||||
} as never);
|
||||
});
|
||||
|
||||
it("handles customer.subscription.deleted", async () => {
|
||||
(db.select as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([{ id: "sub_db_1" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
(db.update as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([{ id: "sub_db_1", status: "canceled" }]),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
await handleWebhookEvent({
|
||||
type: "customer.subscription.deleted",
|
||||
data: {
|
||||
object: {
|
||||
id: "sub_123",
|
||||
},
|
||||
},
|
||||
} as never);
|
||||
});
|
||||
});
|
||||
235
web/src/server/services/billing.service.ts
Normal file
235
web/src/server/services/billing.service.ts
Normal file
@@ -0,0 +1,235 @@
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { db } from "~/server/db";
|
||||
import { stripe } from "~/server/stripe";
|
||||
import { users } from "~/server/db/schema/auth";
|
||||
import { subscriptions } from "~/server/db/schema/subscription";
|
||||
import type Stripe from "stripe";
|
||||
|
||||
type Tier = "basic" | "plus" | "premium";
|
||||
|
||||
export async function getOrCreateCustomer(userId: string, email: string) {
|
||||
const [user] = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(eq(users.id, userId))
|
||||
.limit(1);
|
||||
|
||||
if (!user) {
|
||||
throw new TRPCError({ code: "NOT_FOUND", message: "User not found" });
|
||||
}
|
||||
|
||||
if (user.stripeCustomerId) {
|
||||
return user.stripeCustomerId;
|
||||
}
|
||||
|
||||
const customer = await stripe.customers.create({
|
||||
email,
|
||||
metadata: { userId },
|
||||
});
|
||||
|
||||
await db
|
||||
.update(users)
|
||||
.set({ stripeCustomerId: customer.id })
|
||||
.where(eq(users.id, userId));
|
||||
|
||||
return customer.id;
|
||||
}
|
||||
|
||||
export async function createCheckoutSession(
|
||||
userId: string,
|
||||
email: string,
|
||||
priceId: string,
|
||||
successUrl: string,
|
||||
cancelUrl: string,
|
||||
) {
|
||||
const customerId = await getOrCreateCustomer(userId, email);
|
||||
|
||||
const session = await stripe.checkout.sessions.create({
|
||||
customer: customerId,
|
||||
mode: "subscription",
|
||||
line_items: [{ price: priceId, quantity: 1 }],
|
||||
success_url: successUrl,
|
||||
cancel_url: cancelUrl,
|
||||
metadata: { userId },
|
||||
});
|
||||
|
||||
return { url: session.url, sessionId: session.id };
|
||||
}
|
||||
|
||||
export async function createPortalSession(customerId: string, returnUrl: string) {
|
||||
const session = await stripe.billingPortal.sessions.create({
|
||||
customer: customerId,
|
||||
return_url: returnUrl,
|
||||
});
|
||||
|
||||
return { url: session.url };
|
||||
}
|
||||
|
||||
export async function cancelSubscription(stripeSubscriptionId: string) {
|
||||
await stripe.subscriptions.update(stripeSubscriptionId, {
|
||||
cancel_at_period_end: true,
|
||||
});
|
||||
|
||||
await db
|
||||
.update(subscriptions)
|
||||
.set({ cancelAtPeriodEnd: true })
|
||||
.where(eq(subscriptions.stripeId, stripeSubscriptionId));
|
||||
|
||||
return { cancelAtPeriodEnd: true };
|
||||
}
|
||||
|
||||
export async function reactivateSubscription(stripeSubscriptionId: string) {
|
||||
await stripe.subscriptions.update(stripeSubscriptionId, {
|
||||
cancel_at_period_end: false,
|
||||
});
|
||||
|
||||
await db
|
||||
.update(subscriptions)
|
||||
.set({ cancelAtPeriodEnd: false })
|
||||
.where(eq(subscriptions.stripeId, stripeSubscriptionId));
|
||||
|
||||
return { cancelAtPeriodEnd: false };
|
||||
}
|
||||
|
||||
export async function listInvoices(
|
||||
customerId: string,
|
||||
limit: number = 10,
|
||||
startingAfter?: string,
|
||||
) {
|
||||
const params: Stripe.InvoiceListParams = {
|
||||
customer: customerId,
|
||||
limit,
|
||||
};
|
||||
if (startingAfter) {
|
||||
params.starting_after = startingAfter;
|
||||
}
|
||||
|
||||
const invoices = await stripe.invoices.list(params);
|
||||
return {
|
||||
invoices: invoices.data,
|
||||
hasMore: invoices.has_more,
|
||||
};
|
||||
}
|
||||
|
||||
export async function updateSubscriptionInDB(
|
||||
stripeId: string,
|
||||
data: {
|
||||
tier?: Tier;
|
||||
status?: string;
|
||||
currentPeriodStart?: Date;
|
||||
currentPeriodEnd?: Date;
|
||||
cancelAtPeriodEnd?: boolean;
|
||||
},
|
||||
) {
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.stripeId, stripeId))
|
||||
.limit(1);
|
||||
|
||||
if (existing) {
|
||||
const [updated] = await db
|
||||
.update(subscriptions)
|
||||
.set(data as Record<string, unknown>)
|
||||
.where(eq(subscriptions.stripeId, stripeId))
|
||||
.returning();
|
||||
return updated;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
export async function handleWebhookEvent(event: Stripe.Event) {
|
||||
const obj = event.data.object as unknown as Record<string, unknown>;
|
||||
|
||||
switch (event.type) {
|
||||
case "checkout.session.completed": {
|
||||
const session = obj as unknown as Stripe.Checkout.Session;
|
||||
const userId = session.metadata?.userId;
|
||||
if (!userId || !session.subscription) break;
|
||||
|
||||
const stripeSub = await stripe.subscriptions.retrieve(
|
||||
session.subscription as string,
|
||||
);
|
||||
const sub = stripeSub as unknown as Record<string, unknown>;
|
||||
|
||||
await db.insert(subscriptions).values({
|
||||
userId,
|
||||
stripeId: stripeSub.id,
|
||||
tier: mapStripeProductToTier(
|
||||
stripeSub.items.data[0]?.price?.id ?? "",
|
||||
),
|
||||
status: sub.status as typeof subscriptions.$inferSelect.status,
|
||||
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
|
||||
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
|
||||
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
|
||||
}).onConflictDoNothing();
|
||||
break;
|
||||
}
|
||||
|
||||
case "invoice.paid": {
|
||||
const invoice = obj;
|
||||
if (!invoice.subscription) break;
|
||||
|
||||
await updateSubscriptionInDB(invoice.subscription as string, {
|
||||
status: "active",
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case "invoice.payment_failed": {
|
||||
const invoice = obj;
|
||||
if (!invoice.subscription) break;
|
||||
|
||||
await updateSubscriptionInDB(invoice.subscription as string, {
|
||||
status: "past_due",
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case "customer.subscription.updated": {
|
||||
const stripeSub = obj as unknown as Stripe.Subscription;
|
||||
const userId = stripeSub.metadata?.userId;
|
||||
const sub = stripeSub as unknown as Record<string, unknown>;
|
||||
|
||||
if (!userId) {
|
||||
const [existingSub] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.stripeId, stripeSub.id))
|
||||
.limit(1);
|
||||
|
||||
if (!existingSub) break;
|
||||
}
|
||||
|
||||
const tier = stripeSub.items.data[0]?.price?.id
|
||||
? mapStripeProductToTier(stripeSub.items.data[0].price.id)
|
||||
: undefined;
|
||||
|
||||
await updateSubscriptionInDB(stripeSub.id, {
|
||||
tier,
|
||||
status: sub.status as string,
|
||||
currentPeriodStart: new Date((sub.current_period_start as number) * 1000),
|
||||
currentPeriodEnd: new Date((sub.current_period_end as number) * 1000),
|
||||
cancelAtPeriodEnd: sub.cancel_at_period_end as boolean,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case "customer.subscription.deleted": {
|
||||
const stripeSub = obj as unknown as Stripe.Subscription;
|
||||
await updateSubscriptionInDB(stripeSub.id, {
|
||||
status: "canceled",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function mapStripeProductToTier(priceId: string): Tier {
|
||||
if (priceId === process.env.STRIPE_PRICE_BASIC) return "basic";
|
||||
if (priceId === process.env.STRIPE_PRICE_PLUS) return "plus";
|
||||
if (priceId === process.env.STRIPE_PRICE_PREMIUM) return "premium";
|
||||
return "basic";
|
||||
}
|
||||
6
web/src/server/stripe.ts
Normal file
6
web/src/server/stripe.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import Stripe from "stripe";
|
||||
|
||||
export const stripe = new Stripe(process.env.STRIPE_SECRET_KEY ?? "", {
|
||||
apiVersion: "2026-04-22.dahlia" as const,
|
||||
typescript: true,
|
||||
});
|
||||
Reference in New Issue
Block a user