| | """Model wrapper for LiteLLM""" |
| |
|
| | import os |
| | import json |
| | from typing import List, Dict, Any, Optional |
| |
|
| | try: |
| | import litellm |
| | except ImportError: |
| | print("⚠️ litellm not installed. Install with: pip install litellm") |
| | litellm = None |
| |
|
| |
|
class LiteLLMModel:
    """Wrapper for LiteLLM models.

    Sends chat messages (and optional tool definitions) through
    ``litellm.completion`` and flattens the first choice of the response
    into a plain dictionary.
    """

    def __init__(self, model_id: str):
        """Store the model id and check Groq credentials up front.

        Args:
            model_id: Model identifier passed to ``litellm.completion``.

        Raises:
            RuntimeError: If ``model_id`` contains "groq" (case-insensitive)
                and the ``GROQ_API_KEY`` environment variable is unset or
                empty.
        """
        self.model_id = model_id

        # Fail at construction time rather than on the first request.
        if self._is_groq() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError(
                "GROQ_API_KEY not set. Please add it to your Space secrets."
            )

    def _is_groq(self) -> bool:
        """Return True when the model id contains "groq" (case-insensitive)."""
        return "groq" in self.model_id.lower()

    @staticmethod
    def _format_tools(tools: Optional[List]) -> Optional[List[Dict[str, Any]]]:
        """Convert tool objects into function-tool dictionaries.

        Each tool must expose ``name``, ``description`` and ``parameters``
        attributes. Returns ``None`` when ``tools`` is ``None`` or empty.
        """
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _parse_arguments(raw: Any) -> Any:
        """Decode tool-call arguments, falling back to an empty dict.

        String input is parsed as JSON; undecodable text yields ``{}``.
        A missing value (``None`` or JSON ``null``) also yields ``{}``.
        Any other value is returned unchanged.
        """
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass.
                return {}
        return {} if raw is None else raw

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Request a completion and return it as a plain dictionary.

        Args:
            messages: Chat messages forwarded unchanged to
                ``litellm.completion``.
            tools: Optional tool objects exposing ``name``, ``description``
                and ``parameters`` attributes.

        Returns:
            A dict with a ``"content"`` string. When the response message
            carries tool calls, a ``"tool_calls"`` list is added whose items
            have ``"id"``, ``"name"`` and ``"arguments"`` keys.
            If litellm is unavailable, ``"content"`` is
            ``"Unknown - litellm not installed"``. If anything raises while
            building or sending the request, the error is printed and
            ``{"content": "Unknown"}`` is returned.
        """
        if not litellm:
            return {"content": "Unknown - litellm not installed"}

        # Best-effort boundary: any failure below degrades to "Unknown".
        try:
            request: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": self._format_tools(tools),
                "temperature": 0.1,
            }

            if self._is_groq():
                # Re-read the key: the environment may have changed since
                # construction.
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    raise RuntimeError("GROQ_API_KEY not set in environment")
                print(f"DEBUG: Using Groq model: {self.model_id}")
                request["api_key"] = api_key

            response = litellm.completion(**request)
            message = response.choices[0].message
            result: Dict[str, Any] = {"content": message.content or ""}

            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                result["tool_calls"] = [
                    {
                        # Synthesize an id when it is missing or empty.
                        "id": getattr(tc, "id", None)
                        or f"call_{tc.function.name}",
                        "name": tc.function.name,
                        "arguments": self._parse_arguments(
                            tc.function.arguments
                        ),
                    }
                    for tc in tool_calls
                ]

            return result

        except Exception as e:
            print(f"Model error: {e}")
            return {"content": "Unknown"}
| |
|
| |
|
def get_model(model_type: str, model_id: str):
    """Create the model wrapper named by ``model_type``.

    Args:
        model_type: Name of the wrapper class; only "LiteLLMModel" is
            recognised.
        model_id: Model identifier handed to the wrapper's constructor.

    Returns:
        A ``LiteLLMModel`` built from ``model_id``.

    Raises:
        ValueError: If ``model_type`` is not a recognised wrapper name.
    """
    # Guard clause: reject unsupported types before constructing anything.
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)