"""Prometheus metrics for FastAPI + optional DB query metrics.""" from __future__ import annotations import time from typing import Optional from fastapi import FastAPI, Request, Response try: from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST except Exception: # pragma: no cover Counter = None # type: ignore Histogram = None # type: ignore generate_latest = None # type: ignore CONTENT_TYPE_LATEST = "text/plain; version=0.0.4" # type: ignore _instrumented = False _db_instrumented = False def _get_route_template(request: Request) -> str: route = request.scope.get("route") if route is not None and hasattr(route, "path"): return str(route.path) return request.url.path def instrument_app(app: FastAPI, *, metrics_path: str = "/metrics", enable_db_metrics: bool = False) -> None: """ Add Prometheus request metrics and a `/metrics` endpoint. - Low-cardinality path labels by using FastAPI route templates. - Optional SQLAlchemy query timing metrics (off by default). """ global _instrumented if _instrumented: return _instrumented = True if Counter is None or Histogram is None: # Dependency not installed; keep app working without metrics. return http_requests_total = Counter( "http_requests_total", "Total HTTP requests", ["method", "path", "status"], ) http_request_duration_seconds = Histogram( "http_request_duration_seconds", "HTTP request duration (seconds)", ["method", "path"], buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10), ) @app.middleware("http") async def _metrics_middleware(request: Request, call_next): start = time.perf_counter() response: Optional[Response] = None try: response = await call_next(request) return response finally: duration = time.perf_counter() - start path = _get_route_template(request) method = request.method status = str(getattr(response, "status_code", 500)) http_requests_total.labels(method=method, path=path, status=status).inc() http_request_duration_seconds.labels(method=method, path=path).observe(duration) @app.get(metrics_path, include_in_schema=False) async def _metrics_endpoint(): return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) if enable_db_metrics: _instrument_db_metrics() def _instrument_db_metrics() -> None: """Attach SQLAlchemy event listeners to track query latencies.""" global _db_instrumented if _db_instrumented: return _db_instrumented = True if Counter is None or Histogram is None: return from sqlalchemy import event from app.database import engine db_queries_total = Counter( "db_queries_total", "Total DB queries executed", ["dialect"], ) db_query_duration_seconds = Histogram( "db_query_duration_seconds", "DB query duration (seconds)", ["dialect"], buckets=(0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5), ) dialect = engine.sync_engine.dialect.name @event.listens_for(engine.sync_engine, "before_cursor_execute") def _before_cursor_execute(conn, cursor, statement, parameters, context, executemany): # type: ignore[no-untyped-def] conn.info.setdefault("_query_start_time", []).append(time.perf_counter()) @event.listens_for(engine.sync_engine, "after_cursor_execute") def _after_cursor_execute(conn, cursor, statement, parameters, context, executemany): # type: ignore[no-untyped-def] start_list = conn.info.get("_query_start_time") or [] if not start_list: return start = start_list.pop() duration = time.perf_counter() - start db_queries_total.labels(dialect=dialect).inc() db_query_duration_seconds.labels(dialect=dialect).observe(duration)