// Kordant/web/src/server/services/billing.service.ts
// (file-viewer metadata: 651 lines, 20 KiB, TypeScript)

import { TRPCError } from "@trpc/server";
import { eq, and } from "drizzle-orm";
import { safeParse } from "valibot";
import { db } from "~/server/db";
import { stripe } from "~/server/stripe";
import { resend } from "~/server/lib/resend";
import { users } from "~/server/db/schema/auth";
import { subscriptions } from "~/server/db/schema/subscription";
import type Stripe from "stripe";
import {
CheckoutSessionSchema,
SubscriptionSchema,
InvoiceSchema,
} from "~/server/api/schemas/webhook";
import { paymentFailedEmail, subscriptionActivatedEmail } from "./email.templates";
/** Plan tiers a subscription can be on; this is the value `mapStripeProductToTier` resolves a price ID to. */
export type Tier = "basic" | "plus" | "premium" | "family_guard" | "family_fortress";
/**
 * Lifecycle states written to the local subscription record's `status` field.
 * NOTE(review): these look like Stripe's subscription status values — confirm
 * against the Stripe API version in use.
 */
export type SubscriptionStatus =
| "active"
| "past_due"
| "canceled"
| "unpaid"
| "trialing"
| "paused"
| "incomplete"
| "incomplete_expired";
/** Trial length in days: sent to Stripe as `trial_period_days` and passed to the trial activation email. */
const TRIAL_DAYS = 14;
/* ------------------------------------------------------------------ */
/* Stripe customer lifecycle */
/* ------------------------------------------------------------------ */
/**
 * Resolves the Stripe customer ID for a user, creating the Stripe customer on
 * first use and storing its ID on the user row.
 *
 * @param userId - ID of the row to look up in the `users` table.
 * @param email - Email registered on a newly created Stripe customer.
 * @returns The stored customer ID, or the ID of the customer just created.
 * @throws TRPCError (`NOT_FOUND`) when no user row matches `userId`.
 */
export async function getOrCreateCustomer(userId: string, email: string) {
  const matchingUsers = await db
    .select()
    .from(users)
    .where(eq(users.id, userId))
    .limit(1);
  const account = matchingUsers[0];

  if (!account) {
    throw new TRPCError({ code: "NOT_FOUND", message: "User not found" });
  }

  // Fast path: the user has already been linked to a Stripe customer.
  const knownCustomerId = account.stripeCustomerId;
  if (knownCustomerId) {
    return knownCustomerId;
  }

  // `metadata.userId` ties the Stripe customer back to the local user.
  const { id: createdCustomerId } = await stripe.customers.create({
    email,
    metadata: { userId },
  });

  await db
    .update(users)
    .set({ stripeCustomerId: createdCustomerId })
    .where(eq(users.id, userId));

  return createdCustomerId;
}
/* ------------------------------------------------------------------ */
/* Checkout sessions */
/* ------------------------------------------------------------------ */
/**
 * Creates an embedded Stripe Checkout session for a subscription purchase.
 *
 * @param userId - Local user ID; stored in both session and subscription metadata.
 * @param email - Email used if a Stripe customer has to be created for the user.
 * @param priceId - Stripe price to subscribe to (quantity 1).
 * @param returnUrl - URL the customer is sent back to; a `session_id` query
 *   parameter is appended (with `&` when the URL already carries a query string).
 * @param options - `trial` adds a `TRIAL_DAYS` trial; `isUpgrade` / `isDowngrade`
 *   request `create_prorations` proration behavior.
 * @returns The session's client secret (empty string if Stripe returned none)
 *   and the session ID.
 * @throws TRPCError (`NOT_FOUND`) from `getOrCreateCustomer` when the user is unknown.
 */
export async function createCheckoutSession(
  userId: string,
  email: string,
  priceId: string,
  returnUrl: string,
  options: { trial?: boolean; isUpgrade?: boolean; isDowngrade?: boolean } = {},
) {
  const customerId = await getOrCreateCustomer(userId, email);

  // Only include optional fields that actually apply, rather than sending
  // explicit `undefined` values.
  const subscriptionData: Record<string, unknown> = {
    metadata: { userId },
  };
  if (options.trial) {
    subscriptionData.trial_period_days = TRIAL_DAYS;
  }
  // For upgrades / downgrades, set proration behavior
  if (options.isUpgrade || options.isDowngrade) {
    subscriptionData.proration_behavior = "create_prorations";
  }

  // A return URL that already has a query string must be extended with "&";
  // blindly appending "?session_id=..." would yield a malformed URL.
  // `{CHECKOUT_SESSION_ID}` is left literal so Stripe can substitute it.
  const querySeparator = returnUrl.includes("?") ? "&" : "?";

  const session = await stripe.checkout.sessions.create({
    customer: customerId,
    mode: "subscription",
    ui_mode: "embedded_page",
    line_items: [{ price: priceId, quantity: 1 }],
    return_url: `${returnUrl}${querySeparator}session_id={CHECKOUT_SESSION_ID}`,
    metadata: { userId },
    subscription_data: subscriptionData,
  });

  return { clientSecret: session.client_secret ?? "", sessionId: session.id };
}
/* ------------------------------------------------------------------ */
/* Customer portal */
/* ------------------------------------------------------------------ */
/**
 * Opens a Stripe billing-portal session for a customer.
 *
 * @param customerId - Stripe customer the portal session is created for.
 * @param returnUrl - Passed to Stripe as the portal's `return_url`.
 * @returns An object holding the portal session URL.
 */
export async function createPortalSession(
  customerId: string,
  returnUrl: string,
) {
  const { url } = await stripe.billingPortal.sessions.create({
    customer: customerId,
    return_url: returnUrl,
  });
  return { url };
}
/* ------------------------------------------------------------------ */
/* Subscription management */
/* ------------------------------------------------------------------ */
/**
 * Flags a subscription to cancel at the end of the current period, first in
 * Stripe and then on the matching local subscription row.
 *
 * @param stripeSubscriptionId - Stripe subscription ID (matched against `subscriptions.stripeId`).
 * @returns `{ cancelAtPeriodEnd: true }`.
 */
export async function cancelSubscription(stripeSubscriptionId: string) {
  const cancelAtPeriodEnd = true;

  // Stripe is updated first; the local row is only touched if that call succeeds.
  await stripe.subscriptions.update(stripeSubscriptionId, {
    cancel_at_period_end: cancelAtPeriodEnd,
  });

  await db
    .update(subscriptions)
    .set({ cancelAtPeriodEnd })
    .where(eq(subscriptions.stripeId, stripeSubscriptionId));

  return { cancelAtPeriodEnd };
}
/**
 * Clears a pending end-of-period cancellation, first in Stripe and then on
 * the matching local subscription row.
 *
 * @param stripeSubscriptionId - Stripe subscription ID (matched against `subscriptions.stripeId`).
 * @returns `{ cancelAtPeriodEnd: false }`.
 */
export async function reactivateSubscription(stripeSubscriptionId: string) {
  const cancelAtPeriodEnd = false;

  // Stripe is updated first; the local row is only touched if that call succeeds.
  await stripe.subscriptions.update(stripeSubscriptionId, {
    cancel_at_period_end: cancelAtPeriodEnd,
  });

  await db
    .update(subscriptions)
    .set({ cancelAtPeriodEnd })
    .where(eq(subscriptions.stripeId, stripeSubscriptionId));

  return { cancelAtPeriodEnd };
}
/* ------------------------------------------------------------------ */
/* Invoices */
/* ------------------------------------------------------------------ */
/**
 * Fetches one page of a customer's invoices from Stripe.
 *
 * @param customerId - Stripe customer whose invoices are listed.
 * @param limit - Page size forwarded to Stripe (defaults to 10).
 * @param startingAfter - Optional cursor, forwarded as `starting_after` when non-empty.
 * @returns The invoices on this page and whether Stripe reports more.
 */
export async function listInvoices(
  customerId: string,
  limit: number = 10,
  startingAfter?: string,
) {
  // The cursor is only sent when the caller supplied one.
  const query: Stripe.InvoiceListParams = startingAfter
    ? { customer: customerId, limit, starting_after: startingAfter }
    : { customer: customerId, limit };

  const page = await stripe.invoices.list(query);

  return {
    invoices: page.data,
    hasMore: page.has_more,
  };
}
/* ------------------------------------------------------------------ */
/* Trial creation */
/* ------------------------------------------------------------------ */
/**
 * Starts a Stripe Checkout session for a `TRIAL_DAYS`-day trial on the price
 * configured in `STRIPE_PRICE_BASIC`.
 *
 * @param userId - Local user ID; stored in both session and subscription metadata.
 * @param email - Email used if a Stripe customer has to be created for the user.
 * @param returnUrl - Base URL: success goes to `returnUrl?session_id=...`,
 *   cancellation to `returnUrl/pricing`.
 * @returns The checkout session ID and its URL.
 * @throws TRPCError (`NOT_FOUND`) from `getOrCreateCustomer` when the user is unknown.
 * @throws TRPCError (`INTERNAL_SERVER_ERROR`) when `STRIPE_PRICE_BASIC` is not set.
 */
export async function createTrialSubscription(
  userId: string,
  email: string,
  returnUrl: string,
) {
  // The customer is resolved before the configuration check, so it may be
  // created even when the trial price turns out to be missing.
  const customerId = await getOrCreateCustomer(userId, email);

  // Use the basic plan price for trial subscriptions
  const { STRIPE_PRICE_BASIC: trialPriceId } = process.env;
  if (!trialPriceId) {
    throw new TRPCError({
      code: "INTERNAL_SERVER_ERROR",
      message: "Trial price ID not configured",
    });
  }

  const sharedMetadata = { userId };
  const checkout = await stripe.checkout.sessions.create({
    customer: customerId,
    mode: "subscription",
    line_items: [{ price: trialPriceId, quantity: 1 }],
    allow_promotion_codes: true,
    subscription_data: {
      trial_period_days: TRIAL_DAYS,
      metadata: sharedMetadata,
    },
    success_url: `${returnUrl}?session_id={CHECKOUT_SESSION_ID}`,
    cancel_url: `${returnUrl}/pricing`,
    metadata: sharedMetadata,
  });

  return { sessionId: checkout.id, url: checkout.url };
}
/* ------------------------------------------------------------------ */
/* Tier change (upgrade / downgrade with proration) */
/* ------------------------------------------------------------------ */
/**
 * Moves a subscription to a different price (with prorations) and mirrors the
 * result onto the local subscription row.
 *
 * @param stripeSubscriptionId - Stripe subscription to change.
 * @param newPriceId - Stripe price the first subscription item is switched to.
 * @returns The subscription as returned by Stripe after the update.
 * @throws TRPCError (`NOT_FOUND`) when the subscription has no items.
 */
export async function changeSubscriptionTier(
  stripeSubscriptionId: string,
  newPriceId: string,
) {
  const subscription = await stripe.subscriptions.retrieve(
    stripeSubscriptionId,
    { expand: ["items.data.price"] },
  );

  // Update the subscription item with proration
  const item = subscription.items.data[0];
  if (!item) {
    throw new TRPCError({
      code: "NOT_FOUND",
      message: "No subscription items found",
    });
  }

  const updatedSub = await stripe.subscriptions.update(
    stripeSubscriptionId,
    {
      items: [{ id: item.id, price: newPriceId }],
      proration_behavior: "create_prorations",
    },
  );

  // Read fields loosely so this works whichever Stripe API version shaped the
  // payload: newer versions report the billing period on each subscription
  // item instead of on the subscription itself.
  const subFields = updatedSub as unknown as Record<string, unknown>;
  const itemFields = (updatedSub.items?.data?.[0] ?? {}) as unknown as Record<
    string,
    unknown
  >;
  // Stripe timestamps are Unix seconds; anything else is treated as absent.
  const toDate = (seconds: unknown): Date | undefined =>
    typeof seconds === "number" && seconds > 0
      ? new Date(seconds * 1000)
      : undefined;

  // Update DB record
  await updateSubscriptionInDB(stripeSubscriptionId, {
    tier: mapStripeProductToTier(newPriceId),
    stripePriceId: newPriceId,
    status: (subFields.status as SubscriptionStatus) ?? "active",
    currentPeriodStart: toDate(
      subFields.current_period_start ?? itemFields.current_period_start,
    ),
    currentPeriodEnd: toDate(
      subFields.current_period_end ?? itemFields.current_period_end,
    ),
  });

  return { subscription: updatedSub };
}
/* ------------------------------------------------------------------ */
/* Database helpers */
/* ------------------------------------------------------------------ */
/**
 * Applies a partial update to the local subscription row identified by its
 * Stripe subscription ID.
 *
 * Fields whose value is `undefined` are left untouched. When no field is
 * defined, nothing is written and the current row is returned as-is (an empty
 * `.set({})` would otherwise be rejected by the query builder).
 *
 * @param stripeId - Stripe subscription ID (matched against `subscriptions.stripeId`).
 * @param data - Fields to overwrite.
 * @returns The row after the update, or `null` when no row matches `stripeId`.
 */
export async function updateSubscriptionInDB(
  stripeId: string,
  data: {
    tier?: Tier;
    stripePriceId?: string;
    status?: SubscriptionStatus;
    currentPeriodStart?: Date;
    currentPeriodEnd?: Date;
    trialEnd?: Date;
    cancelAtPeriodEnd?: boolean;
    defaultPaymentMethodLast4?: string;
  },
) {
  const updateData: Record<string, unknown> = Object.fromEntries(
    Object.entries(data).filter(([, value]) => value !== undefined),
  );

  // Nothing to write: just report the current state of the row.
  if (Object.keys(updateData).length === 0) {
    const [existing] = await db
      .select()
      .from(subscriptions)
      .where(eq(subscriptions.stripeId, stripeId))
      .limit(1);
    return existing ?? null;
  }

  // A single UPDATE ... RETURNING both applies the change and tells us whether
  // the row exists, so no separate existence check is needed.
  const [updated] = await db
    .update(subscriptions)
    .set(updateData)
    .where(eq(subscriptions.stripeId, stripeId))
    .returning();

  return updated ?? null;
}
/* ------------------------------------------------------------------ */
/* Valibot parsers */
/* ------------------------------------------------------------------ */
/**
 * Validates an unknown webhook payload against `SubscriptionSchema`.
 *
 * @param obj - Raw object to validate.
 * @returns The parsed output, or `null` (after logging the issue messages) when validation fails.
 */
function safeParseSubscription(obj: unknown) {
  const parsed = safeParse(SubscriptionSchema, obj);
  if (parsed.success) {
    return parsed.output;
  }

  const issueSummary = parsed.issues?.map((issue) => issue.message).join(", ");
  console.error(
    `[billing:webhook] Failed to parse subscription data: ${issueSummary}`,
  );
  return null;
}
/**
 * Validates an unknown webhook payload against `CheckoutSessionSchema`.
 *
 * @param obj - Raw object to validate.
 * @returns The parsed output, or `null` (after logging the issue messages) when validation fails.
 */
function safeParseCheckoutSession(obj: unknown) {
  const parsed = safeParse(CheckoutSessionSchema, obj);
  if (parsed.success) {
    return parsed.output;
  }

  const issueSummary = parsed.issues?.map((issue) => issue.message).join(", ");
  console.error(
    `[billing:webhook] Failed to parse checkout session data: ${issueSummary}`,
  );
  return null;
}
/**
 * Validates an unknown webhook payload against `InvoiceSchema`.
 *
 * @param obj - Raw object to validate.
 * @returns The parsed output, or `null` (after logging the issue messages) when validation fails.
 */
function safeParseInvoice(obj: unknown) {
  const parsed = safeParse(InvoiceSchema, obj);
  if (parsed.success) {
    return parsed.output;
  }

  const issueSummary = parsed.issues?.map((issue) => issue.message).join(", ");
  console.error(
    `[billing:webhook] Failed to parse invoice data: ${issueSummary}`,
  );
  return null;
}
/* ------------------------------------------------------------------ */
/* Webhook event handler */
/* ------------------------------------------------------------------ */
/**
 * Inserts the local row for a Stripe subscription, or refreshes it when a row
 * with the same `stripeId` already exists.
 *
 * On conflict every mirrored field except `userId` and `stripeId` is
 * overwritten. The tier is derived from the first item's price ID.
 *
 * @param userId - Owner recorded on a newly inserted row.
 * @param stripeSub - Subscription object as returned by Stripe.
 */
async function upsertSubscriptionFromStripe(
  userId: string,
  stripeSub: Stripe.Subscription,
) {
  const firstItem = stripeSub.items.data[0];

  // Read fields loosely so this works whichever Stripe API version shaped the
  // payload: newer versions report the billing period on each subscription
  // item instead of on the subscription itself.
  const subFields = stripeSub as unknown as Record<string, unknown>;
  const itemFields = (firstItem ?? {}) as unknown as Record<string, unknown>;
  // Stripe timestamps are Unix seconds; anything else is treated as absent.
  const toDate = (seconds: unknown): Date | undefined =>
    typeof seconds === "number" && seconds > 0
      ? new Date(seconds * 1000)
      : undefined;

  // `price` may be a bare ID or an expanded price object.
  const priceItem = firstItem?.price;
  const priceId =
    typeof priceItem === "string"
      ? priceItem
      : (priceItem as Stripe.Price | undefined)?.id ?? "";

  const insertData = {
    userId,
    stripeId: stripeSub.id,
    stripePriceId: priceId || undefined,
    tier: mapStripeProductToTier(priceId),
    status: (subFields.status as SubscriptionStatus) ?? "active",
    currentPeriodStart: toDate(
      subFields.current_period_start ?? itemFields.current_period_start,
    ),
    currentPeriodEnd: toDate(
      subFields.current_period_end ?? itemFields.current_period_end,
    ),
    trialEnd: toDate(subFields.trial_end),
    cancelAtPeriodEnd: Boolean(subFields.cancel_at_period_end),
  };

  // Upsert: insert or update if stripeId already exists
  await db
    .insert(subscriptions)
    .values(insertData)
    .onConflictDoUpdate({
      target: subscriptions.stripeId,
      set: {
        tier: insertData.tier,
        status: insertData.status,
        currentPeriodStart: insertData.currentPeriodStart,
        currentPeriodEnd: insertData.currentPeriodEnd,
        trialEnd: insertData.trialEnd,
        cancelAtPeriodEnd: insertData.cancelAtPeriodEnd,
        stripePriceId: insertData.stripePriceId,
      },
    });
}
/**
 * Reads the card's last four digits from a subscription's default payment method.
 *
 * @param stripeSub - Subscription whose `default_payment_method` is inspected.
 * @returns The last four digits, or `undefined` when the payment method is
 *   missing, is only an ID string (not an expanded object), or has no card digits.
 */
async function extractPaymentMethodLast4(
  stripeSub: Stripe.Subscription,
): Promise<string | undefined> {
  const paymentMethod = stripeSub.default_payment_method;

  if (paymentMethod && typeof paymentMethod !== "string") {
    const digits = (paymentMethod as Stripe.PaymentMethod).card?.last4;
    return digits ? digits : undefined;
  }

  return undefined;
}
/**
 * Mirrors a Stripe subscription onto the local row and, when the default
 * payment method is an expanded card, stores its last four digits.
 */
async function persistSubscriptionSnapshot(
  userId: string,
  stripeSub: Stripe.Subscription,
) {
  await upsertSubscriptionFromStripe(userId, stripeSub);
  const last4 = await extractPaymentMethodLast4(stripeSub);
  if (last4) {
    await updateSubscriptionInDB(stripeSub.id, {
      defaultPaymentMethodLast4: last4,
    });
  }
}

/** Loads the user row for `userId`, or `undefined` when there is none. */
async function findBillingUser(userId: string) {
  const [user] = await db
    .select()
    .from(users)
    .where(eq(users.id, userId))
    .limit(1);
  return user;
}

/**
 * Applies a Stripe webhook event to local subscription state and sends the
 * related customer emails.
 *
 * Handled event types:
 * - `checkout.session.completed`: stores the new subscription; sends the trial
 *   activation email when the subscription is `trialing`.
 * - `invoice.payment_succeeded` / `invoice.paid`: refreshes a known
 *   subscription; may send the activation email.
 * - `invoice.payment_failed`: marks the subscription `past_due` and emails a
 *   billing-portal link.
 * - `customer.subscription.updated`: refreshes (or first stores) the subscription.
 * - `customer.subscription.deleted`: marks the subscription `canceled`.
 * Any other type is logged and ignored.
 *
 * Payloads that fail schema validation are skipped. Email failures are logged
 * and never rethrown; Stripe and database errors outside the email blocks
 * propagate to the caller.
 *
 * @param event - Stripe event to process.
 *   NOTE(review): presumably signature-verified by the caller — confirm.
 */
export async function handleWebhookEvent(event: Stripe.Event) {
  const eventType = event.type;
  console.log(`[billing:webhook] Processing event: ${eventType} (${event.id})`);

  switch (eventType) {
    case "checkout.session.completed": {
      const session = safeParseCheckoutSession(event.data.object);
      if (!session) break;

      const userId = session.metadata?.userId;
      if (!userId || !session.subscription) {
        console.warn(
          `[billing:webhook] checkout.session.completed missing userId or subscription`,
        );
        break;
      }

      // The payment method is expanded so its card digits can be stored.
      const stripeSub = await stripe.subscriptions.retrieve(
        session.subscription as string,
        { expand: ["default_payment_method"] },
      );
      await persistSubscriptionSnapshot(userId, stripeSub);

      // If this is a trial subscription, send activation email
      // NOTE(review): the tier label is hard-coded to "Basic" even though
      // createCheckoutSession can start a trial on any price — confirm intent.
      if (stripeSub.status === "trialing") {
        try {
          const user = await findBillingUser(userId);
          if (user?.email) {
            await resend.emails.send({
              from: "Kordant <noreply@kordant.com>",
              to: user.email,
              ...subscriptionActivatedEmail(
                user.name ?? "there",
                "Basic",
                TRIAL_DAYS,
              ),
            });
          }
        } catch (emailErr) {
          console.error(
            `[billing:webhook] Failed to send trial activation email:`,
            emailErr,
          );
        }
      }
      break;
    }

    case "invoice.payment_succeeded":
    case "invoice.paid": {
      const invoice = safeParseInvoice(event.data.object);
      if (!invoice?.subscription) break;

      const stripeSub = await stripe.subscriptions.retrieve(
        invoice.subscription as string,
        { expand: ["default_payment_method"] },
      );

      // Find the user from the subscription record
      const [existingSub] = await db
        .select()
        .from(subscriptions)
        .where(eq(subscriptions.stripeId, invoice.subscription as string))
        .limit(1);
      // Without a local row there is no user to attribute the invoice to.
      if (!existingSub) break;

      await persistSubscriptionSnapshot(existingSub.userId, stripeSub);

      // If this was a trial-to-paid transition, send activation email
      // NOTE(review): this condition does not detect a transition. If Stripe
      // keeps `trial_end` populated after the trial is over (confirm), it holds
      // for every later renewal too; and because both invoice.payment_succeeded
      // and invoice.paid are routed here, the email may go out twice per payment.
      if (stripeSub.trial_end && stripeSub.status === "active") {
        try {
          const user = await findBillingUser(existingSub.userId);
          if (user?.email) {
            const tier = mapStripeProductToTier(
              (stripeSub.items.data[0]?.price as Stripe.Price)?.id ?? "",
            );
            await resend.emails.send({
              from: "Kordant <noreply@kordant.com>",
              to: user.email,
              ...subscriptionActivatedEmail(user.name ?? "there", tier, 0),
            });
          }
        } catch (emailErr) {
          console.error(
            `[billing:webhook] Failed to send subscription activation email:`,
            emailErr,
          );
        }
      }
      break;
    }

    case "invoice.payment_failed": {
      const invoice = safeParseInvoice(event.data.object);
      if (!invoice?.subscription) break;

      await updateSubscriptionInDB(invoice.subscription as string, {
        status: "past_due",
      });

      // Send payment failure / retry email
      try {
        const [existingSub] = await db
          .select()
          .from(subscriptions)
          .where(eq(subscriptions.stripeId, invoice.subscription as string))
          .limit(1);
        const user = existingSub
          ? await findBillingUser(existingSub.userId)
          : undefined;

        if (user?.email) {
          if (!user.stripeCustomerId) {
            // A portal session cannot be created without a customer; say so
            // explicitly instead of letting the Stripe call fail.
            console.warn(
              `[billing:webhook] invoice.payment_failed: user ${user.id} has no Stripe customer ID; skipping payment failure email`,
            );
          } else {
            const portalSession = await stripe.billingPortal.sessions.create({
              customer: user.stripeCustomerId,
              return_url: `${process.env.APP_URL ?? "https://kordant.com"}/settings`,
            });
            await resend.emails.send({
              from: "Kordant <noreply@kordant.com>",
              to: user.email,
              ...paymentFailedEmail(user.name ?? "there", portalSession.url),
            });
          }
        }
      } catch (emailErr) {
        console.error(
          `[billing:webhook] Failed to send payment failure email:`,
          emailErr,
        );
      }
      break;
    }

    case "customer.subscription.updated": {
      const validatedSub = safeParseSubscription(event.data.object);
      if (!validatedSub) break;

      // Find existing subscription to get userId
      const [existingSub] = await db
        .select()
        .from(subscriptions)
        .where(eq(subscriptions.stripeId, validatedSub.id))
        .limit(1);

      // Subscription may not exist in DB yet — fall back to the metadata userId
      const userId = existingSub
        ? existingSub.userId
        : validatedSub.metadata?.userId;
      if (!userId) break;

      // Retrieve full subscription from Stripe for accurate data
      const stripeSub = await stripe.subscriptions.retrieve(validatedSub.id, {
        expand: ["default_payment_method"],
      });
      await persistSubscriptionSnapshot(userId, stripeSub);
      break;
    }

    case "customer.subscription.deleted": {
      const stripeSub = safeParseSubscription(event.data.object);
      if (!stripeSub) break;

      await updateSubscriptionInDB(stripeSub.id, {
        status: "canceled",
      });
      break;
    }

    default: {
      console.log(
        `[billing:webhook] Unhandled event type: ${eventType}`,
      );
    }
  }
}
/* ------------------------------------------------------------------ */
/* Tier mapping */
/* ------------------------------------------------------------------ */
/**
 * Resolves a Stripe price ID to a plan tier.
 *
 * Resolution order:
 * 1. An empty ID maps to `"basic"`.
 * 2. Exact match against the `STRIPE_PRICE_*` environment variables.
 * 3. Keyword match on the ID itself, first matching rule wins.
 * 4. Anything else maps to `"basic"`.
 *
 * @param priceId - Stripe price ID (may be empty).
 * @returns The matching tier, defaulting to `"basic"`.
 */
export function mapStripeProductToTier(priceId: string): Tier {
  if (!priceId) return "basic";

  const { env } = process;
  const configuredPrices: Array<[string, Tier]> = [
    [env.STRIPE_PRICE_BASIC ?? "", "basic"],
    [env.STRIPE_PRICE_PLUS ?? "", "plus"],
    [env.STRIPE_PRICE_PREMIUM ?? "", "premium"],
    [env.STRIPE_PRICE_FAMILY_GUARD ?? "", "family_guard"],
    [env.STRIPE_PRICE_FAMILY_FORTRESS ?? "", "family_fortress"],
  ];
  for (const [configuredId, tier] of configuredPrices) {
    if (priceId === configuredId) return tier;
  }

  // Also check for product ID prefixes or metadata patterns.
  // Family rules come FIRST so "family_guard" / "family_fortress" are not
  // swallowed by the shorter "guard" / "fortress" keywords below.
  const keywordRules: Array<[string[], Tier]> = [
    [["family_fortress"], "family_fortress"],
    [["family_guard"], "family_guard"],
    [["basic", "shield"], "basic"],
    [["plus", "guard"], "plus"],
    [["premium", "fortress"], "premium"],
  ];
  for (const [keywords, tier] of keywordRules) {
    if (keywords.some((keyword) => priceId.includes(keyword))) return tier;
  }

  return "basic";
}