pounce/backend/app/services/llm_gateway.py
Yves Gugger bd3046b782
Some checks failed
CI / Frontend Lint & Type Check (push) Has been cancelled
CI / Frontend Build (push) Has been cancelled
CI / Backend Lint (push) Has been cancelled
CI / Backend Tests (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
Deploy / Build & Push Images (push) Has been cancelled
Deploy / Deploy to Server (push) Has been cancelled
Deploy / Notify (push) Has been cancelled
Add LLM gateway proxy endpoint (Trader/Tycoon)
2025-12-17 13:12:45 +01:00

54 lines
1.7 KiB
Python

from __future__ import annotations
import json
from typing import Any, AsyncIterator, Optional
import httpx
from app.config import get_settings
# Settings are resolved once, at import time; the functions below read
# `llm_gateway_url` and `llm_gateway_api_key` from this object.
settings = get_settings()
class LLMGatewayError(RuntimeError):
    """Raised when the LLM gateway is not configured or answers with an error."""
def _auth_headers() -> dict[str, str]:
    """
    Build the ``Authorization`` header sent with every gateway request.

    Returns:
        A single-entry dict carrying the configured API key as a Bearer token.

    Raises:
        LLMGatewayError: if ``settings.llm_gateway_api_key`` is unset, empty,
            or consists only of whitespace.
    """
    configured = settings.llm_gateway_api_key
    # Treat a missing key and a whitespace-only key the same way.
    token = configured.strip() if configured else ""
    if token:
        return {"Authorization": f"Bearer {token}"}
    raise LLMGatewayError("LLM gateway not configured (missing llm_gateway_api_key)")
async def chat_completions(payload: dict[str, Any]) -> dict[str, Any]:
    """
    Non-streaming call to the LLM gateway (OpenAI-ish format).

    POSTs ``payload`` as JSON to ``<llm_gateway_url>/v1/chat/completions``
    with a 60-second timeout and returns the decoded JSON response body.

    Args:
        payload: Request body forwarded to the gateway unchanged.

    Returns:
        The gateway's response body decoded from JSON.

    Raises:
        LLMGatewayError: if the API key is not configured, if the gateway
            answers with an HTTP status >= 400 (the message carries the status
            and up to 500 characters of the body), or if a non-error response
            body is not valid JSON.

    Transport-level httpx errors (connection failures, timeouts) are not
    caught here and propagate to the caller unchanged.
    """
    url = settings.llm_gateway_url.rstrip("/") + "/v1/chat/completions"
    async with httpx.AsyncClient(timeout=60) as client:
        r = await client.post(url, headers=_auth_headers(), json=payload)
        if r.status_code >= 400:
            raise LLMGatewayError(f"LLM gateway error: {r.status_code} {r.text[:500]}")
        try:
            return r.json()
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; surface a malformed body as a gateway error instead
            # of leaking a raw decoding exception to the caller.
            raise LLMGatewayError(
                f"LLM gateway returned invalid JSON: {r.status_code} {r.text[:500]}"
            ) from exc
async def chat_completions_stream(payload: dict[str, Any]) -> AsyncIterator[bytes]:
    """
    Streaming call to the LLM gateway; yields the response body as raw bytes.

    The gateway returns SSE. Chunks are passed through without parsing, and
    empty chunks are skipped.

    Args:
        payload: Request body forwarded to the gateway unchanged.

    Yields:
        Non-empty byte chunks of the gateway response, in arrival order.

    Raises:
        LLMGatewayError: if the API key is not configured, or if the gateway
            answers with an HTTP status >= 400 (the message carries the status
            and up to 500 bytes of the body, decoded as UTF-8 with undecodable
            bytes ignored).
    """
    endpoint = f"{settings.llm_gateway_url.rstrip('/')}/v1/chat/completions"
    # Read timeout is disabled (None); connect/write/pool are capped at 10s.
    # NOTE(review): presumably because the stream may pause between chunks — confirm.
    stream_timeout = httpx.Timeout(connect=10, read=None, write=10, pool=10)
    async with httpx.AsyncClient(timeout=stream_timeout) as http:
        async with http.stream("POST", endpoint, headers=_auth_headers(), json=payload) as response:
            if response.status_code >= 400:
                # The body is not loaded on a streamed response; read it
                # explicitly so it can be quoted in the error.
                raw = await response.aread()
                snippet = raw[:500].decode("utf-8", "ignore")
                raise LLMGatewayError(f"LLM gateway stream error: {response.status_code} {snippet}")
            async for piece in response.aiter_bytes():
                if not piece:
                    continue
                yield piece