|
|
""" |
|
|
Base Tool for SPARKNET |
|
|
Defines the interface for all tools that agents can use, plus the parameter/result models and the tool registry.
|
|
""" |
|
|
|
|
|
import copy
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from loguru import logger
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
class ToolParameter(BaseModel):
    """Declarative description of one input accepted by a tool."""

    # Keyword under which the value is passed to the tool.
    name: str = Field(..., description="Parameter name")
    # Declared type name; the accepted names are listed in the description below.
    type: str = Field(..., description="Parameter type (str, int, float, bool, list, dict)")
    # Human/LLM-readable explanation of the parameter.
    description: str = Field(..., description="Parameter description")
    # Optional parameters may be omitted by the caller.
    required: bool = Field(True, description="Whether parameter is required")
    # Value to use when an optional parameter is omitted.
    default: Optional[Any] = Field(None, description="Default value if not required")
|
|
|
|
|
|
|
|
class ToolResult(BaseModel):
    """Outcome of a single tool invocation: status flag, payload and diagnostics."""

    # True when the tool finished without error.
    success: bool = Field(..., description="Whether execution was successful")
    # The tool's payload; may be any value (including None).
    output: Any = Field(..., description="Tool output")
    # Human-readable failure reason; None when not provided.
    error: Optional[str] = Field(None, description="Error message if failed")
    # Free-form extra information; a fresh dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
|
|
|
|
|
|
|
class BaseTool(ABC):
    """Base class for all tools.

    Subclasses implement :meth:`execute` and declare their inputs with
    :meth:`add_parameter`. Callers normally go through :meth:`safe_execute`,
    which validates inputs, fills in defaults and converts exceptions into a
    failed :class:`ToolResult`.
    """

    # Declared parameter type name -> Python type(s) accepted at validation
    # time. "float" also accepts ints because JSON (and therefore LLM-produced
    # arguments) does not distinguish 3 from 3.0.
    _PYTHON_TYPES: Dict[str, Any] = {
        "str": str,
        "int": int,
        "float": (int, float),
        "bool": bool,
        "list": list,
        "dict": dict,
    }

    # Declared parameter type name -> JSON Schema type name, which is what
    # LLM function-calling schemas expect in the "type" field.
    _JSON_SCHEMA_TYPES: Dict[str, str] = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    def __init__(self, name: str, description: str):
        """
        Initialize tool.

        Args:
            name: Tool name
            description: Tool description
        """
        self.name = name
        self.description = description
        self.parameters: list[ToolParameter] = []

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """
        Execute the tool with given parameters.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results
        """

    def add_parameter(
        self,
        name: str,
        param_type: str,
        description: str,
        required: bool = True,
        default: Optional[Any] = None,
    ):
        """
        Add a parameter definition to the tool.

        Args:
            name: Parameter name
            param_type: Parameter type name (str, int, float, bool, list, dict)
            description: Parameter description
            required: Whether parameter is required
            default: Default value used by safe_execute when an optional
                parameter is omitted
        """
        self.parameters.append(
            ToolParameter(
                name=name,
                type=param_type,
                description=description,
                required=required,
                default=default,
            )
        )

    def validate_parameters(self, **kwargs) -> tuple[bool, Optional[str]]:
        """
        Validate provided parameters against tool definition.

        Missing required parameters are reported before any type mismatch.
        Parameters whose declared type is not one of the known names are not
        type-checked. Booleans are rejected for "int"/"float" parameters even
        though ``bool`` subclasses ``int`` in Python; ints are accepted for
        "float" parameters.

        Args:
            **kwargs: Provided parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        for param in self.parameters:
            if param.required and param.name not in kwargs:
                return False, f"Missing required parameter: {param.name}"

        for param in self.parameters:
            if param.name not in kwargs:
                continue
            expected = self._PYTHON_TYPES.get(param.type)
            if expected is None:
                continue
            value = kwargs[param.name]
            # isinstance(True, int) is True; don't let a bool satisfy a number.
            bool_for_number = isinstance(value, bool) and param.type in ("int", "float")
            if bool_for_number or not isinstance(value, expected):
                return False, f"Parameter {param.name} must be of type {param.type}"

        return True, None

    def get_schema(self) -> Dict[str, Any]:
        """
        Get tool schema for LLM function calling.

        Known parameter type names are translated to their JSON Schema
        equivalents (e.g. "str" -> "string"); unknown names pass through.

        Returns:
            Tool schema dictionary
        """
        properties = {
            param.name: {
                "type": self._JSON_SCHEMA_TYPES.get(param.type, param.type),
                "description": param.description,
            }
            for param in self.parameters
        }
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": [param.name for param in self.parameters if param.required],
            },
        }

    async def safe_execute(self, **kwargs) -> ToolResult:
        """
        Execute tool with parameter validation and error handling.

        Validation failures and exceptions raised by :meth:`execute` are
        returned as a ToolResult with ``success=False`` instead of propagating.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results
        """
        is_valid, error_msg = self.validate_parameters(**kwargs)
        if not is_valid:
            logger.error(f"Tool {self.name} parameter validation failed: {error_msg}")
            return ToolResult(success=False, output=None, error=error_msg)

        # Fill in omitted optional parameters. Deep-copy the default so a tool
        # that mutates a list/dict default cannot leak state into later calls.
        for param in self.parameters:
            if not param.required and param.name not in kwargs:
                kwargs[param.name] = copy.deepcopy(param.default)

        try:
            logger.info(f"Executing tool: {self.name}")
            result = await self.execute(**kwargs)
            logger.info(f"Tool {self.name} executed successfully")
            return result
        except Exception as e:
            # Boundary handler: record the traceback, report failure as a result.
            logger.exception(f"Tool {self.name} execution failed: {e}")
            return ToolResult(
                success=False,
                output=None,
                error=str(e),
            )

    def __repr__(self) -> str:
        return f"<Tool: {self.name}>"
|
|
|
|
|
|
|
|
class ToolRegistry:
    """Registry for managing available tools, keyed by tool name."""

    def __init__(self):
        """Initialize tool registry."""
        self.tools: Dict[str, BaseTool] = {}
        logger.info("Tool registry initialized")

    def register(self, tool: BaseTool):
        """
        Register a tool under ``tool.name``.

        If a different tool is already registered under that name, it is
        replaced and a warning is logged so name collisions are not silent.

        Args:
            tool: Tool instance to register
        """
        existing = self.tools.get(tool.name)
        if existing is not None and existing is not tool:
            logger.warning(f"Replacing already-registered tool: {tool.name}")
        self.tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def unregister(self, tool_name: str):
        """
        Unregister a tool. Unknown names are ignored.

        Args:
            tool_name: Name of tool to unregister
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """
        Get a tool by name.

        Args:
            tool_name: Name of tool

        Returns:
            Tool instance or None if no tool has that name
        """
        return self.tools.get(tool_name)

    def list_tools(self) -> list[str]:
        """
        List all registered tools.

        Returns:
            List of tool names, in registration order
        """
        return list(self.tools.keys())

    def get_schemas(self) -> list[Dict[str, Any]]:
        """
        Get schemas for all tools.

        Returns:
            List of tool schemas (one ``get_schema()`` result per tool)
        """
        return [tool.get_schema() for tool in self.tools.values()]
|
|
|
|
|
|
|
|
|
|
|
# Process-wide registry; stays None until get_tool_registry() first runs.
_tool_registry: Optional[ToolRegistry] = None


def get_tool_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, constructing it on the first call.

    Returns:
        The module-level ToolRegistry instance; later calls return the same object.
    """
    global _tool_registry
    if _tool_registry is not None:
        return _tool_registry
    _tool_registry = ToolRegistry()
    return _tool_registry
|
|
|