pounce/backend/app/services/llm_tools.py
Yves Gugger 8f6e13ffcf
Some checks failed
CI / Frontend Lint & Type Check (push) Has been cancelled
CI / Frontend Build (push) Has been cancelled
CI / Backend Lint (push) Has been cancelled
CI / Backend Tests (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
Deploy / Build & Push Images (push) Has been cancelled
Deploy / Deploy to Server (push) Has been cancelled
Deploy / Notify (push) Has been cancelled
LLM Agent: tool-calling endpoint + HunterCompanion uses /llm/agent
2025-12-17 14:30:25 +01:00

408 lines
15 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.tld_prices import get_trending_tlds
from app.models.auction import DomainAuction
from app.models.domain import Domain
from app.models.listing import DomainListing, ListingStatus
from app.models.subscription import Subscription, SubscriptionTier
from app.models.user import User
from app.services.analyze.service import get_domain_analysis
from app.services.zone_file import get_dropped_domains
# Call signature shared by every tool handler in this module:
# (db session, current user, tool arguments dict) -> awaitable result dict.
ToolHandler = Callable[[AsyncSession, User, dict[str, Any]], Awaitable[dict[str, Any]]]
@dataclass(frozen=True)
class ToolDef:
    """Immutable description of one tool that can be offered to the LLM agent.

    Instances are built in ``get_tool_defs`` and consumed by ``execute_tool``
    (dispatch + tier check) and ``tool_catalog_for_prompt`` (prompt catalog).
    """

    # Identifier used to look the tool up and to filter it per page.
    name: str
    # Human-readable summary exposed in the prompt catalog.
    description: str
    # JSON-schema dict describing the accepted arguments.
    json_schema: dict[str, Any]
    # Lowest subscription tier allowed to run the tool.
    min_tier: SubscriptionTier = SubscriptionTier.TRADER
    # Coroutine implementing the tool; a tool without a handler is treated as unknown.
    handler: ToolHandler | None = None
def _tier_level(tier: SubscriptionTier) -> int:
    """Map a subscription tier to a numeric rank (higher means more access).

    TYCOON ranks 3, TRADER ranks 2, and every other value ranks 1.
    """
    ranked = (
        (SubscriptionTier.TYCOON, 3),
        (SubscriptionTier.TRADER, 2),
    )
    for candidate, level in ranked:
        if tier == candidate:
            return level
    # Anything that is not TYCOON/TRADER gets the lowest rank.
    return 1
async def _get_subscription(db: AsyncSession, user: User) -> Subscription | None:
    """Load the ``Subscription`` row whose ``user_id`` matches *user*.

    Returns:
        The matching subscription, or ``None`` when the query yields no row.
    """
    stmt = select(Subscription).where(Subscription.user_id == user.id)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()
def _require_tier(user_tier: SubscriptionTier, tool_tier: SubscriptionTier) -> None:
    """Ensure *user_tier* ranks at least as high as *tool_tier*.

    Raises:
        PermissionError: if the user's tier ranks below the tool's tier.
    """
    if _tier_level(user_tier) >= _tier_level(tool_tier):
        return
    raise PermissionError(f"Tool requires {tool_tier.value} tier.")
def _clamp_int(value: Any, *, lo: int, hi: int, default: int) -> int:
try:
v = int(value)
except Exception:
return default
return max(lo, min(hi, v))
def _clamp_float(value: Any, *, lo: float, hi: float, default: float) -> float:
try:
v = float(value)
except Exception:
return default
return max(lo, min(hi, v))
# ============================================================================
# TOOL IMPLEMENTATIONS (READ-ONLY)
# ============================================================================
async def tool_get_subscription(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Report the current user's tier, feature flags and limits.

    *args* is not read. When the user has no subscription row a minimal
    "scout" payload with empty features/limits is returned.
    """
    sub = await _get_subscription(db, user)
    if sub:
        cfg = sub.config
        limits = {
            "watchlist": sub.domain_limit,
            "portfolio": cfg.get("portfolio_limit"),
            "listings": cfg.get("listing_limit"),
            "sniper": cfg.get("sniper_limit"),
            "history_days": cfg.get("history_days"),
            "check_frequency": cfg.get("check_frequency"),
        }
        return {
            "tier": sub.tier.value,
            "tier_name": cfg.get("name"),
            "features": cfg.get("features", {}),
            "limits": limits,
        }
    # No subscription row: fall back to the lowest tier with nothing enabled.
    return {"tier": "scout", "features": {}, "limits": {}}
async def tool_get_dashboard_summary(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Build a compact dashboard snapshot for the LLM agent.

    Similar to /dashboard/summary but kept lightweight and tool-friendly.

    Args:
        db: Async database session.
        user: Current user; only their listings are counted.
        args: Unused. Any client-supplied "now" is ignored; server time is used.

    Returns:
        A dict with:
        - ``market``: count of active auctions, count ending within 24 hours,
          and a preview of the 10 soonest-ending auctions;
        - ``listings``: the user's listing counts (active / sold / draft / total);
        - ``tlds``: whatever ``get_trending_tlds`` returns;
        - ``timestamp``: ISO-formatted server time used for the queries.
    """
    from datetime import datetime, timedelta, timezone

    # Naive UTC "now", equivalent to the deprecated datetime.utcnow().
    # NOTE(review): assumes DomainAuction.end_time is stored as naive UTC — confirm.
    t = datetime.now(timezone.utc).replace(tzinfo=None)
    cutoff = t + timedelta(hours=24)

    # `== True` is intentional: it builds a SQL expression, not a Python bool.
    active = and_(DomainAuction.is_active == True, DomainAuction.end_time > t)  # noqa: E712
    ending = and_(
        DomainAuction.is_active == True,  # noqa: E712
        DomainAuction.end_time > t,
        DomainAuction.end_time <= cutoff,
    )

    total_auctions = (await db.execute(select(func.count(DomainAuction.id)).where(active))).scalar() or 0
    ending_soon_count = (await db.execute(select(func.count(DomainAuction.id)).where(ending))).scalar() or 0
    ending_rows = (
        await db.execute(select(DomainAuction).where(ending).order_by(DomainAuction.end_time.asc()).limit(10))
    ).scalars().all()

    # Listings counts, grouped by status for the current user.
    listing_counts = (
        await db.execute(
            select(DomainListing.status, func.count(DomainListing.id))
            .where(DomainListing.user_id == user.id)
            .group_by(DomainListing.status)
        )
    ).all()
    # Key by the enum *value* when the column yields enum members: str() of an
    # Enum member is "ListingStatus.ACTIVE", which would never match the
    # `.value` lookups below. Plain string statuses pass through unchanged.
    by_status = {
        str(getattr(status, "value", status)): int(count)
        for status, count in listing_counts
    }

    trending = await get_trending_tlds(db)

    return {
        "market": {
            "total_auctions": total_auctions,
            "ending_soon_24h": ending_soon_count,
            "ending_soon_preview": [
                {
                    "domain": a.domain,
                    "current_bid": a.current_bid,
                    "platform": a.platform,
                    "end_time": a.end_time.isoformat() if a.end_time else None,
                    "auction_url": a.auction_url,
                }
                for a in ending_rows
            ],
        },
        "listings": {
            "active": by_status.get(ListingStatus.ACTIVE.value, 0),
            "sold": by_status.get(ListingStatus.SOLD.value, 0),
            "draft": by_status.get(ListingStatus.DRAFT.value, 0),
            "total": sum(by_status.values()),
        },
        "tlds": trending,
        "timestamp": t.isoformat(),
    }
async def tool_list_watchlist(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Return one page of the user's watchlist domains, newest first.

    Args:
        db: Async database session.
        user: Owner of the watchlist.
        args: Optional ``page`` (1-50, default 1) and ``per_page`` (1-50, default 20).

    Returns:
        ``page``, ``per_page``, the overall ``total`` and the serialized ``domains``.
    """
    page = _clamp_int(args.get("page"), lo=1, hi=50, default=1)
    per_page = _clamp_int(args.get("per_page"), lo=1, hi=50, default=20)

    owned_by_user = Domain.user_id == user.id

    count_stmt = select(func.count(Domain.id)).where(owned_by_user)
    total = (await db.execute(count_stmt)).scalar() or 0

    page_stmt = (
        select(Domain)
        .where(owned_by_user)
        .order_by(Domain.created_at.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    rows = (await db.execute(page_stmt)).scalars().all()

    def _serialize(d: Domain) -> dict[str, Any]:
        return {
            "id": d.id,
            "name": d.name,
            # Status may be an enum or a plain value; expose the raw value either way.
            "status": getattr(d.status, "value", d.status),
            "is_available": bool(d.is_available),
            "registrar": d.registrar,
            "created_at": d.created_at.isoformat() if d.created_at else None,
            "updated_at": d.updated_at.isoformat() if d.updated_at else None,
        }

    return {
        "page": page,
        "per_page": per_page,
        "total": int(total),
        "domains": [_serialize(d) for d in rows],
    }
async def tool_analyze_domain(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Run ``get_domain_analysis`` for ``args["domain"]`` and return it as JSON-mode dict.

    Args:
        db: Async database session.
        user: Current user (not read here).
        args: ``domain`` (required), plus optional ``fast`` and ``refresh`` flags.

    Returns:
        ``{"error": "Missing domain"}`` when no domain is given, otherwise the
        analysis result dumped via ``model_dump(mode="json")``.
    """
    domain = (args.get("domain") or "").strip()
    if domain:
        analysis = await get_domain_analysis(
            db,
            domain,
            fast=bool(args.get("fast", False)),
            refresh=bool(args.get("refresh", False)),
        )
        return analysis.model_dump(mode="json")
    return {"error": "Missing domain"}
def _escape_like(term: str, escape: str = "/") -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape character itself in *term*."""
    return (
        term.replace(escape, escape * 2)
        .replace("%", f"{escape}%")
        .replace("_", f"{escape}_")
    )


async def tool_market_feed(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Return a capped, filtered list of currently running auctions.

    Read-only query against DomainAuction similar to /auctions/feed.

    Args:
        db: Async database session.
        user: Current user (not read here).
        args: Optional filters:
            - ``limit``: 1-50, default 20;
            - ``source``: "pounce" or "external" filter on ``DomainAuction.source``;
              any other value applies no source filter;
            - ``keyword``: case-insensitive substring of the domain;
            - ``tld``: case-insensitive domain suffix (leading dot optional);
            - ``sort_by``: "score" sorts by score descending (NULLs last), then
              end time; anything else sorts by end time ascending;
            - ``ending_within_hours``: 1-168; unparseable values disable the filter.

    Returns:
        ``items`` (serialized auctions), ``count`` and the server ``timestamp``.
    """
    from datetime import datetime, timedelta, timezone

    limit = _clamp_int(args.get("limit"), lo=1, hi=50, default=20)
    source = (args.get("source") or "all").lower()
    keyword = (args.get("keyword") or "").strip().lower() or None
    tld = (args.get("tld") or "").strip().lower().lstrip(".") or None
    sort_by = (args.get("sort_by") or "time").lower()
    ending_within = args.get("ending_within_hours")
    # default=0 is a falsy sentinel: an unparseable value means "no window filter".
    ending_within_h = (
        _clamp_int(ending_within, lo=1, hi=168, default=0) if ending_within is not None else None
    )

    # Naive UTC "now", equivalent to the deprecated datetime.utcnow().
    # NOTE(review): assumes DomainAuction.end_time is stored as naive UTC — confirm.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    # `== True` is intentional: it builds a SQL expression, not a Python bool.
    q = select(DomainAuction).where(DomainAuction.is_active == True, DomainAuction.end_time > now)  # noqa: E712
    if source in ("pounce", "external"):
        q = q.where(DomainAuction.source == source)
    # keyword/tld come from the model/user: escape LIKE wildcards so "%" and "_"
    # are matched literally instead of widening the search.
    if keyword:
        q = q.where(DomainAuction.domain.ilike(f"%{_escape_like(keyword)}%", escape="/"))
    if tld:
        q = q.where(DomainAuction.domain.ilike(f"%.{_escape_like(tld)}", escape="/"))
    if ending_within_h:
        q = q.where(DomainAuction.end_time <= (now + timedelta(hours=ending_within_h)))

    if sort_by == "score":
        q = q.order_by(DomainAuction.score.desc().nullslast(), DomainAuction.end_time.asc())
    else:
        q = q.order_by(DomainAuction.end_time.asc())

    auctions = (await db.execute(q.limit(limit))).scalars().all()

    return {
        "items": [
            {
                "domain": a.domain,
                "current_bid": a.current_bid,
                "platform": a.platform,
                "end_time": a.end_time.isoformat() if a.end_time else None,
                "bids": a.bids,
                "score": a.score,
                "auction_url": a.auction_url,
                "source": a.source,
            }
            for a in auctions
        ],
        "count": len(auctions),
        "timestamp": now.isoformat(),
    }
async def tool_get_drops(db: AsyncSession, user: User, args: dict[str, Any]) -> dict[str, Any]:
    """Fetch recently dropped domains via ``get_dropped_domains``.

    All arguments are normalised before being forwarded, so malformed values
    never reach the query layer.

    Args:
        db: Async database session.
        user: Current user (not read here).
        args: Optional filters:
            - ``tld``: string; trimmed, lower-cased, leading dots removed;
            - ``hours``: 1-48, default 24;
            - ``limit``: 1-100, default 50;
            - ``offset``: 0-10000, default 0;
            - ``min_length`` / ``max_length``: clamped to 1-63 when present;
            - ``exclude_numeric`` / ``exclude_hyphen``: truthiness flags;
            - ``keyword``: trimmed; blank means no keyword filter.

    Returns:
        Whatever ``get_dropped_domains`` returns.
    """
    hours = _clamp_int(args.get("hours"), lo=1, hi=48, default=24)
    limit = _clamp_int(args.get("limit"), lo=1, hi=100, default=50)
    offset = _clamp_int(args.get("offset"), lo=0, hi=10000, default=0)

    # Strip whitespace BEFORE removing the leading dot (" .com" -> "com"),
    # matching how tool_market_feed normalises its tld argument.
    raw_tld = args.get("tld")
    tld = raw_tld.strip().lower().lstrip(".") if isinstance(raw_tld, str) else ""

    # A whitespace-only keyword is treated as "no keyword".
    raw_keyword = args.get("keyword")
    keyword = (str(raw_keyword).strip() or None) if raw_keyword else None

    # Clamp to the bounds advertised in the tool schema (1..63). An
    # unparseable value falls back to the least restrictive bound rather
    # than raising.
    min_length = args.get("min_length")
    max_length = args.get("max_length")
    min_len = _clamp_int(min_length, lo=1, hi=63, default=1) if min_length is not None else None
    max_len = _clamp_int(max_length, lo=1, hi=63, default=63) if max_length is not None else None

    result = await get_dropped_domains(
        db=db,
        tld=tld or None,
        hours=hours,
        min_length=min_len,
        max_length=max_len,
        exclude_numeric=bool(args.get("exclude_numeric", False)),
        exclude_hyphen=bool(args.get("exclude_hyphen", False)),
        keyword=keyword,
        limit=limit,
        offset=offset,
    )
    return result
def get_tool_defs() -> list[ToolDef]:
    """Build the full list of tool definitions (a fresh list on every call).

    Every tool declared here requires at least the TRADER tier and rejects
    unknown arguments (``additionalProperties`` is ``False``).
    """

    def _object_schema(properties: dict[str, Any], required: list[str] | None = None) -> dict[str, Any]:
        # Key order is kept as type / properties / [required] / additionalProperties.
        schema: dict[str, Any] = {"type": "object", "properties": properties}
        if required is not None:
            schema["required"] = required
        schema["additionalProperties"] = False
        return schema

    def _int_range(minimum: int, maximum: int) -> dict[str, Any]:
        return {"type": "integer", "minimum": minimum, "maximum": maximum}

    def _string() -> dict[str, Any]:
        return {"type": "string"}

    def _boolean() -> dict[str, Any]:
        return {"type": "boolean"}

    subscription_tool = ToolDef(
        name="get_subscription",
        description="Get current user's subscription tier, features and limits.",
        json_schema=_object_schema({}),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_get_subscription,
    )

    dashboard_tool = ToolDef(
        name="get_dashboard_summary",
        description="Get a compact snapshot: ending auctions, listing stats, trending TLDs.",
        json_schema=_object_schema({}),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_get_dashboard_summary,
    )

    watchlist_tool = ToolDef(
        name="list_watchlist",
        description="List user's watchlist domains (monitored domains).",
        json_schema=_object_schema(
            {
                "page": _int_range(1, 50),
                "per_page": _int_range(1, 50),
            }
        ),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_list_watchlist,
    )

    analyze_tool = ToolDef(
        name="analyze_domain",
        description="Run Pounce domain analysis (Authority/Market/Risk/Value) for a given domain.",
        json_schema=_object_schema(
            {
                "domain": _string(),
                "fast": _boolean(),
                "refresh": _boolean(),
            },
            required=["domain"],
        ),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_analyze_domain,
    )

    market_tool = ToolDef(
        name="market_feed",
        description="Get current auction feed (filters: source, keyword, tld, ending window).",
        json_schema=_object_schema(
            {
                "source": {"type": "string", "enum": ["all", "pounce", "external"]},
                "keyword": _string(),
                "tld": _string(),
                "sort_by": {"type": "string", "enum": ["time", "score"]},
                "ending_within_hours": _int_range(1, 168),
                "limit": _int_range(1, 50),
            }
        ),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_market_feed,
    )

    drops_tool = ToolDef(
        name="get_drops",
        description="Get recently dropped domains from zone files (auth required).",
        json_schema=_object_schema(
            {
                "tld": _string(),
                "hours": _int_range(1, 48),
                "min_length": _int_range(1, 63),
                "max_length": _int_range(1, 63),
                "exclude_numeric": _boolean(),
                "exclude_hyphen": _boolean(),
                "keyword": _string(),
                "limit": _int_range(1, 100),
                "offset": _int_range(0, 10000),
            }
        ),
        min_tier=SubscriptionTier.TRADER,
        handler=tool_get_drops,
    )

    return [
        subscription_tool,
        dashboard_tool,
        watchlist_tool,
        analyze_tool,
        market_tool,
        drops_tool,
    ]
def tools_for_path(path: str) -> list[str]:
    """
    Pick the tool names visible on the current Terminal page.

    The query string is dropped and the remaining path is matched by prefix.
    A shorter catalog keeps prompts smaller and makes the model more decisive.
    Unrecognized (or empty) paths get a safe minimal set.
    """
    page = (path or "").partition("?")[0]

    # Checked in order; the first matching prefix wins.
    by_prefix = (
        (
            "/terminal/hunt",
            ("get_subscription", "get_dashboard_summary", "market_feed", "get_drops", "analyze_domain", "list_watchlist"),
        ),
        ("/terminal/market", ("get_subscription", "market_feed", "analyze_domain")),
        ("/terminal/watchlist", ("get_subscription", "list_watchlist", "analyze_domain")),
    )
    for prefix, names in by_prefix:
        if page.startswith(prefix):
            return list(names)

    # default: allow a safe minimal set
    return ["get_subscription", "get_dashboard_summary", "analyze_domain"]
async def execute_tool(db: AsyncSession, user: User, name: str, args: dict[str, Any], *, path: str) -> dict[str, Any]:
    """Run the tool *name* for *user* after page and tier checks.

    Failures are reported as ``{"error": ...}`` dicts instead of raising, so
    the caller can hand the message back to the model.

    Args:
        db: Async database session passed through to the handler.
        user: Current user.
        name: Tool name to execute.
        args: Tool arguments; a falsy value is replaced by an empty dict.
        path: Current Terminal page path, used to restrict the allowed tools.

    Returns:
        The handler's result, or an ``{"error": ...}`` dict when the tool is
        unknown, not allowed on this page, above the user's tier, or when
        the handler raised.
    """
    import logging

    defs = {t.name: t for t in get_tool_defs()}
    tool = defs.get(name)
    if tool is None or tool.handler is None:
        return {"error": f"Unknown tool: {name}"}

    # Enforce tool allowed on this page
    allowed = set(tools_for_path(path))
    if name not in allowed:
        return {"error": f"Tool not allowed for path: {name}"}

    # Users without a subscription row are treated as the SCOUT tier.
    sub = await _get_subscription(db, user)
    user_tier = sub.tier if sub else SubscriptionTier.SCOUT
    try:
        _require_tier(user_tier, tool.min_tier)
    except PermissionError as e:
        return {"error": str(e)}

    try:
        return await tool.handler(db, user, args or {})
    except Exception as e:
        # Boundary catch: the error is returned to the model, so record the
        # traceback server-side instead of losing it.
        logging.getLogger(__name__).exception("LLM tool %r failed", name)
        return {"error": f"{type(e).__name__}: {e}"}
def tool_catalog_for_prompt(path: str) -> list[dict[str, Any]]:
    """Describe the tools visible on *path* as plain dicts for the prompt.

    Each entry carries the tool's ``name``, ``description`` and argument
    ``schema``; entries follow the order of ``get_tool_defs``.
    """
    visible = set(tools_for_path(path))
    return [
        {
            "name": tool.name,
            "description": tool.description,
            "schema": tool.json_schema,
        }
        for tool in get_tool_defs()
        if tool.name in visible
    ]