| """API key manager - multi-user key management""" |
|
|
| import os |
| import orjson |
| import time |
| import secrets |
| import asyncio |
| from typing import List, Dict, Optional |
| from pathlib import Path |
|
|
| from app.core.logger import logger |
| from app.core.config import setting |
|
|
|
|
class ApiKeyManager:
    """API key management service (process-wide singleton).

    Every ``ApiKeyManager()`` call returns the same instance; ``__init__``
    only sets up state the first time.

    Keys are kept in memory as a list of dicts shaped like
    ``{"key": str, "name": str, "created_at": int, "is_active": bool}``
    where ``created_at`` is ``int(time.time())``. They are persisted through
    ``self._storage`` when it provides ``load_api_keys``/``save_api_keys``,
    otherwise to the JSON file at ``self.file_path``.
    """

    _instance = None

    def __new__(cls):
        # Reuse the first instance so every importer shares one key registry.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every ApiKeyManager() call even when
        # __new__ returns the cached instance, so only initialise once.
        if hasattr(self, "_initialized"):
            return

        self.file_path = self._resolve_data_dir() / "api_keys.json"
        self._keys: List[Dict] = []
        self._lock = asyncio.Lock()  # serialises loads and saves
        self._loaded = False  # True only after a successful load
        self._storage = None

        self._initialized = True
        logger.debug(f"[ApiKey] Initialized: {self.file_path}")

    @staticmethod
    def _resolve_data_dir() -> Path:
        """Resolve the data directory for persistence.

        Order: ``$DATA_DIR`` if set, else ``/data`` if it exists, else
        ``Path(__file__).parents[2] / "data"``.
        """
        data_dir_env = os.getenv("DATA_DIR")
        if data_dir_env:
            return Path(data_dir_env)
        if Path("/data").exists():
            return Path("/data")
        return Path(__file__).parents[2] / "data"

    def set_storage(self, storage) -> None:
        """Use ``storage`` as the persistence backend.

        If the backend has a truthy ``data_dir`` attribute, the JSON file
        path is relocated into that directory as well.
        """
        self._storage = storage
        data_dir = getattr(storage, "data_dir", None)
        if data_dir:
            self.file_path = Path(data_dir) / "api_keys.json"

    def _use_storage(self) -> bool:
        """Return True when a backend with load_api_keys/save_api_keys is set."""
        return bool(
            self._storage
            and hasattr(self._storage, "load_api_keys")
            and hasattr(self._storage, "save_api_keys")
        )

    async def init(self):
        """Load persisted keys if not already loaded.

        After a failed load ``_loaded`` stays False, so calling this again
        retries the load.
        """
        if not self._loaded:
            await self._load_data()

    @staticmethod
    def _coerce_key_list(data) -> List[Dict]:
        """Validate loaded data and return it as a new list.

        ``None`` is treated as "no keys". Anything that is not a list raises
        ``TypeError`` so the load is reported as failed instead of being
        silently replaced.
        """
        if data is None:
            return []
        if not isinstance(data, list):
            raise TypeError(f"expected a list of API keys, got {type(data).__name__}")
        return list(data)

    async def _load_data(self):
        """Load API keys from the storage backend or the JSON file.

        A missing file means "no keys yet" and counts as a successful load.
        On any error, the failure is logged and the in-memory list is
        cleared, but ``_loaded`` stays False. ``_save_data`` then refuses to
        write, so an unreadable file or a transient backend error is never
        overwritten with an empty list, and a later ``init()`` retries.
        """
        async with self._lock:
            # Re-check under the lock so concurrent init() calls load only once.
            if self._loaded:
                return
            try:
                if self._use_storage():
                    keys = await self._storage.load_api_keys()
                    self._keys = self._coerce_key_list(keys)
                    self._loaded = True
                    logger.debug(f"[ApiKey] Loaded {len(self._keys)} API keys (storage)")
                    return

                if not self.file_path.exists():
                    self._keys = []
                    self._loaded = True
                    return

                content = await asyncio.to_thread(self.file_path.read_bytes)
                if content:
                    self._keys = self._coerce_key_list(orjson.loads(content))
                self._loaded = True
                logger.debug(f"[ApiKey] Loaded {len(self._keys)} API keys")
            except Exception as e:
                logger.error(f"[ApiKey] Load failed: {e}")
                # Deliberately NOT marking as loaded: see docstring.
                self._keys = []

    @staticmethod
    def _atomic_write_bytes(path: Path, content: bytes) -> None:
        """Replace ``path`` with ``content`` via a temp file and ``os.replace``.

        A crash mid-write leaves the previous file intact instead of a
        truncated one. The parent directory is created if needed.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        # PID suffix keeps concurrent worker processes from sharing a temp file.
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        try:
            tmp_path.write_bytes(content)
            try:
                os.replace(tmp_path, path)
            except OSError:
                # The target may not be renamable over (e.g. a bind-mounted
                # file); fall back to a plain in-place write.
                path.write_bytes(content)
        finally:
            tmp_path.unlink(missing_ok=True)

    async def _save_data(self):
        """Persist the in-memory keys; errors are logged, not raised.

        Skipped with a warning until a load has succeeded, so persisted keys
        are never replaced by an incomplete in-memory list.
        """
        if not self._loaded:
            logger.warning("[ApiKey] Save skipped because data is not loaded to avoid overwrite")
            return

        try:
            # Hold the lock for both backends so saves complete in call order
            # and the newest snapshot is written last.
            async with self._lock:
                if self._use_storage():
                    await self._storage.save_api_keys(self._keys)
                    return
                content = orjson.dumps(self._keys, option=orjson.OPT_INDENT_2)
                await asyncio.to_thread(self._atomic_write_bytes, self.file_path, content)
        except Exception as e:
            logger.error(f"[ApiKey] Save failed: {e}")

    def generate_key(self) -> str:
        """Return a new random key: ``sk-`` plus ``secrets.token_urlsafe(24)``."""
        return f"sk-{secrets.token_urlsafe(24)}"

    async def add_key(self, name: str) -> Dict:
        """Create, store, persist and return one new active key named ``name``."""
        new_key = {
            "key": self.generate_key(),
            "name": name,
            "created_at": int(time.time()),
            "is_active": True,
        }
        self._keys.append(new_key)
        await self._save_data()
        logger.info(f"[ApiKey] Added new key: {name}")
        return new_key

    async def batch_add_keys(self, name_prefix: str, count: int) -> List[Dict]:
        """Create ``count`` active keys and return them.

        Keys are named ``<name_prefix>-1`` … ``<name_prefix>-<count>``, or
        just ``name_prefix`` when ``count == 1``. A non-positive ``count``
        creates nothing and does not touch persistence.
        """
        if count < 1:
            return []
        created_at = int(time.time())  # one timestamp for the whole batch
        new_keys = [
            {
                "key": self.generate_key(),
                "name": f"{name_prefix}-{i}" if count > 1 else name_prefix,
                "created_at": created_at,
                "is_active": True,
            }
            for i in range(1, count + 1)
        ]

        self._keys.extend(new_keys)
        await self._save_data()
        logger.info(f"[ApiKey] Batch added {count} keys, prefix: {name_prefix}")
        return new_keys

    async def delete_key(self, key: str) -> bool:
        """Remove ``key``; return True if it existed (and persist the change)."""
        initial_len = len(self._keys)
        self._keys = [k for k in self._keys if k["key"] != key]

        if len(self._keys) != initial_len:
            await self._save_data()
            logger.info(f"[ApiKey] Deleted key: {key[:10]}...")
            return True
        return False

    async def batch_delete_keys(self, keys: List[str]) -> int:
        """Remove every key listed in ``keys``; return how many were removed."""
        targets = set(keys)  # O(1) membership instead of scanning the list per key
        initial_len = len(self._keys)
        self._keys = [k for k in self._keys if k["key"] not in targets]

        deleted_count = initial_len - len(self._keys)
        if deleted_count > 0:
            await self._save_data()
            logger.info(f"[ApiKey] Batch deleted {deleted_count} keys")
        return deleted_count

    async def update_key_status(self, key: str, is_active: bool) -> bool:
        """Set ``is_active`` on ``key``; return False if the key is unknown."""
        for k in self._keys:
            if k["key"] == key:
                k["is_active"] = is_active
                await self._save_data()
                return True
        return False

    async def batch_update_keys_status(self, keys: List[str], is_active: bool) -> int:
        """Set ``is_active`` on every listed key; return how many actually changed."""
        targets = set(keys)
        updated_count = 0
        for k in self._keys:
            if k["key"] in targets and k["is_active"] != is_active:
                k["is_active"] = is_active
                updated_count += 1

        if updated_count > 0:
            await self._save_data()
            logger.info(f"[ApiKey] Batch updated {updated_count} keys to: {is_active}")
        return updated_count

    async def update_key_name(self, key: str, name: str) -> bool:
        """Rename ``key``; return False if the key is unknown."""
        for k in self._keys:
            if k["key"] == key:
                k["name"] = name
                await self._save_data()
                return True
        return False

    @staticmethod
    def _keys_match(candidate, expected) -> bool:
        """Constant-time string equality; non-str values never match."""
        if not isinstance(candidate, str) or not isinstance(expected, str):
            return False
        # compare_digest rejects non-ASCII str, so compare UTF-8 bytes.
        return secrets.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))

    def validate_key(self, key: str) -> Optional[Dict]:
        """Return key info for a valid ``key``, or None.

        The global ``api_key`` from ``setting.grok_config`` is accepted as an
        admin key. Otherwise the stored keys are searched; a matching but
        inactive key yields None. Comparisons use ``secrets.compare_digest``
        so response timing does not reveal how much of a guess is correct.
        """
        global_key = setting.grok_config.get("api_key")
        if global_key and self._keys_match(key, global_key):
            return {
                "key": global_key,
                "name": "Default admin",
                "is_active": True,
                "is_admin": True,
            }

        for k in self._keys:
            if self._keys_match(key, k["key"]):
                return {**k, "is_admin": False} if k["is_active"] else None

        return None

    def get_all_keys(self) -> List[Dict]:
        """Return the live internal key list (not a copy; do not mutate)."""
        return self._keys
|
|
|
|
| |
# Shared module-level instance; ApiKeyManager.__new__ makes any later
# ApiKeyManager() call return this same object.
api_key_manager: ApiKeyManager = ApiKeyManager()
|
|