diff --git a/packages/api/src/routes/subscription.routes.ts b/packages/api/src/routes/subscription.routes.ts
index 193b302..b387f22 100644
--- a/packages/api/src/routes/subscription.routes.ts
+++ b/packages/api/src/routes/subscription.routes.ts
@@ -1,7 +1,7 @@
import { FastifyInstance } from 'fastify';
import { BillingService } from '@shieldai/shared-billing/src/services/billing.service';
import { SubscriptionService, customerService, webhookService } from '@shieldai/shared-billing/src/services/billing.services';
-import { SubscriptionTier } from '@shieldai/shared-billing/src/config/billing.config';
+import { SubscriptionTier, isValidReturnUrl } from '@shieldai/shared-billing/src/config/billing.config';
import { AuthRequest } from './auth.middleware';
const billingService = BillingService.getInstance();
@@ -218,6 +218,13 @@ export async function subscriptionRoutes(fastify: FastifyInstance) {
});
}
+ if (!isValidReturnUrl(returnUrl)) {
+ return reply.status(400).send({
+ error: 'Invalid return URL',
+ message: 'returnUrl must be from an allowed origin',
+ });
+ }
+
try {
const portalSession = await billingService.createCustomerPortalSession(
customerId,
diff --git a/packages/shared-billing/src/__tests__/billing.config.test.ts b/packages/shared-billing/src/__tests__/billing.config.test.ts
new file mode 100644
index 0000000..09da1a0
--- /dev/null
+++ b/packages/shared-billing/src/__tests__/billing.config.test.ts
@@ -0,0 +1,59 @@
+import { describe, it, expect, beforeEach } from 'vitest';
+import { isValidReturnUrl } from '../config/billing.config';
+
+describe('isValidReturnUrl', () => {
+ const originalEnv = process.env;
+
+ beforeEach(() => {
+ process.env = { ...originalEnv, ALLOWED_RETURN_URL_ORIGINS: 'https://app.shieldai.com,shieldai://' };
+ });
+
+ it('accepts allowed web origin', () => {
+ expect(isValidReturnUrl('https://app.shieldai.com/account')).toBe(true);
+ });
+
+ it('accepts allowed web origin with path and query params', () => {
+ expect(isValidReturnUrl('https://app.shieldai.com/settings?tab=billing')).toBe(true);
+ });
+
+ it('accepts custom app scheme', () => {
+ expect(isValidReturnUrl('shieldai://portal-return')).toBe(true);
+ });
+
+ it('rejects attacker-controlled domain', () => {
+ expect(isValidReturnUrl('https://evil.com/phish')).toBe(false);
+ });
+
+ it('rejects subdomain of allowed origin', () => {
+ expect(isValidReturnUrl('https://evil-app.shieldai.com/hack')).toBe(false);
+ });
+
+ it('rejects domain that contains allowed origin as substring', () => {
+ expect(isValidReturnUrl('https://notapp.shieldai.com/hack')).toBe(false);
+ });
+
+ it('rejects domain that has allowed origin as prefix (substring attack)', () => {
+ expect(isValidReturnUrl('https://app.shieldai.com.evil.com/hack')).toBe(false);
+ });
+
+ it('rejects malformed URLs', () => {
+ expect(isValidReturnUrl('not-a-url')).toBe(false);
+ expect(isValidReturnUrl('')).toBe(false);
+ });
+
+ it('rejects javascript protocol', () => {
+ expect(isValidReturnUrl('javascript:alert(1)')).toBe(false);
+ });
+
+ it('rejects data protocol', () => {
+ expect(isValidReturnUrl('data:text/html,')).toBe(false);
+ });
+
+ it('rejects protocol-relative URLs', () => {
+ expect(isValidReturnUrl('//evil.com/steal')).toBe(false);
+ });
+
+ it('rejects URL with allowed origin in query string', () => {
+ expect(isValidReturnUrl('https://evil.com/?redirect=https://app.shieldai.com')).toBe(false);
+ });
+});
diff --git a/packages/shared-billing/src/config/billing.config.ts b/packages/shared-billing/src/config/billing.config.ts
index 52cb4ea..85d809b 100644
--- a/packages/shared-billing/src/config/billing.config.ts
+++ b/packages/shared-billing/src/config/billing.config.ts
@@ -52,6 +52,32 @@ export const BillingConfigSchema = z.object({
export type BillingConfig = z.infer;
+// Allowed return URL origins for Stripe customer portal (open redirect prevention)
+const allowedReturnUrlOrigins: string[] = (process.env.ALLOWED_RETURN_URL_ORIGINS || 'https://app.shieldai.com,shieldai://').split(',').map(s => s.trim()).filter(Boolean);
+
+export function isValidReturnUrl(url: string): boolean {
+ try {
+ const parsed = new URL(url);
+ // Custom app schemes (e.g., shieldai://) have origin === 'null'
+ // Use prefix matching for these since they don't have a standard origin
+ if (parsed.origin === 'null') {
+ return allowedReturnUrlOrigins.some(origin => url.startsWith(origin));
+ }
+ // Standard web protocols: compare origins to prevent substring attacks
+ return allowedReturnUrlOrigins.some(origin => {
+ try {
+ const originUrl = new URL(origin);
+ return parsed.origin === originUrl.origin;
+ } catch {
+ // If the allowed origin can't be parsed as a URL, it's a custom scheme
+ return false;
+ }
+ });
+ } catch {
+ return false;
+ }
+}
+
export const loadBillingConfig = (): BillingConfig => {
const rawConfig = {
stripe: {
diff --git a/packages/shared-billing/src/index.ts b/packages/shared-billing/src/index.ts
index 96be282..006a447 100644
--- a/packages/shared-billing/src/index.ts
+++ b/packages/shared-billing/src/index.ts
@@ -1,5 +1,5 @@
export { BillingService } from './services/billing.service';
-export { loadBillingConfig, SubscriptionTier } from './config/billing.config';
+export { loadBillingConfig, SubscriptionTier, isValidReturnUrl } from './config/billing.config';
export {
requireTier,
checkUsageLimit,
diff --git a/packages/shared-billing/src/services/billing.service.ts b/packages/shared-billing/src/services/billing.service.ts
index 5d218ec..71c0e1e 100644
--- a/packages/shared-billing/src/services/billing.service.ts
+++ b/packages/shared-billing/src/services/billing.service.ts
@@ -1,5 +1,5 @@
import Stripe from 'stripe';
-import { loadBillingConfig, SubscriptionTier } from '../config/billing.config';
+import { loadBillingConfig, SubscriptionTier, isValidReturnUrl } from '../config/billing.config';
import { RedisService } from '@shieldsai/shared-notifications';
import type { Subscription, SubscriptionCreateSchema, SubscriptionUpdateSchema } from '../models/subscription.model';
@@ -168,6 +168,9 @@ export class BillingService {
customerId: string,
returnUrl: string
): Promise {
+ if (!isValidReturnUrl(returnUrl)) {
+ throw new Error(`Invalid return URL: ${returnUrl}`);
+ }
return await stripe.billingPortal.sessions.create({
customer: customerId,
return_url: returnUrl,