Spaces:
Build error
Build error
| """ | |
| Usage Tracker for ARF API – quotas, tiers, and audit logging. | |
| Thread‑safe, atomic quota consumption, idempotent, fail‑closed. | |
| Extended for multi‑tenancy: each API key is linked to a tenant ID. | |
| Tenant ID is stored in the `api_keys` table and used for resource isolation. | |
| """ | |
| import json | |
| import sqlite3 | |
| import threading | |
| import time | |
| from contextlib import contextmanager | |
| from datetime import datetime, timedelta | |
| from dataclasses import dataclass | |
| from typing import Dict, Any, Optional, List, Tuple, Callable | |
| from enum import Enum | |
| from fastapi import BackgroundTasks, HTTPException, Request | |
| # Optional Redis support | |
| try: | |
| import redis | |
| REDIS_AVAILABLE = True | |
| except ImportError: | |
| REDIS_AVAILABLE = False | |
| redis = None | |
class Tier(str, Enum):
    """Pricing tiers with associated quota limits and audit retention.

    Each member's value is the lowercase string persisted in the database
    (e.g. ``Tier("pro") is Tier.PRO``).  The quota and retention lookups are
    exposed as read-only properties so callers can write
    ``tier.monthly_evaluation_limit`` rather than calling a method.
    """

    FREE = "free"
    PRO = "pro"
    PREMIUM = "premium"
    ENTERPRISE = "enterprise"

    @property
    def monthly_evaluation_limit(self) -> Optional[int]:
        """Monthly evaluation quota. None = unlimited."""
        # The mapping lives inside the property: a dict assigned in an Enum
        # class body would itself become an enum member.
        limits = {
            Tier.FREE: 1000,
            Tier.PRO: 10_000,
            Tier.PREMIUM: 50_000,
            Tier.ENTERPRISE: None,
        }
        return limits[self]

    @property
    def audit_log_retention_days(self) -> int:
        """How many days to keep usage and decision audit logs."""
        retention = {
            Tier.FREE: 7,
            Tier.PRO: 30,
            Tier.PREMIUM: 90,
            Tier.ENTERPRISE: 365,
        }
        return retention[self]
@dataclass
class UsageRecord:
    """Single API call usage record (for quota and debugging).

    Fields:
        api_key: Key the call was made with (quota is counted per key).
        tier: Tier of the key at call time; its ``.value`` is what gets logged.
        timestamp: Call time as a Unix epoch float.
        endpoint: Endpoint path/name that was called.
        request_body: Optional JSON-serialisable request payload.
        response: Optional JSON-serialisable response payload.
        error: Optional error message if the call failed.
        processing_ms: Optional processing time in milliseconds.
    """

    api_key: str
    tier: "Tier"
    timestamp: float
    endpoint: str
    request_body: Optional[Dict[str, Any]] = None
    response: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    processing_ms: Optional[float] = None
class UsageTracker:
    """
    Thread-safe usage tracker with atomic quota consumption and idempotency.

    Extended to support tenant isolation: each API key is linked to a tenant
    through the ``tenant_id`` column of the ``api_keys`` table.

    Storage:
        * SQLite (``db_path``) holds API keys, the usage/audit log, the
          idempotency keys and, when Redis is not configured, the monthly
          quota counters (``monthly_counts``).
        * Redis (``redis_url``, optional) holds the monthly quota counters
          instead, under keys ``arf:quota:<api_key>:<YYYY-MM>``.

    Each thread lazily opens its own SQLite connection in autocommit mode
    (``isolation_level=None``); sections that must be atomic issue an explicit
    ``BEGIN IMMEDIATE``.
    """

    # Increment the (api_key, month) counter, creating it at 1 on first use.
    _UPSERT_MONTHLY_COUNT_SQL = (
        "INSERT INTO monthly_counts (api_key, year_month, count) VALUES (?, ?, 1) "
        "ON CONFLICT(api_key, year_month) DO UPDATE SET count = count + 1"
    )

    def __init__(self, db_path: str = "arf_usage.db",
                 redis_url: Optional[str] = None):
        """
        Args:
            db_path: Path of the SQLite database file (tables are created if missing).
            redis_url: Optional Redis URL; when given, quota counters use Redis.

        Raises:
            ImportError: If ``redis_url`` is given but the ``redis`` package
                is not installed.
        """
        self.db_path = db_path
        self._local = threading.local()
        self._init_db()
        self._redis_client = None
        if redis_url and REDIS_AVAILABLE:
            self._redis_client = redis.from_url(redis_url)
        elif redis_url:
            raise ImportError("Redis client not installed. Run: pip install redis")

    @contextmanager
    def _get_conn(self):
        """Yield this thread's SQLite connection, opening it on first use.

        The connection is cached in thread-local storage and deliberately not
        closed when the ``with`` block exits, so later calls on the same thread
        reuse it.  It runs in autocommit mode with WAL journaling.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(
                self.db_path, check_same_thread=False, isolation_level=None)
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            self._local.conn = conn
        yield conn

    def _init_db(self):
        """Initialise SQLite tables with tenant_id support."""
        with self._get_conn() as conn:
            # api_keys carries tenant_id so every key maps to exactly one tenant.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    key TEXT PRIMARY KEY,
                    tenant_id TEXT NOT NULL,
                    tier TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    last_used_at REAL,
                    is_active INTEGER DEFAULT 1
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS usage_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key TEXT NOT NULL,
                    tier TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    endpoint TEXT NOT NULL,
                    request_body TEXT,
                    response TEXT,
                    error TEXT,
                    processing_ms REAL,
                    idempotency_key TEXT UNIQUE
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_api_key_timestamp
                ON usage_log (api_key, timestamp)
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS monthly_counts (
                    api_key TEXT NOT NULL,
                    year_month TEXT NOT NULL,
                    count INTEGER DEFAULT 0,
                    PRIMARY KEY (api_key, year_month)
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS idempotency_keys (
                    key TEXT PRIMARY KEY,
                    consumed_at REAL NOT NULL
                )
            """)
            conn.commit()

    def _get_month_key(self) -> str:
        """Return the current quota period as ``YYYY-MM``.

        NOTE(review): uses the server's local time zone, so the month
        boundary depends on the host's TZ setting; consider UTC.
        """
        return datetime.now().strftime("%Y-%m")

    def get_or_create_api_key(self, key: str, tenant_id: str, tier: Tier = Tier.FREE) -> bool:
        """
        Register a new API key for a given tenant.

        Args:
            key: The API key. It is stored as given; no hashing happens here
                (NOTE(review): hashing is expected before production use).
            tenant_id: ID of the owning tenant (expected to already exist in
                the main DB; not validated here).
            tier: Initial tier for a newly created key; ignored if the key
                already exists.

        Returns:
            True if the key was created or already exists for the same tenant.

        Raises:
            ValueError: If the key is already registered to a different tenant.
        """
        with self._get_conn() as conn:
            # INSERT OR IGNORE keeps concurrent registrations of the same key
            # race-free: one insert wins, nobody sees an IntegrityError.
            conn.execute(
                "INSERT OR IGNORE INTO api_keys (key, tenant_id, tier, created_at, is_active) "
                "VALUES (?, ?, ?, ?, ?)",
                (key, tenant_id, tier.value, time.time(), 1),
            )
            conn.commit()
            row = conn.execute(
                "SELECT tenant_id FROM api_keys WHERE key = ?", (key,)).fetchone()
            if row["tenant_id"] != tenant_id:
                raise ValueError(f"Key {key[:8]}... already belongs to a different tenant.")
            return True

    def get_tier(self, api_key: str) -> Optional[Tier]:
        """Return the tier for a given API key, or None if key invalid/inactive."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT tier FROM api_keys WHERE key = ? AND is_active = 1",
                (api_key,)
            ).fetchone()
            if not row:
                return None
            return Tier(row["tier"])

    def get_tenant_id(self, api_key: str) -> Optional[str]:
        """Return the tenant ID associated with the API key, or None if key invalid/inactive."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT tenant_id FROM api_keys WHERE key = ? AND is_active = 1",
                (api_key,)
            ).fetchone()
            if not row:
                return None
            return row["tenant_id"]

    def update_api_key_tier(self, api_key: str, new_tier: Tier) -> bool:
        """Update the tier of an existing API key (active or not).

        Returns:
            True if a row for ``api_key`` exists and was updated, else False.
        """
        with self._get_conn() as conn:
            # A single UPDATE plus rowcount avoids a SELECT-then-UPDATE race.
            cur = conn.execute(
                "UPDATE api_keys SET tier = ? WHERE key = ?",
                (new_tier.value, api_key))
            conn.commit()
            return cur.rowcount > 0

    # --------------------------------------------------------------------------
    # Atomic quota consumption (keyed by api_key, which links to a tenant)
    # --------------------------------------------------------------------------
    def _consume_quota_atomic_sqlite(self, api_key: str, tier: Tier, month: str) -> bool:
        """Try to take one unit of ``api_key``'s quota for ``month`` in SQLite.

        Returns:
            True if the unit was taken, False if the tier's limit is reached.
        """
        limit = tier.monthly_evaluation_limit
        with self._get_conn() as conn:
            if limit is None:
                # Unlimited tier: still record the call in monthly_counts,
                # no limit check needed.
                conn.execute(self._UPSERT_MONTHLY_COUNT_SQL, (api_key, month))
                conn.commit()
                return True
            # BEGIN IMMEDIATE takes the write lock before the read, so two
            # callers cannot both observe count == limit - 1 and both increment.
            conn.execute("BEGIN IMMEDIATE")
            try:
                row = conn.execute(
                    "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                    (api_key, month)
                ).fetchone()
                current = row["count"] if row else 0
                if current >= limit:
                    conn.rollback()
                    return False
                conn.execute(self._UPSERT_MONTHLY_COUNT_SQL, (api_key, month))
                conn.commit()
                return True
            except Exception:
                conn.rollback()
                raise

    def _consume_quota_atomic_redis(self, api_key: str, tier: Tier, month: str) -> bool:
        """Try to take one unit of ``api_key``'s quota for ``month`` in Redis.

        Returns:
            True if the unit was taken, False if the tier's limit is reached.
        """
        limit = tier.monthly_evaluation_limit
        redis_key = f"arf:quota:{api_key}:{month}"
        if limit is None:
            self._redis_client.incr(redis_key)
            self._redis_client.expire(redis_key, timedelta(days=31))
            return True
        # Redis executes a Lua script atomically, so check-and-increment
        # cannot interleave with another client. 2678400 s = 31 days.
        lua_script = """
        local key = KEYS[1]
        local limit = tonumber(ARGV[1])
        local current = redis.call('GET', key)
        if current and tonumber(current) >= limit then
            return 0
        end
        redis.call('INCR', key)
        redis.call('EXPIRE', key, 2678400)
        return 1
        """
        result = self._redis_client.eval(lua_script, 1, redis_key, limit)
        return result == 1

    # --------------------------------------------------------------------------
    # Idempotency handling
    # --------------------------------------------------------------------------
    def _is_idempotent_key_used(self, key: str) -> bool:
        """Return True if ``key`` is recorded in ``idempotency_keys``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT 1 FROM idempotency_keys WHERE key = ?", (key,)).fetchone()
            return row is not None

    def _mark_idempotent_key_used(self, key: str, ttl_seconds: int = 86400) -> bool:
        """Atomically record ``key`` as used.

        Returns:
            True if this call recorded the key, False if it was already present.

        ``ttl_seconds`` is accepted for compatibility but not used; stored keys
        are purged by ``clean_old_logs`` after 7 days.
        """
        with self._get_conn() as conn:
            cur = conn.execute(
                "INSERT OR IGNORE INTO idempotency_keys (key, consumed_at) VALUES (?, ?)",
                (key, time.time())
            )
            conn.commit()
            return cur.rowcount == 1

    def _release_idempotent_key(self, key: str) -> None:
        """Forget ``key`` so a request that did not complete may be retried."""
        with self._get_conn() as conn:
            conn.execute("DELETE FROM idempotency_keys WHERE key = ?", (key,))
            conn.commit()

    # --------------------------------------------------------------------------
    # Core usage recording (atomic + idempotent)
    # --------------------------------------------------------------------------
    def consume_quota_and_log(self, record: UsageRecord, idempotency_key: Optional[str] = None
                              ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """
        Consume one unit of monthly quota for ``record.api_key`` and log the call.

        The idempotency key (if any) is claimed *before* quota is touched, so a
        duplicate request never consumes quota.  If quota is refused or an error
        occurs, the claim is released so the client may retry.

        NOTE(review): idempotency keys are global, not scoped per API key or
        tenant; two tenants choosing the same key will collide.

        Returns:
            ``(True, None)`` on success; ``(False, None)`` when the quota is
            exhausted; ``(False, {"idempotent": True, "message": ...})`` when
            the idempotency key was already processed.
        """
        if idempotency_key and not self._mark_idempotent_key_used(idempotency_key):
            return False, {"idempotent": True, "message": "Already processed"}
        try:
            accepted, info = self._consume_and_record(record, idempotency_key)
        except Exception:
            if idempotency_key:
                self._release_idempotent_key(idempotency_key)
            raise
        if not accepted and info is None and idempotency_key:
            # Quota refused: nothing was processed, so free the key for a retry.
            self._release_idempotent_key(idempotency_key)
        return accepted, info

    def _consume_and_record(self, record: UsageRecord, idempotency_key: Optional[str]
                            ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """Take one quota unit, then append ``record`` to ``usage_log``."""
        month = self._get_month_key()
        if self._redis_client:
            quota_ok = self._consume_quota_atomic_redis(record.api_key, record.tier, month)
        else:
            quota_ok = self._consume_quota_atomic_sqlite(record.api_key, record.tier, month)
        if not quota_ok:
            return False, None
        try:
            with self._get_conn() as conn:
                conn.execute(
                    """INSERT INTO usage_log
                       (api_key, tier, timestamp, endpoint, request_body, response, error, processing_ms, idempotency_key)
                       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                    (record.api_key, record.tier.value, record.timestamp, record.endpoint,
                     json.dumps(record.request_body) if record.request_body else None,
                     json.dumps(record.response) if record.response else None,
                     record.error, record.processing_ms, idempotency_key)
                )
                conn.commit()
        except sqlite3.IntegrityError as e:
            # Backstop: the key already has a usage_log row (e.g. its
            # idempotency_keys entry was purged while the log was retained).
            if "UNIQUE constraint failed: usage_log.idempotency_key" in str(e):
                return False, {"idempotent": True, "message": "Already processed"}
            raise
        return True, None

    # --------------------------------------------------------------------------
    # Legacy interface (kept for compatibility)
    # --------------------------------------------------------------------------
    def increment_usage_sync(self, record: UsageRecord, idempotency_key: Optional[str] = None) -> bool:
        """Consume quota and log; return only the success flag."""
        success, _ = self.consume_quota_and_log(record, idempotency_key)
        return success

    async def increment_usage_async(self, record: UsageRecord, background_tasks: BackgroundTasks,
                                    idempotency_key: Optional[str] = None) -> bool:
        """Async wrapper around ``increment_usage_sync``.

        NOTE(review): runs the blocking DB work directly on the event loop,
        and ``background_tasks`` is unused.
        """
        return self.increment_usage_sync(record, idempotency_key)

    # --------------------------------------------------------------------------
    # Quota inspection
    # --------------------------------------------------------------------------
    def get_remaining_quota(self, api_key: str, tier: Tier) -> Optional[int]:
        """Return units left this month (never negative), or None if unlimited."""
        limit = tier.monthly_evaluation_limit
        if limit is None:
            return None
        month = self._get_month_key()
        if self._redis_client:
            redis_key = f"arf:quota:{api_key}:{month}"
            count = int(self._redis_client.get(redis_key) or 0)
            return max(0, limit - count)
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                (api_key, month)
            ).fetchone()
            count = row["count"] if row else 0
            return max(0, limit - count)

    # --------------------------------------------------------------------------
    # Audit and maintenance (usage_log)
    # --------------------------------------------------------------------------
    def get_audit_logs(self, api_key: str, start_date: Optional[datetime] = None,
                       end_date: Optional[datetime] = None, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` usage_log rows for ``api_key``, newest first.

        ``start_date``/``end_date`` are inclusive bounds on the stored timestamp.
        """
        query = "SELECT * FROM usage_log WHERE api_key = ?"
        params = [api_key]
        if start_date:
            query += " AND timestamp >= ?"
            params.append(start_date.timestamp())
        if end_date:
            query += " AND timestamp <= ?"
            params.append(end_date.timestamp())
        query += " ORDER BY timestamp DESC LIMIT ?"
        params.append(limit)
        with self._get_conn() as conn:
            rows = conn.execute(query, params).fetchall()
            return [dict(row) for row in rows]

    def clean_old_logs(self):
        """Delete usage_log rows past their tier's retention and idempotency keys older than 7 days."""
        with self._get_conn() as conn:
            for tier in Tier:
                retention_days = tier.audit_log_retention_days
                cutoff = time.time() - retention_days * 86400
                conn.execute(
                    "DELETE FROM usage_log WHERE tier = ? AND timestamp < ?",
                    (tier.value, cutoff)
                )
            cutoff = time.time() - 7 * 86400
            conn.execute("DELETE FROM idempotency_keys WHERE consumed_at < ?", (cutoff,))
            conn.commit()
| # -------------------------------------------------------------------------- | |
| # Global instance and FastAPI dependency | |
| # -------------------------------------------------------------------------- | |
# Process-wide tracker; stays None until init_tracker() is called.
tracker: Optional[UsageTracker] = None


def init_tracker(db_path: str = "arf_usage.db", redis_url: Optional[str] = None):
    """Build a UsageTracker and install it as the module-level ``tracker``.

    Args:
        db_path: SQLite database path passed through to UsageTracker.
        redis_url: Optional Redis URL passed through to UsageTracker.
    """
    global tracker
    new_tracker = UsageTracker(db_path, redis_url)
    tracker = new_tracker
def update_key_tier(api_key: str, new_tier: Tier) -> bool:
    """Change an API key's tier through the global tracker.

    Returns:
        False when the tracker has not been initialised; otherwise the result
        of ``tracker.update_api_key_tier``.
    """
    if tracker is not None:
        return tracker.update_api_key_tier(api_key, new_tier)
    return False
async def enforce_quota(request: Request, api_key: Optional[str] = None):
    """
    FastAPI dependency that enforces quota and attaches tenant_id to request state.

    The API key is resolved in this order:
      1. the ``api_key`` argument, if not None;
      2. an ``Authorization: Bearer <key>`` header (scheme matched
         case-insensitively, token whitespace-stripped);
      3. the ``api_key`` query parameter.

    This only checks that quota remains; it does not consume any.

    On success it sets ``request.state.api_key``, ``.tier`` and ``.tenant_id``.

    Returns:
        dict with keys ``api_key``, ``tier``, ``tenant_id`` and ``remaining``
        (None for unlimited tiers).

    Raises:
        HTTPException:
            * 503 if the tracker is not initialised.
            * 401 if no key is supplied.
            * 403 if the key is invalid or inactive, or has no tenant.
            * 429 if the monthly quota is exhausted.
    """
    if tracker is None:
        raise HTTPException(status_code=503, detail="Usage tracking service not initialised.")
    if api_key is None:
        auth_header = request.headers.get("Authorization") or ""
        scheme, sep, credentials = auth_header.partition(" ")
        # RFC 7235 §2.1: the auth scheme token is case-insensitive.
        if sep and scheme.lower() == "bearer":
            api_key = credentials.strip()
        else:
            api_key = request.query_params.get("api_key")
    if not api_key:
        raise HTTPException(status_code=401, detail="Missing API key")
    tier = tracker.get_tier(api_key)
    if tier is None:
        raise HTTPException(status_code=403, detail="Invalid or inactive API key")
    remaining = tracker.get_remaining_quota(api_key, tier)
    if remaining is not None and remaining <= 0:
        raise HTTPException(status_code=429, detail="Monthly evaluation quota exceeded")
    # Every valid key must map to a tenant for resource isolation.
    tenant_id = tracker.get_tenant_id(api_key)
    if not tenant_id:
        raise HTTPException(status_code=403, detail="API key not associated with a tenant")
    request.state.api_key = api_key
    request.state.tier = tier
    request.state.tenant_id = tenant_id
    return {"api_key": api_key, "tier": tier, "tenant_id": tenant_id, "remaining": remaining}