| | """
|
| | Security Middleware & Utilities
|
| | ===============================
|
| | Provides Rate Limiting and Security Headers for the API.
|
| | Supports differentiated rate limits per endpoint category.
|
| | """
|
| |
|
| | import time
|
| | from typing import Optional, Callable
|
| | from fastapi import Request, HTTPException, status, Response
|
| | from fastapi.responses import JSONResponse
|
| | from starlette.middleware.base import BaseHTTPMiddleware
|
| | from loguru import logger
|
| |
|
| | from mnemocore.core.config import get_config
|
| |
|
| |
|
| |
|
# Per-category rate limits: maximum ``requests`` per ``window`` seconds,
# plus a human-readable ``description``. Keys are the category names returned
# by ``get_endpoint_category``; "default" covers every other path.
RATE_LIMIT_CONFIGS = {
    name: {
        "requests": max_requests,
        "window": window_seconds,
        "description": description,
    }
    for name, max_requests, window_seconds, description in (
        ("store", 100, 60, "Memory storage operations"),
        ("query", 500, 60, "Query operations"),
        ("concept", 100, 60, "Concept operations"),
        ("analogy", 100, 60, "Analogy operations"),
        ("default", 100, 60, "Default rate limit"),
    )
}
|
| |
|
| |
|
def get_endpoint_category(path: str) -> str:
    """
    Classify a request path into a rate-limit category.

    The path is searched (plain substring match) for ``/store``, ``/query``,
    ``/concept`` and ``/analogy`` in that priority order; the first hit wins.
    Paths matching none of them are classified as ``"default"``.

    Args:
        path: The URL path of the request.

    Returns:
        One of the keys of ``RATE_LIMIT_CONFIGS``.
    """
    for candidate in ("store", "query", "concept", "analogy"):
        if "/" + candidate in path:
            return candidate
    return "default"
|
| |
|
| |
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Attach a fixed set of security-related HTTP headers to every response.

    The headers are written after the downstream handler has produced the
    response, so they replace any value the endpoint set under the same name.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the rest of the stack, then stamp the security headers on its response."""
        outgoing = await call_next(request)

        # NOTE(review): OWASP now recommends "X-XSS-Protection: 0" because the
        # legacy browser XSS auditor could itself introduce vulnerabilities;
        # the Content-Security-Policy below is the modern replacement. The
        # value is kept unchanged here to preserve current behavior.
        # NOTE(review): "default-src 'self'" presumably blocks the CDN-hosted
        # assets of FastAPI's default /docs page — confirm if docs are served.
        security_headers = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Content-Security-Policy", "default-src 'self'"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
        )
        for header_name, header_value in security_headers:
            outgoing.headers[header_name] = header_value

        return outgoing
|
| |
|
| |
|
class RateLimiter:
    """
    Dependency for rate limiting based on IP address using Redis.
    Supports differentiated rate limits per endpoint category.

    Implements a fixed-window counter: requests from one client in one
    category are counted under a Redis key that embeds the current window
    index, so the count starts over at every window boundary.

    Usage:
        # Default rate limit (from config)
        @app.post("/endpoint", dependencies=[Depends(RateLimiter())])

        # Custom rate limit
        @app.post("/endpoint", dependencies=[Depends(RateLimiter(requests=100, window=60))])

        # Category-based rate limit
        @app.post("/store", dependencies=[Depends(RateLimiter(category="store"))])
    """

    def __init__(
        self,
        requests: Optional[int] = None,
        window: Optional[int] = None,
        category: Optional[str] = None
    ):
        """
        Args:
            requests: Maximum number of requests allowed per window. ``None``
                uses the category (or global config) default. An explicit
                ``0`` is honoured and rejects every request.
            window: Window length in seconds. ``None`` uses the category (or
                global config) default.
            category: Key of ``RATE_LIMIT_CONFIGS`` supplying the default
                limit and window. An unknown category takes its defaults from
                the category derived from the request path.

        Raises:
            ValueError: If ``requests`` is negative or ``window`` is not positive.
        """
        if requests is not None and requests < 0:
            raise ValueError(f"requests must be >= 0, got {requests}")
        if window is not None and window <= 0:
            # A zero window would make ``current_time // window`` divide by zero.
            raise ValueError(f"window must be > 0, got {window}")
        self.requests = requests
        self.window = window
        self.category = category

    def _resolve_limits(self, path: str, config) -> tuple:
        """
        Return the ``(limit, window)`` pair that applies to a request on ``path``.

        ``config`` is the object returned by ``get_config()``; its
        ``security.rate_limit_requests`` / ``security.rate_limit_window`` are
        the defaults when this limiter has no category.
        """
        if self.category:
            # An unrecognised explicit category borrows the path-derived
            # category's defaults (the Redis key still uses the explicit name).
            category = (
                self.category
                if self.category in RATE_LIMIT_CONFIGS
                else get_endpoint_category(path)
            )
            defaults = RATE_LIMIT_CONFIGS[category]
            default_limit = defaults["requests"]
            default_window = defaults["window"]
        else:
            default_limit = config.security.rate_limit_requests
            default_window = config.security.rate_limit_window

        # Explicit overrides win even when they are 0 (``or`` would discard them).
        limit = default_limit if self.requests is None else self.requests
        window = default_window if self.window is None else self.window
        return limit, window

    @staticmethod
    def _client_ip(request: Request) -> str:
        """
        Return the identifier used to bucket requests from one client.

        Prefers the first entry of ``X-Forwarded-For``; otherwise the socket
        peer address, or ``"unknown"`` when neither is available.
        """
        # NOTE(review): X-Forwarded-For is supplied by the client unless a
        # trusted reverse proxy overwrites it. Trusting it unconditionally lets
        # a caller rotate the header value to get a fresh bucket per request.
        # Confirm the deployment sits behind such a proxy, or only honour the
        # header for known proxy addresses.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_entry = forwarded.split(",")[0].strip()
            # An empty first entry (e.g. ", 10.0.0.1") would otherwise lump
            # unrelated clients into a single "" bucket.
            if first_entry:
                return first_entry
        return request.client.host if request.client else "unknown"

    async def __call__(self, request: Request) -> None:
        """
        Enforce the rate limit for ``request``.

        Returns without limiting when rate limiting is disabled in config,
        when no Redis storage is registered on ``request.app.state.container``,
        or when the Redis round-trip raises (fail-open).

        Raises:
            HTTPException: 429 with a ``Retry-After`` header once the client
                exceeds the limit within the current window.
        """
        config = get_config()
        if not config.security.rate_limit_enabled:
            return

        container = getattr(request.app.state, 'container', None)
        if not container or not container.redis_storage:
            logger.warning("Redis storage not available, skipping rate limit.")
            return

        limit, window = self._resolve_limits(request.url.path, config)
        client_ip = self._client_ip(request)
        effective_category = self.category or get_endpoint_category(request.url.path)

        # Fixed window: every request in
        # [window_index * window, (window_index + 1) * window) shares one key.
        current_time = int(time.time())
        window_index = current_time // window
        key = f"rate_limit:{effective_category}:{client_ip}:{window_index}"

        try:
            redis = container.redis_storage.redis_client
            # INCR and EXPIRE go out in one pipeline; the 10 s grace past the
            # window length lets stale keys expire on their own.
            async with redis.pipeline() as pipe:
                pipe.incr(key)
                pipe.expire(key, window + 10)
                results = await pipe.execute()
            count = results[0]
        except Exception as e:
            # Fail open: a Redis outage must not take the API down with it.
            logger.error(f"Rate limiter error (failing open): {e}")
            return

        if count > limit:
            retry_after = (window_index + 1) * window - current_time

            logger.warning(f"Rate limit exceeded for {client_ip} on {effective_category}: {count}/{limit}")

            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded. Try again in {retry_after} seconds.",
                headers={"Retry-After": str(retry_after)}
            )
|
| |
|
| |
|
class StoreRateLimiter(RateLimiter):
    """
    RateLimiter preset bound to the "store" category.

    Limits come from RATE_LIMIT_CONFIGS["store"] (100 requests per 60 s).
    """

    def __init__(self):
        super().__init__(requests=None, window=None, category="store")
|
| |
|
| |
|
class QueryRateLimiter(RateLimiter):
    """
    RateLimiter preset bound to the "query" category.

    Limits come from RATE_LIMIT_CONFIGS["query"] (500 requests per 60 s).
    """

    def __init__(self):
        super().__init__(requests=None, window=None, category="query")
|
| |
|
| |
|
class ConceptRateLimiter(RateLimiter):
    """
    RateLimiter preset bound to the "concept" category.

    Limits come from RATE_LIMIT_CONFIGS["concept"] (100 requests per 60 s).
    """

    def __init__(self):
        super().__init__(requests=None, window=None, category="concept")
|
| |
|
| |
|
class AnalogyRateLimiter(RateLimiter):
    """
    RateLimiter preset bound to the "analogy" category.

    Limits come from RATE_LIMIT_CONFIGS["analogy"] (100 requests per 60 s).
    """

    def __init__(self):
        super().__init__(requests=None, window=None, category="analogy")
|
| |
|
| |
|
async def rate_limit_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """
    Custom exception handler for rate limit errors.

    Ensures HTTP 429 with a ``Retry-After`` header. The header value is taken
    from ``exc.headers`` when present, otherwise it defaults to ``"60"``. The
    JSON body carries ``detail``, ``error_type`` and ``retry_after`` (the
    delay in seconds, or ``None`` when the header is not an integer, e.g. an
    HTTP-date).

    Exceptions whose status code is not 429 are passed through with their own
    status, detail and headers, so registering this handler for
    ``HTTPException`` as a whole does not turn every error into a 429.

    Args:
        request: The incoming request (unused; part of the handler signature).
        exc: The exception raised by an endpoint or dependency.

    Returns:
        The JSON error response.
    """
    if exc.status_code != status.HTTP_429_TOO_MANY_REQUESTS:
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail},
            headers=exc.headers,
        )

    # str() so a non-string header value (e.g. an int) can still be encoded.
    retry_after = str((exc.headers or {}).get("Retry-After", "60"))
    try:
        retry_after_seconds: Optional[int] = int(retry_after)
    except ValueError:
        # Retry-After may legally be an HTTP-date; keep the header verbatim
        # instead of crashing the handler.
        retry_after_seconds = None

    return JSONResponse(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        content={
            "detail": exc.detail,
            "error_type": "rate_limit_exceeded",
            "retry_after": retry_after_seconds,
        },
        headers={"Retry-After": retry_after},
    )
|
| |
|