"""OpenAI-compatible API routes — proxies to OB-1 backend.""" from __future__ import annotations import json import time import uuid from typing import Any, AsyncGenerator, Optional from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse from ..core.auth import verify_api_key from ..core.logger import get_logger from ..core.models import AnthropicMessagesRequest, ChatCompletionRequest from ..services.ob1_client import OB1Client from ..services.token_manager import OB1TokenManager log = get_logger("routes") router = APIRouter() _token_manager: Optional[OB1TokenManager] = None _ob1_client: Optional[OB1Client] = None def init(token_manager: OB1TokenManager, ob1_client: OB1Client): global _token_manager, _ob1_client _token_manager = token_manager _ob1_client = ob1_client def _require_token_manager() -> OB1TokenManager: if _token_manager is None: raise RuntimeError("Token manager is not initialized") return _token_manager def _require_ob1_client() -> OB1Client: if _ob1_client is None: raise RuntimeError("OB1 client is not initialized") return _ob1_client @router.get("/v1/models") async def list_models(_: str = Depends(verify_api_key)): token_manager = _require_token_manager() ob1_client = _require_ob1_client() api_key = await token_manager.get_api_key() if not api_key: return {"object": "list", "data": []} raw = await ob1_client.fetch_models(api_key) models = [] for m in raw: models.append( { "id": m["id"], "object": "model", "created": m.get("created", 0), "owned_by": m["id"].split("/")[0] if "/" in m["id"] else "ob1", "name": m.get("name", m["id"]), } ) return {"object": "list", "data": models} @router.post("/v1/chat/completions") async def chat_completions( request: ChatCompletionRequest, _: str = Depends(verify_api_key), ): messages = [{"role": m.role, "content": m.content} for m in request.messages] extra_payload = _build_openai_extra_payload(request.tools, request.tool_choice) resp = await _send_chat_request( messages=messages, model=request.model, stream=request.stream, temperature=request.temperature, top_p=request.top_p, max_tokens=request.max_tokens, extra_payload=extra_payload, ) if isinstance(resp, JSONResponse): return resp if request.stream: log.debug("Streaming response started") return StreamingResponse( _proxy_stream(resp, _token_manager), media_type="text/event-stream", ) else: data = resp.json() usage = data.get("usage", {}) _track_usage(usage) log.info( "Chat response: model=%s prompt_tokens=%d completion_tokens=%d", data.get("model", "?"), usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0), ) return JSONResponse(content=data) @router.post("/v1/messages") async def anthropic_messages( request: AnthropicMessagesRequest, _: str = Depends(verify_api_key), ): messages = _anthropic_to_openai_messages(request) extra_payload = _build_openai_extra_payload( _anthropic_tools_to_openai(request.tools), _anthropic_tool_choice_to_openai(request.tool_choice), ) resp = await _send_chat_request( messages=messages, model=request.model, stream=request.stream, temperature=request.temperature, top_p=request.top_p, max_tokens=request.max_tokens, extra_payload=extra_payload, ) if isinstance(resp, JSONResponse): return resp if request.stream: return StreamingResponse( _proxy_stream_anthropic(resp, request.model), media_type="text/event-stream", ) data = resp.json() usage = data.get("usage", {}) _track_usage(usage) await resp.aclose() return JSONResponse(content=_openai_to_anthropic_response(data, request.model)) async def 
async def _send_chat_request(
    *,
    messages: list[dict[str, Any]],
    model: str,
    stream: bool,
    temperature: float | None,
    top_p: float | None,
    max_tokens: int | None,
    extra_payload: dict[str, Any] | None = None,
):
    token_manager = _require_token_manager()
    ob1_client = _require_ob1_client()
    api_key = await token_manager.get_api_key()
    if not api_key:
        log.warning("No valid OB-1 token available")
        return JSONResponse(
            status_code=503,
            content={"error": "No valid OB-1 token. Run ob1 auth to login."},
        )
    resolved_model = await _resolve_model_name(model, api_key)
    log.info(
        "Chat request: model=%s resolved_model=%s stream=%s messages=%d",
        model,
        resolved_model,
        stream,
        len(messages),
    )
    try:
        resp = await ob1_client.chat(
            api_key=api_key,
            messages=messages,
            model=resolved_model,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            extra_payload=extra_payload,
        )
    except Exception as e:
        log.error("Backend error: %s", e)
        return JSONResponse(status_code=502, content={"error": f"Backend error: {e}"})
    if resp.status_code == 401:
        await resp.aclose()
        log.warning("Token rejected (401), refreshing...")
        ok = await token_manager.refresh()
        if not ok:
            log.error("Token refresh failed")
            return JSONResponse(
                status_code=401, content={"error": "Token expired and refresh failed"}
            )
        api_key = await token_manager.get_api_key()
        if not api_key:
            return JSONResponse(
                status_code=401, content={"error": "Token refresh failed"}
            )
        try:
            resp = await ob1_client.chat(
                api_key=api_key,
                messages=messages,
                model=resolved_model,
                stream=stream,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                extra_payload=extra_payload,
            )
        except Exception as e:
            log.error("Backend error after refresh: %s", e)
            return JSONResponse(
                status_code=502, content={"error": f"Backend error: {e}"}
            )
    if resp.status_code != 200:
        try:
            body = (await resp.aread()).decode()
        except Exception:
            body = "unable to read response body"
        await resp.aclose()
        log.error("OB-1 returned %d: %s", resp.status_code, body[:200])
        return JSONResponse(
            status_code=resp.status_code,
            content={"error": f"OB-1 returned {resp.status_code}: {body[:500]}"},
        )
    return resp


async def _resolve_model_name(requested_model: str, api_key: str) -> str:
    ob1_client = _require_ob1_client()
    raw_models = await ob1_client.fetch_models(api_key)
    available = [item.get("id") for item in raw_models if item.get("id")]
    if requested_model in available:
        return requested_model
    anthropic_prefixed = f"anthropic/{requested_model}"
    if anthropic_prefixed in available:
        return anthropic_prefixed
    if requested_model.startswith("claude-"):
        lowered = requested_model.lower()
        family = None
        for candidate in ("haiku", "sonnet", "opus"):
            if candidate in lowered:
                family = candidate
                break
        if family:
            family_matches = [
                model_id
                for model_id in available
                if model_id.startswith(f"anthropic/claude-{family}")
            ]
            if family_matches:
                return sorted(family_matches)[-1]
    anthropic_models = [
        model_id for model_id in available if model_id.startswith("anthropic/")
    ]
    preferred_order = ["anthropic/claude-sonnet-4.6", "anthropic/claude-opus-4.6"]
    for model_id in preferred_order:
        if model_id in anthropic_models:
            return model_id
    if anthropic_models:
        return sorted(anthropic_models)[-1]
    return requested_model


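# Worked example of the fallback order in _resolve_model_name, assuming a
# hypothetical catalogue returned by fetch_models; the exact IDs depend on the
# OB-1 deployment and are only illustrative:
#
#     available = ["anthropic/claude-haiku-4.5", "anthropic/claude-sonnet-4.6"]
#
#     "anthropic/claude-sonnet-4.6"  -> exact match, returned unchanged
#     "claude-sonnet-4.6"            -> matched after adding the "anthropic/" prefix
#     "claude-sonnet-20250101"       -> family match on "sonnet"; the
#                                       lexicographically last family match wins
#     "gpt-4o"                       -> no match; falls back to the preferred list
#                                       ("anthropic/claude-sonnet-4.6" here), then
#                                       any "anthropic/" model, then the name as-is

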
[{"type": "text", "text": message.content}] ) text_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] tool_results: list[dict[str, Any]] = [] for block in blocks: if not isinstance(block, dict): continue block_type = block.get("type") if block_type == "text" and isinstance(block.get("text"), str): text_parts.append(block["text"]) elif block_type == "tool_use": tool_calls.append( { "id": block.get("id") or f"call_{uuid.uuid4().hex}", "type": "function", "function": { "name": block.get("name", "tool"), "arguments": json.dumps( block.get("input", {}), ensure_ascii=False ), }, } ) elif block_type == "tool_result": tool_results.append( { "role": "tool", "tool_call_id": block.get("tool_use_id", ""), "content": _flatten_content(block.get("content", "")), } ) if message.role == "assistant": assistant_message: dict[str, Any] = {"role": "assistant"} assistant_message["content"] = "\n".join( part for part in text_parts if part ) if tool_calls: assistant_message["tool_calls"] = tool_calls messages.append(assistant_message) elif message.role == "user" and tool_results: messages.extend(tool_results) if text_parts: messages.append({"role": "user", "content": "\n".join(text_parts)}) else: messages.append( { "role": message.role, "content": "\n".join(part for part in text_parts if part), } ) return messages def _anthropic_tools_to_openai( tools: Optional[list[dict[str, Any]]], ) -> Optional[list[dict[str, Any]]]: if not tools: return None converted: list[dict[str, Any]] = [] for tool in tools: converted.append( { "type": "function", "function": { "name": tool.get("name", "tool"), "description": tool.get("description", ""), "parameters": tool.get( "input_schema", {"type": "object", "properties": {}} ), }, } ) return converted def _anthropic_tool_choice_to_openai( tool_choice: Optional[dict[str, Any]], ) -> Optional[dict[str, Any] | str]: if not tool_choice: return None choice_type = tool_choice.get("type") if choice_type in {"auto", "none"}: return choice_type if choice_type in {"any", "required"}: return "required" if choice_type == "tool": name = tool_choice.get("name") if name: return {"type": "function", "function": {"name": name}} return None def _build_openai_extra_payload( tools: Optional[list[dict[str, Any]]], tool_choice: Optional[dict[str, Any] | str], ) -> Optional[dict[str, Any]]: extra_payload: dict[str, Any] = {} if tools: extra_payload["tools"] = tools if tool_choice is not None: extra_payload["tool_choice"] = tool_choice return extra_payload or None def _flatten_content(content: Any) -> str: if content is None: return "" if isinstance(content, str): return content if isinstance(content, list): parts: list[str] = [] for block in content: if not isinstance(block, dict): continue if block.get("type") == "text" and isinstance(block.get("text"), str): parts.append(block["text"]) elif block.get("type") == "tool_result": parts.append(_flatten_content(block.get("content", ""))) return "\n".join(part for part in parts if part) return str(content) def _openai_to_anthropic_response(data: dict[str, Any], model: str) -> dict[str, Any]: choice = (data.get("choices") or [{}])[0] message = choice.get("message") or {} content_blocks: list[dict[str, Any]] = [] text = _flatten_content(message.get("content", "")) if text: content_blocks.append({"type": "text", "text": text}) for tool_call in message.get("tool_calls") or []: function = tool_call.get("function") or {} content_blocks.append( { "type": "tool_use", "id": tool_call.get("id", f"toolu_{uuid.uuid4().hex}"), "name": function.get("name", "tool"), 
"input": _parse_json_object(function.get("arguments")), } ) usage = data.get("usage") or {} return { "id": data.get("id", f"msg_{uuid.uuid4().hex}"), "type": "message", "role": "assistant", "content": content_blocks or [{"type": "text", "text": ""}], "model": data.get("model", model), "stop_reason": _map_finish_reason(choice.get("finish_reason")), "stop_sequence": None, "usage": { "input_tokens": usage.get("prompt_tokens", 0), "output_tokens": usage.get("completion_tokens", 0), }, } def _map_finish_reason(reason: Optional[str]) -> str: if reason in {None, "stop"}: return "end_turn" if reason == "length": return "max_tokens" if reason == "tool_calls": return "tool_use" return reason or "end_turn" def _parse_json_object(value: Any) -> dict[str, Any]: if isinstance(value, dict): return value if isinstance(value, str) and value: try: parsed = json.loads(value) if isinstance(parsed, dict): return parsed except json.JSONDecodeError: pass return {} async def _proxy_stream_anthropic(resp, model: str) -> AsyncGenerator[str, None]: message_id = f"msg_{uuid.uuid4().hex}" sent_start = False text_started = False text_index = 0 usage: dict[str, Any] = {"input_tokens": 0, "output_tokens": 0} stop_reason = "end_turn" tool_state: dict[int, dict[str, Any]] = {} next_content_index = 0 try: async for line in resp.aiter_lines(): if not line or not line.startswith("data: "): continue payload = line[6:] if payload == "[DONE]": break try: chunk = json.loads(payload) except json.JSONDecodeError: continue if not sent_start: prompt_tokens = ((chunk.get("usage") or {}).get("prompt_tokens")) or 0 usage["input_tokens"] = prompt_tokens yield _anthropic_sse( "message_start", { "type": "message_start", "message": { "id": message_id, "type": "message", "role": "assistant", "content": [], "model": chunk.get("model", model), "stop_reason": None, "stop_sequence": None, "usage": usage, }, }, ) sent_start = True delta = (chunk.get("choices") or [{}])[0].get("delta") or {} finish_reason = (chunk.get("choices") or [{}])[0].get("finish_reason") text = delta.get("content") if text: if not text_started: text_index = next_content_index yield _anthropic_sse( "content_block_start", { "type": "content_block_start", "index": text_index, "content_block": {"type": "text", "text": ""}, }, ) text_started = True next_content_index += 1 yield _anthropic_sse( "content_block_delta", { "type": "content_block_delta", "index": text_index, "delta": {"type": "text_delta", "text": text}, }, ) chunk_usage = chunk.get("usage") or {} if chunk_usage.get("completion_tokens") is not None: usage["output_tokens"] = chunk_usage.get( "completion_tokens", usage["output_tokens"] ) _track_usage(chunk_usage) for tool_delta in delta.get("tool_calls") or []: tool_idx = tool_delta.get("index", 0) state = tool_state.setdefault( tool_idx, { "event_index": next_content_index, "id": tool_delta.get("id") or f"toolu_{uuid.uuid4().hex}", "name": ((tool_delta.get("function") or {}).get("name")) or "tool", "arguments": "", "started": False, }, ) if tool_delta.get("id"): state["id"] = tool_delta["id"] function = tool_delta.get("function") or {} if function.get("name"): state["name"] = function["name"] if not state["started"]: next_content_index = max( next_content_index, state["event_index"] + 1 ) yield _anthropic_sse( "content_block_start", { "type": "content_block_start", "index": state["event_index"], "content_block": { "type": "tool_use", "id": state["id"], "name": state["name"], "input": {}, }, }, ) state["started"] = True if function.get("arguments"): state["arguments"] 
+= function["arguments"] yield _anthropic_sse( "content_block_delta", { "type": "content_block_delta", "index": state["event_index"], "delta": { "type": "input_json_delta", "partial_json": function["arguments"], }, }, ) if finish_reason: stop_reason = _map_finish_reason(finish_reason) if not sent_start: yield _anthropic_sse( "message_start", { "type": "message_start", "message": { "id": message_id, "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "stop_sequence": None, "usage": usage, }, }, ) if text_started: yield _anthropic_sse( "content_block_stop", {"type": "content_block_stop", "index": text_index}, ) elif not tool_state: yield _anthropic_sse( "content_block_start", { "type": "content_block_start", "index": next_content_index, "content_block": {"type": "text", "text": ""}, }, ) yield _anthropic_sse( "content_block_stop", {"type": "content_block_stop", "index": next_content_index}, ) for state in sorted(tool_state.values(), key=lambda item: item["event_index"]): if state["started"]: yield _anthropic_sse( "content_block_stop", {"type": "content_block_stop", "index": state["event_index"]}, ) yield _anthropic_sse( "message_delta", { "type": "message_delta", "delta": {"stop_reason": stop_reason, "stop_sequence": None}, "usage": {"output_tokens": usage["output_tokens"]}, }, ) yield _anthropic_sse("message_stop", {"type": "message_stop"}) finally: await resp.aclose() def _anthropic_sse(event: str, data: dict[str, Any]) -> str: return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def _track_usage(usage: dict): """Extract token counts from usage and record cost.""" pt = usage.get("prompt_tokens", 0) ct = usage.get("completion_tokens", 0) if pt or ct: # Rough OpenRouter-style cost estimate (per 1M tokens) cost = pt * 0.000015 + ct * 0.000075 _require_token_manager().add_cost(cost) elif usage: _require_token_manager().add_cost(0) async def _proxy_stream(resp, tm) -> AsyncGenerator[str, None]: """Proxy SSE stream from OB-1 backend directly to client.""" try: async for line in resp.aiter_lines(): if line: yield f"{line}\n\n" # Extract usage from the final chunk if line.startswith("data: ") and '"usage"' in line: try: chunk = json.loads(line[6:]) usage = chunk.get("usage") or {} if usage: _track_usage(usage) except Exception: pass finally: await resp.aclose()