Spaces:
Running
Running
| """OB-1 multi-account token manager.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| import time | |
| import httpx | |
| from ..core.config import ( | |
| OB1_WORKOS_AUTH_URL, | |
| OB1_WORKOS_CLIENT_ID, | |
| OB1_REFRESH_BUFFER, | |
| OB1_API_BASE, | |
| ) | |
| from ..core import config as _config | |
| from ..core.logger import get_logger | |
# Module logger, obtained under the name "token".
log = get_logger("token")

# WorkOS device-authorization endpoint.
DEVICE_AUTH_URL = "https://api.workos.com/user_management/authorize/device"
# OB-1 endpoint listing the organizations a user belongs to (queried by user_id).
ORG_API_URL = "{}/auth/organizations".format(OB1_API_BASE)
def _accounts_path() -> str:
    """Return the (un-normalised) path of config/accounts.json, two levels above this module's directory."""
    tail = ("..", "..", "config", "accounts.json")
    return os.path.join(os.path.dirname(__file__), *tail)
class Account:
    """One OB-1 account and its OAuth tokens.

    Built from a plain dict of the shape produced by ``to_dict``; missing keys
    fall back to empty values.  ``expires_at`` is an epoch timestamp in seconds
    (it is compared against ``time.time()``).
    """

    def __init__(self, data: dict):
        self.email: str = data.get("email", "")
        self.access_token: str = data.get("access_token", "")
        self.refresh_token: str = data.get("refresh_token", "")
        # `or 0` tolerates an explicit null in the stored JSON, which would
        # otherwise crash the comparison in `active` and the maths in `to_public`.
        self.expires_at: float = float(data.get("expires_at") or 0)
        self.org_id: str = data.get("org_id", "")
        self.org_name: str = data.get("org_name", "")
        self.user_id: str = data.get("user_id", "")
        self.user_data: dict = data.get("user_data") or {}

    @property
    def active(self) -> bool:
        """True when an access token is present and has not yet expired.

        A property because callers read it as an attribute (``acct.active``);
        as a plain method that expression would be an always-truthy bound method.
        """
        return bool(self.access_token) and self.expires_at > time.time()

    def to_dict(self) -> dict:
        """Return the full serialisable form, including the raw tokens."""
        return {
            "email": self.email,
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
            "org_id": self.org_id,
            "org_name": self.org_name,
            "user_id": self.user_id,
            "user_data": self.user_data,
        }

    @staticmethod
    def _mask(token: str) -> str:
        """Return a redacted form of *token* that is safe to display.

        Keeps the first/last 4 characters (2 for tokens of 5-8 characters).
        Tokens of 4 characters or fewer would be echoed in full by that scheme,
        so they are replaced entirely by asterisks.
        """
        if not token:
            return ""
        if len(token) <= 4:
            return "*" * len(token)
        if len(token) <= 8:
            return token[:2] + "..." + token[-2:]
        return token[:4] + "..." + token[-4:]

    def to_public(self) -> dict:
        """Return a display-safe summary: tokens masked, expiry in epoch milliseconds."""
        return {
            "email": self.email,
            "org_id": self.org_id,
            "org_name": self.org_name,
            "at_mask": self._mask(self.access_token),
            "rt_mask": self._mask(self.refresh_token),
            "active": self.active,
            "expires_at": int(self.expires_at * 1000),
        }
class OB1TokenManager:
    """Manage several OB-1 accounts with rotation and automatic token refresh.

    Accounts are held as ``Account`` objects and persisted as a JSON list at
    ``_accounts_path()``.  ``get_api_key`` chooses an account according to
    ``_config.OB1_ROTATION_MODE``:

    - "performance": random order.
    - "cache-first": start from the account that last produced a key.
    - anything else: round-robin.

    Tokens within ``OB1_REFRESH_BUFFER`` seconds of expiry are refreshed first.
    """

    def __init__(self):
        self._accounts: list[Account] = []
        # Rotation cursor into self._accounts; see get_api_key for per-mode use.
        self._current_idx: int = 0
        self._path = _accounts_path()
        # Running totals fed by add_cost(); not persisted or reset by this class.
        self._request_count: int = 0
        self._cost_today: float = 0

    def load(self):
        """Load accounts from accounts.json, falling back to ~/.ob1/credentials.json.

        If accounts.json exists, its entries replace the in-memory list.  If that
        leaves no accounts, the OB-1 credentials file is imported (and saved)
        when present.  A malformed accounts.json raises rather than being
        silently overwritten.
        """
        if os.path.exists(self._path):
            with open(self._path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._accounts = [Account(a) for a in data]
            log.info("Loaded %d accounts", len(self._accounts))
        # Also import from ~/.ob1/credentials.json if accounts.json is empty
        if not self._accounts:
            cred_path = os.path.join(os.path.expanduser("~"), ".ob1", "credentials.json")
            if os.path.exists(cred_path):
                self._import_credentials(cred_path)

    def _import_credentials(self, path: str):
        """Import the single account stored under ``data["oauth"]`` in *path*.

        Does nothing when no access token is present.  Otherwise the account is
        appended and the list is saved.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # `or {}` tolerates explicit nulls, which .get(key, {}) would pass through.
        oauth = data.get("oauth") or {}
        if not oauth.get("access_token"):
            return
        user = oauth.get("user") or {}
        acct = Account({
            "email": user.get("email", ""),
            "access_token": oauth.get("access_token", ""),
            "refresh_token": oauth.get("refresh_token", ""),
            # The stored value is divided by 1000, i.e. treated as epoch
            # milliseconds; `or 0` keeps a null from raising TypeError.
            "expires_at": (oauth.get("expires_at") or 0) / 1000,
            "org_id": oauth.get("organization_id", ""),
            "user_id": user.get("id", ""),
            "user_data": user,
        })
        self._accounts.append(acct)
        self._save()
        log.info("Imported %s from credentials.json", acct.email)

    def _save(self):
        """Persist all accounts (including raw tokens) to accounts.json.

        Writes a sibling ``.tmp`` file and ``os.replace``s it over the target.
        That way a crash mid-write cannot leave a truncated, unparseable
        accounts file behind.
        """
        directory = os.path.dirname(self._path)
        if directory:  # makedirs("") raises for a bare relative filename
            os.makedirs(directory, exist_ok=True)
        tmp_path = self._path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump([a.to_dict() for a in self._accounts], f, indent=2)
        os.replace(tmp_path, self._path)

    # NOTE(review): callers are not visible here.  These accessors are plain
    # methods, but their names read like properties; confirm call sites use
    # them with parentheses (e.g. ``mgr.is_loaded()``).
    def is_loaded(self) -> bool:
        """Return True if at least one account is known."""
        return len(self._accounts) > 0

    def user_email(self) -> str:
        """Return the email of the first account, or "" when there are none."""
        if self._accounts:
            return self._accounts[0].email
        return ""

    def org_id(self) -> str:
        """Return the org id of the first account, or "" when there are none."""
        if self._accounts:
            return self._accounts[0].org_id
        return ""

    def list_accounts(self) -> list[dict]:
        """Return display-safe summaries (masked tokens) of all accounts."""
        return [a.to_public() for a in self._accounts]

    def current_idx(self) -> int:
        """Return the rotation cursor used by get_api_key."""
        return self._current_idx

    def stats(self) -> dict:
        """Return account counts plus the accumulated cost/request totals."""
        active = sum(1 for a in self._accounts if a.active)
        return {
            "total": len(self._accounts),
            "active": active,
            "cost": self._cost_today,
            "requests": self._request_count,
        }

    def add_cost(self, cost: float):
        """Record one request and add *cost* to the running total."""
        self._cost_today += cost
        self._request_count += 1

    async def refresh_account(self, idx: int, force: bool = False) -> bool:
        """Refresh the tokens of account *idx* via the WorkOS refresh_token grant.

        Returns True if the token was refreshed, or was still valid and skipped.
        A token is skipped when it has more than OB1_REFRESH_BUFFER seconds
        left and *force* is false.  Returns False for a bad index, a missing
        refresh token, a non-200 response, or any error (logged, not raised).
        """
        if idx < 0 or idx >= len(self._accounts):
            return False
        acct = self._accounts[idx]
        if not acct.refresh_token:
            return False
        # Skip if token still valid (not within buffer), unless forced
        if not force and acct.expires_at - time.time() > OB1_REFRESH_BUFFER:
            log.debug("Skipping refresh for %s, token still valid (%.0fh remaining)",
                      acct.email, (acct.expires_at - time.time()) / 3600)
            return True
        try:
            proxy = _config.PROXY_URL or None
            async with httpx.AsyncClient(proxy=proxy, timeout=30) as client:
                resp = await client.post(
                    OB1_WORKOS_AUTH_URL,
                    data={
                        "grant_type": "refresh_token",
                        "refresh_token": acct.refresh_token,
                        "client_id": OB1_WORKOS_CLIENT_ID,
                    },
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                )
                if resp.status_code != 200:
                    log.warning("Refresh failed for %s: %d %s", acct.email, resp.status_code, resp.text)
                    return False
                result = resp.json()
                acct.access_token = result["access_token"]
                # Keep the old refresh token if the server did not rotate it.
                acct.refresh_token = result.get("refresh_token", acct.refresh_token)
                acct.expires_at = time.time() + result.get("expires_in", 3600)
                self._save()
                log.info("Refreshed %s", acct.email)
                return True
        except Exception as e:
            # Best effort: a failed refresh leaves the account as-is.
            log.error("Refresh error for %s: %s", acct.email, e)
            return False

    def _reindex_after_pop(self, popped_idx: int):
        """Keep the rotation cursor on the same account after popping *popped_idx*.

        Without this, removals could leave _current_idx past the end of the
        list.  Cache-first rotation would then index out of range.
        """
        if popped_idx < self._current_idx:
            self._current_idx -= 1
        if self._current_idx >= len(self._accounts):
            self._current_idx = 0

    def remove_account(self, idx: int) -> bool:
        """Remove the account at *idx* and save; return False if *idx* is out of range."""
        if idx < 0 or idx >= len(self._accounts):
            return False
        removed = self._accounts.pop(idx)
        self._reindex_after_pop(idx)
        self._save()
        log.info("Removed %s", removed.email)
        return True

    async def add_account_from_device(self, auth_result: dict) -> str:
        """Add, or update in place, an account from a device-auth token result.

        *auth_result* must contain "access_token" and "refresh_token"; a
        KeyError is raised otherwise.  "expires_in" defaults to 3600 seconds,
        and "user" supplies id and email.  The user's first organization is
        fetched best-effort from ORG_API_URL.  An existing account with the same
        email is updated rather than duplicated.  Returns the email.
        """
        user = auth_result.get("user") or {}
        at = auth_result["access_token"]
        rt = auth_result["refresh_token"]
        expires_in = auth_result.get("expires_in", 3600)
        user_id = user.get("id", "")
        email = user.get("email", "")
        # Fetch org
        org_id = ""
        org_name = ""
        try:
            proxy = _config.PROXY_URL or None
            async with httpx.AsyncClient(proxy=proxy, timeout=15) as client:
                # params= URL-encodes user_id instead of splicing it into the URL.
                resp = await client.get(
                    ORG_API_URL,
                    params={"user_id": user_id},
                    headers={"Authorization": f"Bearer {at}"},
                )
                if resp.status_code == 200:
                    orgs = resp.json().get("data", [])
                    if orgs:
                        org_id = orgs[0].get("organizationId", "")
                        org_name = orgs[0].get("organizationName", "")
        except Exception as e:
            log.error("Org fetch error: %s", e)
        # Check duplicate
        for a in self._accounts:
            if a.email == email:
                a.access_token = at
                a.refresh_token = rt
                a.expires_at = time.time() + expires_in
                a.org_id = org_id or a.org_id
                a.org_name = org_name or a.org_name
                self._save()
                return email
        acct = Account({
            "email": email,
            "access_token": at,
            "refresh_token": rt,
            "expires_at": time.time() + expires_in,
            "org_id": org_id,
            "org_name": org_name,
            "user_id": user_id,
            "user_data": user,
        })
        self._accounts.append(acct)
        self._save()
        log.info("Added account %s (org: %s)", email, org_name)
        return email

    async def get_api_key(self) -> str | None:
        """Return an API key from the first usable account in rotation order.

        Accounts near expiry are refreshed on the way.  The key is
        "<access_token>:<org_id>" when the account has an org, otherwise the
        bare access token.  Returns None when no account is active.
        """
        if not self._accounts:
            return None
        n = len(self._accounts)
        # Accounts may have been removed since the cursor was last set.
        self._current_idx %= n
        start = self._current_idx
        mode = _config.OB1_ROTATION_MODE
        if mode == "performance":
            order = random.sample(range(n), n)
        elif mode == "cache-first":
            # Prefer the account that last succeeded (recorded below).
            order = [start] + [i for i in range(n) if i != start]
        else:  # balanced (default): take turns
            order = [(start + i) % n for i in range(n)]
            self._current_idx = (start + 1) % n
        for idx in order:
            acct = self._accounts[idx]
            if acct.expires_at - time.time() < OB1_REFRESH_BUFFER:
                await self.refresh_account(idx)
            if acct.active:
                if mode == "cache-first":
                    # Remember the winner so the next call starts here instead
                    # of re-trying (and re-refreshing) a dead account first.
                    self._current_idx = idx
                if acct.org_id:
                    return f"{acct.access_token}:{acct.org_id}"
                return acct.access_token
        return None

    async def refresh(self) -> bool:
        """Run refresh_account on every account (non-forced).

        Returns True if at least one call succeeded.  A still-valid token that
        was skipped also counts as success.
        """
        ok = False
        # Deliberately no short-circuit: every account must be visited.
        for i in range(len(self._accounts)):
            if await self.refresh_account(i):
                ok = True
        return ok

    def import_accounts(self, data: list[dict]) -> int:
        """Append accounts from dicts, skipping entries without or with a known email.

        Saves only if something was added.  Returns the number imported.
        """
        existing = {a.email for a in self._accounts}
        count = 0
        for d in data:
            if d.get("email") and d["email"] not in existing:
                self._accounts.append(Account(d))
                existing.add(d["email"])
                count += 1
        if count:
            self._save()
        return count

    def batch_remove(self, indices: list[int]) -> int:
        """Remove the accounts at *indices*; return how many were removed.

        Indices are de-duplicated (a repeated index would otherwise delete a
        second, different account).  They are processed in descending order so
        earlier pops don't shift later ones.  Out-of-range indices are ignored.
        """
        removed = 0
        for i in sorted(set(indices), reverse=True):
            if 0 <= i < len(self._accounts):
                self._accounts.pop(i)
                self._reindex_after_pop(i)
                removed += 1
        if removed:
            self._save()
        return removed