from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from typing import Any, Type

import litellm
from litellm import completion
from pydantic import BaseModel, ValidationError


@dataclass
class ModelConfig:
    provider: str
    model: str
    temperature: float = 0.2
    max_tokens: int = 12000

    @property
    def model_name(self) -> str:
        """Return the model identifier in LiteLLM's "provider/model" form."""
        if "/" in self.model:
            return self.model
        if self.provider.lower() == "openai":
            return f"openai/{self.model}"
        if self.provider.lower() == "gemini":
            return f"gemini/{self.model}"
        return self.model


class MultiProviderLLMClient:
    """Thin wrapper around litellm.completion with per-stage model overrides."""

    def __init__(self, default_config: ModelConfig, stage_models: dict[str, str] | None = None):
        self.default_config = default_config
        self.stage_models = stage_models or {}
        # Drop request params that a given provider does not support instead of erroring.
        litellm.drop_params = True
        self._validate_env(default_config.provider)

    def _validate_env(self, provider: str) -> None:
        provider = provider.lower()
        if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY is required for provider=openai")
        if provider == "gemini" and not os.getenv("GEMINI_API_KEY"):
            raise ValueError("GEMINI_API_KEY is required for provider=gemini")

    def config_for_stage(self, stage_name: str) -> ModelConfig:
        """Resolve the effective ModelConfig for a stage, honoring any override."""
        model_override = self.stage_models.get(stage_name)
        if not model_override:
            return self.default_config
        provider = self.default_config.provider
        model = model_override
        # An override may be given as "provider/model" to switch providers per stage.
        if "/" in model_override:
            provider, model = model_override.split("/", 1)
        self._validate_env(provider)
        return ModelConfig(
            provider=provider,
            model=model,
            temperature=self.default_config.temperature,
            max_tokens=self.default_config.max_tokens,
        )

    def generate_structured(
        self,
        *,
        stage_name: str,
        system_prompt: str,
        user_prompt: str,
        response_model: Type[BaseModel],
    ) -> BaseModel:
        """Call the stage's model and validate its JSON reply against response_model."""
        config = self.config_for_stage(stage_name)
        completion_kwargs = {
            "model": config.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": config.max_tokens,
            "response_format": {"type": "json_object"},
        }
        temperature = self._temperature_for_model(config)
        if temperature is not None:
            completion_kwargs["temperature"] = temperature
        response = completion(**completion_kwargs)
        content = response.choices[0].message.content or ""
        payload = self._parse_json(content)
        try:
            return response_model.model_validate(payload)
        except ValidationError as exc:
            # Some models wrap the object in a single-element list; unwrap and retry.
            if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
                try:
                    return response_model.model_validate(payload[0])
                except ValidationError:
                    pass
            raise ValueError(
                f"Stage {stage_name} returned invalid JSON for "
                f"{response_model.__name__}: {exc}\nRaw content:\n{content}"
            ) from exc

    def generate_text(
        self,
        *,
        stage_name: str,
        system_prompt: str,
        user_prompt: str,
    ) -> str:
        """Call the stage's model and return its raw text reply."""
        config = self.config_for_stage(stage_name)
        completion_kwargs = {
            "model": config.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": config.max_tokens,
        }
        temperature = self._temperature_for_model(config)
        if temperature is not None:
            completion_kwargs["temperature"] = temperature
        response = completion(**completion_kwargs)
        return (response.choices[0].message.content or "").strip()

    @staticmethod
    def _parse_json(text: str) -> Any:
        """Parse JSON from a reply, tolerating code fences and surrounding prose."""
        text = text.strip()
        if text.startswith("```"):
            match = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.S)
            if match:
                text = match.group(1).strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the first JSON object or array embedded in the text.
            match = re.search(r"(\{.*\}|\[.*\])", text, flags=re.S)
            if match:
                return json.loads(match.group(1))
            raise

    @staticmethod
    def _temperature_for_model(config: ModelConfig) -> float | None:
        # gpt-5 models reject a custom temperature, so omit the parameter for them.
        model_name = config.model_name.lower()
        if "gpt-5" in model_name:
            return None
        return config.temperature
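

# Usage sketch (illustrative, not part of the client): the stage name, model
# strings, and Summary schema below are assumptions for demonstration; swap in
# whatever your pipeline actually uses. The relevant API keys (OPENAI_API_KEY,
# GEMINI_API_KEY) must be set for the providers involved.
if __name__ == "__main__":
    class Summary(BaseModel):
        title: str
        bullet_points: list[str]

    client = MultiProviderLLMClient(
        default_config=ModelConfig(provider="openai", model="gpt-4o-mini"),
        # Per-stage overrides accept either a bare model name or "provider/model".
        stage_models={"summarize": "gemini/gemini-1.5-flash"},
    )

    summary = client.generate_structured(
        stage_name="summarize",
        system_prompt="You are a concise technical summarizer. Reply with JSON only.",
        user_prompt="Summarize: LiteLLM exposes one completion API over many providers.",
        response_model=Summary,
    )
    print(summary.model_dump_json(indent=2))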