"""Authentication service.

Password hashing (bcrypt), JWT access-token creation/validation, and the
database operations needed to look up, authenticate and register users.
"""
from datetime import datetime, timedelta, timezone
import secrets
from typing import Optional

import bcrypt
from jose import JWTError, jwt
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
from app.models.user import User
from app.models.subscription import Subscription, SubscriptionTier, SubscriptionStatus, TIER_CONFIG

settings = get_settings()

# secrets.token_hex(6) yields 12 hex characters: easy to validate, share and
# embed in URLs.
_INVITE_CODE_BYTES = 6
# Upper bound on collision retries before giving up on invite-code generation.
_INVITE_CODE_MAX_ATTEMPTS = 12


class AuthService:
    """Service for authentication operations.

    All methods are static; the class is a namespace and holds no state.
    """

    @staticmethod
    def _normalize_email(email: str) -> str:
        """Return the canonical form of *email* used for storage and lookup."""
        return email.lower().strip()

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Verify a password against its bcrypt hash.

        Returns:
            True when the password matches the hash. False when it does not,
            or when bcrypt rejects the inputs with ``ValueError`` (for
            example a stored hash that is not a valid bcrypt hash), so a bad
            record fails authentication instead of raising.
        """
        try:
            return bcrypt.checkpw(
                plain_password.encode('utf-8'),
                hashed_password.encode('utf-8'),
            )
        except ValueError:
            # Fail closed: an unverifiable hash never authenticates.
            return False

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password with bcrypt (cost factor 12) and return it as text."""
        salt = bcrypt.gensalt(rounds=12)
        hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
        return hashed.decode('utf-8')

    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed. The dict is shallow-copied, not mutated.
            expires_delta: Token lifetime. When ``None`` the configured
                ``settings.access_token_expire_minutes`` is used. A zero
                ``timedelta`` is honoured as given rather than being treated
                as "not provided".

        Returns:
            The encoded JWT string, signed with ``settings.secret_key`` using
            ``settings.algorithm``.
        """
        to_encode = data.copy()
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
        # Timezone-aware "now": datetime.utcnow() is deprecated, and an aware
        # UTC datetime produces the same "exp" timestamp when encoded.
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
        return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)

    @staticmethod
    def decode_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT token.

        Returns:
            The decoded claims, or ``None`` when decoding raises ``JWTError``.
        """
        try:
            return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
        except JWTError:
            return None

    @staticmethod
    async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
        """Get user by email (case-insensitive, surrounding whitespace ignored).

        The input is normalized the same way ``create_user`` normalizes the
        address it stores, so a lookup with stray whitespace still matches.
        """
        normalized = AuthService._normalize_email(email)
        result = await db.execute(
            select(User).where(func.lower(User.email) == normalized)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_user_by_id(db: AsyncSession, user_id: int) -> Optional[User]:
        """Get user by ID, or ``None`` when no row matches."""
        result = await db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    @staticmethod
    async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
        """Authenticate user with email and password.

        Returns:
            The matching ``User`` when the email exists and the password
            verifies, otherwise ``None``.
        """
        user = await AuthService.get_user_by_email(db, email)
        # NOTE(review): an unknown email returns before any bcrypt work, so
        # response time may differ between unknown and known accounts.
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def _generate_unique_invite_code(db: AsyncSession) -> str:
        """Return a random invite code that no existing user row holds.

        Raises:
            RuntimeError: If every attempt collided with an existing code.
        """
        for _ in range(_INVITE_CODE_MAX_ATTEMPTS):
            code = secrets.token_hex(_INVITE_CODE_BYTES)
            exists = await db.execute(select(User.id).where(User.invite_code == code))
            if exists.scalar_one_or_none() is None:
                return code
        raise RuntimeError("Failed to generate unique invite code")

    @staticmethod
    async def create_user(
        db: AsyncSession,
        email: str,
        password: str,
        name: Optional[str] = None
    ) -> User:
        """Create a new user with default subscription.

        The email is stored normalized (lowercased, stripped), the password
        is stored as a bcrypt hash, and a fresh invite code is assigned. An
        active Scout-tier subscription is created for the user and both rows
        are committed together.

        Returns:
            The persisted ``User``, refreshed from the database.

        Raises:
            RuntimeError: If no unique invite code could be generated.
        """
        user = User(
            email=AuthService._normalize_email(email),
            hashed_password=AuthService.hash_password(password),
            name=name,
            invite_code=await AuthService._generate_unique_invite_code(db),
        )
        db.add(user)
        # Flush before building the subscription, which references user.id.
        await db.flush()

        # Create default Scout (free) subscription
        subscription = Subscription(
            user_id=user.id,
            tier=SubscriptionTier.SCOUT,
            status=SubscriptionStatus.ACTIVE,
        )
        db.add(subscription)

        await db.commit()
        await db.refresh(user)
        return user


# Singleton instance
auth_service = AuthService()