|
|
""" |
|
|
Base Agent for SPARKNET |
|
|
Defines the core agent interface and functionality |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import List, Dict, Optional, Any |
|
|
from dataclasses import dataclass |
|
|
from datetime import datetime |
|
|
from loguru import logger |
|
|
import json |
|
|
|
|
|
from ..llm.ollama_client import OllamaClient |
|
|
from ..tools.base_tool import BaseTool, ToolRegistry, ToolResult |
|
|
|
|
|
|
|
|
@dataclass
class Message:
    """A single message exchanged between agents, users and the LLM.

    Attributes:
        role: Conversation role of the message (e.g. "user", "assistant",
            "agent"); see ``to_dict`` for how it is mapped for the API.
        content: Message text.
        sender: Name of the originating agent or user, if known.
        timestamp: Creation time; filled with ``datetime.now()`` when omitted.
        metadata: Optional free-form extra data attached to the message.
    """

    role: str
    content: str
    sender: Optional[str] = None
    timestamp: Optional[datetime] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Stamp the creation time only when the caller did not supply one.
        self.timestamp = datetime.now() if self.timestamp is None else self.timestamp

    def to_dict(self) -> Dict[str, str]:
        """Convert to dictionary for Ollama API.

        Returns:
            A ``{"role": ..., "content": ...}`` mapping. Messages whose role is
            "agent" (sent by another agent) are presented to the LLM as
            "user" turns; every other role is passed through unchanged.
        """
        api_role = self.role
        if api_role == "agent":
            api_role = "user"
        return {"role": api_role, "content": self.content}
|
|
|
|
|
|
|
|
@dataclass
class Task:
    """A unit of work handed to an agent's ``process_task``.

    Attributes:
        id: Task identifier.
        description: What the task asks for.
        priority: Integer priority (defaults to 0).
        status: Status label (defaults to "pending").
        result: Output produced by processing, if any.
        error: Error message, if processing failed.
        metadata: Extra data; always a dict after construction.
    """

    id: str
    description: str
    priority: int = 0
    status: str = "pending"
    result: Optional[Any] = None
    error: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Replace a missing mapping with a fresh per-instance dict so callers
        # can write into ``task.metadata`` without a None check; a dict passed
        # in by the caller is kept as the very same object.
        self.metadata = {} if self.metadata is None else self.metadata
|
|
|
|
|
|
|
|
class BaseAgent(ABC):
    """Base class for all SPARKNET agents.

    An agent bundles the following:
    - an LLM configuration (client, model, system prompt, temperature, max tokens);
    - a local toolbox keyed by tool name;
    - an optional shared ``ToolRegistry`` that is consulted for tools missing locally;
    - an in-memory conversation history.

    Subclasses must implement :meth:`process_task`. They may override
    :meth:`process_message` to customise how incoming messages are answered.
    """

    def __init__(
        self,
        name: str,
        description: str,
        llm_client: OllamaClient,
        model: str,
        system_prompt: str,
        tools: Optional[List[BaseTool]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize agent.

        Args:
            name: Agent name (also used as ``sender`` on messages it emits)
            description: Agent description
            llm_client: Ollama client instance used for every LLM call
            model: Model to use
            system_prompt: System prompt sent with every LLM call
            tools: List of available tools, indexed by ``tool.name``; a later
                tool with the same name replaces an earlier one
            temperature: Default LLM temperature
            max_tokens: Max tokens to generate (currently forwarded only on
                prompt-style calls, see ``call_llm``)
        """
        self.name = name
        self.description = description
        self.llm_client = llm_client
        self.model = model
        self.system_prompt = system_prompt
        self.tools = {tool.name: tool for tool in (tools or [])}
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Conversation history in arrival order (messages are appended).
        self.messages: List[Message] = []

        # Optional shared registry, consulted when a tool is not found locally.
        self.tool_registry: Optional[ToolRegistry] = None

        logger.info(f"Initialized agent: {self.name} with model {self.model}")

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent's toolbox.

        An existing tool with the same name is replaced.

        Args:
            tool: Tool to add
        """
        self.tools[tool.name] = tool
        logger.info(f"Agent {self.name} added tool: {tool.name}")

    def remove_tool(self, tool_name: str):
        """
        Remove a tool from the agent's toolbox.

        Unknown names are ignored silently.

        Args:
            tool_name: Name of tool to remove
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Agent {self.name} removed tool: {tool_name}")

    def set_tool_registry(self, registry: ToolRegistry):
        """
        Set the tool registry for accessing shared tools.

        Args:
            registry: Tool registry instance
        """
        self.tool_registry = registry

    async def call_llm(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Message]] = None,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Call the LLM with a prompt or messages.

        A non-empty ``prompt`` takes precedence and is sent through
        ``llm_client.generate``. Otherwise a non-empty ``messages`` list is
        sent through ``llm_client.chat``, preceded by the system prompt.

        Args:
            prompt: Single prompt string
            messages: List of messages
            temperature: Override temperature (defaults to ``self.temperature``)

        Returns:
            LLM response

        Raises:
            ValueError: If neither a non-empty prompt nor non-empty messages
                are provided.
        """
        temp = temperature if temperature is not None else self.temperature

        # NOTE(review): the client calls below are not awaited. If
        # OllamaClient.generate/chat are blocking, they stall the event loop
        # for the duration of the request. Confirm against the client.
        if prompt:
            response = self.llm_client.generate(
                prompt=prompt,
                model=self.model,
                system=self.system_prompt,
                temperature=temp,
                max_tokens=self.max_tokens,
            )
        elif messages:
            # The system prompt is injected as the first chat turn because
            # chat() receives no separate system argument here.
            chat_messages = [
                {"role": "system", "content": self.system_prompt}
            ]
            chat_messages.extend([msg.to_dict() for msg in messages])

            # NOTE(review): max_tokens is not forwarded on the chat path,
            # unlike generate(). Confirm whether OllamaClient.chat accepts it.
            response = self.llm_client.chat(
                messages=chat_messages,
                model=self.model,
                temperature=temp,
            )
        else:
            raise ValueError("Either prompt or messages must be provided")

        logger.debug(f"Agent {self.name} received LLM response: {len(response)} chars")
        return response

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """
        Execute a tool by name.

        Local tools shadow registry tools of the same name.

        Args:
            tool_name: Name of tool to execute
            **kwargs: Tool parameters

        Returns:
            ToolResult from tool execution, or a failed ToolResult when no
            tool with that name is found
        """
        tool = self.tools.get(tool_name)

        # Fall back to the shared registry only when the tool is not local.
        if tool is None and self.tool_registry:
            tool = self.tool_registry.get_tool(tool_name)

        if tool is None:
            logger.error(f"Tool not found: {tool_name}")
            return ToolResult(
                success=False,
                output=None,
                error=f"Tool not found: {tool_name}",
            )

        logger.info(f"Agent {self.name} executing tool: {tool_name}")
        result = await tool.safe_execute(**kwargs)

        return result

    def add_message(self, message: Message):
        """
        Add a message to the agent's history.

        Args:
            message: Message to add
        """
        self.messages.append(message)

    async def receive_message(self, message: Message) -> Optional[str]:
        """
        Receive and process a message from another agent or user.

        The message is recorded in the history before it is processed.

        Args:
            message: Incoming message

        Returns:
            Response or None
        """
        logger.info(f"Agent {self.name} received message from {message.sender}")
        self.add_message(message)

        return await self.process_message(message)

    async def process_message(self, message: Message) -> Optional[str]:
        """
        Process an incoming message. Can be overridden by subclasses.

        The default implementation sends the whole history (which already
        contains ``message``) to the LLM and records the reply in the history
        as an "assistant" message.

        Args:
            message: Message to process

        Returns:
            Response or None
        """
        response = await self.call_llm(messages=self.messages)

        self.add_message(
            Message(
                role="assistant",
                content=response,
                sender=self.name,
            )
        )

        return response

    @abstractmethod
    async def process_task(self, task: Task) -> Task:
        """
        Process a task. Must be implemented by subclasses.

        Args:
            task: Task to process

        Returns:
            Updated task with results
        """
        pass

    async def send_message(self, recipient: "BaseAgent", content: str) -> Optional[str]:
        """
        Send a message to another agent.

        Args:
            recipient: Recipient agent
            content: Message content

        Returns:
            Response from recipient
        """
        message = Message(
            role="agent",
            content=content,
            sender=self.name,
        )

        logger.info(f"Agent {self.name} sending message to {recipient.name}")
        response = await recipient.receive_message(message)

        return response

    def get_available_tools(self) -> List[str]:
        """
        Get list of available tool names.

        Local tools come first, followed by registry tools not already listed.
        Each name appears once, and the order is deterministic.

        Returns:
            List of tool names
        """
        tool_names = list(self.tools.keys())

        if self.tool_registry:
            tool_names.extend(self.tool_registry.list_tools())

        # dict.fromkeys de-duplicates while keeping first-seen order. The
        # previous list(set(...)) produced an order that changed between
        # interpreter runs because str hashes are randomised.
        return list(dict.fromkeys(tool_names))

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Get schemas for all available tools.

        Returns:
            List of tool schemas: local tools first, then the registry's
        """
        schemas = [tool.get_schema() for tool in self.tools.values()]

        # NOTE(review): unlike get_available_tools, this does not de-duplicate.
        # A tool present both locally and in the registry contributes two
        # schemas. Consider filtering if the schema format exposes the name.
        if self.tool_registry:
            schemas.extend(self.tool_registry.get_schemas())

        return schemas

    def clear_history(self):
        """Clear message history."""
        self.messages.clear()
        logger.info(f"Agent {self.name} cleared message history")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get agent statistics.

        Returns:
            Dictionary with agent stats. ``tools_count`` counts local tools
            only; registry tools are not included.
        """
        return {
            "name": self.name,
            "model": self.model,
            "messages_count": len(self.messages),
            "tools_count": len(self.tools),
        }

    def __repr__(self) -> str:
        return f"<Agent: {self.name} (model={self.model}, tools={len(self.tools)})>"
|
|
|