"""
Subscription API endpoints with Stripe integration.

Endpoints:
- GET /subscription - Get current subscription
- GET /subscription/tiers - Get available tiers
- GET /subscription/features - Get current features
- POST /subscription/checkout - Create Stripe checkout session
- POST /subscription/portal - Create Stripe customer portal session
- POST /subscription/cancel - Cancel subscription
"""
import logging
import os
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from sqlalchemy import select, func

from app.api.deps import Database, CurrentUser
from app.models.domain import Domain
from app.models.user import User
from app.models.subscription import Subscription, SubscriptionTier, TIER_CONFIG
from app.schemas.subscription import SubscriptionResponse
from app.services.stripe_service import StripeService, TIER_FEATURES
from app.services.email_service import email_service

logger = logging.getLogger(__name__)

router = APIRouter()

# Plans that can be purchased through Stripe Checkout.
_PAID_PLANS = ("trader", "tycoon")

# Frontend base URL used when the SITE_URL environment variable is not set.
_DEFAULT_SITE_URL = "http://localhost:3000"

# Human-readable label per check frequency; anything else is shown as daily.
_CHECK_FREQUENCY_LABELS = {
    "realtime": "10-minute availability checks",
    "hourly": "Hourly availability checks",
}
_DEFAULT_CHECK_FREQUENCY_LABEL = "Daily availability checks"


# ============== Schemas ==============


class CheckoutRequest(BaseModel):
    """Request to create checkout session."""

    plan: str  # "trader" or "tycoon"
    success_url: Optional[str] = None
    cancel_url: Optional[str] = None


class CheckoutResponse(BaseModel):
    """Response with checkout URL."""

    checkout_url: str
    session_id: str


class PortalResponse(BaseModel):
    """Response with portal URL."""

    portal_url: str


# ============== Helpers ==============


def _site_url() -> str:
    """Return the frontend base URL from SITE_URL, without a trailing slash.

    Stripping the slash keeps the redirect URLs built from it free of a
    doubled "//" when SITE_URL is configured as e.g. "https://example.com/".
    """
    return os.getenv("SITE_URL", _DEFAULT_SITE_URL).rstrip("/")


def _require_stripe() -> None:
    """Raise a 503 HTTPException when Stripe is not configured."""
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )


async def _get_user_subscription(db: Database, user_id) -> Optional[Subscription]:
    """Return the Subscription row for ``user_id``, or None if there is none."""
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == user_id)
    )
    return result.scalar_one_or_none()


def _features_payload(
    tier: str, tier_name: str, config: Dict[str, Any]
) -> Dict[str, Any]:
    """Build the /features response body from a tier config mapping."""
    return {
        "tier": tier,
        "tier_name": tier_name,
        "domain_limit": config["domain_limit"],
        "portfolio_limit": config.get("portfolio_limit", 0),
        "check_frequency": config["check_frequency"],
        "history_days": config["history_days"],
        "features": config["features"],
    }


def _describe_tier(config: Dict[str, Any]) -> List[str]:
    """Return the human-readable feature bullet list for one tier config.

    The order of the bullets is part of the response and is kept fixed.
    A limit of -1 is rendered as "Unlimited"; a limit of 0 (or a missing
    portfolio limit) produces no bullet.
    """
    flags = config["features"]
    bullets = [f"{config['domain_limit']} domains in watchlist"]

    portfolio_limit = config.get("portfolio_limit")
    if portfolio_limit:
        if portfolio_limit == -1:
            bullets.append("Unlimited portfolio domains")
        elif portfolio_limit > 0:
            bullets.append(f"{portfolio_limit} portfolio domains")

    bullets.append(
        _CHECK_FREQUENCY_LABELS.get(
            config["check_frequency"], _DEFAULT_CHECK_FREQUENCY_LABEL
        )
    )

    # SMS takes precedence: only one notification bullet is ever shown.
    if flags.get("sms_alerts"):
        bullets.append("SMS & Telegram notifications")
    elif flags.get("email_alerts"):
        bullets.append("Email notifications")

    if flags.get("domain_valuation"):
        bullets.append("Domain valuation")
    if flags.get("market_insights"):
        bullets.append("Full market insights")

    history_days = config["history_days"]
    if history_days == -1:
        bullets.append("Unlimited check history")
    elif history_days > 0:
        bullets.append(f"{history_days}-day check history")

    if flags.get("api_access"):
        bullets.append("REST API access")
    if flags.get("bulk_tools"):
        bullets.append("Bulk import/export tools")
    if flags.get("seo_metrics"):
        bullets.append("SEO metrics (DA, backlinks)")

    return bullets


# ============== Endpoints ==============


@router.get("", response_model=SubscriptionResponse)
async def get_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """Get current user's subscription details.

    If the user has no subscription row yet, a Scout subscription is created
    and committed before the response is built.

    Returns:
        SubscriptionResponse with the tier, limits, feature flags and the
        number of domains the user currently has.
    """
    subscription = await _get_user_subscription(db, current_user.id)

    if subscription is None:
        # Create the default Scout subscription from TIER_CONFIG so the
        # stored limits match the ones applied on downgrade in
        # cancel_subscription.
        scout = TIER_CONFIG[SubscriptionTier.SCOUT]
        subscription = Subscription(
            user_id=current_user.id,
            tier=SubscriptionTier.SCOUT,
            max_domains=scout["domain_limit"],
            check_frequency=scout["check_frequency"],
        )
        db.add(subscription)
        await db.commit()
        await db.refresh(subscription)

    # Count domains used
    domain_count = await db.execute(
        select(func.count(Domain.id)).where(Domain.user_id == current_user.id)
    )
    domains_used = domain_count.scalar() or 0

    config = subscription.config

    return SubscriptionResponse(
        id=subscription.id,
        tier=subscription.tier.value,
        tier_name=config["name"],
        status=subscription.status.value,
        domain_limit=subscription.max_domains,
        domains_used=domains_used,
        portfolio_limit=config.get("portfolio_limit", 0),
        check_frequency=config["check_frequency"],
        history_days=config["history_days"],
        features=config["features"],
        started_at=subscription.started_at,
        expires_at=subscription.expires_at,
    )


@router.get("/tiers")
async def get_subscription_tiers():
    """Get available subscription tiers and their features.

    Returns:
        {"tiers": [...]} with one entry per TIER_CONFIG item. Each entry
        carries the raw limits and price, a human-readable ``features`` list
        and the raw ``feature_flags`` mapping.
    """
    tiers = [
        {
            "id": tier_enum.value,
            "name": config["name"],
            "domain_limit": config["domain_limit"],
            "portfolio_limit": config.get("portfolio_limit", 0),
            "price": config["price"],
            "currency": config.get("currency", "USD"),
            "check_frequency": config["check_frequency"],
            "features": _describe_tier(config),
            "feature_flags": config["features"],
        }
        for tier_enum, config in TIER_CONFIG.items()
    ]

    return {"tiers": tiers}


@router.get("/features")
async def get_my_features(current_user: CurrentUser, db: Database):
    """Get current user's available features based on subscription.

    Unlike GET "", this does not create a subscription row: a user without
    one simply gets the Scout tier's configuration.
    """
    subscription = await _get_user_subscription(db, current_user.id)

    if subscription is None:
        # Default to Scout
        return _features_payload(
            "scout", "Scout", TIER_CONFIG[SubscriptionTier.SCOUT]
        )

    config = subscription.config
    return _features_payload(subscription.tier.value, config["name"], config)


@router.post("/checkout", response_model=CheckoutResponse)
async def create_checkout_session(
    request: CheckoutRequest,
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Checkout session for subscription upgrade.

    Args:
        plan: "trader" or "tycoon"
        success_url: URL to redirect after successful payment
        cancel_url: URL to redirect if user cancels

    Returns:
        checkout_url: Stripe Checkout page URL
        session_id: Stripe session ID

    Raises:
        HTTPException: 503 if Stripe is not configured, 400 for an unknown
            plan or a ValueError from the Stripe service, 500 for any other
            failure (details are logged, not returned to the client).
    """
    _require_stripe()

    if request.plan not in _PAID_PLANS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid plan. Must be 'trader' or 'tycoon'",
        )

    # Get site URL from environment
    site_url = _site_url()

    # NOTE(review): caller-supplied redirect URLs are forwarded unvalidated;
    # consider restricting them to the SITE_URL origin.
    success_url = request.success_url or f"{site_url}/dashboard?upgraded=true"
    cancel_url = request.cancel_url or f"{site_url}/pricing?cancelled=true"

    try:
        result = await StripeService.create_checkout_session(
            user=current_user,
            plan=request.plan,
            success_url=success_url,
            cancel_url=cancel_url,
        )

        # Save Stripe customer ID if new
        if result.get("customer_id") and not current_user.stripe_customer_id:
            current_user.stripe_customer_id = result["customer_id"]
            await db.commit()

        return CheckoutResponse(
            checkout_url=result["checkout_url"],
            session_id=result["session_id"],
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Keep the underlying error in the server log only; its text may
        # contain internal details that must not reach the client.
        logger.exception(
            "Failed to create checkout session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create checkout session",
        ) from e


@router.post("/portal", response_model=PortalResponse)
async def create_portal_session(
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Customer Portal session.

    Users can:
    - Update payment method
    - View invoices
    - Cancel subscription
    - Update billing info

    Raises:
        HTTPException: 503 if Stripe is not configured, 400 if the user has
            no Stripe customer ID, 500 if the portal session cannot be
            created (details are logged, not returned to the client).
    """
    _require_stripe()

    if not current_user.stripe_customer_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No billing account found. Please subscribe to a plan first.",
        )

    return_url = f"{_site_url()}/dashboard"

    try:
        portal_url = await StripeService.create_portal_session(
            customer_id=current_user.stripe_customer_id,
            return_url=return_url,
        )

        return PortalResponse(portal_url=portal_url)
    except Exception as e:
        # Keep the underlying error in the server log only; its text may
        # contain internal details that must not reach the client.
        logger.exception(
            "Failed to create portal session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create portal session",
        ) from e


@router.post("/cancel")
async def cancel_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """
    Cancel subscription and downgrade to Scout.

    Note: For Stripe-managed subscriptions, use the Customer Portal instead.
    This endpoint is for manual cancellation.

    Raises:
        HTTPException: 404 if the user has no subscription row, 400 if the
            subscription is already on the Scout tier.
    """
    subscription = await _get_user_subscription(db, current_user.id)

    if subscription is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No subscription found",
        )

    if subscription.tier == SubscriptionTier.SCOUT:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Already on free plan",
        )

    # Downgrade to Scout
    scout = TIER_CONFIG[SubscriptionTier.SCOUT]
    old_tier = subscription.tier.value
    subscription.tier = SubscriptionTier.SCOUT
    subscription.max_domains = scout["domain_limit"]
    subscription.check_frequency = scout["check_frequency"]
    # NOTE(review): this only clears the local reference. Nothing here cancels
    # the subscription in Stripe, so a Stripe-managed plan would presumably
    # keep billing -- confirm whether such users can reach this endpoint.
    subscription.stripe_subscription_id = None

    await db.commit()

    return {
        "status": "cancelled",
        "message": f"Subscription cancelled. Downgraded from {old_tier} to Scout.",
        "new_tier": "scout",
    }