from __future__ import annotations import json from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.api.tld_prices import get_trending_tlds from app.models.auction import DomainAuction from app.models.domain import Domain from app.models.listing import DomainListing, ListingStatus from app.models.subscription import Subscription, SubscriptionTier from app.models.user import User from app.services.analyze.service import get_domain_analysis from app.services.zone_file import get_dropped_domains ToolHandler = Callable[[AsyncSession, User, dict[str, Any]], Awaitable[dict[str, Any]]] @dataclass(frozen=True) class ToolDef: name: str description: str json_schema: dict[str, Any] min_tier: SubscriptionTier = SubscriptionTier.TRADER handler: ToolHandler | None = None def _tier_level(tier: SubscriptionTier) -> int: if tier == SubscriptionTier.TYCOON: return 3 if tier == SubscriptionTier.TRADER: return 2 return 1 async def _get_subscription(db: AsyncSession, user: User) -> Subscription | None: res = await db.execute(select(Subscription).where(Subscription.user_id == user.id)) return res.scalar_one_or_none() def _require_tier(user_tier: SubscriptionTier, tool_tier: SubscriptionTier) -> None: if _tier_level(user_tier) < _tier_level(tool_tier): raise PermissionError(f"Tool requires {tool_tier.value} tier.") def _clamp_int(value: Any, *, lo: int, hi: int, default: int) -> int: try: v = int(value) except Exception: return default return max(lo, min(hi, v)) def _clamp_float(value: Any, *, lo: float, hi: float, default: float) -> float: try: v = float(value) except Exception: return default return max(lo, min(hi, v)) # ============================================================================ # TOOL IMPLEMENTATIONS (READ-ONLY) # ============================================================================ async def tool_get_subscription(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: sub = await _get_subscription(db, user) if not sub: return {"tier": "scout", "features": {}, "limits": {}} cfg = sub.config return { "tier": sub.tier.value, "tier_name": cfg.get("name"), "features": cfg.get("features", {}), "limits": { "watchlist": sub.domain_limit, "portfolio": cfg.get("portfolio_limit"), "listings": cfg.get("listing_limit"), "sniper": cfg.get("sniper_limit"), "history_days": cfg.get("history_days"), "check_frequency": cfg.get("check_frequency"), }, } async def tool_get_dashboard_summary(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: # Similar to /dashboard/summary but kept lightweight and tool-friendly. now = args.get("now") # ignored (server time used) _ = now from datetime import datetime, timedelta t = datetime.utcnow() active = and_(DomainAuction.is_active == True, DomainAuction.end_time > t) total_auctions = (await db.execute(select(func.count(DomainAuction.id)).where(active))).scalar() or 0 cutoff = t + timedelta(hours=24) ending = and_(DomainAuction.is_active == True, DomainAuction.end_time > t, DomainAuction.end_time <= cutoff) ending_soon_count = (await db.execute(select(func.count(DomainAuction.id)).where(ending))).scalar() or 0 ending_rows = ( await db.execute(select(DomainAuction).where(ending).order_by(DomainAuction.end_time.asc()).limit(10)) ).scalars().all() # Listings counts listing_counts = ( await db.execute( select(DomainListing.status, func.count(DomainListing.id)) .where(DomainListing.user_id == user.id) .group_by(DomainListing.status) ) ).all() by_status = {str(status): int(count) for status, count in listing_counts} trending = await get_trending_tlds(db) return { "market": { "total_auctions": total_auctions, "ending_soon_24h": ending_soon_count, "ending_soon_preview": [ { "domain": a.domain, "current_bid": a.current_bid, "platform": a.platform, "end_time": a.end_time.isoformat() if a.end_time else None, "auction_url": a.auction_url, } for a in ending_rows ], }, "listings": { "active": by_status.get(ListingStatus.ACTIVE.value, 0), "sold": by_status.get(ListingStatus.SOLD.value, 0), "draft": by_status.get(ListingStatus.DRAFT.value, 0), "total": sum(by_status.values()), }, "tlds": trending, "timestamp": t.isoformat(), } async def tool_list_watchlist(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: page = _clamp_int(args.get("page"), lo=1, hi=50, default=1) per_page = _clamp_int(args.get("per_page"), lo=1, hi=50, default=20) offset = (page - 1) * per_page total = (await db.execute(select(func.count(Domain.id)).where(Domain.user_id == user.id))).scalar() or 0 rows = ( await db.execute( select(Domain) .where(Domain.user_id == user.id) .order_by(Domain.created_at.desc()) .offset(offset) .limit(per_page) ) ).scalars().all() return { "page": page, "per_page": per_page, "total": int(total), "domains": [ { "id": d.id, "name": d.name, "status": getattr(d.status, "value", d.status), "is_available": bool(d.is_available), "registrar": d.registrar, "created_at": d.created_at.isoformat() if d.created_at else None, "updated_at": d.updated_at.isoformat() if d.updated_at else None, } for d in rows ], } async def tool_analyze_domain(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: domain = (args.get("domain") or "").strip() if not domain: return {"error": "Missing domain"} fast = bool(args.get("fast", False)) refresh = bool(args.get("refresh", False)) res = await get_domain_analysis(db, domain, fast=fast, refresh=refresh) return res.model_dump(mode="json") async def tool_market_feed(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: # Read-only query against DomainAuction similar to /auctions/feed; keep it capped. from datetime import datetime, timedelta limit = _clamp_int(args.get("limit"), lo=1, hi=50, default=20) source = (args.get("source") or "all").lower() keyword = (args.get("keyword") or "").strip().lower() or None tld = (args.get("tld") or "").strip().lower().lstrip(".") or None sort_by = (args.get("sort_by") or "time").lower() ending_within = args.get("ending_within_hours") ending_within_h = _clamp_int(ending_within, lo=1, hi=168, default=0) if ending_within is not None else None now = datetime.utcnow() q = select(DomainAuction).where(DomainAuction.is_active == True, DomainAuction.end_time > now) if source in ("pounce", "external"): q = q.where(DomainAuction.source == source) if keyword: q = q.where(DomainAuction.domain.ilike(f"%{keyword}%")) if tld: q = q.where(DomainAuction.domain.ilike(f"%.{tld}")) if ending_within_h: q = q.where(DomainAuction.end_time <= (now + timedelta(hours=ending_within_h))) if sort_by == "score": q = q.order_by(DomainAuction.score.desc().nullslast(), DomainAuction.end_time.asc()) else: q = q.order_by(DomainAuction.end_time.asc()) auctions = (await db.execute(q.limit(limit))).scalars().all() return { "items": [ { "domain": a.domain, "current_bid": a.current_bid, "platform": a.platform, "end_time": a.end_time.isoformat() if a.end_time else None, "bids": a.bids, "score": a.score, "auction_url": a.auction_url, "source": a.source, } for a in auctions ], "count": len(auctions), "timestamp": now.isoformat(), } async def tool_get_drops(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]: tld = (args.get("tld") or None) hours = _clamp_int(args.get("hours"), lo=1, hi=48, default=24) limit = _clamp_int(args.get("limit"), lo=1, hi=100, default=50) offset = _clamp_int(args.get("offset"), lo=0, hi=10000, default=0) keyword = (args.get("keyword") or None) min_length = args.get("min_length") max_length = args.get("max_length") exclude_numeric = bool(args.get("exclude_numeric", False)) exclude_hyphen = bool(args.get("exclude_hyphen", False)) result = await get_dropped_domains( db=db, tld=(tld.lower().lstrip(".") if isinstance(tld, str) and tld.strip() else None), hours=hours, min_length=int(min_length) if min_length is not None else None, max_length=int(max_length) if max_length is not None else None, exclude_numeric=exclude_numeric, exclude_hyphen=exclude_hyphen, keyword=(str(keyword).strip() if keyword else None), limit=limit, offset=offset, ) return result def get_tool_defs() -> list[ToolDef]: return [ ToolDef( name="get_subscription", description="Get current user's subscription tier, features and limits.", json_schema={"type": "object", "properties": {}, "additionalProperties": False}, min_tier=SubscriptionTier.TRADER, handler=tool_get_subscription, ), ToolDef( name="get_dashboard_summary", description="Get a compact snapshot: ending auctions, listing stats, trending TLDs.", json_schema={"type": "object", "properties": {}, "additionalProperties": False}, min_tier=SubscriptionTier.TRADER, handler=tool_get_dashboard_summary, ), ToolDef( name="list_watchlist", description="List user's watchlist domains (monitored domains).", json_schema={ "type": "object", "properties": { "page": {"type": "integer", "minimum": 1, "maximum": 50}, "per_page": {"type": "integer", "minimum": 1, "maximum": 50}, }, "additionalProperties": False, }, min_tier=SubscriptionTier.TRADER, handler=tool_list_watchlist, ), ToolDef( name="analyze_domain", description="Run Pounce domain analysis (Authority/Market/Risk/Value) for a given domain.", json_schema={ "type": "object", "properties": { "domain": {"type": "string"}, "fast": {"type": "boolean"}, "refresh": {"type": "boolean"}, }, "required": ["domain"], "additionalProperties": False, }, min_tier=SubscriptionTier.TRADER, handler=tool_analyze_domain, ), ToolDef( name="market_feed", description="Get current auction feed (filters: source, keyword, tld, ending window).", json_schema={ "type": "object", "properties": { "source": {"type": "string", "enum": ["all", "pounce", "external"]}, "keyword": {"type": "string"}, "tld": {"type": "string"}, "sort_by": {"type": "string", "enum": ["time", "score"]}, "ending_within_hours": {"type": "integer", "minimum": 1, "maximum": 168}, "limit": {"type": "integer", "minimum": 1, "maximum": 50}, }, "additionalProperties": False, }, min_tier=SubscriptionTier.TRADER, handler=tool_market_feed, ), ToolDef( name="get_drops", description="Get recently dropped domains from zone files (auth required).", json_schema={ "type": "object", "properties": { "tld": {"type": "string"}, "hours": {"type": "integer", "minimum": 1, "maximum": 48}, "min_length": {"type": "integer", "minimum": 1, "maximum": 63}, "max_length": {"type": "integer", "minimum": 1, "maximum": 63}, "exclude_numeric": {"type": "boolean"}, "exclude_hyphen": {"type": "boolean"}, "keyword": {"type": "string"}, "limit": {"type": "integer", "minimum": 1, "maximum": 100}, "offset": {"type": "integer", "minimum": 0, "maximum": 10000}, }, "additionalProperties": False, }, min_tier=SubscriptionTier.TRADER, handler=tool_get_drops, ), ] def tools_for_path(path: str) -> list[str]: """ Limit the visible tool list depending on the current Terminal page. This keeps prompts smaller and makes the model more decisive. """ p = (path or "").split("?")[0] if p.startswith("/terminal/hunt"): return ["get_subscription", "get_dashboard_summary", "market_feed", "get_drops", "analyze_domain", "list_watchlist"] if p.startswith("/terminal/market"): return ["get_subscription", "market_feed", "analyze_domain"] if p.startswith("/terminal/watchlist"): return ["get_subscription", "list_watchlist", "analyze_domain"] # default: allow a safe minimal set return ["get_subscription", "get_dashboard_summary", "analyze_domain"] async def execute_tool(db: AsyncSession, user: User, name: str, args: dict[str, Any], *, path: str) -> dict[str, Any]: defs = {t.name: t for t in get_tool_defs()} tool = defs.get(name) if tool is None or tool.handler is None: return {"error": f"Unknown tool: {name}"} # Enforce tool allowed on this page allowed = set(tools_for_path(path)) if name not in allowed: return {"error": f"Tool not allowed for path: {name}"} sub = await _get_subscription(db, user) user_tier = sub.tier if sub else SubscriptionTier.SCOUT try: _require_tier(user_tier, tool.min_tier) except PermissionError as e: return {"error": str(e)} try: return await tool.handler(db, user, args or {}) except Exception as e: return {"error": f"{type(e).__name__}: {e}"} def tool_catalog_for_prompt(path: str) -> list[dict[str, Any]]: allowed = set(tools_for_path(path)) out: list[dict[str, Any]] = [] for t in get_tool_defs(): if t.name in allowed: out.append( { "name": t.name, "description": t.description, "schema": t.json_schema, } ) return out