|
|
""" |
|
|
Simple Task Example for SPARKNET |
|
|
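Demonstrates basic agent and tool usage: an ExecutorAgent backed by a local Ollama server runs a few tool-driven tasks and reports its statistics.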
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from src.llm.ollama_client import OllamaClient |
|
|
from src.agents.executor_agent import ExecutorAgent |
|
|
from src.agents.base_agent import Task |
|
|
from src.tools import register_default_tools |
|
|
from src.utils.logging import setup_logging |
|
|
from src.utils.gpu_manager import get_gpu_manager |
|
|
from loguru import logger |
|
|
|
|
|
|
|
|
# Width of the "=" / "-" rules used to frame log sections.
_BANNER_WIDTH = 60
# Single source of truth for the model used by both the client and the agent.
_DEFAULT_MODEL = "llama3.2:latest"


def _log_banner(title: str, *, leading_newline: bool = True) -> None:
    """Log *title* framed above and below by a rule of ``=`` characters.

    Args:
        title: Section heading to log.
        leading_newline: When True (default), prefix the first rule with a
            newline so the section is visually separated from prior output.
    """
    rule = "=" * _BANNER_WIDTH
    logger.info(("\n" if leading_newline else "") + rule)
    logger.info(title)
    logger.info(rule)


def _build_tasks() -> "list[Task]":
    """Return the demo tasks; each description names the tool the agent should use."""
    return [
        Task(
            id="task_1",
            description="Use the gpu_monitor tool to check the status of all GPUs",
        ),
        Task(
            id="task_2",
            description="Use the directory_list tool to list all items in the current directory",
        ),
        Task(
            id="task_3",
            description="Use the python_executor tool to calculate the sum of numbers from 1 to 100",
        ),
    ]


async def _run_tasks(agent: "ExecutorAgent", tasks: "list[Task]") -> None:
    """Run *tasks* one after another on *agent*, logging each outcome.

    For every task the status is always logged; the result and error are
    logged only when they are truthy.
    """
    separator = "-" * _BANNER_WIDTH
    for task in tasks:
        logger.info(f"\nTask {task.id}: {task.description}")
        logger.info(separator)

        result = await agent.process_task(task)

        logger.info(f"Status: {result.status}")
        if result.result:
            logger.info(f"Result: {result.result}")
        if result.error:
            logger.error(f"Error: {result.error}")

        logger.info(separator)


async def main(
    *,
    host: str = "localhost",
    port: int = 11434,
    model: str = _DEFAULT_MODEL,
) -> int:
    """Run the simple task example.

    Steps: set up logging, log GPU status, connect to Ollama, register the
    default tools, run the demo tasks through an ExecutorAgent, and log the
    agent's statistics.

    Args:
        host: Ollama server host.
        port: Ollama server port.
        model: Model name used both as the client's default model and as the
            agent's model, so the two cannot drift apart.

    Returns:
        0 on success, or 1 if the Ollama server is not available. The value
        is suitable for passing to ``sys.exit``.
    """
    setup_logging(log_level="INFO")
    _log_banner("SPARKNET Simple Task Example", leading_newline=False)

    # Log the current GPU state before any model work starts.
    gpu_manager = get_gpu_manager()
    logger.info("\n" + gpu_manager.monitor())

    logger.info("\nInitializing Ollama client...")
    ollama_client = OllamaClient(
        host=host,
        port=port,
        default_model=model,
    )

    if not ollama_client.is_available():
        logger.error("Ollama server is not available! Make sure it's running with 'ollama serve'")
        # Non-zero so a caller / shell can detect the failure.
        return 1

    models = ollama_client.list_models()
    logger.info(f"\nAvailable models: {len(models)}")
    for entry in models:
        logger.info(f" - {entry['name']}")

    # Warn early rather than let every task fail with an opaque model error.
    # Only a warning: the server may still resolve the name (e.g. tag aliases).
    if model not in {entry["name"] for entry in models}:
        logger.warning(
            f"Model '{model}' is not in the local model list; tasks may fail. "
            f"Try 'ollama pull {model}'."
        )

    logger.info("\nRegistering tools...")
    tool_registry = register_default_tools()
    tool_names = tool_registry.list_tools()
    logger.info(f"Registered {len(tool_names)} tools: {tool_names}")

    logger.info("\nCreating ExecutorAgent...")
    agent = ExecutorAgent(
        llm_client=ollama_client,
        model=model,
        temperature=0.5,
    )
    agent.set_tool_registry(tool_registry)

    _log_banner("Executing Tasks")
    await _run_tasks(agent, _build_tasks())

    _log_banner("Agent Statistics")
    for key, value in agent.get_stats().items():
        logger.info(f"{key}: {value}")

    _log_banner("Example completed!")
    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so failures
    # reach the shell; a None return is treated by sys.exit as success (0).
    sys.exit(asyncio.run(main()))
|
|
|