"""
LangGraph Workflow for SPARKNET

Implements cyclic multi-agent workflows with StateGraph.
"""

from typing import Literal, Dict, Any, Optional
from datetime import datetime
from loguru import logger

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

from .langgraph_state import (
    AgentState,
    ScenarioType,
    TaskStatus,
    WorkflowOutput,
    create_initial_state,
    state_to_output,
)
from ..llm.langchain_ollama_client import LangChainOllamaClient
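
# Illustrative usage (assumes a configured LangChainOllamaClient; the agent
# arguments are optional and the workflow falls back to direct LLM calls
# when they are omitted):
#
#     client = LangChainOllamaClient()
#     workflow = create_workflow(llm_client=client)
#     output = await workflow.run(
#         "Analyze patent portfolio",
#         scenario=ScenarioType.GENERAL,
#     )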


class SparknetWorkflow:
    """
    LangGraph-powered workflow orchestrator for SPARKNET.

    Implements a cyclic workflow with conditional routing:

        START → PLANNER → ROUTER → [scenario executors] → CRITIC
                   ↑                                        ↓
                   └─────────────── REFINE ←────────────────┘
    """

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        planner_agent: Optional[Any] = None,
        critic_agent: Optional[Any] = None,
        memory_agent: Optional[Any] = None,
        vision_ocr_agent: Optional[Any] = None,
        quality_threshold: float = 0.85,
        max_iterations: int = 3,
    ):
        self.llm_client = llm_client
        self.planner_agent = planner_agent
        self.critic_agent = critic_agent
        self.memory_agent = memory_agent
        self.vision_ocr_agent = vision_ocr_agent
        self.quality_threshold = quality_threshold
        self.max_iterations = max_iterations
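
        # Build and compile the graph once at construction time. MemorySaver
        # keeps checkpoints in process memory only; passing the same thread_id
        # in the run config resumes a task from its stored state.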
        self.graph = self._build_graph()
        self.checkpointer = MemorySaver()
        self.app = self.graph.compile(checkpointer=self.checkpointer)

        if vision_ocr_agent:
            logger.info("Initialized SparknetWorkflow with LangGraph StateGraph and VisionOCR support")
        else:
            logger.info("Initialized SparknetWorkflow with LangGraph StateGraph")

    def _build_graph(self) -> StateGraph:
        """Assemble the StateGraph: nodes, linear edges, and the critic's conditional branch."""
        workflow = StateGraph(AgentState)

        workflow.add_node("planner", self._planner_node)
        workflow.add_node("router", self._router_node)
        workflow.add_node("executor", self._executor_node)
        workflow.add_node("critic", self._critic_node)
        workflow.add_node("refine", self._refine_node)
        workflow.add_node("finish", self._finish_node)

        workflow.set_entry_point("planner")
        workflow.add_edge("planner", "router")
        workflow.add_edge("router", "executor")
        workflow.add_edge("executor", "critic")

        workflow.add_conditional_edges(
            "critic",
            self._should_refine,
            {
                "refine": "refine",
                "finish": "finish",
            }
        )
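
        # The refine → planner back-edge is what makes the graph cyclic;
        # _should_refine bounds the loop via quality_threshold and max_iterations.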
        workflow.add_edge("refine", "planner")
        workflow.add_edge("finish", END)

        return workflow

    async def _planner_node(self, state: AgentState) -> AgentState:
        """Decompose the task into subtasks, optionally seeded with memory context."""
        logger.info(f"PLANNER node processing task: {state['task_id']}")
        state["status"] = TaskStatus.PLANNING
        state["current_agent"] = "PlannerAgent"

        # Retrieve relevant past experiences to ground the plan.
        context_docs = []
        if self.memory_agent:
            try:
                logger.info("Retrieving relevant context from memory...")
                context_docs = await self.memory_agent.retrieve_relevant_context(
                    query=state["task_description"],
                    context_type="all",
                    top_k=3,
                    scenario_filter=state["scenario"],
                    min_quality_score=0.8,
                )
                if context_docs:
                    logger.info(f"Retrieved {len(context_docs)} relevant memories")
                    state["agent_outputs"]["memory_context"] = [
                        {"content": doc.page_content, "metadata": doc.metadata}
                        for doc in context_docs
                    ]
            except Exception as e:
                logger.warning(f"Memory retrieval failed: {e}")

        system_msg = SystemMessage(content="Decompose the task into executable subtasks.")

        context_text = ""
        if context_docs:
            context_text = "\n\nRelevant past experiences:\n"
            for i, doc in enumerate(context_docs, 1):
                context_text += f"\n{i}. {doc.page_content[:200]}..."

        user_msg = HumanMessage(
            content=f"Task: {state['task_description']}\nScenario: {state['scenario']}{context_text}"
        )

        if self.planner_agent:
            # Delegate planning to the dedicated PlannerAgent when available.
            from ..agents.base_agent import Task
            task = Task(
                id=state["task_id"],
                description=state["task_description"],
                metadata={"scenario": state["scenario"].value},
            )
            result_task = await self.planner_agent.process_task(task)

            if result_task.status == "completed":
                state["subtasks"] = [
                    {
                        "id": st.id,
                        "description": st.description,
                        "agent_type": st.agent_type,
                        "dependencies": st.dependencies,
                    }
                    for st in result_task.result["task_graph"].subtasks.values()
                ]
                state["execution_order"] = result_task.result["execution_order"]
                response_msg = AIMessage(content=f"Created plan with {len(state['subtasks'])} subtasks")
                state["messages"].append(response_msg)
        else:
            # Fall back to a direct LLM call with a single default subtask.
            llm = self.llm_client.get_llm(complexity="complex")
            response = await llm.ainvoke([system_msg, user_msg])
            state["messages"].append(response)
            state["subtasks"] = [
                {"id": "subtask_1", "description": "Execute primary task", "agent_type": "ExecutorAgent", "dependencies": []}
            ]
            state["execution_order"] = [["subtask_1"]]

        logger.info(f"Planning completed: {len(state.get('subtasks', []))} subtasks created")
        return state

    async def _router_node(self, state: AgentState) -> AgentState:
        """Record which agents the scenario should use; actual dispatch happens in the executor."""
        logger.info(f"ROUTER node routing for scenario: {state['scenario']}")
        state["current_agent"] = "Router"

        scenario = state["scenario"]
        routing_msg = AIMessage(content=f"Routing to {scenario.value} workflow agents")
        state["messages"].append(routing_msg)

        state["agent_outputs"]["router"] = {
            "scenario": scenario.value,
            "agents_to_use": self._get_scenario_agents(scenario),
        }

        return state

    async def _executor_node(self, state: AgentState) -> AgentState:
        """Execute the task, either via a dedicated scenario pipeline or a tool-calling LLM."""
        logger.info(f"EXECUTOR node executing for scenario: {state['scenario']}")
        state["status"] = TaskStatus.EXECUTING
        state["current_agent"] = "Executor"

        scenario = state["scenario"]

        if scenario == ScenarioType.PATENT_WAKEUP:
            logger.info("🎯 Routing to Patent Wake-Up pipeline")
            return await self._execute_patent_wakeup(state)

        agents = self._get_scenario_agents(scenario)

        # Load the scenario-specific tool set.
        from ..tools.langchain_tools import get_vista_tools
        tools = get_vista_tools(scenario.value)
        logger.info(f"Loaded {len(tools)} tools for scenario: {scenario.value}")

        llm = self.llm_client.get_llm(complexity="standard")
        llm_with_tools = llm.bind_tools(tools)
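        # bind_tools only advertises the tool schemas to the model so it can
        # emit structured tool calls; nothing in this node executes those calls.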

        tool_descriptions = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
        execution_prompt = HumanMessage(
            content=f"""Execute the following task using the available tools when needed:

Task: {state['task_description']}
Scenario: {scenario.value}

Available tools:
{tool_descriptions}

Provide detailed results."""
        )

        response = await llm_with_tools.ainvoke([execution_prompt])
        state["messages"].append(response)

        # Record any tool calls the model requested (telemetry only).
        tool_calls = []
        if hasattr(response, 'tool_calls') and response.tool_calls:
            logger.info(f"LLM requested {len(response.tool_calls)} tool calls")
            for tool_call in response.tool_calls:
                tool_name = tool_call.get('name', 'unknown')
                tool_calls.append(tool_name)
                logger.info(f"Tool called: {tool_name}")

        state["agent_outputs"]["executor"] = {
            "result": response.content,
            "agents_used": agents,
            "tools_available": [tool.name for tool in tools],
            "tools_called": tool_calls,
        }
        state["final_output"] = response.content

        logger.info("Execution completed")
        return state

    async def _execute_patent_wakeup(self, state: AgentState) -> AgentState:
        """
        Execute the Patent Wake-Up scenario pipeline.

        Runs sequentially, each step consuming the previous step's output:
        Document → Market → Matchmaking → Outreach. Any step failure aborts
        the pipeline and records the error in the executor output.
        """
        logger.info("🚀 Executing Patent Wake-Up pipeline")

        from ..agents.scenario1 import (
            DocumentAnalysisAgent,
            MarketAnalysisAgent,
            MatchmakingAgent,
            OutreachAgent,
        )

        patent_path = state.get("input_data", {}).get("patent_path", "mock_patent.txt")

        try:
            logger.info("📄 Step 1/4: Analyzing patent document...")
            doc_agent = DocumentAnalysisAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
                vision_ocr_agent=self.vision_ocr_agent,
            )
            patent_analysis = await doc_agent.analyze_patent(patent_path)
            state["agent_outputs"]["document_analysis"] = patent_analysis.model_dump()
            logger.success(f"✅ Patent analyzed: {patent_analysis.title}")

            logger.info("📊 Step 2/4: Analyzing market opportunities...")
            market_agent = MarketAnalysisAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            market_analysis = await market_agent.analyze_market(patent_analysis)
            state["agent_outputs"]["market_analysis"] = market_analysis.model_dump()
            logger.success(f"✅ Market analyzed: {len(market_analysis.opportunities)} opportunities")

            logger.info("🤝 Step 3/4: Finding potential partners...")
            matching_agent = MatchmakingAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            matches = await matching_agent.find_matches(
                patent_analysis,
                market_analysis,
                max_matches=10,
            )
            state["agent_outputs"]["matches"] = [m.model_dump() for m in matches]
            logger.success(f"✅ Found {len(matches)} potential partners")

            logger.info("📝 Step 4/4: Creating valorization brief...")
            outreach_agent = OutreachAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            brief = await outreach_agent.create_valorization_brief(
                patent_analysis,
                market_analysis,
                matches,
            )
            state["agent_outputs"]["brief"] = brief.model_dump()
            state["final_output"] = brief.content
            logger.success(f"✅ Brief created: {brief.pdf_path}")

            state["agent_outputs"]["executor"] = {
                "result": "Patent Wake-Up workflow completed successfully",
                "patent_title": patent_analysis.title,
                "opportunities_found": len(market_analysis.opportunities),
                "matches_found": len(matches),
                "brief_path": brief.pdf_path,
                "agents_used": ["DocumentAnalysisAgent", "MarketAnalysisAgent",
                                "MatchmakingAgent", "OutreachAgent"],
            }

            logger.success("✅ Patent Wake-Up pipeline completed successfully!")

        except Exception as e:
            logger.error(f"Patent Wake-Up pipeline failed: {e}")
            state["agent_outputs"]["executor"] = {
                "result": f"Pipeline failed: {str(e)}",
                "error": str(e),
                "agents_used": [],
            }
            state["final_output"] = f"Error: {str(e)}"

        return state

    async def _critic_node(self, state: AgentState) -> AgentState:
        """Score the current output; the score drives the refine/finish branch."""
        logger.info("CRITIC node validating output")
        state["status"] = TaskStatus.VALIDATING
        state["current_agent"] = "CriticAgent"

        if self.critic_agent:
            from ..agents.base_agent import Task
            task = Task(
                id=state["task_id"],
                description=state["task_description"],
                metadata={
                    "output_to_validate": state["final_output"],
                    "output_type": self._get_output_type(state["scenario"]),
                },
            )
            result_task = await self.critic_agent.process_task(task)

            if result_task.status == "completed":
                validation = result_task.result
                state["validation_score"] = validation.overall_score
                state["validation_feedback"] = self.critic_agent.get_feedback_for_iteration(validation)
                state["validation_issues"] = validation.issues
                state["validation_suggestions"] = validation.suggestions

                feedback_msg = AIMessage(
                    content=f"Validation score: {validation.overall_score:.2f}\n{state['validation_feedback']}"
                )
                state["messages"].append(feedback_msg)
        else:
            # Fallback: ask the LLM for feedback but record a fixed placeholder
            # score, since the free-text response is not parsed for one.
            llm = self.llm_client.get_llm(complexity="analysis")
            validation_prompt = HumanMessage(
                content=f"Validate the following output:\n\n{state['final_output']}\n\nProvide a quality score (0.0-1.0) and feedback."
            )

            response = await llm.ainvoke([validation_prompt])
            state["messages"].append(response)

            state["validation_score"] = 0.90
            state["validation_feedback"] = response.content
            state["validation_issues"] = []
            state["validation_suggestions"] = []

        logger.info(f"Validation completed: score={state['validation_score']:.2f}")
        return state

    async def _refine_node(self, state: AgentState) -> AgentState:
        """Archive the current attempt and queue the critic's feedback for the next pass."""
        logger.info(f"REFINE node preparing for iteration {state['iteration_count'] + 1}")
        state["status"] = TaskStatus.REFINING
        state["current_agent"] = "Refiner"
        state["iteration_count"] += 1

        refine_msg = HumanMessage(
            content=f"Iteration {state['iteration_count']}: Address the following issues:\n{state['validation_feedback']}"
        )
        state["messages"].append(refine_msg)

        # Keep the superseded attempt so intermediate results stay inspectable.
        state["intermediate_results"].append({
            "iteration": state["iteration_count"] - 1,
            "output": state["final_output"],
            "score": state["validation_score"],
            "feedback": state["validation_feedback"],
        })

        logger.info(f"Refinement prepared for iteration {state['iteration_count']}")
        return state

    async def _finish_node(self, state: AgentState) -> AgentState:
        """Mark the workflow complete and persist the episode to memory if it scored well."""
        logger.info("FINISH node completing workflow")
        state["status"] = TaskStatus.COMPLETED
        state["current_agent"] = None
        state["success"] = True
        state["end_time"] = datetime.now()
        state["execution_time_seconds"] = (state["end_time"] - state["start_time"]).total_seconds()

        # Only store reasonably good episodes (score >= 0.75) as future context.
        if self.memory_agent and state.get("validation_score", 0) >= 0.75:
            try:
                logger.info("Storing episode in memory...")
                await self.memory_agent.store_episode(
                    task_id=state["task_id"],
                    task_description=state["task_description"],
                    scenario=state["scenario"],
                    workflow_steps=state.get("subtasks", []),
                    outcome={
                        "final_output": state["final_output"],
                        "validation_score": state.get("validation_score", 0),
                        "success": state["success"],
                        "tools_used": state.get("agent_outputs", {}).get("executor", {}).get("tools_called", []),
                    },
                    quality_score=state.get("validation_score", 0),
                    execution_time=state["execution_time_seconds"],
                    iterations_used=state.get("iteration_count", 0),
                )
                logger.info(f"Episode stored: {state['task_id']}")
            except Exception as e:
                logger.warning(f"Failed to store episode: {e}")

        completion_msg = AIMessage(
            content=f"Workflow completed successfully in {state['execution_time_seconds']:.2f}s"
        )
        state["messages"].append(completion_msg)

        logger.info(f"Workflow completed: {state['task_id']}")
        return state

    def _should_refine(self, state: AgentState) -> Literal["refine", "finish"]:
        """Branch after the critic: finish when quality or the iteration budget is reached, else refine."""
        score = state.get("validation_score", 0.0)
        iterations = state.get("iteration_count", 0)

        if score >= self.quality_threshold:
            logger.info(f"Quality threshold met ({score:.2f} >= {self.quality_threshold}), finishing")
            return "finish"

        if iterations >= state.get("max_iterations", self.max_iterations):
            logger.warning(f"Max iterations reached ({iterations}), finishing anyway")
            return "finish"

        logger.info(f"Refining (score={score:.2f}, iteration={iterations})")
        return "refine"

    def _get_scenario_agents(self, scenario: ScenarioType) -> list:
        """Map a scenario to the agent names expected to handle it."""
        scenario_map = {
            ScenarioType.PATENT_WAKEUP: ["DocumentAnalysisAgent", "MarketAnalysisAgent", "MatchmakingAgent", "OutreachAgent"],
            ScenarioType.AGREEMENT_SAFETY: ["LegalAnalysisAgent", "ComplianceAgent", "RiskAssessmentAgent", "RecommendationAgent"],
            ScenarioType.PARTNER_MATCHING: ["ProfilingAgent", "SemanticMatchingAgent", "NetworkAnalysisAgent", "ConnectionFacilitatorAgent"],
            ScenarioType.GENERAL: ["ExecutorAgent"],
        }
        return scenario_map.get(scenario, ["ExecutorAgent"])

    def _get_output_type(self, scenario: ScenarioType) -> str:
        """Map a scenario to the output type the critic validates against."""
        type_map = {
            ScenarioType.PATENT_WAKEUP: "patent_analysis",
            ScenarioType.AGREEMENT_SAFETY: "legal_review",
            ScenarioType.PARTNER_MATCHING: "stakeholder_matching",
            ScenarioType.GENERAL: "general",
        }
        return type_map.get(scenario, "general")

    async def run(
        self,
        task_description: str,
        scenario: ScenarioType = ScenarioType.GENERAL,
        task_id: Optional[str] = None,
        input_data: Optional[Dict[str, Any]] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> WorkflowOutput:
        """Run the workflow to completion and return a WorkflowOutput."""
        if task_id is None:
            task_id = f"task_{hash(task_description) % 100000}"
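            # Note: str hashes are randomized per process, so generated IDs are
            # not stable across runs; pass an explicit task_id for reproducibility.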

        initial_state = create_initial_state(
            task_id=task_id,
            task_description=task_description,
            scenario=scenario,
            max_iterations=self.max_iterations,
            input_data=input_data,
        )

        logger.info(f"Starting workflow for task: {task_id}")

        try:
            final_state = await self.app.ainvoke(
                initial_state,
                config=config or {"configurable": {"thread_id": task_id}},
            )

            output = state_to_output(final_state)
            logger.info(f"Workflow completed successfully: {task_id}")
            return output

        except Exception as e:
            # On failure, return the initial state annotated with the error.
            logger.error(f"Workflow failed: {e}")
            initial_state["status"] = TaskStatus.FAILED
            initial_state["success"] = False
            initial_state["error"] = str(e)
            initial_state["end_time"] = datetime.now()
            return state_to_output(initial_state)
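
    # Illustrative streaming consumption (hypothetical caller code):
    #
    #     async for event in workflow.stream("Draft outreach email"):
    #         print(event)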

    async def stream(
        self,
        task_description: str,
        scenario: ScenarioType = ScenarioType.GENERAL,
        task_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Stream workflow events as they are produced, instead of waiting for completion."""
        if task_id is None:
            task_id = f"task_{hash(task_description) % 100000}"

        initial_state = create_initial_state(
            task_id=task_id,
            task_description=task_description,
            scenario=scenario,
            max_iterations=self.max_iterations,
        )

        async for event in self.app.astream(
            initial_state,
            config=config or {"configurable": {"thread_id": task_id}},
        ):
            yield event


def create_workflow(
    llm_client: LangChainOllamaClient,
    planner_agent: Optional[Any] = None,
    critic_agent: Optional[Any] = None,
    memory_agent: Optional[Any] = None,
    vision_ocr_agent: Optional[Any] = None,
    quality_threshold: float = 0.85,
    max_iterations: int = 3,
) -> SparknetWorkflow:
    """Convenience factory for a fully wired SparknetWorkflow."""
    return SparknetWorkflow(
        llm_client=llm_client,
        planner_agent=planner_agent,
        critic_agent=critic_agent,
        memory_agent=memory_agent,
        vision_ocr_agent=vision_ocr_agent,
        quality_threshold=quality_threshold,
        max_iterations=max_iterations,
    )