""" Structured logging middleware — JSON logs with request tracing, timing, AI provider health, cache hit ratios, and WebSocket events. """ import json import logging import time import uuid from collections import defaultdict, deque from datetime import datetime from typing import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware # ─── Structured JSON logger ─────────────────────────────────────────────────── class JSONFormatter(logging.Formatter): def format(self, record: logging.LogRecord) -> str: log = { "ts": datetime.utcnow().isoformat() + "Z", "level": record.levelname, "logger": record.name, "msg": record.getMessage(), } if hasattr(record, "extra"): log.update(record.extra) if record.exc_info: log["exc"] = self.formatException(record.exc_info) return json.dumps(log) def get_logger(name: str) -> logging.Logger: logger = logging.getLogger(name) if not logger.handlers: handler = logging.StreamHandler() handler.setFormatter(JSONFormatter()) logger.addHandler(handler) logger.setLevel(logging.INFO) logger.propagate = False return logger api_logger = get_logger("bankbot.api") ai_logger = get_logger("bankbot.ai") ws_logger = get_logger("bankbot.ws") db_logger = get_logger("bankbot.db") # ─── In-process metrics store ───────────────────────────────────────────────── class MetricsStore: """Thread-safe in-memory metrics — no external dependency.""" def __init__(self): self.request_count: int = 0 self.error_count: int = 0 self.auth_failures: int = 0 self.ws_connects: int = 0 self.ws_reconnects: int = 0 self.ai_calls: dict = defaultdict(int) # provider → count self.ai_errors: dict = defaultdict(int) # provider → errors self.ai_latencies: dict = defaultdict(list) # provider → [ms] self.ai_fallbacks: int = 0 self.cache_hits: int = 0 self.cache_misses: int = 0 self.route_timings: dict = defaultdict(list) # path → [ms] self._recent_errors: deque = deque(maxlen=50) # last 50 errors self.start_time: float = time.time() # ── AI tracking ────────────────────────────────────────────────────────── def record_ai_call(self, provider: str, latency_ms: float, success: bool): self.ai_calls[provider] += 1 self.ai_latencies[provider].append(latency_ms) if len(self.ai_latencies[provider]) > 200: self.ai_latencies[provider] = self.ai_latencies[provider][-200:] if not success: self.ai_errors[provider] += 1 def record_ai_fallback(self): self.ai_fallbacks += 1 # ── Cache tracking ──────────────────────────────────────────────────────── def record_cache_hit(self): self.cache_hits += 1 def record_cache_miss(self): self.cache_misses += 1 # ── Error tracking ──────────────────────────────────────────────────────── def record_error(self, path: str, status: int, detail: str): self._recent_errors.append({ "ts": datetime.utcnow().isoformat() + "Z", "path": path, "status": status, "detail": detail[:200], }) self.error_count += 1 if status == 401: self.auth_failures += 1 # ── Summary ─────────────────────────────────────────────────────────────── def summary(self) -> dict: uptime = time.time() - self.start_time cache_total = self.cache_hits + self.cache_misses cache_ratio = round(self.cache_hits / cache_total * 100, 1) if cache_total else 0 ai_summary = {} for provider in set(list(self.ai_calls.keys()) + list(self.ai_errors.keys())): lats = self.ai_latencies.get(provider, []) ai_summary[provider] = { "calls": self.ai_calls[provider], "errors": self.ai_errors[provider], "avg_latency_ms": round(sum(lats) / len(lats), 1) if lats else 0, "p95_latency_ms": round(sorted(lats)[int(len(lats) * 0.95)], 1) if len(lats) >= 20 else None, } slow_routes = {} for path, times in self.route_timings.items(): if times: slow_routes[path] = { "calls": len(times), "avg_ms": round(sum(times) / len(times), 1), "max_ms": round(max(times), 1), } return { "uptime_seconds": round(uptime, 0), "requests": { "total": self.request_count, "errors": self.error_count, "auth_failures": self.auth_failures, "error_rate_pct": round(self.error_count / max(self.request_count, 1) * 100, 2), }, "websocket": { "total_connects": self.ws_connects, "reconnects": self.ws_reconnects, }, "ai": { "fallbacks": self.ai_fallbacks, "by_provider": ai_summary, }, "cache": { "hits": self.cache_hits, "misses": self.cache_misses, "hit_ratio_pct": cache_ratio, }, "route_timings": dict(sorted(slow_routes.items(), key=lambda x: -x[1]["avg_ms"])[:10]), "recent_errors": list(self._recent_errors)[-10:], } metrics = MetricsStore() # ─── Request logging middleware ─────────────────────────────────────────────── class RequestLoggingMiddleware(BaseHTTPMiddleware): SKIP_PATHS = {"/health", "/openapi.json", "/docs", "/redoc", "/docs/oauth2-redirect"} async def dispatch(self, request: Request, call_next: Callable) -> Response: if request.url.path in self.SKIP_PATHS: return await call_next(request) request_id = str(uuid.uuid4())[:8] start = time.time() metrics.request_count += 1 response = await call_next(request) elapsed_ms = (time.time() - start) * 1000 path = request.url.path metrics.route_timings[path].append(elapsed_ms) if len(metrics.route_timings[path]) > 500: metrics.route_timings[path] = metrics.route_timings[path][-500:] level = logging.WARNING if elapsed_ms > 2000 else logging.INFO if response.status_code >= 400: metrics.record_error(path, response.status_code, "") level = logging.WARNING if response.status_code < 500 else logging.ERROR api_logger.log(level, f"{request.method} {path}", extra={ "request_id": request_id, "method": request.method, "path": path, "status": response.status_code, "duration_ms": round(elapsed_ms, 1), "ip": request.client.host if request.client else "unknown", }) response.headers["X-Request-ID"] = request_id return response