# SPARKNET / tests/unit/test_planner_migration.py
# MHamdan — commit a9dc537: "Initial commit: SPARKNET framework"
"""
Test migrated PlannerAgent with LangChain
"""
import asyncio
from src.llm.langchain_ollama_client import get_langchain_client
from src.agents.planner_agent import PlannerAgent
from src.workflow.langgraph_state import ScenarioType
async def test_planner_migration():
    """Smoke-test PlannerAgent after its migration to a LangChain LLM client.

    Test 1 decomposes a task with an explicit ``scenario="patent_wakeup"``
    (labelled "template-based planning"). Test 2 decomposes a free-form task
    with no scenario (labelled "LangChain-based planning"), which presumably
    needs a reachable Ollama server — TODO confirm.

    Test 2 is best-effort. If ``decompose_task`` raises, the error is
    reported and Test 2 counts as skipped, so the run does not fail. Errors
    raised while inspecting a successfully produced graph are NOT swallowed.
    The final summary says whether Test 2 actually ran, instead of always
    claiming that every test passed.
    """
    print("Testing PlannerAgent migration to LangChain...")
    print()

    # Initialize LangChain client (monitoring disabled for this test).
    client = get_langchain_client(default_complexity='complex', enable_monitoring=False)
    print("✓ LangChain client initialized")

    # Create PlannerAgent backed by the LangChain client.
    planner = PlannerAgent(llm_client=client)
    print("✓ PlannerAgent created with LangChain")
    print()

    # Test 1: Template-based planning
    print("Test 1: Template-based planning (patent_wakeup)")
    task_graph = await planner.decompose_task(
        task_description="Analyze dormant patent US123456 for commercialization",
        scenario="patent_wakeup"
    )
    print(f" ✓ Generated {len(task_graph.subtasks)} subtasks")
    print(f" ✓ Execution order: {len(task_graph.get_execution_order())} parallel layers")
    print(f" ✓ Graph valid: {task_graph.validate()}")
    print()

    # Test 2: LangChain-based planning
    print("Test 2: LangChain-based planning (custom task)")
    # Only the LLM-backed call is guarded. Failures in validate()/subtasks on a
    # graph that was returned are real bugs and must not be masked as
    # "backend unavailable".
    try:
        task_graph2 = await planner.decompose_task(
            task_description="Research market opportunities for AI-powered drug discovery platform"
        )
    except Exception as e:  # deliberate best-effort: the LLM backend may be down
        langchain_ran = False
        print(" Note: LangChain planning requires Ollama running")
        print(f" Error: {type(e).__name__}: {e}")
    else:
        langchain_ran = True
        print(f" ✓ Generated {len(task_graph2.subtasks)} subtasks via LangChain")
        print(f" ✓ Graph valid: {task_graph2.validate()}")
    print()

    # Report honestly: don't claim "all passed" when Test 2 was skipped.
    if langchain_ran:
        print("✓ All PlannerAgent migration tests passed!")
    else:
        print("✓ Template-based planning passed; LangChain planning was skipped (see note above)")
if __name__ == "__main__":
    # Script entry point: build the test coroutine, then drive it to
    # completion on a fresh event loop.
    _planner_test_coro = test_planner_migration()
    asyncio.run(_planner_test_coro)