pounce/backend/app/api/subscription.py
yves.gugger c1316d8b38 feat: Perfect onboarding journey after Stripe payment
NEW WELCOME PAGE (/command/welcome):
- Celebratory confetti animation on arrival
- Plan-specific welcome message (Trader/Tycoon)
- Features unlocked section with icons
- Next steps with quick links to key features
- Link to documentation and support

UPDATED USER JOURNEY:

1. Pricing Page (/pricing)
   ↓ Click plan button
2. (If not logged in) → Register → Back to Pricing
   ↓ Click plan button
3. Stripe Checkout (external)
   ↓ Payment successful
4. Welcome Page (/command/welcome?plan=trader)
   - Shows unlocked features
   - Guided next steps
   ↓ 'Go to Dashboard'
5. Dashboard (/command/dashboard)

CANCEL FLOW:
- Stripe Cancel → /pricing?cancelled=true
- Shows friendly banner: 'No worries! Card not charged.'
- Dismissible with X button
- URL cleaned up automatically

BACKEND UPDATES:
- Default success URL: /command/welcome?plan={plan}
- Default cancel URL: /pricing?cancelled=true
- Portal return URL: /command/settings (not /dashboard)

This creates a complete, professional onboarding experience
that celebrates the upgrade and guides users to get started.
2025-12-10 16:17:29 +01:00

347 lines
11 KiB
Python

"""
Subscription API endpoints with Stripe integration.
Endpoints:
- GET /subscription - Get current subscription
- GET /subscription/tiers - Get available tiers
- GET /subscription/features - Get current features
- POST /subscription/checkout - Create Stripe checkout session
- POST /subscription/portal - Create Stripe customer portal session
- POST /subscription/cancel - Cancel subscription
"""
import os
from fastapi import APIRouter, HTTPException, status, Request
from sqlalchemy import select, func
from pydantic import BaseModel
from typing import Optional
from app.api.deps import Database, CurrentUser
from app.models.domain import Domain
from app.models.user import User
from app.models.subscription import Subscription, SubscriptionTier, TIER_CONFIG
from app.schemas.subscription import SubscriptionResponse
from app.services.stripe_service import StripeService, TIER_FEATURES
from app.services.email_service import email_service
# Every endpoint defined in this module is registered on this router.
router = APIRouter()

# ============== Schemas ==============
class CheckoutRequest(BaseModel):
    """Body of ``POST /checkout``: the paid plan to buy plus optional redirects."""

    # Plan identifier; the checkout endpoint accepts only "trader" or "tycoon".
    plan: str
    # Redirect targets handed to Stripe. When omitted, the checkout endpoint
    # builds defaults from the SITE_URL environment variable.
    success_url: Optional[str] = None
    cancel_url: Optional[str] = None
class CheckoutResponse(BaseModel):
    """Payload returned by ``POST /checkout``."""

    # URL of the Stripe Checkout page the client should navigate to.
    checkout_url: str
    # Identifier of the Stripe Checkout session that was created.
    session_id: str
class PortalResponse(BaseModel):
    """Payload returned by ``POST /portal``."""

    # URL of the Stripe Customer Portal session the client should open.
    portal_url: str
# ============== Endpoints ==============


@router.get("", response_model=SubscriptionResponse)
async def get_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """Get current user's subscription details.

    If the user has no subscription row yet, a Scout subscription is created,
    committed and refreshed before the response is built. Its limits come
    from ``TIER_CONFIG[SubscriptionTier.SCOUT]``.

    Returns:
        SubscriptionResponse combining the stored subscription, its tier
        config, and the number of Domain rows owned by the user.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if not subscription:
        # Lazily create the free tier. Limits are taken from TIER_CONFIG (the
        # same source the /cancel downgrade uses) rather than hard-coded
        # numbers that could silently drift from the tier definition.
        scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
        subscription = Subscription(
            user_id=current_user.id,
            tier=SubscriptionTier.SCOUT,
            max_domains=scout_config["domain_limit"],
            check_frequency=scout_config["check_frequency"],
        )
        db.add(subscription)
        await db.commit()
        await db.refresh(subscription)

    # Count the Domain rows belonging to this user; reported as domains_used.
    domain_count = await db.execute(
        select(func.count(Domain.id)).where(Domain.user_id == current_user.id)
    )
    domains_used = domain_count.scalar() or 0

    config = subscription.config
    return SubscriptionResponse(
        id=subscription.id,
        tier=subscription.tier.value,
        tier_name=config["name"],
        status=subscription.status.value,
        domain_limit=subscription.max_domains,
        domains_used=domains_used,
        portfolio_limit=config.get("portfolio_limit", 0),
        check_frequency=config["check_frequency"],
        history_days=config["history_days"],
        features=config["features"],
        started_at=subscription.started_at,
        expires_at=subscription.expires_at,
    )
def _describe_tier_features(config):
    """Turn one ``TIER_CONFIG`` entry into the human-readable bullets for /tiers.

    Bullets appear in a fixed order:
      1. watchlist size and portfolio size;
      2. check frequency and alert channel;
      3. valuation / market insights;
      4. check history;
      5. API / bulk / SEO extras.

    A value of -1 for ``portfolio_limit`` or ``history_days`` is shown as
    "Unlimited".
    """
    bullets = [f"{config['domain_limit']} domains in watchlist"]

    portfolio_limit = config.get("portfolio_limit")
    if portfolio_limit:  # missing / 0 / None contributes no bullet
        if portfolio_limit == -1:
            bullets.append("Unlimited portfolio domains")
        elif portfolio_limit > 0:
            bullets.append(f"{portfolio_limit} portfolio domains")

    # Any frequency other than the two listed here is described as daily.
    frequency_bullets = {
        "realtime": "10-minute availability checks",
        "hourly": "Hourly availability checks",
    }
    bullets.append(
        frequency_bullets.get(config["check_frequency"], "Daily availability checks")
    )

    flags = config["features"]
    # SMS implies the richer channel, so email is only mentioned without it.
    if flags.get("sms_alerts"):
        bullets.append("SMS & Telegram notifications")
    elif flags.get("email_alerts"):
        bullets.append("Email notifications")

    bullets.extend(
        label
        for flag, label in (
            ("domain_valuation", "Domain valuation"),
            ("market_insights", "Full market insights"),
        )
        if flags.get(flag)
    )

    history_days = config["history_days"]
    if history_days == -1:
        bullets.append("Unlimited check history")
    elif history_days > 0:
        bullets.append(f"{history_days}-day check history")

    bullets.extend(
        label
        for flag, label in (
            ("api_access", "REST API access"),
            ("bulk_tools", "Bulk import/export tools"),
            ("seo_metrics", "SEO metrics (DA, backlinks)"),
        )
        if flags.get(flag)
    )
    return bullets


def _serialize_tier(tier_enum, config):
    """Build the public JSON description of a single tier."""
    feature_bullets = _describe_tier_features(config)
    return {
        "id": tier_enum.value,
        "name": config["name"],
        "domain_limit": config["domain_limit"],
        "portfolio_limit": config.get("portfolio_limit", 0),
        "price": config["price"],
        "currency": config.get("currency", "USD"),
        "check_frequency": config["check_frequency"],
        "features": feature_bullets,
        "feature_flags": config["features"],
    }


@router.get("/tiers")
async def get_subscription_tiers():
    """List every tier in ``TIER_CONFIG`` with its price, limits and feature bullets."""
    return {
        "tiers": [
            _serialize_tier(tier_enum, config)
            for tier_enum, config in TIER_CONFIG.items()
        ]
    }
@router.get("/features")
async def get_my_features(current_user: CurrentUser, db: Database):
    """Report the limits and feature flags of the current user's plan.

    A user without a subscription row is described as Scout using
    ``TIER_CONFIG[SubscriptionTier.SCOUT]``. Unlike ``GET /subscription``,
    nothing is written to the database here.
    """
    lookup = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = lookup.scalar_one_or_none()

    if subscription:
        tier_config = subscription.config
        tier_id = subscription.tier.value
        tier_name = tier_config["name"]
    else:
        # No row yet: fall back to the free tier without persisting anything.
        tier_config = TIER_CONFIG[SubscriptionTier.SCOUT]
        tier_id, tier_name = "scout", "Scout"

    return {
        "tier": tier_id,
        "tier_name": tier_name,
        "domain_limit": tier_config["domain_limit"],
        "portfolio_limit": tier_config.get("portfolio_limit", 0),
        "check_frequency": tier_config["check_frequency"],
        "history_days": tier_config["history_days"],
        "features": tier_config["features"],
    }
@router.post("/checkout", response_model=CheckoutResponse)
async def create_checkout_session(
    request: CheckoutRequest,
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Checkout session for subscription upgrade.

    Args:
        plan: "trader" or "tycoon"
        success_url: URL to redirect after successful payment
            (default: ``{SITE_URL}/command/welcome?plan={plan}``)
        cancel_url: URL to redirect if user cancels
            (default: ``{SITE_URL}/pricing?cancelled=true``)

    Returns:
        checkout_url: Stripe Checkout page URL
        session_id: Stripe session ID

    Raises:
        HTTPException(503): Stripe is not configured.
        HTTPException(400): the plan is unknown, or StripeService raised ValueError.
        HTTPException(500): any other failure while creating the session or
            saving the Stripe customer ID.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if request.plan not in ("trader", "tycoon"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid plan. Must be 'trader' or 'tycoon'",
        )

    # Strip a trailing slash so SITE_URL="https://example.com/" does not yield
    # "https://example.com//command/welcome".
    site_url = os.getenv("SITE_URL", "http://localhost:3000").rstrip("/")
    # NOTE(review): client-supplied success_url / cancel_url are forwarded to
    # Stripe unvalidated; consider restricting them to SITE_URL.
    success_url = request.success_url or f"{site_url}/command/welcome?plan={request.plan}"
    cancel_url = request.cancel_url or f"{site_url}/pricing?cancelled=true"

    try:
        result = await StripeService.create_checkout_session(
            user=current_user,
            plan=request.plan,
            success_url=success_url,
            cancel_url=cancel_url,
        )

        # Remember the Stripe customer the first time one is created, so that
        # /portal (which requires stripe_customer_id) works afterwards.
        if result.get("customer_id") and not current_user.stripe_customer_id:
            current_user.stripe_customer_id = result["customer_id"]
            await db.commit()

        return CheckoutResponse(
            checkout_url=result["checkout_url"],
            session_id=result["session_id"],
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create checkout session: {str(e)}",
        ) from e
@router.post("/portal", response_model=PortalResponse)
async def create_portal_session(
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Customer Portal session.

    Users can:
    - Update payment method
    - View invoices
    - Cancel subscription
    - Update billing info

    The portal returns the user to ``{SITE_URL}/command/settings``.
    ``db`` is accepted for interface parity but is not used here.

    Raises:
        HTTPException(503): Stripe is not configured.
        HTTPException(400): the user has no Stripe customer ID yet.
        HTTPException(500): StripeService failed to create the session.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if not current_user.stripe_customer_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No billing account found. Please subscribe to a plan first.",
        )

    # Strip a trailing slash so SITE_URL="https://example.com/" does not yield
    # "https://example.com//command/settings".
    site_url = os.getenv("SITE_URL", "http://localhost:3000").rstrip("/")
    return_url = f"{site_url}/command/settings"

    try:
        portal_url = await StripeService.create_portal_session(
            customer_id=current_user.stripe_customer_id,
            return_url=return_url,
        )
        return PortalResponse(portal_url=portal_url)
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create portal session: {str(e)}",
        ) from e
@router.post("/cancel")
async def cancel_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """
    Cancel subscription and downgrade to Scout.

    Note: For Stripe-managed subscriptions, use the Customer Portal instead.
    This endpoint is for manual cancellation only. It edits the local row and
    never talks to Stripe, so it refuses subscriptions that carry a
    ``stripe_subscription_id``. Downgrading those locally would leave the
    Stripe subscription running (and billing) while dropping our reference to it.

    Raises:
        HTTPException(404): the user has no subscription row.
        HTTPException(400): already on Scout, or the subscription is billed
            through Stripe.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if not subscription:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No subscription found",
        )

    if subscription.tier == SubscriptionTier.SCOUT:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Already on free plan",
        )

    if subscription.stripe_subscription_id:
        # A live Stripe subscription must be cancelled in Stripe (via the
        # Customer Portal); clearing it here would keep charging the user.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "This subscription is billed through Stripe. "
                "Please cancel it from the billing portal."
            ),
        )

    # Downgrade to Scout using the Scout tier definition.
    old_tier = subscription.tier.value
    scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
    subscription.tier = SubscriptionTier.SCOUT
    subscription.max_domains = scout_config["domain_limit"]
    subscription.check_frequency = scout_config["check_frequency"]
    # Normalize any empty/falsy id to None, as before.
    subscription.stripe_subscription_id = None

    await db.commit()

    return {
        "status": "cancelled",
        "message": f"Subscription cancelled. Downgraded from {old_tier} to Scout.",
        "new_tier": "scout",
    }