BankBot-AI / backend /app /middleware /logging.py
mohsin-devs's picture
Deploy to HF
a282d4b
Raw
History Blame Contribute Delete
7.9 kB
"""
Structured logging middleware — JSON logs with request tracing,
timing, AI provider health, cache hit ratios, and WebSocket events.
"""
import json
import logging
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timezone
from typing import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
# ─── Structured JSON logger ───────────────────────────────────────────────────
class JSONFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Every line carries ``ts`` (UTC, ISO-8601 with a ``Z`` suffix), ``level``,
    ``logger`` and ``msg``. Fields passed via ``logger.log(..., extra={...})``
    are merged in as top-level keys. ``exc`` / ``stack`` are added when
    exception or stack information is attached to the record.
    """

    # Attribute names that every LogRecord has, plus the two that
    # ``logging.Formatter`` may add later ("message", "asctime"). Any other
    # attribute on a record was injected through ``extra=``.
    _RESERVED_ATTRS = frozenset(
        vars(logging.LogRecord("", logging.INFO, "", 0, "", (), None))
    ) | {"message", "asctime"}

    def format(self, record: logging.LogRecord) -> str:
        log = {
            # Use the record's creation time, not the moment it is formatted.
            "ts": datetime.fromtimestamp(record.created, tz=timezone.utc)
            .isoformat()
            .replace("+00:00", "Z"),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        # ``extra={...}`` sets each key as its own attribute on the record
        # (there is no ``record.extra``), so pick up every non-standard attribute.
        for key, value in vars(record).items():
            if key in self._RESERVED_ATTRS:
                continue
            if key == "extra" and isinstance(value, dict):
                # Backward compatibility with callers that passed
                # extra={"extra": {...}}: flatten the nested dict.
                log.update(value)
            else:
                log[key] = value
        if record.exc_info:
            log["exc"] = self.formatException(record.exc_info)
        if record.stack_info:
            log["stack"] = self.formatStack(record.stack_info)
        # default=str: a log line must never fail just because an extra value
        # (e.g. a UUID or an exception object) is not JSON-serializable.
        return json.dumps(log, default=str)
def get_logger(name: str) -> logging.Logger:
    """Return the logger called *name*, wiring up JSON output on first use.

    If the logger already has handlers it is returned untouched, so repeated
    calls for the same name never stack duplicate handlers. Otherwise a
    ``StreamHandler`` using ``JSONFormatter`` is attached, the level is set to
    INFO, and propagation is turned off so records are not also passed to
    ancestor loggers' handlers.
    """
    named_logger = logging.getLogger(name)
    if named_logger.handlers:
        return named_logger

    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JSONFormatter())
    named_logger.addHandler(json_handler)
    named_logger.setLevel(logging.INFO)
    named_logger.propagate = False
    return named_logger


# One JSON logger per subsystem, created at import time.
api_logger = get_logger("bankbot.api")
ai_logger = get_logger("bankbot.ai")
ws_logger = get_logger("bankbot.ws")
db_logger = get_logger("bankbot.db")
# ─── In-process metrics store ─────────────────────────────────────────────────
class MetricsStore:
    """In-memory process metrics: requests, AI providers, cache and WebSockets.

    Has no external dependency. Counters are plain attributes, so callers may
    also update them directly.

    NOTE(review): no locking is performed, and ``+=`` on an attribute is not
    atomic across threads. This is fine while every update happens on one
    thread (e.g. the asyncio event loop). Add a lock before updating from
    worker threads.
    """

    # Per-provider cap on retained AI latency samples (oldest dropped first).
    MAX_AI_LATENCY_SAMPLES = 200
    # Minimum number of latency samples before a p95 figure is reported.
    MIN_P95_SAMPLES = 20

    def __init__(self):
        self.request_count: int = 0
        self.error_count: int = 0
        self.auth_failures: int = 0
        self.ws_connects: int = 0
        self.ws_reconnects: int = 0
        self.ai_calls: dict = defaultdict(int)        # provider → count
        self.ai_errors: dict = defaultdict(int)       # provider → errors
        self.ai_latencies: dict = defaultdict(list)   # provider → [ms]
        self.ai_fallbacks: int = 0
        self.cache_hits: int = 0
        self.cache_misses: int = 0
        self.route_timings: dict = defaultdict(list)  # path → [ms]
        self._recent_errors: deque = deque(maxlen=50)  # last 50 errors
        self.start_time: float = time.time()

    # ── AI tracking ──────────────────────────────────────────────────────────
    def record_ai_call(self, provider: str, latency_ms: float, success: bool):
        """Count one call to *provider*, keep its latency, and count it if it failed."""
        self.ai_calls[provider] += 1
        samples = self.ai_latencies[provider]
        samples.append(latency_ms)
        if len(samples) > self.MAX_AI_LATENCY_SAMPLES:
            # Keep only the most recent samples so memory stays bounded.
            self.ai_latencies[provider] = samples[-self.MAX_AI_LATENCY_SAMPLES:]
        if not success:
            self.ai_errors[provider] += 1

    def record_ai_fallback(self):
        """Count one fallback from one AI provider to another."""
        self.ai_fallbacks += 1

    # ── Cache tracking ────────────────────────────────────────────────────────
    def record_cache_hit(self):
        self.cache_hits += 1

    def record_cache_miss(self):
        self.cache_misses += 1

    # ── Error tracking ────────────────────────────────────────────────────────
    def record_error(self, path: str, status: int, detail: str):
        """Remember an error response and bump the error (and 401) counters.

        *detail* is truncated to 200 characters. Only the last 50 errors are
        retained.
        """
        self._recent_errors.append({
            "ts": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "path": path,
            "status": status,
            "detail": detail[:200],
        })
        self.error_count += 1
        if status == 401:
            self.auth_failures += 1

    # ── Summary ───────────────────────────────────────────────────────────────
    @staticmethod
    def _p95(samples: list) -> float:
        """Nearest-rank 95th percentile of a non-empty list of numbers."""
        ordered = sorted(samples)
        # Nearest-rank picks the ceil(0.95 * n)-th smallest value (1-based).
        # Computed in integers to avoid float rounding:
        # (95n - 1) // 100 == ceil(0.95n) - 1.
        return ordered[(95 * len(ordered) - 1) // 100]

    def summary(self) -> dict:
        """Return a JSON-serializable snapshot of all metrics.

        Lookups use ``.get`` so that building a summary never inserts empty
        entries into the defaultdicts.
        """
        uptime = time.time() - self.start_time
        cache_total = self.cache_hits + self.cache_misses
        cache_ratio = round(self.cache_hits / cache_total * 100, 1) if cache_total else 0

        ai_summary = {}
        for provider in self.ai_calls.keys() | self.ai_errors.keys():
            lats = self.ai_latencies.get(provider, [])
            ai_summary[provider] = {
                "calls": self.ai_calls.get(provider, 0),
                "errors": self.ai_errors.get(provider, 0),
                "avg_latency_ms": round(sum(lats) / len(lats), 1) if lats else 0,
                "p95_latency_ms": (
                    round(self._p95(lats), 1)
                    if len(lats) >= self.MIN_P95_SAMPLES
                    else None
                ),
            }

        slow_routes = {}
        for path, times in self.route_timings.items():
            if times:
                slow_routes[path] = {
                    "calls": len(times),
                    "avg_ms": round(sum(times) / len(times), 1),
                    "max_ms": round(max(times), 1),
                }

        return {
            "uptime_seconds": round(uptime, 0),
            "requests": {
                "total": self.request_count,
                "errors": self.error_count,
                "auth_failures": self.auth_failures,
                "error_rate_pct": round(self.error_count / max(self.request_count, 1) * 100, 2),
            },
            "websocket": {
                "total_connects": self.ws_connects,
                "reconnects": self.ws_reconnects,
            },
            "ai": {
                "fallbacks": self.ai_fallbacks,
                "by_provider": ai_summary,
            },
            "cache": {
                "hits": self.cache_hits,
                "misses": self.cache_misses,
                "hit_ratio_pct": cache_ratio,
            },
            # Ten slowest routes by average latency.
            "route_timings": dict(sorted(slow_routes.items(), key=lambda x: -x[1]["avg_ms"])[:10]),
            "recent_errors": list(self._recent_errors)[-10:],
        }


# Process-wide metrics instance shared by the code in this module.
metrics = MetricsStore()
# ─── Request logging middleware ───────────────────────────────────────────────
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Emit one structured log line per HTTP request and feed ``metrics``.

    For every path outside ``SKIP_PATHS`` this middleware:
    - assigns an 8-character request ID;
    - times the downstream handler;
    - records the timing, and any status >= 400, in ``metrics``;
    - logs method, path, status, duration and client IP on ``api_logger``;
    - echoes the ID back in the ``X-Request-ID`` response header.

    Exceptions escaping the handler are logged with their traceback, counted
    as status 500, and re-raised.
    """

    # Health-check and API-docs endpoints: passed straight through, never
    # logged or counted.
    SKIP_PATHS = {"/health", "/openapi.json", "/docs", "/redoc", "/docs/oauth2-redirect"}
    # Requests slower than this (ms) are logged at WARNING even when successful.
    SLOW_REQUEST_MS = 2000
    # Per-path cap on retained timing samples (oldest dropped first).
    MAX_ROUTE_SAMPLES = 500

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        path = request.url.path
        if path in self.SKIP_PATHS:
            return await call_next(request)

        request_id = uuid.uuid4().hex[:8]
        metrics.request_count += 1
        # perf_counter is monotonic, unlike time.time(), so wall-clock
        # adjustments cannot skew or negate measured durations.
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as exc:
            # Without this branch a crashing handler leaves no access-log line,
            # no timing sample and no error metric. 500 is assumed as the status
            # the client receives for an unhandled exception.
            # NOTE(review): only the exception type is stored, because messages
            # may carry sensitive data.
            elapsed_ms = self._record_timing(path, start)
            metrics.record_error(path, 500, type(exc).__name__)
            api_logger.error(
                "%s %s", request.method, path,
                exc_info=True,
                extra=self._log_fields(request, request_id, path, 500, elapsed_ms),
            )
            raise

        elapsed_ms = self._record_timing(path, start)
        status = response.status_code
        level = logging.WARNING if elapsed_ms > self.SLOW_REQUEST_MS else logging.INFO
        if status >= 400:
            metrics.record_error(path, status, "")
            level = logging.WARNING if status < 500 else logging.ERROR
        api_logger.log(
            level, "%s %s", request.method, path,
            extra=self._log_fields(request, request_id, path, status, elapsed_ms),
        )
        response.headers["X-Request-ID"] = request_id
        return response

    @classmethod
    def _record_timing(cls, path: str, start: float) -> float:
        """Store the milliseconds elapsed since *start* under *path* and return them."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        # NOTE(review): keyed by the raw URL path, so paths with embedded IDs
        # create one entry each and the dict grows without bound. Consider
        # keying by the route template instead.
        samples = metrics.route_timings[path]
        samples.append(elapsed_ms)
        if len(samples) > cls.MAX_ROUTE_SAMPLES:
            metrics.route_timings[path] = samples[-cls.MAX_ROUTE_SAMPLES:]
        return elapsed_ms

    @staticmethod
    def _log_fields(request: Request, request_id: str, path: str, status: int, elapsed_ms: float) -> dict:
        """Build the structured fields attached to the access-log record via ``extra``."""
        return {
            "request_id": request_id,
            "method": request.method,
            "path": path,
            "status": status,
            "duration_ms": round(elapsed_ms, 1),
            "ip": request.client.host if request.client else "unknown",
        }