grok2api / app /services /api_keys.py
tejmar's picture
Initial commit
2c97e18
"""API key manager - multi-user key management"""
import os
import orjson
import time
import secrets
import asyncio
from typing import List, Dict, Optional
from pathlib import Path
from app.core.logger import logger
from app.core.config import setting
class ApiKeyManager:
    """API key management service.

    Keeps a list of API key records in memory and persists it either through
    an injected storage object (when it offers ``load_api_keys`` and
    ``save_api_keys``) or as a JSON file named ``api_keys.json``.

    The class is a singleton: ``__new__`` always hands back the same
    instance and ``__init__`` only initializes it once.

    Each key record created by this class is a dict with the keys
    ``key``, ``name``, ``created_at`` (Unix seconds) and ``is_active``.
    """

    _instance = None

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Set up the in-memory state the first time the singleton is built."""
        # __init__ runs on every ApiKeyManager() call even though __new__
        # returns the same object; bail out so state is not reset.
        if hasattr(self, '_initialized'):
            return

        self.file_path = self._resolve_data_dir() / "api_keys.json"
        self._keys: List[Dict] = []
        self._lock = asyncio.Lock()
        self._loaded = False
        self._storage = None
        self._initialized = True
        logger.debug(f"[ApiKey] Initialized: {self.file_path}")

    @staticmethod
    def _resolve_data_dir() -> Path:
        """Resolve data directory for persistence.

        Order of precedence: the ``DATA_DIR`` environment variable, then
        ``/data`` if that path exists, then a ``data`` directory two levels
        above this file's directory.
        """
        data_dir_env = os.getenv("DATA_DIR")
        if data_dir_env:
            return Path(data_dir_env)
        if Path("/data").exists():
            return Path("/data")
        return Path(__file__).parents[2] / "data"

    @staticmethod
    def _keys_match(candidate, expected) -> bool:
        """Compare two key strings without leaking timing information.

        Args:
            candidate: Value supplied by the caller (possibly untrusted).
            expected: Known-good key to compare against.

        Returns:
            True only when both values are strings with identical content.
        """
        if not isinstance(candidate, str) or not isinstance(expected, str):
            return False
        # compare_digest rejects non-ASCII str arguments, so compare bytes.
        # "surrogatepass" keeps lone surrogates from raising during encode.
        return secrets.compare_digest(
            candidate.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    def set_storage(self, storage) -> None:
        """Set storage instance.

        If the storage object exposes a truthy ``data_dir`` attribute, the
        JSON file path is re-pointed into that directory.
        """
        self._storage = storage
        data_dir = getattr(storage, "data_dir", None)
        if data_dir:
            self.file_path = Path(data_dir) / "api_keys.json"

    def _use_storage(self) -> bool:
        """Return True when a storage object with both key methods is set."""
        return bool(
            self._storage
            and hasattr(self._storage, "load_api_keys")
            and hasattr(self._storage, "save_api_keys")
        )

    async def init(self):
        """Initialize and load data (no-op once data has been loaded)."""
        if not self._loaded:
            await self._load_data()

    async def _load_data(self):
        """Load API keys from the storage object or from the JSON file.

        Any exception is logged and leaves the manager with an empty key
        list; it is never propagated to the caller.
        """
        if self._loaded:
            return

        try:
            if self._use_storage():
                loaded = await self._storage.load_api_keys()
                # Treat a storage that answers None as "no keys yet".
                self._keys = loaded if loaded is not None else []
                self._loaded = True
                logger.debug(f"[ApiKey] Loaded {len(self._keys)} API keys (storage)")
                return

            if not self.file_path.exists():
                self._keys = []
                self._loaded = True
                return

            async with self._lock:
                # File I/O is pushed to a worker thread to keep the loop free.
                content = await asyncio.to_thread(self.file_path.read_bytes)
                if content:
                    self._keys = orjson.loads(content)
            self._loaded = True
            logger.debug(f"[ApiKey] Loaded {len(self._keys)} API keys")
        except Exception as e:
            logger.error(f"[ApiKey] Load failed: {e}")
            self._keys = []
            # Mark as loaded so the manager keeps working with an empty list.
            # NOTE(review): because _loaded is True, the guard in _save_data
            # does not trigger and the next save replaces whatever could not
            # be read - confirm this is the intended trade-off.
            self._loaded = True

    async def _save_data(self):
        """Save API keys to the storage object or to the JSON file.

        Skipped (with a warning) while data has not been loaded. Failures
        are logged and swallowed, so callers are not told about them.
        """
        if not self._loaded:
            logger.warning("[ApiKey] Save skipped because data is not loaded to avoid overwrite")
            return

        try:
            if self._use_storage():
                await self._storage.save_api_keys(self._keys)
                return

            # Ensure directory exists
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            async with self._lock:
                content = orjson.dumps(self._keys, option=orjson.OPT_INDENT_2)
                await asyncio.to_thread(self.file_path.write_bytes, content)
        except Exception as e:
            logger.error(f"[ApiKey] Save failed: {e}")

    def generate_key(self) -> str:
        """Generate a new sk- prefixed key from 24 random bytes."""
        return f"sk-{secrets.token_urlsafe(24)}"

    async def add_key(self, name: str) -> Dict:
        """Add API key.

        Args:
            name: Label stored with the key.

        Returns:
            The newly created key record.
        """
        new_key = {
            "key": self.generate_key(),
            "name": name,
            "created_at": int(time.time()),
            "is_active": True,
        }
        self._keys.append(new_key)
        await self._save_data()
        logger.info(f"[ApiKey] Added new key: {name}")
        return new_key

    async def batch_add_keys(self, name_prefix: str, count: int) -> List[Dict]:
        """Batch add API keys.

        Args:
            name_prefix: Name for a single key, or prefix for ``prefix-1``,
                ``prefix-2``, ... when more than one key is requested.
            count: Number of keys to create.

        Returns:
            The created key records; empty when ``count`` is not positive.
        """
        # Nothing to create: do not touch persistence or log a bogus count.
        if count <= 0:
            return []

        created_at = int(time.time())
        new_keys = [
            {
                "key": self.generate_key(),
                "name": f"{name_prefix}-{i}" if count > 1 else name_prefix,
                "created_at": created_at,
                "is_active": True,
            }
            for i in range(1, count + 1)
        ]
        self._keys.extend(new_keys)
        await self._save_data()
        logger.info(f"[ApiKey] Batch added {count} keys, prefix: {name_prefix}")
        return new_keys

    async def delete_key(self, key: str) -> bool:
        """Delete API key.

        Returns:
            True if at least one record with this key was removed.
        """
        initial_len = len(self._keys)
        self._keys = [k for k in self._keys if k["key"] != key]
        if len(self._keys) != initial_len:
            await self._save_data()
            # Only a prefix of the key is logged.
            logger.info(f"[ApiKey] Deleted key: {key[:10]}...")
            return True
        return False

    async def batch_delete_keys(self, keys: List[str]) -> int:
        """Batch delete API keys.

        Returns:
            Number of records removed.
        """
        # Build the lookup once instead of scanning the list per record.
        targets = set(keys)
        initial_len = len(self._keys)
        self._keys = [k for k in self._keys if k["key"] not in targets]
        deleted_count = initial_len - len(self._keys)
        if deleted_count > 0:
            await self._save_data()
            logger.info(f"[ApiKey] Batch deleted {deleted_count} keys")
        return deleted_count

    async def update_key_status(self, key: str, is_active: bool) -> bool:
        """Update key status.

        Returns:
            True if the key was found (and saved), False otherwise.
        """
        for k in self._keys:
            if k["key"] == key:
                k["is_active"] = is_active
                await self._save_data()
                return True
        return False

    async def batch_update_keys_status(self, keys: List[str], is_active: bool) -> int:
        """Batch update key status.

        Returns:
            Number of records whose status actually changed.
        """
        # Build the lookup once instead of scanning the list per record.
        targets = set(keys)
        updated_count = 0
        for k in self._keys:
            if k["key"] in targets and k["is_active"] != is_active:
                k["is_active"] = is_active
                updated_count += 1

        if updated_count > 0:
            await self._save_data()
            logger.info(f"[ApiKey] Batch updated {updated_count} keys to: {is_active}")
        return updated_count

    async def update_key_name(self, key: str, name: str) -> bool:
        """Update key note (the ``name`` field).

        Returns:
            True if the key was found (and saved), False otherwise.
        """
        for k in self._keys:
            if k["key"] == key:
                k["name"] = name
                await self._save_data()
                return True
        return False

    def validate_key(self, key: str) -> Optional[Dict]:
        """Validate key and return key info.

        Args:
            key: Key presented by the caller.

        Returns:
            A dict describing the key (with an ``is_admin`` flag) when it is
            the configured global key or an active managed key; ``None`` when
            the key is unknown, inactive, or not a string.
        """
        # 1. Check global config key (default admin key)
        global_key = setting.grok_config.get("api_key")
        if global_key and self._keys_match(key, global_key):
            return {
                "key": global_key,
                "name": "Default admin",
                "is_active": True,
                "is_admin": True,
            }

        # 2. Check multi-key list
        for k in self._keys:
            if self._keys_match(key, k["key"]):
                if k["is_active"]:
                    # Keys are not treated as admins for now
                    return {**k, "is_admin": False}
                return None
        return None

    def get_all_keys(self) -> List[Dict]:
        """Get all keys (the internal list itself, not a copy)."""
        return self._keys
# Global instance, created when this module is imported. ApiKeyManager() always
# returns the same object (see ApiKeyManager.__new__), so this name and any
# later ApiKeyManager() call refer to one shared manager.
api_key_manager = ApiKeyManager()