diff --git a/packages/api/src/routes/subscription.routes.ts b/packages/api/src/routes/subscription.routes.ts index 193b302..b387f22 100644 --- a/packages/api/src/routes/subscription.routes.ts +++ b/packages/api/src/routes/subscription.routes.ts @@ -1,7 +1,7 @@ import { FastifyInstance } from 'fastify'; import { BillingService } from '@shieldai/shared-billing/src/services/billing.service'; import { SubscriptionService, customerService, webhookService } from '@shieldai/shared-billing/src/services/billing.services'; -import { SubscriptionTier } from '@shieldai/shared-billing/src/config/billing.config'; +import { SubscriptionTier, isValidReturnUrl } from '@shieldai/shared-billing/src/config/billing.config'; import { AuthRequest } from './auth.middleware'; const billingService = BillingService.getInstance(); @@ -218,6 +218,13 @@ export async function subscriptionRoutes(fastify: FastifyInstance) { }); } + if (!isValidReturnUrl(returnUrl)) { + return reply.status(400).send({ + error: 'Invalid return URL', + message: 'returnUrl must be from an allowed origin', + }); + } + try { const portalSession = await billingService.createCustomerPortalSession( customerId, diff --git a/packages/shared-billing/src/__tests__/billing.config.test.ts b/packages/shared-billing/src/__tests__/billing.config.test.ts new file mode 100644 index 0000000..09da1a0 --- /dev/null +++ b/packages/shared-billing/src/__tests__/billing.config.test.ts @@ -0,0 +1,59 @@ +import { describe, it, expect, beforeEach } from 'vitest'; +import { isValidReturnUrl } from '../config/billing.config'; + +describe('isValidReturnUrl', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv, ALLOWED_RETURN_URL_ORIGINS: 'https://app.shieldai.com,shieldai://' }; + }); + + it('accepts allowed web origin', () => { + expect(isValidReturnUrl('https://app.shieldai.com/account')).toBe(true); + }); + + it('accepts allowed web origin with path and query params', () => { + expect(isValidReturnUrl('https://app.shieldai.com/settings?tab=billing')).toBe(true); + }); + + it('accepts custom app scheme', () => { + expect(isValidReturnUrl('shieldai://portal-return')).toBe(true); + }); + + it('rejects attacker-controlled domain', () => { + expect(isValidReturnUrl('https://evil.com/phish')).toBe(false); + }); + + it('rejects subdomain of allowed origin', () => { + expect(isValidReturnUrl('https://evil-app.shieldai.com/hack')).toBe(false); + }); + + it('rejects domain that contains allowed origin as substring', () => { + expect(isValidReturnUrl('https://notapp.shieldai.com/hack')).toBe(false); + }); + + it('rejects domain that has allowed origin as prefix (substring attack)', () => { + expect(isValidReturnUrl('https://app.shieldai.com.evil.com/hack')).toBe(false); + }); + + it('rejects malformed URLs', () => { + expect(isValidReturnUrl('not-a-url')).toBe(false); + expect(isValidReturnUrl('')).toBe(false); + }); + + it('rejects javascript protocol', () => { + expect(isValidReturnUrl('javascript:alert(1)')).toBe(false); + }); + + it('rejects data protocol', () => { + expect(isValidReturnUrl('data:text/html,')).toBe(false); + }); + + it('rejects protocol-relative URLs', () => { + expect(isValidReturnUrl('//evil.com/steal')).toBe(false); + }); + + it('rejects URL with allowed origin in query string', () => { + expect(isValidReturnUrl('https://evil.com/?redirect=https://app.shieldai.com')).toBe(false); + }); +}); diff --git a/packages/shared-billing/src/config/billing.config.ts b/packages/shared-billing/src/config/billing.config.ts index 52cb4ea..85d809b 100644 --- a/packages/shared-billing/src/config/billing.config.ts +++ b/packages/shared-billing/src/config/billing.config.ts @@ -52,6 +52,32 @@ export const BillingConfigSchema = z.object({ export type BillingConfig = z.infer; +// Allowed return URL origins for Stripe customer portal (open redirect prevention) +const allowedReturnUrlOrigins: string[] = (process.env.ALLOWED_RETURN_URL_ORIGINS || 'https://app.shieldai.com,shieldai://').split(',').map(s => s.trim()).filter(Boolean); + +export function isValidReturnUrl(url: string): boolean { + try { + const parsed = new URL(url); + // Custom app schemes (e.g., shieldai://) have origin === 'null' + // Use prefix matching for these since they don't have a standard origin + if (parsed.origin === 'null') { + return allowedReturnUrlOrigins.some(origin => url.startsWith(origin)); + } + // Standard web protocols: compare origins to prevent substring attacks + return allowedReturnUrlOrigins.some(origin => { + try { + const originUrl = new URL(origin); + return parsed.origin === originUrl.origin; + } catch { + // If the allowed origin can't be parsed as a URL, it's a custom scheme + return false; + } + }); + } catch { + return false; + } +} + export const loadBillingConfig = (): BillingConfig => { const rawConfig = { stripe: { diff --git a/packages/shared-billing/src/index.ts b/packages/shared-billing/src/index.ts index 96be282..006a447 100644 --- a/packages/shared-billing/src/index.ts +++ b/packages/shared-billing/src/index.ts @@ -1,5 +1,5 @@ export { BillingService } from './services/billing.service'; -export { loadBillingConfig, SubscriptionTier } from './config/billing.config'; +export { loadBillingConfig, SubscriptionTier, isValidReturnUrl } from './config/billing.config'; export { requireTier, checkUsageLimit, diff --git a/packages/shared-billing/src/services/billing.service.ts b/packages/shared-billing/src/services/billing.service.ts index 5d218ec..71c0e1e 100644 --- a/packages/shared-billing/src/services/billing.service.ts +++ b/packages/shared-billing/src/services/billing.service.ts @@ -1,5 +1,5 @@ import Stripe from 'stripe'; -import { loadBillingConfig, SubscriptionTier } from '../config/billing.config'; +import { loadBillingConfig, SubscriptionTier, isValidReturnUrl } from '../config/billing.config'; import { RedisService } from '@shieldsai/shared-notifications'; import type { Subscription, SubscriptionCreateSchema, SubscriptionUpdateSchema } from '../models/subscription.model'; @@ -168,6 +168,9 @@ export class BillingService { customerId: string, returnUrl: string ): Promise { + if (!isValidReturnUrl(returnUrl)) { + throw new Error(`Invalid return URL: ${returnUrl}`); + } return await stripe.billingPortal.sessions.create({ customer: customerId, return_url: returnUrl,