pounce/backend/app/api/subscription.py
Yves Gugger bb7ce97330
Some checks failed
CI / Frontend Lint & Type Check (push) Has been cancelled
CI / Frontend Build (push) Has been cancelled
CI / Backend Lint (push) Has been cancelled
CI / Backend Tests (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
Deploy / Build & Push Images (push) Has been cancelled
Deploy / Deploy to Server (push) Has been cancelled
Deploy / Notify (push) Has been cancelled
Deploy: referral rewards antifraud + legal contact updates
2025-12-15 13:56:43 +01:00

347 lines
11 KiB
Python

"""
Subscription API endpoints with Stripe integration.
Endpoints:
- GET /subscription - Get current subscription
- GET /subscription/tiers - Get available tiers
- GET /subscription/features - Get current features
- POST /subscription/checkout - Create Stripe checkout session
- POST /subscription/portal - Create Stripe customer portal session
- POST /subscription/cancel - Cancel subscription
"""
import logging
import os
from typing import Optional

from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from sqlalchemy import select, func

from app.api.deps import Database, CurrentUser
from app.models.domain import Domain
from app.models.user import User
from app.models.subscription import Subscription, SubscriptionTier, TIER_CONFIG
from app.schemas.subscription import SubscriptionResponse
from app.services.stripe_service import StripeService, TIER_FEATURES
from app.services.email_service import email_service
# Router holding the subscription endpoints. Route paths below are relative
# ("", "/tiers", ...); the "/subscription" prefix listed in the module
# docstring is presumably applied where this router is included — confirm.
router = APIRouter()


# ============== Schemas ==============
class CheckoutRequest(BaseModel):
    """Request body for ``POST /checkout``.

    ``plan`` is a free-form string here; the checkout endpoint itself rejects
    anything other than "trader" or "tycoon" with HTTP 400.
    """

    plan: str  # "trader" or "tycoon"
    # Optional redirect overrides; when omitted, the checkout endpoint builds
    # defaults from the SITE_URL environment variable.
    success_url: Optional[str] = None
    cancel_url: Optional[str] = None
class CheckoutResponse(BaseModel):
    """Response for ``POST /checkout``."""

    checkout_url: str  # Stripe Checkout page URL the client should redirect to
    session_id: str  # Stripe Checkout session ID
class PortalResponse(BaseModel):
    """Response for ``POST /portal``."""

    portal_url: str  # Stripe Customer Portal session URL
# ============== Endpoints ==============


@router.get("", response_model=SubscriptionResponse)
async def get_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """Return the current user's subscription details.

    If the user has no subscription row yet, a default Scout subscription is
    created and committed first. Its limits are taken from
    ``TIER_CONFIG[SubscriptionTier.SCOUT]`` — the same source used when a
    paid plan is cancelled back to Scout — so both Scout paths stay in sync.

    The response merges the stored subscription, its tier ``config`` and the
    number of domains the user currently owns.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if not subscription:
        # Seed the free tier from TIER_CONFIG instead of hard-coded limits so
        # a config change cannot leave new users out of sync with downgraded ones.
        scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
        subscription = Subscription(
            user_id=current_user.id,
            tier=SubscriptionTier.SCOUT,
            max_domains=scout_config["domain_limit"],
            check_frequency=scout_config["check_frequency"],
        )
        db.add(subscription)
        await db.commit()
        await db.refresh(subscription)

    # Count the domains the user is currently tracking.
    domain_count = await db.execute(
        select(func.count(Domain.id)).where(Domain.user_id == current_user.id)
    )
    domains_used = domain_count.scalar() or 0

    config = subscription.config
    return SubscriptionResponse(
        id=subscription.id,
        tier=subscription.tier.value,
        tier_name=config["name"],
        status=subscription.status.value,
        domain_limit=subscription.domain_limit,
        domains_used=domains_used,
        portfolio_limit=config.get("portfolio_limit", 0),
        check_frequency=config["check_frequency"],
        history_days=config["history_days"],
        features=config["features"],
        started_at=subscription.started_at,
        expires_at=subscription.expires_at,
    )
def _describe_tier_features(config: dict) -> list:
    """Return the human-readable feature bullets for one ``TIER_CONFIG`` entry.

    A value of -1 for ``domain_limit``, ``portfolio_limit`` or
    ``history_days`` is rendered as "Unlimited ..." rather than as a literal
    "-1". The bullet order is fixed and matches what the pricing page shows.
    """
    flags = config["features"]
    bullets = []

    domain_limit = config["domain_limit"]
    if domain_limit == -1:
        bullets.append("Unlimited domains in watchlist")
    else:
        bullets.append(f"{domain_limit} domains in watchlist")

    # A missing, None or 0 portfolio_limit contributes no bullet.
    portfolio_limit = config.get("portfolio_limit") or 0
    if portfolio_limit == -1:
        bullets.append("Unlimited portfolio domains")
    elif portfolio_limit > 0:
        bullets.append(f"{portfolio_limit} portfolio domains")

    frequency = config["check_frequency"]
    if frequency == "realtime":
        bullets.append("10-minute availability checks")
    elif frequency == "hourly":
        bullets.append("Hourly availability checks")
    else:
        bullets.append("Daily availability checks")

    # At most one notification bullet: SMS/Telegram takes precedence over email.
    if flags.get("sms_alerts"):
        bullets.append("SMS & Telegram notifications")
    elif flags.get("email_alerts"):
        bullets.append("Email notifications")

    if flags.get("domain_valuation"):
        bullets.append("Domain valuation")
    if flags.get("market_insights"):
        bullets.append("Full market insights")

    history_days = config["history_days"]
    if history_days == -1:
        bullets.append("Unlimited check history")
    elif history_days > 0:
        bullets.append(f"{history_days}-day check history")

    if flags.get("api_access"):
        bullets.append("REST API access")
    if flags.get("bulk_tools"):
        bullets.append("Bulk import/export tools")
    if flags.get("seo_metrics"):
        bullets.append("SEO metrics (DA, backlinks)")

    return bullets


@router.get("/tiers")
async def get_subscription_tiers():
    """Return every subscription tier with its limits, price and features.

    Each entry carries both a human-readable ``features`` list (for display)
    and the raw ``feature_flags`` mapping from ``TIER_CONFIG``.
    """
    tiers = [
        {
            "id": tier_enum.value,
            "name": config["name"],
            "domain_limit": config["domain_limit"],
            "portfolio_limit": config.get("portfolio_limit", 0),
            "price": config["price"],
            "currency": config.get("currency", "USD"),
            "check_frequency": config["check_frequency"],
            "features": _describe_tier_features(config),
            "feature_flags": config["features"],
        }
        for tier_enum, config in TIER_CONFIG.items()
    ]
    return {"tiers": tiers}
@router.get("/features")
async def get_my_features(current_user: CurrentUser, db: Database):
    """Return the limits and feature flags available to the current user.

    Users without a subscription row are reported as Scout using the static
    ``TIER_CONFIG`` entry; unlike ``GET ""``, nothing is persisted here.
    """
    lookup = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = lookup.scalar_one_or_none()

    if subscription:
        cfg = subscription.config
        tier_id = subscription.tier.value
        display_name = cfg["name"]
    else:
        # No row yet: fall back to the Scout defaults.
        cfg = TIER_CONFIG[SubscriptionTier.SCOUT]
        tier_id = "scout"
        display_name = "Scout"

    return {
        "tier": tier_id,
        "tier_name": display_name,
        "domain_limit": cfg["domain_limit"],
        "portfolio_limit": cfg.get("portfolio_limit", 0),
        "check_frequency": cfg["check_frequency"],
        "history_days": cfg["history_days"],
        "features": cfg["features"],
    }
@router.post("/checkout", response_model=CheckoutResponse)
async def create_checkout_session(
    request: CheckoutRequest,
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Checkout session for a subscription upgrade.

    Args:
        request: Body with ``plan`` ("trader" or "tycoon") and optional
            ``success_url`` / ``cancel_url`` redirect overrides.
        current_user: The authenticated user being upgraded.
        db: Session used to persist a newly created Stripe customer ID.

    Returns:
        CheckoutResponse with the Stripe Checkout page URL and session ID.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 400: Unknown plan, or StripeService raised ValueError.
        HTTPException 500: Any other failure. The underlying error is logged
            server-side and is not echoed to the client.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if request.plan not in ("trader", "tycoon"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid plan. Must be 'trader' or 'tycoon'",
        )

    # Default redirect targets are derived from SITE_URL.
    # NOTE(review): caller-supplied success_url/cancel_url are forwarded to
    # Stripe unvalidated; consider restricting them to SITE_URL's origin.
    site_url = os.getenv("SITE_URL", "http://localhost:3000")
    success_url = request.success_url or f"{site_url}/command/welcome?plan={request.plan}"
    cancel_url = request.cancel_url or f"{site_url}/pricing?cancelled=true"

    try:
        result = await StripeService.create_checkout_session(
            user=current_user,
            plan=request.plan,
            success_url=success_url,
            cancel_url=cancel_url,
        )

        # Remember the Stripe customer the first time one is created; the
        # billing-portal endpoint requires current_user.stripe_customer_id.
        if result.get("customer_id") and not current_user.stripe_customer_id:
            current_user.stripe_customer_id = result["customer_id"]
            await db.commit()

        return CheckoutResponse(
            checkout_url=result["checkout_url"],
            session_id=result["session_id"],
        )
    except ValueError as e:
        # Validation errors from StripeService are intended for the client.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Log the full error (Stripe responses, DB failures, ...) for operators
        # but do not leak internal details to the API client.
        logging.getLogger(__name__).exception(
            "Failed to create checkout session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create checkout session. Please try again later.",
        ) from e
@router.post("/portal", response_model=PortalResponse)
async def create_portal_session(
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Customer Portal session for the current user.

    In the portal users can:
    - Update payment method
    - View invoices
    - Cancel subscription
    - Update billing info

    ``db`` is not used in this handler; it remains in the signature so the
    injected dependencies are unchanged.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 400: The user has no Stripe customer ID yet.
        HTTPException 500: Portal creation failed. The underlying error is
            logged server-side and is not echoed to the client.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if not current_user.stripe_customer_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No billing account found. Please subscribe to a plan first.",
        )

    site_url = os.getenv("SITE_URL", "http://localhost:3000")
    return_url = f"{site_url}/command/settings"

    try:
        portal_url = await StripeService.create_portal_session(
            customer_id=current_user.stripe_customer_id,
            return_url=return_url,
        )
    except Exception as e:
        # Log details for operators; keep internal/Stripe error text out of
        # the client response.
        logging.getLogger(__name__).exception(
            "Failed to create portal session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create portal session. Please try again later.",
        ) from e

    return PortalResponse(portal_url=portal_url)
@router.post("/cancel")
async def cancel_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """
    Cancel a manually managed subscription and downgrade the user to Scout.

    This endpoint only changes local database state; it never calls Stripe.
    Subscriptions linked to Stripe (non-empty ``stripe_subscription_id``)
    are therefore rejected and must be cancelled via the Customer Portal.
    Otherwise, clearing the link here would leave the Stripe subscription
    untouched (and presumably still billing) while the user is on Scout.

    Raises:
        HTTPException 404: The user has no subscription row.
        HTTPException 400: Already on Scout, or the subscription is Stripe-managed.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if not subscription:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No subscription found",
        )

    if subscription.tier == SubscriptionTier.SCOUT:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Already on free plan",
        )

    if subscription.stripe_subscription_id:
        # Don't orphan a live Stripe subscription: cancellation must go
        # through Stripe (Customer Portal) so billing actually stops.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "This subscription is billed through Stripe. "
                "Please cancel it from the billing portal."
            ),
        )

    # Downgrade to Scout using the canonical tier limits.
    old_tier = subscription.tier.value
    scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
    subscription.tier = SubscriptionTier.SCOUT
    subscription.max_domains = scout_config["domain_limit"]
    subscription.check_frequency = scout_config["check_frequency"]
    # Normalize any falsy placeholder (e.g. "") to None.
    subscription.stripe_subscription_id = None
    await db.commit()

    return {
        "status": "cancelled",
        "message": f"Subscription cancelled. Downgraded from {old_tier} to Scout.",
        "new_tier": "scout",
    }