Spaces:
Build error
Build error
| """ | |
| Admin API endpoints for API key management and audit logs. | |
| These endpoints should be protected (e.g., by an admin API key) in production. | |
| """ | |
import os
import secrets
import uuid
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from pydantic import BaseModel

from app.core.usage_tracker import tracker, Tier
router = APIRouter(prefix="/admin", tags=["admin"])

# Shared secret checked by verify_admin. It is read from the ADMIN_API_KEY
# environment variable so the real secret does not live in source control.
# The old hard-coded value is kept only as a backward-compatible fallback.
# That value is public, so deployments MUST set ADMIN_API_KEY (and should
# move to proper auth in production).
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY", "admin_secret_change_me")
def verify_admin(admin_key: str = Query(..., alias="admin_key")) -> bool:
    """Reject the caller unless ``admin_key`` equals ``ADMIN_API_KEY``.

    The parameter is declared as a FastAPI ``Query``, so this function is
    shaped to be used as a dependency (e.g. ``Depends(verify_admin)``).

    Args:
        admin_key: Value of the ``admin_key`` query-string parameter.

    Returns:
        True when the key matches.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    # Constant-time comparison: plain `!=` can return early on the first
    # differing character, leaking the secret through response timing.
    # The strings are compared as UTF-8 bytes because compare_digest raises
    # TypeError for non-ASCII str input.
    if not secrets.compare_digest(
        admin_key.encode("utf-8"), ADMIN_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin key")
    return True
class CreateKeyRequest(BaseModel):
    """Request body for creating a new API key."""

    # Requested tier name. The model accepts any string; create_api_key
    # checks it against the Tier enum values and answers 400 otherwise.
    tier: str


class UpdateTierRequest(BaseModel):
    """Request body for changing the tier of an existing API key."""

    # New tier name. The model accepts any string; update_key_tier checks
    # it against the Tier enum values and answers 400 otherwise.
    tier: str
async def create_api_key(req: CreateKeyRequest):
    """Generate a new API key with the requested tier and register it.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Args:
        req: Body carrying the desired tier name.

    Returns:
        ``{"api_key": <new key>, "tier": <tier name as sent>}``.

    Raises:
        HTTPException: 400 when ``req.tier`` is not one of the ``Tier`` values.
    """
    valid_tiers = [t.value for t in Tier]
    if req.tier not in valid_tiers:
        raise HTTPException(
            status_code=400, detail=f"Invalid tier. Must be one of {valid_tiers}")
    # 24 hex chars from the OS CSPRNG. This keeps the previous key format, but
    # every character is random: uuid4().hex has a fixed version nibble at
    # index 12, so a 24-char slice carried only 92 random bits instead of 96.
    new_key = f"sk_live_{secrets.token_hex(12)}"
    tracker.get_or_create_api_key(new_key, Tier(req.tier))
    return {"api_key": new_key, "tier": req.tier}
async def list_api_keys(limit: int = 100, offset: int = 0):
    """Return one page of API keys, newest first, with current-month usage.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Args:
        limit: Maximum number of keys to return (must be >= 0).
        offset: Number of keys to skip, ordered by ``created_at`` descending
            (must be >= 0).

    Returns:
        ``{"keys": [...], "total": n}``. Each entry has ``key``, ``tier``,
        ``created_at`` and ``last_used_at`` (ISO strings built with
        ``datetime.fromtimestamp``, i.e. server-local time), ``is_active``
        and ``current_month_usage`` (0 when no monthly row exists).
        ``total`` is the number of keys on THIS page, not the overall count.

    Raises:
        HTTPException: 400 when ``limit`` or ``offset`` is negative. SQLite
            treats a negative LIMIT as "no limit", which would return every key.
    """
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=400, detail="limit and offset must be non-negative")

    # Resolve the month once so every row is measured against the same month,
    # even if the request straddles a month boundary.
    month = tracker._get_month_key()
    keys = []
    with tracker._get_conn() as conn:
        rows = conn.execute(
            "SELECT key, tier, created_at, last_used_at, is_active "
            "FROM api_keys ORDER BY created_at DESC LIMIT ? OFFSET ?",
            (limit, offset),
        ).fetchall()
        for row in rows:
            usage_row = conn.execute(
                "SELECT count FROM monthly_counts WHERE api_key = ? AND year_month = ?",
                (row["key"], month),
            ).fetchone()
            usage = usage_row["count"] if usage_row else 0
            last_used = row["last_used_at"]
            keys.append(
                {
                    "key": row["key"],
                    "tier": row["tier"],
                    # NOTE(review): assumes created_at is always set; a NULL
                    # here would raise TypeError.
                    "created_at": datetime.fromtimestamp(row["created_at"]).isoformat(),
                    "last_used_at": (
                        datetime.fromtimestamp(last_used).isoformat() if last_used else None
                    ),
                    "is_active": bool(row["is_active"]),
                    "current_month_usage": usage,
                }
            )
    return {"keys": keys, "total": len(keys)}
async def update_key_tier(
    api_key: str = Path(..., description="The API key to update"),
    req: UpdateTierRequest = Body(...),
):
    """Change the tier stored for an existing API key.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Args:
        api_key: Key whose tier should change (path parameter).
        req: Body carrying the new tier name.

    Returns:
        ``{"message": "Tier updated to <tier>"}``.

    Raises:
        HTTPException: 400 for an unknown tier; 404 when no row has ``api_key``.
    """
    valid_tiers = [t.value for t in Tier]
    if req.tier not in valid_tiers:
        raise HTTPException(
            status_code=400, detail=f"Invalid tier. Must be one of {valid_tiers}")
    with tracker._get_conn() as conn:
        # A single UPDATE both finds and modifies the row. The old
        # SELECT-then-UPDATE pair could report success for a key removed
        # between the two statements.
        cursor = conn.execute(
            "UPDATE api_keys SET tier = ? WHERE key = ?", (req.tier, api_key))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": f"Tier updated to {req.tier}"}
async def deactivate_api_key(
        api_key: str = Path(..., description="The API key to deactivate")):
    """Mark an API key inactive (``is_active = 0``). The row itself is kept.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Args:
        api_key: Key to deactivate (path parameter).

    Returns:
        ``{"message": "API key deactivated"}``.

    Raises:
        HTTPException: 404 when no row has ``api_key``.
    """
    with tracker._get_conn() as conn:
        # One statement both locates and updates the row; rowcount tells us
        # whether the key existed, without a separate racing SELECT.
        cursor = conn.execute(
            "UPDATE api_keys SET is_active = 0 WHERE key = ?", (api_key,))
        if cursor.rowcount == 0:
            raise HTTPException(status_code=404, detail="API key not found")
        conn.commit()
    return {"message": "API key deactivated"}
async def get_audit_logs(
    api_key: str = Path(..., description="The API key to audit"),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    limit: int = 100,
):
    """Return audit-log entries recorded by the tracker for one API key.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Args:
        api_key: Key to audit (path parameter).
        start_date: Optional lower bound, parsed with ``datetime.fromisoformat``.
        end_date: Optional upper bound, parsed the same way. A bare date such
            as ``2024-01-31`` parses to midnight.
        limit: Passed through to ``tracker.get_audit_logs``.

    Returns:
        ``{"api_key": api_key, "logs": <tracker.get_audit_logs result>}``.

    Raises:
        HTTPException: 400 when either date string is not a valid ISO format.
    """
    try:
        start = datetime.fromisoformat(start_date) if start_date else None
        end = datetime.fromisoformat(end_date) if end_date else None
    except ValueError as exc:
        # Malformed client input previously escaped as an unhandled
        # ValueError, producing a 500 instead of a client error.
        raise HTTPException(
            status_code=400,
            detail="start_date and end_date must be ISO 8601 dates or datetimes",
        ) from exc
    logs = tracker.get_audit_logs(api_key, start, end, limit)
    return {"api_key": api_key, "logs": logs}
async def get_global_stats():
    """Summarise usage across every API key.

    NOTE(review): no route decorator or ``Depends(verify_admin)`` is visible
    in this copy of the file; confirm the route is registered and guarded.

    Returns:
        A dict with:

        * ``active_api_keys``: count of ``api_keys`` rows with ``is_active = 1``.
        * ``total_evaluations``: count of all ``usage_log`` rows.
        * ``current_month_evaluations``: ``SUM(count)`` of ``monthly_counts``
          for the tracker's current month key, or 0 when there are none.
        * ``by_tier``: ``usage_log`` row counts grouped by ``tier``.
    """
    with tracker._get_conn() as conn:

        def first_column(sql, *params):
            # Run a single-row query and hand back its first column.
            return conn.execute(sql, params).fetchone()[0]

        active_keys = first_column(
            "SELECT COUNT(*) FROM api_keys WHERE is_active = 1")
        evaluation_total = first_column("SELECT COUNT(*) FROM usage_log")
        tier_rows = conn.execute(
            "SELECT tier, COUNT(*) as count FROM usage_log GROUP BY tier"
        ).fetchall()
        this_month = tracker._get_month_key()
        month_total = first_column(
            "SELECT SUM(count) FROM monthly_counts WHERE year_month = ?", this_month
        )

    # SUM over zero rows is NULL in SQL; report that as 0.
    if not month_total:
        month_total = 0

    breakdown = []
    for tier_row in tier_rows:
        breakdown.append({"tier": tier_row[0], "count": tier_row[1]})

    return {
        "active_api_keys": active_keys,
        "total_evaluations": evaluation_total,
        "current_month_evaluations": month_total,
        "by_tier": breakdown,
    }