"""OpenAI-compatible API routes — proxies to OB-1 backend."""
from __future__ import annotations
import json
import uuid
from typing import Any, AsyncGenerator, Optional
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from ..core.auth import verify_api_key
from ..core.logger import get_logger
from ..core.models import AnthropicMessagesRequest, ChatCompletionRequest
from ..services.ob1_client import OB1Client
from ..services.token_manager import OB1TokenManager
log = get_logger("routes")
router = APIRouter()
_token_manager: Optional[OB1TokenManager] = None
_ob1_client: Optional[OB1Client] = None
def init(token_manager: OB1TokenManager, ob1_client: OB1Client):
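    """Register the shared token manager and OB-1 client used by the route handlers."""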
global _token_manager, _ob1_client
_token_manager = token_manager
_ob1_client = ob1_client
def _require_token_manager() -> OB1TokenManager:
if _token_manager is None:
raise RuntimeError("Token manager is not initialized")
return _token_manager
def _require_ob1_client() -> OB1Client:
if _ob1_client is None:
raise RuntimeError("OB1 client is not initialized")
return _ob1_client
@router.get("/v1/models")
async def list_models(_: str = Depends(verify_api_key)):
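    """Return the models exposed by the OB-1 backend in OpenAI list format."""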
token_manager = _require_token_manager()
ob1_client = _require_ob1_client()
api_key = await token_manager.get_api_key()
if not api_key:
return {"object": "list", "data": []}
raw = await ob1_client.fetch_models(api_key)
    models = []
    for m in raw:
        model_id = m.get("id")
        if not model_id:
            # Skip malformed entries that lack an id.
            continue
        models.append(
            {
                "id": model_id,
                "object": "model",
                "created": m.get("created", 0),
                "owned_by": model_id.split("/")[0] if "/" in model_id else "ob1",
                "name": m.get("name", model_id),
            }
        )
    return {"object": "list", "data": models}
@router.post("/v1/chat/completions")
async def chat_completions(
request: ChatCompletionRequest,
_: str = Depends(verify_api_key),
):
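    """Handle OpenAI-style chat completions by proxying to the OB-1 backend."""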
messages = [{"role": m.role, "content": m.content} for m in request.messages]
extra_payload = _build_openai_extra_payload(request.tools, request.tool_choice)
resp = await _send_chat_request(
messages=messages,
model=request.model,
stream=request.stream,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
extra_payload=extra_payload,
)
if isinstance(resp, JSONResponse):
return resp
if request.stream:
log.debug("Streaming response started")
        return StreamingResponse(
            _proxy_stream(resp),
            media_type="text/event-stream",
        )
    else:
        data = resp.json()
        await resp.aclose()
        usage = data.get("usage", {})
        _track_usage(usage)
        log.info(
            "Chat response: model=%s prompt_tokens=%d completion_tokens=%d",
            data.get("model", "?"),
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
        )
        return JSONResponse(content=data)
@router.post("/v1/messages")
async def anthropic_messages(
request: AnthropicMessagesRequest,
_: str = Depends(verify_api_key),
):
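    """Handle Anthropic-style /v1/messages requests.

    The request is translated to OpenAI chat format, proxied to OB-1, and the
    response (streaming or not) is translated back to the Anthropic format.
    """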
messages = _anthropic_to_openai_messages(request)
extra_payload = _build_openai_extra_payload(
_anthropic_tools_to_openai(request.tools),
_anthropic_tool_choice_to_openai(request.tool_choice),
)
resp = await _send_chat_request(
messages=messages,
model=request.model,
stream=request.stream,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
extra_payload=extra_payload,
)
if isinstance(resp, JSONResponse):
return resp
if request.stream:
return StreamingResponse(
_proxy_stream_anthropic(resp, request.model),
media_type="text/event-stream",
)
data = resp.json()
usage = data.get("usage", {})
_track_usage(usage)
await resp.aclose()
return JSONResponse(content=_openai_to_anthropic_response(data, request.model))
async def _send_chat_request(
*,
messages: list[dict[str, Any]],
model: str,
stream: bool,
temperature: float | None,
top_p: float | None,
max_tokens: int | None,
extra_payload: dict[str, Any] | None = None,
):
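    """Send a chat request to OB-1, refreshing the token once on a 401.

    Returns the backend response on success, or a JSONResponse describing the
    error so the caller can return it to the client directly.
    """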
token_manager = _require_token_manager()
ob1_client = _require_ob1_client()
api_key = await token_manager.get_api_key()
if not api_key:
log.warning("No valid OB-1 token available")
return JSONResponse(
status_code=503,
content={"error": "No valid OB-1 token. Run ob1 auth to login."},
)
resolved_model = await _resolve_model_name(model, api_key)
log.info(
"Chat request: model=%s resolved_model=%s stream=%s messages=%d",
model,
resolved_model,
stream,
len(messages),
)
try:
resp = await ob1_client.chat(
api_key=api_key,
messages=messages,
model=resolved_model,
stream=stream,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
extra_payload=extra_payload,
)
except Exception as e:
log.error("Backend error: %s", e)
return JSONResponse(status_code=502, content={"error": f"Backend error: {e}"})
if resp.status_code == 401:
await resp.aclose()
log.warning("Token rejected (401), refreshing...")
ok = await token_manager.refresh()
if not ok:
log.error("Token refresh failed")
return JSONResponse(
status_code=401, content={"error": "Token expired and refresh failed"}
)
api_key = await token_manager.get_api_key()
if not api_key:
return JSONResponse(
status_code=401, content={"error": "Token refresh failed"}
)
try:
resp = await ob1_client.chat(
api_key=api_key,
messages=messages,
model=resolved_model,
stream=stream,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
extra_payload=extra_payload,
)
except Exception as e:
log.error("Backend error after refresh: %s", e)
return JSONResponse(
status_code=502, content={"error": f"Backend error: {e}"}
)
if resp.status_code != 200:
try:
body = (await resp.aread()).decode()
except Exception:
body = "unable to read response body"
await resp.aclose()
log.error("OB-1 returned %d: %s", resp.status_code, body[:200])
return JSONResponse(
status_code=resp.status_code,
content={"error": f"OB-1 returned {resp.status_code}: {body[:500]}"},
)
return resp
async def _resolve_model_name(requested_model: str, api_key: str) -> str:
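    """Map the requested model name onto an id the OB-1 backend actually serves.

    Falls back from an exact match to an "anthropic/"-prefixed match, then to
    the lexicographically last model of the same Claude family, then to a
    preferred or last available Anthropic model, and finally returns the
    requested name unchanged.
    """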
ob1_client = _require_ob1_client()
raw_models = await ob1_client.fetch_models(api_key)
available = [item.get("id") for item in raw_models if item.get("id")]
if requested_model in available:
return requested_model
anthropic_prefixed = f"anthropic/{requested_model}"
if anthropic_prefixed in available:
return anthropic_prefixed
if requested_model.startswith("claude-"):
lowered = requested_model.lower()
family = None
for candidate in ("haiku", "sonnet", "opus"):
if candidate in lowered:
family = candidate
break
if family:
family_matches = [
model_id
for model_id in available
if model_id.startswith(f"anthropic/claude-{family}")
]
if family_matches:
return sorted(family_matches)[-1]
anthropic_models = [
model_id for model_id in available if model_id.startswith("anthropic/")
]
preferred_order = ["anthropic/claude-sonnet-4.6", "anthropic/claude-opus-4.6"]
for model_id in preferred_order:
if model_id in anthropic_models:
return model_id
if anthropic_models:
return sorted(anthropic_models)[-1]
return requested_model
def _anthropic_to_openai_messages(
request: AnthropicMessagesRequest,
) -> list[dict[str, Any]]:
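    """Convert Anthropic messages (text, tool_use, and tool_result blocks) into
    OpenAI chat messages, including assistant tool_calls and tool results."""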
messages: list[dict[str, Any]] = []
if request.system:
messages.append({"role": "system", "content": _flatten_content(request.system)})
for message in request.messages:
blocks = (
message.content
if isinstance(message.content, list)
else [{"type": "text", "text": message.content}]
)
text_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[dict[str, Any]] = []
for block in blocks:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text" and isinstance(block.get("text"), str):
text_parts.append(block["text"])
elif block_type == "tool_use":
tool_calls.append(
{
"id": block.get("id") or f"call_{uuid.uuid4().hex}",
"type": "function",
"function": {
"name": block.get("name", "tool"),
"arguments": json.dumps(
block.get("input", {}), ensure_ascii=False
),
},
}
)
elif block_type == "tool_result":
tool_results.append(
{
"role": "tool",
"tool_call_id": block.get("tool_use_id", ""),
"content": _flatten_content(block.get("content", "")),
}
)
if message.role == "assistant":
assistant_message: dict[str, Any] = {"role": "assistant"}
assistant_message["content"] = "\n".join(
part for part in text_parts if part
)
if tool_calls:
assistant_message["tool_calls"] = tool_calls
messages.append(assistant_message)
elif message.role == "user" and tool_results:
messages.extend(tool_results)
if text_parts:
messages.append({"role": "user", "content": "\n".join(text_parts)})
else:
messages.append(
{
"role": message.role,
"content": "\n".join(part for part in text_parts if part),
}
)
return messages
def _anthropic_tools_to_openai(
tools: Optional[list[dict[str, Any]]],
) -> Optional[list[dict[str, Any]]]:
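    """Convert Anthropic tool definitions into OpenAI function-tool definitions."""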
if not tools:
return None
converted: list[dict[str, Any]] = []
for tool in tools:
converted.append(
{
"type": "function",
"function": {
"name": tool.get("name", "tool"),
"description": tool.get("description", ""),
"parameters": tool.get(
"input_schema", {"type": "object", "properties": {}}
),
},
}
)
return converted
def _anthropic_tool_choice_to_openai(
tool_choice: Optional[dict[str, Any]],
) -> Optional[dict[str, Any] | str]:
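    """Map an Anthropic tool_choice value onto its OpenAI equivalent."""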
if not tool_choice:
return None
choice_type = tool_choice.get("type")
if choice_type in {"auto", "none"}:
return choice_type
if choice_type in {"any", "required"}:
return "required"
if choice_type == "tool":
name = tool_choice.get("name")
if name:
return {"type": "function", "function": {"name": name}}
return None
def _build_openai_extra_payload(
tools: Optional[list[dict[str, Any]]],
tool_choice: Optional[dict[str, Any] | str],
) -> Optional[dict[str, Any]]:
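    """Bundle optional tools and tool_choice into an extra payload, if any."""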
extra_payload: dict[str, Any] = {}
if tools:
extra_payload["tools"] = tools
if tool_choice is not None:
extra_payload["tool_choice"] = tool_choice
return extra_payload or None
def _flatten_content(content: Any) -> str:
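    """Flatten string-or-block-list content into a single plain-text string."""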
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
parts.append(block["text"])
elif block.get("type") == "tool_result":
parts.append(_flatten_content(block.get("content", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _openai_to_anthropic_response(data: dict[str, Any], model: str) -> dict[str, Any]:
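    """Convert a non-streaming OpenAI chat completion into an Anthropic
    Messages response body."""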
choice = (data.get("choices") or [{}])[0]
message = choice.get("message") or {}
content_blocks: list[dict[str, Any]] = []
text = _flatten_content(message.get("content", ""))
if text:
content_blocks.append({"type": "text", "text": text})
for tool_call in message.get("tool_calls") or []:
function = tool_call.get("function") or {}
content_blocks.append(
{
"type": "tool_use",
"id": tool_call.get("id", f"toolu_{uuid.uuid4().hex}"),
"name": function.get("name", "tool"),
"input": _parse_json_object(function.get("arguments")),
}
)
usage = data.get("usage") or {}
return {
"id": data.get("id", f"msg_{uuid.uuid4().hex}"),
"type": "message",
"role": "assistant",
"content": content_blocks or [{"type": "text", "text": ""}],
"model": data.get("model", model),
"stop_reason": _map_finish_reason(choice.get("finish_reason")),
"stop_sequence": None,
"usage": {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
},
}
def _map_finish_reason(reason: Optional[str]) -> str:
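    """Translate an OpenAI finish_reason into an Anthropic stop_reason."""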
if reason in {None, "stop"}:
return "end_turn"
if reason == "length":
return "max_tokens"
if reason == "tool_calls":
return "tool_use"
return reason or "end_turn"
def _parse_json_object(value: Any) -> dict[str, Any]:
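    """Parse tool-call arguments into a dict, returning {} on invalid input."""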
if isinstance(value, dict):
return value
if isinstance(value, str) and value:
try:
parsed = json.loads(value)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return {}
async def _proxy_stream_anthropic(resp, model: str) -> AsyncGenerator[str, None]:
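    """Translate an OpenAI-style SSE stream from OB-1 into Anthropic SSE events.

    Emits message_start, then content_block_start/delta/stop events for text
    and tool_use blocks, then message_delta and message_stop.
    """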
message_id = f"msg_{uuid.uuid4().hex}"
sent_start = False
text_started = False
text_index = 0
usage: dict[str, Any] = {"input_tokens": 0, "output_tokens": 0}
stop_reason = "end_turn"
tool_state: dict[int, dict[str, Any]] = {}
next_content_index = 0
try:
async for line in resp.aiter_lines():
if not line or not line.startswith("data: "):
continue
payload = line[6:]
if payload == "[DONE]":
break
try:
chunk = json.loads(payload)
except json.JSONDecodeError:
continue
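            # Emit message_start once, seeding input_tokens from the first chunk's usage.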
if not sent_start:
prompt_tokens = ((chunk.get("usage") or {}).get("prompt_tokens")) or 0
usage["input_tokens"] = prompt_tokens
yield _anthropic_sse(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": chunk.get("model", model),
"stop_reason": None,
"stop_sequence": None,
"usage": usage,
},
},
)
sent_start = True
delta = (chunk.get("choices") or [{}])[0].get("delta") or {}
finish_reason = (chunk.get("choices") or [{}])[0].get("finish_reason")
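            # Text deltas: open a text content block on first use, then stream
            # text_delta events for each piece of content.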
text = delta.get("content")
if text:
if not text_started:
text_index = next_content_index
yield _anthropic_sse(
"content_block_start",
{
"type": "content_block_start",
"index": text_index,
"content_block": {"type": "text", "text": ""},
},
)
text_started = True
next_content_index += 1
yield _anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": text_index,
"delta": {"type": "text_delta", "text": text},
},
)
chunk_usage = chunk.get("usage") or {}
if chunk_usage.get("completion_tokens") is not None:
usage["output_tokens"] = chunk_usage.get(
"completion_tokens", usage["output_tokens"]
)
_track_usage(chunk_usage)
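            # Tool-call deltas: open a tool_use block per tool index and stream
            # its arguments as input_json_delta events.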
for tool_delta in delta.get("tool_calls") or []:
tool_idx = tool_delta.get("index", 0)
state = tool_state.setdefault(
tool_idx,
{
"event_index": next_content_index,
"id": tool_delta.get("id") or f"toolu_{uuid.uuid4().hex}",
"name": ((tool_delta.get("function") or {}).get("name"))
or "tool",
"arguments": "",
"started": False,
},
)
if tool_delta.get("id"):
state["id"] = tool_delta["id"]
function = tool_delta.get("function") or {}
if function.get("name"):
state["name"] = function["name"]
if not state["started"]:
next_content_index = max(
next_content_index, state["event_index"] + 1
)
yield _anthropic_sse(
"content_block_start",
{
"type": "content_block_start",
"index": state["event_index"],
"content_block": {
"type": "tool_use",
"id": state["id"],
"name": state["name"],
"input": {},
},
},
)
state["started"] = True
if function.get("arguments"):
state["arguments"] += function["arguments"]
yield _anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": state["event_index"],
"delta": {
"type": "input_json_delta",
"partial_json": function["arguments"],
},
},
)
if finish_reason:
stop_reason = _map_finish_reason(finish_reason)
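        # The upstream stream has ended: emit message_start if it never fired,
        # close any open content blocks, then send message_delta and message_stop.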
if not sent_start:
yield _anthropic_sse(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": model,
"stop_reason": None,
"stop_sequence": None,
"usage": usage,
},
},
)
if text_started:
yield _anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": text_index},
)
elif not tool_state:
yield _anthropic_sse(
"content_block_start",
{
"type": "content_block_start",
"index": next_content_index,
"content_block": {"type": "text", "text": ""},
},
)
yield _anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": next_content_index},
)
for state in sorted(tool_state.values(), key=lambda item: item["event_index"]):
if state["started"]:
yield _anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": state["event_index"]},
)
yield _anthropic_sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": {"output_tokens": usage["output_tokens"]},
},
)
yield _anthropic_sse("message_stop", {"type": "message_stop"})
finally:
await resp.aclose()
def _anthropic_sse(event: str, data: dict[str, Any]) -> str:
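    """Format a single Anthropic-style server-sent event."""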
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _track_usage(usage: dict):
"""Extract token counts from usage and record cost."""
pt = usage.get("prompt_tokens", 0)
ct = usage.get("completion_tokens", 0)
    if pt or ct:
        # Rough cost estimate using OpenRouter-style per-token rates
        # (about $15 per 1M prompt tokens and $75 per 1M completion tokens).
        cost = pt * 0.000015 + ct * 0.000075
        _require_token_manager().add_cost(cost)
    elif usage:
        # Usage reported without token counts; record zero cost to keep the
        # accounting path consistent.
        _require_token_manager().add_cost(0)
async def _proxy_stream(resp) -> AsyncGenerator[str, None]:
"""Proxy SSE stream from OB-1 backend directly to client."""
try:
async for line in resp.aiter_lines():
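            # Assumes OB-1 emits OpenAI-style SSE with only "data: ..." lines,
            # so each non-empty line is forwarded as a complete event.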
if line:
yield f"{line}\n\n"
# Extract usage from the final chunk
if line.startswith("data: ") and '"usage"' in line:
try:
chunk = json.loads(line[6:])
usage = chunk.get("usage") or {}
if usage:
_track_usage(usage)
except Exception:
pass
finally:
await resp.aclose()