Some checks failed
CI / Frontend Lint & Type Check (push) Has been cancelled
CI / Frontend Build (push) Has been cancelled
CI / Backend Lint (push) Has been cancelled
CI / Backend Tests (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
Deploy / Build & Push Images (push) Has been cancelled
Deploy / Deploy to Server (push) Has been cancelled
Deploy / Notify (push) Has been cancelled
347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""
|
|
Subscription API endpoints with Stripe integration.

Endpoints:
- GET  /subscription           - Get the current user's subscription
- GET  /subscription/tiers     - List the available tiers and their features
- GET  /subscription/features  - Get the current user's feature flags and limits
- POST /subscription/checkout  - Create a Stripe Checkout session
- POST /subscription/portal    - Create a Stripe Customer Portal session
- POST /subscription/cancel    - Cancel the subscription (downgrade to Scout)
|
|
"""
|
|
import logging
import os
from typing import Optional

from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from sqlalchemy import select, func

from app.api.deps import Database, CurrentUser
from app.models.domain import Domain
from app.models.user import User
from app.models.subscription import Subscription, SubscriptionTier, TIER_CONFIG
from app.schemas.subscription import SubscriptionResponse
from app.services.stripe_service import StripeService, TIER_FEATURES
from app.services.email_service import email_service
|
|
|
|
# Router holding the subscription endpoints defined below. Routes are declared
# relative to the router ("" , "/tiers", ...); the "/subscription" prefix shown
# in the module docstring is presumably applied where this router is included
# — TODO confirm in the application's router setup.
router = APIRouter()
|
|
|
|
|
|
# ============== Schemas ==============
|
|
|
|
class CheckoutRequest(BaseModel):
    """Request body accepted by the checkout endpoint.

    ``plan`` names the paid tier being purchased; the endpoint (not this
    schema) restricts it to "trader" or "tycoon". Both redirect URLs are
    optional and fall back to pages under SITE_URL when omitted.
    """

    # Paid tier identifier: "trader" or "tycoon".
    plan: str
    # Redirect target after a successful payment (optional).
    success_url: Optional[str] = None
    # Redirect target if the user abandons checkout (optional).
    cancel_url: Optional[str] = None
|
|
|
|
|
|
class CheckoutResponse(BaseModel):
    """Result of creating a Stripe Checkout session."""

    # Stripe-hosted Checkout page the client should be sent to.
    checkout_url: str
    # Identifier of the created Stripe Checkout session.
    session_id: str
|
|
|
|
|
|
class PortalResponse(BaseModel):
    """Result of creating a Stripe Customer Portal session."""

    # Stripe-hosted portal page the client should be sent to.
    portal_url: str
|
|
|
|
|
|
# ============== Endpoints ==============
|
|
|
|
@router.get("", response_model=SubscriptionResponse)
async def get_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """Return the current user's subscription details.

    If the user has no subscription row yet, a free Scout subscription is
    created and committed before the response is built.

    Args:
        current_user: The authenticated user.
        db: Async database session.

    Returns:
        SubscriptionResponse with the tier, its limits and features, and the
        number of domains the user currently has.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if subscription is None:
        # Seed the default row from TIER_CONFIG rather than hard-coded values
        # so the stored Scout limits always match the advertised Scout tier
        # (the downgrade path in /cancel uses the same source).
        scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
        subscription = Subscription(
            user_id=current_user.id,
            tier=SubscriptionTier.SCOUT,
            max_domains=scout_config["domain_limit"],
            check_frequency=scout_config["check_frequency"],
        )
        db.add(subscription)
        await db.commit()
        await db.refresh(subscription)

    # Count the domains this user currently has.
    domain_count = await db.execute(
        select(func.count(Domain.id)).where(Domain.user_id == current_user.id)
    )
    domains_used = domain_count.scalar() or 0

    config = subscription.config

    return SubscriptionResponse(
        id=subscription.id,
        tier=subscription.tier.value,
        tier_name=config["name"],
        status=subscription.status.value,
        domain_limit=subscription.domain_limit,
        domains_used=domains_used,
        portfolio_limit=config.get("portfolio_limit", 0),
        check_frequency=config["check_frequency"],
        history_days=config["history_days"],
        features=config["features"],
        started_at=subscription.started_at,
        expires_at=subscription.expires_at,
    )
|
|
|
|
|
|
def _tier_feature_list(config: dict) -> list:
    """Render one TIER_CONFIG entry as human-readable feature bullets.

    A limit of -1 is rendered as "unlimited", following the sentinel the
    config uses for ``portfolio_limit`` and ``history_days``. Bullet order is
    part of the public response, so keep it stable.
    """
    flags = config["features"]
    bullets = []

    domain_limit = config["domain_limit"]
    if domain_limit == -1:
        bullets.append("Unlimited domains in watchlist")
    else:
        bullets.append(f"{domain_limit} domains in watchlist")

    # A missing, None or 0 portfolio limit produces no bullet at all.
    portfolio_limit = config.get("portfolio_limit")
    if portfolio_limit == -1:
        bullets.append("Unlimited portfolio domains")
    elif portfolio_limit and portfolio_limit > 0:
        bullets.append(f"{portfolio_limit} portfolio domains")

    frequency = config["check_frequency"]
    if frequency == "realtime":
        bullets.append("10-minute availability checks")
    elif frequency == "hourly":
        bullets.append("Hourly availability checks")
    else:
        bullets.append("Daily availability checks")

    # SMS implies the richer notification bundle; only one alert bullet is shown.
    if flags.get("sms_alerts"):
        bullets.append("SMS & Telegram notifications")
    elif flags.get("email_alerts"):
        bullets.append("Email notifications")

    if flags.get("domain_valuation"):
        bullets.append("Domain valuation")

    if flags.get("market_insights"):
        bullets.append("Full market insights")

    history_days = config["history_days"]
    if history_days == -1:
        bullets.append("Unlimited check history")
    elif history_days > 0:
        bullets.append(f"{history_days}-day check history")

    if flags.get("api_access"):
        bullets.append("REST API access")

    if flags.get("bulk_tools"):
        bullets.append("Bulk import/export tools")

    if flags.get("seo_metrics"):
        bullets.append("SEO metrics (DA, backlinks)")

    return bullets


@router.get("/tiers")
async def get_subscription_tiers():
    """List every subscription tier with its limits, price and features.

    Returns:
        ``{"tiers": [...]}``, one dict per TIER_CONFIG entry, in TIER_CONFIG
        iteration order. Each dict has human-readable ``features`` bullets
        plus the raw ``feature_flags`` mapping.
    """
    tiers = [
        {
            "id": tier_enum.value,
            "name": config["name"],
            "domain_limit": config["domain_limit"],
            "portfolio_limit": config.get("portfolio_limit", 0),
            "price": config["price"],
            "currency": config.get("currency", "USD"),
            "check_frequency": config["check_frequency"],
            "features": _tier_feature_list(config),
            "feature_flags": config["features"],
        }
        for tier_enum, config in TIER_CONFIG.items()
    ]
    return {"tiers": tiers}
|
|
|
|
|
|
@router.get("/features")
async def get_my_features(current_user: CurrentUser, db: Database):
    """Report the tier, limits and feature flags available to the current user.

    A user with no subscription row is reported as Scout, using the Scout
    entry of TIER_CONFIG. No row is created here.
    """
    lookup = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    sub = lookup.scalar_one_or_none()

    # Choose the tier identity and config once, then build a single payload.
    if sub:
        tier_id = sub.tier.value
        tier_cfg = sub.config
        display_name = tier_cfg["name"]
    else:
        tier_id = "scout"
        tier_cfg = TIER_CONFIG[SubscriptionTier.SCOUT]
        display_name = "Scout"

    return {
        "tier": tier_id,
        "tier_name": display_name,
        "domain_limit": tier_cfg["domain_limit"],
        "portfolio_limit": tier_cfg.get("portfolio_limit", 0),
        "check_frequency": tier_cfg["check_frequency"],
        "history_days": tier_cfg["history_days"],
        "features": tier_cfg["features"],
    }
|
|
|
|
|
|
@router.post("/checkout", response_model=CheckoutResponse)
async def create_checkout_session(
    request: CheckoutRequest,
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Checkout session for a subscription upgrade.

    Args:
        request: Body with ``plan`` ("trader" or "tycoon") and optional
            ``success_url`` / ``cancel_url``. Missing URLs default to pages
            under the SITE_URL environment variable.
        current_user: The authenticated user starting checkout.
        db: Database session, used to persist a newly created Stripe
            customer ID on the user.

    Returns:
        CheckoutResponse with the Stripe Checkout page URL and session ID.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 400: Unknown plan, or StripeService raised ValueError.
        HTTPException 500: Any other failure. Details are logged server-side
            and not returned to the client.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if request.plan not in ("trader", "tycoon"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid plan. Must be 'trader' or 'tycoon'",
        )

    site_url = os.getenv("SITE_URL", "http://localhost:3000")

    # NOTE(review): client-supplied redirect URLs are forwarded to Stripe
    # without validation; consider restricting them to SITE_URL.
    success_url = request.success_url or f"{site_url}/command/welcome?plan={request.plan}"
    cancel_url = request.cancel_url or f"{site_url}/pricing?cancelled=true"

    try:
        result = await StripeService.create_checkout_session(
            user=current_user,
            plan=request.plan,
            success_url=success_url,
            cancel_url=cancel_url,
        )

        # Remember the Stripe customer so later sessions reuse it.
        if result.get("customer_id") and not current_user.stripe_customer_id:
            current_user.stripe_customer_id = result["customer_id"]
            await db.commit()

        return CheckoutResponse(
            checkout_url=result["checkout_url"],
            session_id=result["session_id"],
        )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Exception text may contain Stripe/DB internals: log it for
        # operators, but return only a generic message to the client.
        logging.getLogger(__name__).exception(
            "Failed to create checkout session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create checkout session. Please try again later.",
        ) from e
|
|
|
|
|
|
@router.post("/portal", response_model=PortalResponse)
async def create_portal_session(
    current_user: CurrentUser,
    db: Database,
):
    """
    Create a Stripe Customer Portal session for the current user.

    In the portal, users can update their payment method, view invoices,
    cancel their subscription and update billing info. After leaving the
    portal they return to ``{SITE_URL}/command/settings``.

    Args:
        current_user: The authenticated user; must already have a Stripe
            customer ID.
        db: Database session (not used by this endpoint; kept so the
            dependency signature is unchanged).

    Returns:
        PortalResponse with the portal URL.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 400: The user has no Stripe customer ID yet.
        HTTPException 500: Portal creation failed. Details are logged
            server-side and not returned to the client.
    """
    if not StripeService.is_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Payment system not configured. Please contact support.",
        )

    if not current_user.stripe_customer_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No billing account found. Please subscribe to a plan first.",
        )

    site_url = os.getenv("SITE_URL", "http://localhost:3000")
    return_url = f"{site_url}/command/settings"

    try:
        portal_url = await StripeService.create_portal_session(
            customer_id=current_user.stripe_customer_id,
            return_url=return_url,
        )
    except Exception as e:
        # Exception text may contain Stripe internals: log it for operators,
        # but return only a generic message to the client.
        logging.getLogger(__name__).exception(
            "Failed to create portal session for user %s", current_user.id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create portal session. Please try again later.",
        ) from e

    return PortalResponse(portal_url=portal_url)
|
|
|
|
|
|
@router.post("/cancel")
async def cancel_subscription(
    current_user: CurrentUser,
    db: Database,
):
    """
    Cancel a manually managed subscription and downgrade to Scout.

    Subscriptions linked to Stripe (a ``stripe_subscription_id`` is set) are
    refused. This endpoint makes no Stripe call, so downgrading locally and
    clearing the ID would leave Stripe billing active while dropping the
    link to it. Those must be cancelled through the Customer Portal.

    Returns:
        Dict with ``status``, a human-readable ``message`` and ``new_tier``.

    Raises:
        HTTPException 404: The user has no subscription row.
        HTTPException 400: Already on the free plan, or the subscription is
            managed by Stripe.
    """
    result = await db.execute(
        select(Subscription).where(Subscription.user_id == current_user.id)
    )
    subscription = result.scalar_one_or_none()

    if not subscription:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No subscription found",
        )

    if subscription.tier == SubscriptionTier.SCOUT:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Already on free plan",
        )

    if subscription.stripe_subscription_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "This subscription is billed through Stripe. "
                "Please cancel it from the billing portal."
            ),
        )

    # Downgrade to Scout using the tier's configured limits.
    scout_config = TIER_CONFIG[SubscriptionTier.SCOUT]
    old_tier = subscription.tier.value
    subscription.tier = SubscriptionTier.SCOUT
    subscription.max_domains = scout_config["domain_limit"]
    subscription.check_frequency = scout_config["check_frequency"]
    subscription.stripe_subscription_id = None

    await db.commit()

    return {
        "status": "cancelled",
        "message": f"Subscription cancelled. Downgraded from {old_tier} to Scout.",
        "new_tier": "scout",
    }
|