# ob12api / src/services/token_manager.py
# (uploaded by david-baxter — "Upload 24 files", commit cef045d)
"""OB-1 multi-account token manager."""
from __future__ import annotations
import json
import os
import random
import time
import httpx
from ..core.config import (
OB1_WORKOS_AUTH_URL,
OB1_WORKOS_CLIENT_ID,
OB1_REFRESH_BUFFER,
OB1_API_BASE,
)
from ..core import config as _config
from ..core.logger import get_logger
# Module logger, obtained from the project's get_logger helper under the name "token".
log = get_logger("token")
# WorkOS device-authorization endpoint.
# NOTE(review): not referenced by the code in this module — presumably used by a
# device-login flow elsewhere; confirm before removing.
DEVICE_AUTH_URL = "https://api.workos.com/user_management/authorize/device"
# Endpoint queried by OB1TokenManager.add_account_from_device to look up the
# organizations of a freshly authorized user.
ORG_API_URL = f"{OB1_API_BASE}/auth/organizations"
def _accounts_path() -> str:
return os.path.join(os.path.dirname(__file__), "..", "..", "config", "accounts.json")
class Account:
    """A single stored OB-1 login.

    Holds one user's OAuth tokens together with the organization and user
    metadata kept alongside them. Instances are built from, and serialised
    back to, plain dicts; absent keys default to ``""`` for strings, ``0``
    for ``expires_at`` and ``{}`` for ``user_data``.
    """

    def __init__(self, data: dict):
        field = data.get
        self.email: str = field("email", "")
        self.access_token: str = field("access_token", "")
        self.refresh_token: str = field("refresh_token", "")
        # Unix timestamp in seconds (compared against time.time() in `active`).
        self.expires_at: float = field("expires_at", 0)
        self.org_id: str = field("org_id", "")
        self.org_name: str = field("org_name", "")
        self.user_id: str = field("user_id", "")
        self.user_data: dict = field("user_data", {})

    @property
    def active(self) -> bool:
        """Whether an access token is present and its expiry lies in the future."""
        if not self.access_token:
            return False
        return self.expires_at > time.time()

    def to_dict(self) -> dict:
        """Return every stored field as a plain dict (tokens are NOT masked)."""
        names = (
            "email",
            "access_token",
            "refresh_token",
            "expires_at",
            "org_id",
            "org_name",
            "user_id",
            "user_data",
        )
        return {name: getattr(self, name) for name in names}

    @staticmethod
    def _mask(token: str) -> str:
        """Shorten a secret to its first and last few characters.

        Tokens of up to 8 characters keep 2 characters at each end; longer
        tokens keep 4. An empty/falsy token yields ``""``.
        """
        if not token:
            return ""
        keep = 2 if len(token) <= 8 else 4
        return f"{token[:keep]}...{token[-keep:]}"

    def to_public(self) -> dict:
        """Return a display-safe summary: tokens masked, expiry in milliseconds."""
        return dict(
            email=self.email,
            org_id=self.org_id,
            org_name=self.org_name,
            at_mask=self._mask(self.access_token),
            rt_mask=self._mask(self.refresh_token),
            active=self.active,
            expires_at=int(self.expires_at * 1000),
        )
class OB1TokenManager:
    """Manages multiple OB-1 accounts with round-robin and auto-refresh.

    Accounts are held in memory as ``Account`` objects and persisted to the
    JSON file returned by ``_accounts_path()``. ``get_api_key`` picks an
    account according to ``_config.OB1_ROTATION_MODE``:

    * ``"performance"`` -- try accounts in a random order;
    * ``"cache-first"`` -- start with the account that last produced a key;
    * anything else (``"balanced"``) -- round-robin over all accounts.
    """

    def __init__(self):
        self._accounts: list[Account] = []
        # Rotation cursor. Balanced mode: index to start from on the next call.
        # Cache-first mode: index of the account that last produced a key.
        self._current_idx: int = 0
        self._path = _accounts_path()
        self._request_count: int = 0
        # Accumulated by add_cost(); nothing in this class resets it.
        self._cost_today: float = 0

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def load(self):
        """Load accounts from the accounts file.

        If that leaves no accounts, fall back to importing the single login
        stored in ``~/.ob1/credentials.json`` (when that file exists).
        """
        if os.path.exists(self._path):
            with open(self._path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._accounts = [Account(a) for a in data]
            log.info("Loaded %d accounts", len(self._accounts))
        if not self._accounts:
            cred_path = os.path.join(os.path.expanduser("~"), ".ob1", "credentials.json")
            if os.path.exists(cred_path):
                self._import_credentials(cred_path)

    def _import_credentials(self, path: str):
        """Import the login from a ``credentials.json`` file and persist it.

        Reads the file's ``oauth`` section and does nothing when it carries
        no access token. ``oauth.expires_at`` is divided by 1000 because the
        file stores milliseconds, while ``Account.expires_at`` is compared
        against ``time.time()`` seconds.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # `or {}` also tolerates explicit nulls in the file.
        oauth = data.get("oauth") or {}
        if not oauth.get("access_token"):
            return
        user = oauth.get("user") or {}
        acct = Account({
            "email": user.get("email", ""),
            "access_token": oauth.get("access_token", ""),
            "refresh_token": oauth.get("refresh_token", ""),
            "expires_at": (oauth.get("expires_at") or 0) / 1000,
            "org_id": oauth.get("organization_id", ""),
            "user_id": user.get("id", ""),
            "user_data": user,
        })
        self._accounts.append(acct)
        self._save()
        log.info("Imported %s from credentials.json", acct.email)

    def _save(self):
        """Persist all accounts to the accounts file.

        The JSON is written to a sibling ``.tmp`` file first and then swapped
        in with ``os.replace``, so a crash mid-write cannot leave a truncated
        accounts file (which would lose every stored token).
        """
        os.makedirs(os.path.dirname(self._path), exist_ok=True)
        tmp_path = f"{self._path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump([a.to_dict() for a in self._accounts], f, indent=2)
            os.replace(tmp_path, self._path)
        except Exception:
            # Don't leave a half-written temp file behind; surface the error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    # ------------------------------------------------------------------
    # Read-only views
    # ------------------------------------------------------------------

    @property
    def is_loaded(self) -> bool:
        """True when at least one account is known."""
        return len(self._accounts) > 0

    @property
    def user_email(self) -> str:
        """Email of the first account, or ``""`` when there are none."""
        if self._accounts:
            return self._accounts[0].email
        return ""

    @property
    def org_id(self) -> str:
        """Organization id of the first account, or ``""`` when there are none."""
        if self._accounts:
            return self._accounts[0].org_id
        return ""

    def list_accounts(self) -> list[dict]:
        """Return the masked, display-safe view of every account."""
        return [a.to_public() for a in self._accounts]

    @property
    def current_idx(self) -> int:
        """Current value of the rotation cursor (see ``__init__``)."""
        return self._current_idx

    @property
    def stats(self) -> dict:
        """Account totals plus the cost/request counters fed by ``add_cost``."""
        active = sum(1 for a in self._accounts if a.active)
        return {
            "total": len(self._accounts),
            "active": active,
            "cost": self._cost_today,
            "requests": self._request_count,
        }

    def add_cost(self, cost: float):
        """Record one request and add ``cost`` to the running total."""
        self._cost_today += cost
        self._request_count += 1

    # ------------------------------------------------------------------
    # Token refresh
    # ------------------------------------------------------------------

    async def refresh_account(self, idx: int, force: bool = False) -> bool:
        """Refresh the access token of the account at ``idx``.

        POSTs a ``refresh_token`` grant to ``OB1_WORKOS_AUTH_URL`` and, on a
        200 response, stores the new tokens and persists them.

        Args:
            idx: Position of the account in the list.
            force: Refresh even if the token still has more than
                ``OB1_REFRESH_BUFFER`` seconds left.

        Returns:
            True if the token was refreshed or is still valid (refresh
            skipped); False for an invalid ``idx``, a missing refresh token,
            a non-200 response or any exception during the request.
        """
        if idx < 0 or idx >= len(self._accounts):
            return False
        acct = self._accounts[idx]
        if not acct.refresh_token:
            return False
        # Skip if token still valid (not within buffer), unless forced
        if not force and acct.expires_at - time.time() > OB1_REFRESH_BUFFER:
            log.debug("Skipping refresh for %s, token still valid (%.0fh remaining)",
                      acct.email, (acct.expires_at - time.time()) / 3600)
            return True
        try:
            proxy = _config.PROXY_URL or None
            async with httpx.AsyncClient(proxy=proxy, timeout=30) as client:
                resp = await client.post(
                    OB1_WORKOS_AUTH_URL,
                    data={
                        "grant_type": "refresh_token",
                        "refresh_token": acct.refresh_token,
                        "client_id": OB1_WORKOS_CLIENT_ID,
                    },
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                )
            if resp.status_code != 200:
                log.warning("Refresh failed for %s: %d %s", acct.email, resp.status_code, resp.text)
                return False
            result = resp.json()
            acct.access_token = result["access_token"]
            # Keep the old refresh token if the server did not rotate it.
            acct.refresh_token = result.get("refresh_token", acct.refresh_token)
            acct.expires_at = time.time() + result.get("expires_in", 3600)
            self._save()
            log.info("Refreshed %s", acct.email)
            return True
        except Exception as e:
            # Best effort: a failed refresh must not break the caller.
            log.error("Refresh error for %s: %s", acct.email, e)
            return False

    async def refresh(self) -> bool:
        """Refresh all accounts; True if at least one refresh succeeded."""
        ok = False
        for i in range(len(self._accounts)):
            if await self.refresh_account(i):
                ok = True
        return ok

    # ------------------------------------------------------------------
    # Adding / removing accounts
    # ------------------------------------------------------------------

    def _pop_account(self, idx: int) -> "Account":
        """Remove and return the account at ``idx``, keeping the cursor valid.

        The rotation cursor is shifted down when an earlier account is removed
        (so it keeps pointing at the same account) and reset to 0 when it
        would fall past the end of the list.
        """
        removed = self._accounts.pop(idx)
        if idx < self._current_idx:
            self._current_idx -= 1
        if self._current_idx >= len(self._accounts):
            self._current_idx = 0
        return removed

    def remove_account(self, idx: int) -> bool:
        """Remove the account at ``idx``; False if ``idx`` is out of range."""
        if idx < 0 or idx >= len(self._accounts):
            return False
        removed = self._pop_account(idx)
        self._save()
        log.info("Removed %s", removed.email)
        return True

    async def add_account_from_device(self, auth_result: dict) -> str:
        """Add (or update) an account from a device-authorization result.

        ``auth_result`` must contain ``access_token`` and ``refresh_token``
        (KeyError otherwise); ``user`` and ``expires_in`` (default 3600 s)
        are optional. The user's first organization is looked up at
        ``ORG_API_URL``; a failed lookup is logged and the account is stored
        without an organization. An existing account with the same email has
        its tokens and organization updated in place instead of being
        duplicated.

        Returns:
            The account's email.
        """
        user = auth_result.get("user", {})
        at = auth_result["access_token"]
        rt = auth_result["refresh_token"]
        expires_in = auth_result.get("expires_in", 3600)
        user_id = user.get("id", "")
        email = user.get("email", "")
        # Fetch org
        org_id = ""
        org_name = ""
        try:
            proxy = _config.PROXY_URL or None
            async with httpx.AsyncClient(proxy=proxy, timeout=15) as client:
                # `params` lets httpx URL-encode the user id.
                resp = await client.get(
                    ORG_API_URL,
                    params={"user_id": user_id},
                    headers={"Authorization": f"Bearer {at}"},
                )
            if resp.status_code == 200:
                orgs = resp.json().get("data", [])
                if orgs:
                    org_id = orgs[0].get("organizationId", "")
                    org_name = orgs[0].get("organizationName", "")
        except Exception as e:
            log.error("Org fetch error: %s", e)
        # Check duplicate
        for a in self._accounts:
            if a.email == email:
                a.access_token = at
                a.refresh_token = rt
                a.expires_at = time.time() + expires_in
                a.org_id = org_id or a.org_id
                a.org_name = org_name or a.org_name
                self._save()
                return email
        acct = Account({
            "email": email,
            "access_token": at,
            "refresh_token": rt,
            "expires_at": time.time() + expires_in,
            "org_id": org_id,
            "org_name": org_name,
            "user_id": user_id,
            "user_data": user,
        })
        self._accounts.append(acct)
        self._save()
        log.info("Added account %s (org: %s)", email, org_name)
        return email

    def import_accounts(self, data: list[dict]) -> int:
        """Import accounts from a list of dicts; return how many were added.

        Entries that are not dicts, have no email, or whose email is already
        known are skipped, so one malformed entry cannot abort the import
        halfway (which would leave unsaved accounts in memory).
        """
        existing = {a.email for a in self._accounts}
        count = 0
        for d in data:
            if not isinstance(d, dict):
                continue
            email = d.get("email")
            if not email or email in existing:
                continue
            self._accounts.append(Account(d))
            existing.add(email)
            count += 1
        if count:
            self._save()
        return count

    def batch_remove(self, indices: list[int]) -> int:
        """Remove the accounts at ``indices``; return how many were removed.

        Duplicate and out-of-range indices are ignored. Indices are handled in
        descending order so each removal does not shift the ones still to do.
        """
        removed = 0
        # set(): a repeated index must not pop a second, different account.
        for i in sorted(set(indices), reverse=True):
            if 0 <= i < len(self._accounts):
                self._pop_account(i)
                removed += 1
        if removed:
            self._save()
        return removed

    # ------------------------------------------------------------------
    # Key selection
    # ------------------------------------------------------------------

    async def get_api_key(self) -> str | None:
        """Get a valid API key based on rotation mode.

        Accounts are tried in the order dictated by
        ``_config.OB1_ROTATION_MODE``; any account within
        ``OB1_REFRESH_BUFFER`` seconds of expiry is refreshed first.

        Returns:
            ``"<access_token>:<org_id>"`` for the first active account that
            has an organization, its bare access token if it has none, or
            None when no account is usable.
        """
        if not self._accounts:
            return None
        n = len(self._accounts)
        # Defensive: keep the cursor in range so cache-first mode can never
        # index past the end of the list.
        self._current_idx %= n
        mode = _config.OB1_ROTATION_MODE
        if mode == "performance":
            order = random.sample(range(n), n)
        elif mode == "cache-first":
            # Prefer the account that last produced a key.
            order = [self._current_idx] + [i for i in range(n) if i != self._current_idx]
        else:  # balanced (default) — take turns
            order = [(self._current_idx + i) % n for i in range(n)]
            self._current_idx = (self._current_idx + 1) % n
        for idx in order:
            acct = self._accounts[idx]
            if acct.expires_at - time.time() < OB1_REFRESH_BUFFER:
                await self.refresh_account(idx)
            if not acct.active:
                continue
            if mode == "cache-first":
                # Remember the account that worked so the next call starts here.
                self._current_idx = idx
            if acct.org_id:
                return f"{acct.access_token}:{acct.org_id}"
            return acct.access_token
        return None