| """ |
| Usage Tracker for ARF API – quotas, tiers, and audit logging. |
| Thread‑safe, atomic quota consumption, idempotent, fail‑closed. |
| """ |
|
|
import json
import sqlite3
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from fastapi import BackgroundTasks, HTTPException, Request
|
|
| |
| try: |
| import redis |
| REDIS_AVAILABLE = True |
| except ImportError: |
| REDIS_AVAILABLE = False |
| redis = None |
|
|
|
|
class Tier(str, Enum):
    """API subscription tiers and their quota / retention policies."""

    FREE = "free"
    PRO = "pro"
    PREMIUM = "premium"
    ENTERPRISE = "enterprise"

    @property
    def monthly_evaluation_limit(self) -> Optional[int]:
        """Evaluations allowed per calendar month; None means unlimited."""
        if self is Tier.ENTERPRISE:
            return None
        per_tier = {"free": 1000, "pro": 10_000, "premium": 50_000}
        return per_tier[self.value]

    @property
    def audit_log_retention_days(self) -> int:
        """Number of days audit-log rows for this tier are retained."""
        per_tier = {"free": 7, "pro": 30, "premium": 90, "enterprise": 365}
        return per_tier[self.value]
|
|
|
|
@dataclass
class UsageRecord:
    """Single evaluation usage record.

    One instance per API call; persisted into the usage_log table by
    UsageTracker.consume_quota_and_log (request_body and response are
    JSON-serialised there).
    """
    api_key: str                                  # key that made the call
    tier: Tier                                    # tier of the key at call time
    timestamp: float                              # seconds since epoch (time.time() style)
    endpoint: str                                 # endpoint path that was invoked
    request_body: Optional[Dict[str, Any]] = None  # raw request payload, if captured
    response: Optional[Dict[str, Any]] = None      # response payload, if captured
    error: Optional[str] = None                    # error text when the call failed
    processing_ms: Optional[float] = None          # server-side processing time
|
|
|
|
class UsageTracker:
    """
    Thread-safe usage tracker with atomic quota consumption and idempotency.

    SQLite storage layout:
      * api_keys         -- key registry with tier and active flag
      * usage_log        -- per-request audit trail (UNIQUE idempotency_key)
      * monthly_counts   -- per-key, per-month quota counters
      * idempotency_keys -- processed idempotency keys awaiting cleanup

    When a Redis URL is supplied, the monthly quota counters live in Redis
    and are consumed atomically via a Lua script; audit logs, API keys and
    idempotency bookkeeping stay in SQLite either way.
    """

    def __init__(self, db_path: str = "arf_usage.db",
                 redis_url: Optional[str] = None):
        """
        Args:
            db_path: Path of the SQLite database file.
            redis_url: Optional Redis connection URL for quota counters.

        Raises:
            ImportError: redis_url was given but the redis package is not
                installed -- fail closed instead of silently degrading to
                the SQLite counter backend.
        """
        self.db_path = db_path
        self._local = threading.local()  # one SQLite connection per thread
        self._init_db()

        self._redis_client = None
        if redis_url and REDIS_AVAILABLE:
            self._redis_client = redis.from_url(redis_url)
        elif redis_url:
            raise ImportError(
                "Redis client not installed. Run: pip install redis")

    @contextmanager
    def _get_conn(self):
        """Yield this thread's SQLite connection, creating it on first use.

        check_same_thread=False plus thread-local storage gives each thread
        a private connection; WAL allows readers to proceed alongside a
        writer; isolation_level=None (autocommit) leaves transaction
        control to us -- see the explicit BEGIN IMMEDIATE in the quota path.
        """
        if not hasattr(self._local, "conn"):
            self._local.conn = sqlite3.connect(
                self.db_path, check_same_thread=False, isolation_level=None)
            self._local.conn.row_factory = sqlite3.Row
            self._local.conn.execute("PRAGMA journal_mode=WAL")
        yield self._local.conn

    def _init_db(self):
        """Create tables and indexes if absent (safe to call repeatedly)."""
        with self._get_conn() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    key TEXT PRIMARY KEY,
                    tier TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    last_used_at REAL,
                    is_active INTEGER DEFAULT 1
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS usage_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key TEXT NOT NULL,
                    tier TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    endpoint TEXT NOT NULL,
                    request_body TEXT,
                    response TEXT,
                    error TEXT,
                    processing_ms REAL,
                    idempotency_key TEXT UNIQUE
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_api_key_timestamp
                ON usage_log (api_key, timestamp)
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS monthly_counts (
                    api_key TEXT NOT NULL,
                    year_month TEXT NOT NULL,
                    count INTEGER DEFAULT 0,
                    PRIMARY KEY (api_key, year_month)
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS idempotency_keys (
                    key TEXT PRIMARY KEY,
                    consumed_at REAL NOT NULL
                )
            """)
            conn.commit()

    def _get_month_key(self) -> str:
        """Return the current month as 'YYYY-MM'.

        FIX: use UTC so month boundaries do not shift with the server's
        local timezone/DST and are consistent with the fixed 31-day TTL
        applied to the Redis quota keys.
        """
        return datetime.now(timezone.utc).strftime("%Y-%m")

    def get_or_create_api_key(self, key: str, tier: Tier = Tier.FREE) -> bool:
        """Register a new API key. Returns True if key exists or was created.

        FIX: single INSERT OR IGNORE replaces the SELECT-then-INSERT, which
        could raise IntegrityError when two threads registered the same key
        concurrently. An existing key's tier is left untouched, as before.
        """
        with self._get_conn() as conn:
            conn.execute(
                "INSERT OR IGNORE INTO api_keys "
                "(key, tier, created_at, is_active) VALUES (?, ?, ?, ?)",
                (key, tier.value, time.time(), 1)
            )
            conn.commit()
            return True

    def get_tier(self, api_key: str) -> Optional[Tier]:
        """Return the tier for a given API key, or None if key invalid/inactive."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT tier FROM api_keys WHERE key = ? AND is_active = 1",
                (api_key,)
            ).fetchone()
            if not row:
                return None
            return Tier(row["tier"])

    def update_api_key_tier(self, api_key: str, new_tier: Tier) -> bool:
        """Update the tier of an existing API key. Returns True if successful.

        A single UPDATE plus rowcount replaces the racy SELECT-then-UPDATE;
        rowcount is 0 exactly when the key does not exist.
        """
        with self._get_conn() as conn:
            cur = conn.execute(
                "UPDATE api_keys SET tier = ? WHERE key = ?",
                (new_tier.value, api_key))
            conn.commit()
            return cur.rowcount > 0

    def _consume_quota_atomic_sqlite(
            self,
            api_key: str,
            tier: Tier,
            month: str) -> bool:
        """
        Atomically increment counter only if under limit.
        Returns True if quota was consumed, False if limit reached.
        """
        limit = tier.monthly_evaluation_limit
        if limit is None:
            # Unlimited tier: still count usage for reporting, no check.
            with self._get_conn() as conn:
                conn.execute(
                    """INSERT INTO monthly_counts (api_key, year_month, count)
                       VALUES (?, ?, 1)
                       ON CONFLICT(api_key, year_month) DO UPDATE SET count = count + 1""",
                    (api_key, month)
                )
                conn.commit()
            return True

        # BEGIN IMMEDIATE takes the write lock up front so the
        # read-check-increment below is a single atomic unit across
        # threads and processes.
        with self._get_conn() as conn:
            conn.execute("BEGIN IMMEDIATE")
            try:
                row = conn.execute(
                    "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                    (api_key, month)
                ).fetchone()
                current = row["count"] if row else 0
                if current >= limit:
                    conn.rollback()
                    return False

                conn.execute(
                    """INSERT INTO monthly_counts (api_key, year_month, count)
                       VALUES (?, ?, 1)
                       ON CONFLICT(api_key, year_month) DO UPDATE SET count = count + 1""",
                    (api_key, month)
                )
                conn.commit()
                return True
            except Exception:
                conn.rollback()
                raise

    def _consume_quota_atomic_redis(
            self,
            api_key: str,
            tier: Tier,
            month: str) -> bool:
        """Atomic Lua script for Redis: INCR only if below limit."""
        limit = tier.monthly_evaluation_limit
        if limit is None:
            # Unlimited tier: count for reporting only.
            redis_key = f"arf:quota:{api_key}:{month}"
            self._redis_client.incr(redis_key)
            self._redis_client.expire(redis_key, timedelta(days=31))
            return True

        # Lua runs atomically inside Redis, so check + INCR cannot race.
        lua_script = """
        local key = KEYS[1]
        local limit = tonumber(ARGV[1])
        local current = redis.call('GET', key)
        if current and tonumber(current) >= limit then
            return 0
        end
        local new = redis.call('INCR', key)
        redis.call('EXPIRE', key, 2678400) -- 31 days
        return 1
        """
        redis_key = f"arf:quota:{api_key}:{month}"
        result = self._redis_client.eval(lua_script, 1, redis_key, limit)
        return result == 1

    def _refund_quota(self, api_key: str, month: str) -> None:
        """Give back one previously consumed quota unit.

        Compensation used when the audit-log insert fails after the counter
        was already incremented, so the counter stays in step with the log.
        """
        if self._redis_client:
            self._redis_client.decr(f"arf:quota:{api_key}:{month}")
            return
        with self._get_conn() as conn:
            conn.execute(
                """UPDATE monthly_counts SET count = count - 1
                   WHERE api_key = ? AND year_month = ? AND count > 0""",
                (api_key, month)
            )
            conn.commit()

    def _is_idempotent_key_used(self, key: str) -> bool:
        """Best-effort pre-check for a processed idempotency key.

        The UNIQUE constraint on usage_log.idempotency_key remains the
        authoritative guard against concurrent duplicates.
        """
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT 1 FROM idempotency_keys WHERE key = ?", (key,)).fetchone()
            return row is not None

    def _mark_idempotent_key_used(self, key: str, ttl_seconds: int = 86400):
        """Record an idempotency key for later cleanup (see clean_old_logs).

        ttl_seconds is kept for interface compatibility; expiry is enforced
        by clean_old_logs, not here.
        FIX: INSERT OR IGNORE makes re-marking a key (retry races) a no-op
        instead of raising IntegrityError.
        """
        with self._get_conn() as conn:
            conn.execute(
                "INSERT OR IGNORE INTO idempotency_keys (key, consumed_at) VALUES (?, ?)",
                (key, time.time())
            )
            conn.commit()

    def consume_quota_and_log(
        self,
        record: UsageRecord,
        idempotency_key: Optional[str] = None,
    ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """
        Atomically consume quota and insert audit log.

        Returns:
            (success, info): success is True when a quota unit was consumed
            and the call was logged. info is non-None only when
            idempotency_key matched a previous call; in that case success is
            False and no quota unit remains consumed.
        """
        # Fast-path duplicate detection (the UNIQUE constraint below is
        # the authoritative guard).
        if idempotency_key:
            if self._is_idempotent_key_used(idempotency_key):
                return False, {"idempotent": True,
                               "message": "Already processed"}

        month = self._get_month_key()
        if self._redis_client:
            quota_ok = self._consume_quota_atomic_redis(
                record.api_key, record.tier, month)
        else:
            quota_ok = self._consume_quota_atomic_sqlite(
                record.api_key, record.tier, month)

        if not quota_ok:
            return False, None

        try:
            with self._get_conn() as conn:
                conn.execute(
                    """INSERT INTO usage_log
                       (api_key, tier, timestamp, endpoint,
                        request_body, response, error, processing_ms,
                        idempotency_key)
                       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                    (record.api_key,
                     record.tier.value,
                     record.timestamp,
                     record.endpoint,
                     json.dumps(
                         record.request_body) if record.request_body else None,
                     json.dumps(
                         record.response) if record.response else None,
                     record.error,
                     record.processing_ms,
                     idempotency_key,
                     ))
                conn.commit()
        except sqlite3.IntegrityError as e:
            # A concurrent request with the same idempotency key won the
            # race past the pre-check. FIX: refund the quota unit consumed
            # above -- previously it leaked on this path.
            self._refund_quota(record.api_key, month)
            if "UNIQUE constraint failed: usage_log.idempotency_key" in str(e):
                return False, {"idempotent": True,
                               "message": "Already processed"}
            raise
        except Exception:
            # Any other insert failure: refund before propagating so the
            # counter stays consistent with the audit log.
            self._refund_quota(record.api_key, month)
            raise

        if idempotency_key:
            self._mark_idempotent_key_used(idempotency_key)

        return True, None

    def increment_usage_sync(
            self,
            record: UsageRecord,
            idempotency_key: Optional[str] = None) -> bool:
        """
        Synchronously record usage and increment counter.
        Returns True if within quota and recorded, False otherwise.
        Thin wrapper over the atomic consume_quota_and_log.
        """
        success, _ = self.consume_quota_and_log(record, idempotency_key)
        return success

    async def increment_usage_async(
        self,
        record: UsageRecord,
        background_tasks: BackgroundTasks,
        idempotency_key: Optional[str] = None
    ) -> bool:
        """
        Record usage from an async handler.

        background_tasks is currently unused: the quota check must run
        synchronously so the caller can fail closed before answering.
        The parameter is kept for interface compatibility with callers
        that already pass it.
        """
        return self.increment_usage_sync(record, idempotency_key)

    def get_remaining_quota(self, api_key: str, tier: Tier) -> Optional[int]:
        """Return remaining evaluations for the month, or None if unlimited.

        Non-atomic and for information only -- the authoritative check is
        consume_quota_and_log.
        """
        limit = tier.monthly_evaluation_limit
        if limit is None:
            return None

        month = self._get_month_key()
        if self._redis_client:
            redis_key = f"arf:quota:{api_key}:{month}"
            count = int(self._redis_client.get(redis_key) or 0)
            return max(0, limit - count)

        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                (api_key, month)
            ).fetchone()
            count = row["count"] if row else 0
            return max(0, limit - count)

    def get_audit_logs(
        self,
        api_key: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """Retrieve audit logs for a given API key.

        Args:
            api_key: Key whose rows to fetch.
            start_date / end_date: Optional inclusive time-range bounds.
            limit: Maximum number of rows, newest first.
        """
        query = "SELECT * FROM usage_log WHERE api_key = ?"
        params: List[Any] = [api_key]
        if start_date:
            query += " AND timestamp >= ?"
            params.append(start_date.timestamp())
        if end_date:
            query += " AND timestamp <= ?"
            params.append(end_date.timestamp())
        query += " ORDER BY timestamp DESC LIMIT ?"
        params.append(limit)

        with self._get_conn() as conn:
            rows = conn.execute(query, params).fetchall()
            return [dict(row) for row in rows]

    def clean_old_logs(self):
        """Delete logs older than each tier's retention period, plus stale
        idempotency keys (older than 7 days)."""
        with self._get_conn() as conn:
            for tier in Tier:
                retention_days = tier.audit_log_retention_days
                if retention_days is None:
                    # Defensive: retention is currently always an int.
                    continue
                cutoff = time.time() - retention_days * 86400
                conn.execute(
                    "DELETE FROM usage_log WHERE tier = ? AND timestamp < ?",
                    (tier.value, cutoff)
                )
            cutoff = time.time() - 7 * 86400
            conn.execute(
                "DELETE FROM idempotency_keys WHERE consumed_at < ?", (cutoff,))
            conn.commit()
|
|
|
|
| |
| |
| |
| tracker: Optional[UsageTracker] = None |
|
|
|
|
def init_tracker(
        db_path: str = "arf_usage.db",
        redis_url: Optional[str] = None):
    """Build and install the module-level UsageTracker singleton.

    Call once at application startup, before enforce_quota is ever used;
    until then enforce_quota fails closed.
    """
    global tracker
    tracker = UsageTracker(db_path=db_path, redis_url=redis_url)
|
|
|
|
def update_key_tier(api_key: str, new_tier: Tier) -> bool:
    """Globally accessible helper to update API key tier.

    Returns False when the global tracker has not been initialised yet.
    """
    active = tracker
    return active is not None and active.update_api_key_tier(api_key, new_tier)
|
|
|
|
async def enforce_quota(request: Request, api_key: Optional[str] = None):
    """
    FastAPI dependency that authenticates the API key and checks quota.

    Key resolution order: explicit argument, then the
    "Authorization: Bearer ..." header, then the api_key query parameter.

    FAILS CLOSED: if the tracker is not initialised, raises HTTP 503.

    Raises:
        HTTPException 503: usage tracker not initialised.
        HTTPException 401: no API key supplied.
        HTTPException 403: unknown or inactive key.
        HTTPException 429: monthly quota exhausted.

    NOTE: this quota read is informational (non-atomic); the authoritative
    check is UsageTracker.consume_quota_and_log.
    FIX: annotate api_key as Optional[str] instead of the implicit-Optional
    `str = None` (disallowed by PEP 484).
    """
    if tracker is None:
        raise HTTPException(
            status_code=503,
            detail="Usage tracking service not initialised. Please contact administrator.")

    if api_key is None:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            api_key = auth_header[7:]
        else:
            api_key = request.query_params.get("api_key")

    if not api_key:
        raise HTTPException(status_code=401, detail="Missing API key")

    tier = tracker.get_tier(api_key)
    if tier is None:
        raise HTTPException(
            status_code=403,
            detail="Invalid or inactive API key")

    remaining = tracker.get_remaining_quota(api_key, tier)
    if remaining is not None and remaining <= 0:
        raise HTTPException(status_code=429,
                            detail="Monthly evaluation quota exceeded")

    # Stash identity on the request for downstream handlers.
    request.state.api_key = api_key
    request.state.tier = tier
    return {"api_key": api_key, "tier": tier, "remaining": remaining}
|
|