File size: 19,810 Bytes
a9dc537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
"""
PlannerAgent for SPARKNET - LangChain Version
Breaks down complex VISTA scenarios into executable workflows
Uses LangChain chains for structured task decomposition
"""

import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import networkx as nx
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger
from pydantic import BaseModel, Field

from .base_agent import BaseAgent, Message, Task
from ..llm.langchain_ollama_client import LangChainOllamaClient
from ..workflow.langgraph_state import SubTask as SubTaskModel, TaskStatus


# Pydantic model for planning output
class TaskDecomposition(BaseModel):
    """Structured output from planning chain"""
    subtasks: List[Dict[str, Any]] = Field(description="List of subtasks with dependencies")
    reasoning: str = Field(description="Explanation of the planning strategy")
    estimated_total_duration: float = Field(description="Total estimated duration in seconds")


@dataclass
class TaskGraph:
    """Directed acyclic graph of tasks with dependencies."""
    subtasks: Dict[str, SubTaskModel] = field(default_factory=dict)
    graph: nx.DiGraph = field(default_factory=nx.DiGraph)

    def add_subtask(self, subtask: SubTaskModel):
        """Add a subtask to the graph."""
        self.subtasks[subtask.id] = subtask
        self.graph.add_node(subtask.id, task=subtask)

        # Add edges for dependencies
        for dep_id in subtask.dependencies:
            if dep_id in self.subtasks:
                self.graph.add_edge(dep_id, subtask.id)

    def get_execution_order(self) -> List[List[str]]:
        """
        Get tasks in execution order (topological sort).
        Returns list of lists - inner lists can be executed in parallel.
        """
        try:
            generations = list(nx.topological_generations(self.graph))
            return generations
        except nx.NetworkXError as e:
            logger.error(f"Error in topological sort: {e}")
            return []

    def validate(self) -> bool:
        """Validate graph has no cycles."""
        return nx.is_directed_acyclic_graph(self.graph)


class PlannerAgent(BaseAgent):
    """
    Agent specialized in task decomposition and workflow planning.
    Uses LangChain chains with qwen2.5:14b for complex reasoning.
    """

    # Scenario templates for common VISTA workflows
    SCENARIO_TEMPLATES = {
        'patent_wakeup': {
            'description': 'Analyze dormant patent and create valorization roadmap',
            'stages': [
                {
                    'name': 'document_analysis',
                    'agent': 'DocumentAnalysisAgent',
                    'description': 'Extract and analyze patent content',
                    'dependencies': [],
                },
                {
                    'name': 'market_analysis',
                    'agent': 'MarketAnalysisAgent',
                    'description': 'Identify market opportunities for patent',
                    'dependencies': ['document_analysis'],
                },
                {
                    'name': 'matchmaking',
                    'agent': 'MatchmakingAgent',
                    'description': 'Match patent with potential licensees',
                    'dependencies': ['document_analysis', 'market_analysis'],
                },
                {
                    'name': 'outreach',
                    'agent': 'OutreachAgent',
                    'description': 'Generate valorization brief and outreach materials',
                    'dependencies': ['matchmaking'],
                },
            ],
        },
        'agreement_safety': {
            'description': 'Review legal agreement for risks and compliance',
            'stages': [
                {
                    'name': 'document_parsing',
                    'agent': 'LegalAnalysisAgent',
                    'description': 'Parse agreement and extract clauses',
                    'dependencies': [],
                },
                {
                    'name': 'compliance_check',
                    'agent': 'ComplianceAgent',
                    'description': 'Check GDPR and Law 25 compliance',
                    'dependencies': ['document_parsing'],
                },
                {
                    'name': 'risk_assessment',
                    'agent': 'RiskAssessmentAgent',
                    'description': 'Identify problematic clauses and risks',
                    'dependencies': ['document_parsing'],
                },
                {
                    'name': 'recommendations',
                    'agent': 'RecommendationAgent',
                    'description': 'Generate improvement suggestions',
                    'dependencies': ['compliance_check', 'risk_assessment'],
                },
            ],
        },
        'partner_matching': {
            'description': 'Match stakeholders based on complementary capabilities',
            'stages': [
                {
                    'name': 'profiling',
                    'agent': 'ProfilingAgent',
                    'description': 'Extract stakeholder capabilities and needs',
                    'dependencies': [],
                },
                {
                    'name': 'semantic_matching',
                    'agent': 'SemanticMatchingAgent',
                    'description': 'Find complementary partners using embeddings',
                    'dependencies': ['profiling'],
                },
                {
                    'name': 'network_analysis',
                    'agent': 'NetworkAnalysisAgent',
                    'description': 'Identify strategic network connections',
                    'dependencies': ['profiling'],
                },
                {
                    'name': 'facilitation',
                    'agent': 'ConnectionFacilitatorAgent',
                    'description': 'Generate introduction materials',
                    'dependencies': ['semantic_matching', 'network_analysis'],
                },
            ],
        },
    }

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        memory_agent: Optional['MemoryAgent'] = None,
        temperature: float = 0.7,
    ):
        """
        Initialize PlannerAgent with LangChain client.

        Args:
            llm_client: LangChain Ollama client
            memory_agent: Optional memory agent for context
            temperature: LLM temperature for planning
        """
        self.llm_client = llm_client
        self.memory_agent = memory_agent
        self.temperature = temperature

        # Create planning chains
        self.planning_chain = self._create_planning_chain()
        self.refinement_chain = self._create_refinement_chain()

        # Store for backward compatibility
        self.name = "PlannerAgent"
        self.description = "Task decomposition and workflow planning"

        logger.info(f"Initialized PlannerAgent with LangChain (complexity: complex)")

    def _create_planning_chain(self):
        """
        Create LangChain chain for task decomposition.

        Returns:
            Runnable chain: prompt | llm | parser
        """
        system_template = """You are a strategic planning agent for research valorization tasks.

Your role is to:
1. Analyze complex tasks and break them into manageable subtasks
2. Identify dependencies between subtasks
3. Assign appropriate agents to each subtask
4. Estimate task complexity and duration
5. Create optimal execution plans

Available agent types:
- ExecutorAgent: General task execution
- DocumentAnalysisAgent: Analyze patents and documents
- MarketAnalysisAgent: Market research and opportunity identification
- MatchmakingAgent: Stakeholder matching and connections
- OutreachAgent: Generate outreach materials and briefs
- LegalAnalysisAgent: Legal document analysis
- ComplianceAgent: Compliance checking (GDPR, Law 25)
- RiskAssessmentAgent: Risk identification
- ProfilingAgent: Stakeholder profiling
- SemanticMatchingAgent: Semantic similarity matching
- NetworkAnalysisAgent: Network and relationship analysis

Output your plan as a structured JSON object with:
- subtasks: List of subtask objects with id, description, agent_type, dependencies, estimated_duration, priority
- reasoning: Your strategic reasoning for this decomposition
- estimated_total_duration: Total estimated time in seconds"""

        human_template = """Given the following task, create a detailed execution plan:

Task: {task_description}

{context_section}

Break this down into specific subtasks. For each subtask:
- Give it a unique ID (use snake_case)
- Describe what needs to be done
- Specify which agent type should handle it
- List any dependencies (IDs of tasks that must complete first)
- Estimate duration in seconds
- Set priority (1=highest)

Think step-by-step about:
- What is the ultimate goal?
- What information is needed?
- What are the logical stages?
- Which subtasks can run in parallel?
- What are the critical dependencies?

Output JSON only."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_template),
            ("human", human_template)
        ])

        # Use complex model for planning
        llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)

        # JSON output parser
        parser = JsonOutputParser(pydantic_object=TaskDecomposition)

        # Create chain
        chain = prompt | llm | parser

        return chain

    def _create_refinement_chain(self):
        """
        Create LangChain chain for replanning based on feedback.

        Returns:
            Runnable chain for refinement
        """
        system_template = """You are refining an existing task plan based on feedback.

Your role is to:
1. Review the original plan and feedback
2. Identify what went wrong or could be improved
3. Create an improved plan that addresses the issues
4. Maintain successful elements from the original plan

Be thoughtful about what to change and what to keep."""

        human_template = """Refine the following plan based on feedback:

Original Task: {task_description}

Original Plan:
{original_plan}

Feedback from execution:
{feedback}

Issues encountered:
{issues}

Create an improved plan that addresses these issues while maintaining what worked well.
Output JSON in the same format as before."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_template),
            ("human", human_template)
        ])

        llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)
        parser = JsonOutputParser(pydantic_object=TaskDecomposition)

        chain = prompt | llm | parser

        return chain

    async def process_task(self, task: Task) -> Task:
        """
        Process planning task by decomposing into workflow.

        Args:
            task: High-level task to plan

        Returns:
            Updated task with plan in result
        """
        logger.info(f"PlannerAgent planning task: {task.id}")
        task.status = "in_progress"

        try:
            # Check if this is a known scenario
            scenario = task.metadata.get('scenario') if task.metadata else None

            if scenario and scenario in self.SCENARIO_TEMPLATES:
                # Use template-based planning
                logger.info(f"Using template for scenario: {scenario}")
                task_graph = await self._plan_from_template(task, scenario)
            else:
                # Use LangChain-based planning for custom tasks
                logger.info("Using LangChain planning for custom task")
                task_graph = await self._plan_with_langchain(task)

            # Validate the graph
            if not task_graph.validate():
                raise ValueError("Generated task graph contains cycles!")

            # Store plan in task result
            task.result = {
                'task_graph': task_graph,
                'execution_order': task_graph.get_execution_order(),
                'total_subtasks': len(task_graph.subtasks),
            }
            task.status = "completed"

            logger.info(f"Planning completed: {len(task_graph.subtasks)} subtasks")

        except Exception as e:
            logger.error(f"Planning failed: {e}")
            task.status = "failed"
            task.error = str(e)

        return task

    async def _plan_from_template(self, task: Task, scenario: str) -> TaskGraph:
        """
        Create task graph from scenario template.

        Args:
            task: Original task
            scenario: Scenario identifier

        Returns:
            TaskGraph based on template
        """
        template = self.SCENARIO_TEMPLATES[scenario]
        task_graph = TaskGraph()

        # Get task parameters
        params = task.metadata.get('parameters', {}) if task.metadata else {}

        # Create subtasks from template stages
        for i, stage in enumerate(template['stages']):
            subtask = SubTaskModel(
                id=f"{task.id}_{stage['name']}",
                description=stage['description'],
                agent_type=stage['agent'],
                dependencies=[f"{task.id}_{dep}" for dep in stage['dependencies']],
                estimated_duration=30.0,
                priority=i + 1,
                parameters=params,
                status=TaskStatus.PENDING
            )
            task_graph.add_subtask(subtask)

        logger.debug(f"Created task graph with {len(task_graph.subtasks)} subtasks from template")

        return task_graph

    async def _plan_with_langchain(self, task: Task, context: Optional[List[Any]] = None) -> TaskGraph:
        """
        Create task graph using LangChain planning chain.

        Args:
            task: Original task
            context: Optional context from memory

        Returns:
            TaskGraph generated by LangChain
        """
        # Prepare context section
        context_section = ""
        if context and len(context) > 0:
            context_section = "Relevant past experiences:\n"
            for i, ctx in enumerate(context[:3], 1):  # Top 3 contexts
                context_section += f"{i}. {ctx.page_content[:200]}...\n"

        # Invoke planning chain
        try:
            result = await self.planning_chain.ainvoke({
                "task_description": task.description,
                "context_section": context_section
            })

            # Parse result into TaskGraph
            task_graph = TaskGraph()

            for subtask_data in result.get('subtasks', []):
                subtask = SubTaskModel(
                    id=f"{task.id}_{subtask_data.get('id', f'subtask_{len(task_graph.subtasks)}')}",
                    description=subtask_data.get('description', ''),
                    agent_type=subtask_data.get('agent_type', 'ExecutorAgent'),
                    dependencies=[f"{task.id}_{dep}" for dep in subtask_data.get('dependencies', [])],
                    estimated_duration=subtask_data.get('estimated_duration', 30.0),
                    priority=subtask_data.get('priority', 0),
                    parameters=subtask_data.get('parameters', {}),
                    status=TaskStatus.PENDING
                )
                task_graph.add_subtask(subtask)

            logger.debug(f"Created task graph with {len(task_graph.subtasks)} subtasks from LangChain")

            return task_graph

        except Exception as e:
            logger.error(f"Failed to parse LangChain planning response: {e}")
            raise ValueError(f"Failed to generate plan: {e}")

    async def decompose_task(
        self,
        task_description: str,
        scenario: Optional[str] = None,
        context: Optional[List[Any]] = None
    ) -> TaskGraph:
        """
        Decompose a high-level task into subtasks.

        Args:
            task_description: Natural language description
            scenario: Optional scenario identifier
            context: Optional context from memory

        Returns:
            TaskGraph with subtasks and dependencies
        """
        # Create a task object
        task = Task(
            id=f"plan_{hash(task_description) % 10000}",
            description=task_description,
            metadata={'scenario': scenario} if scenario else {},
        )

        # Process with planning logic
        result_task = await self.process_task(task)

        if result_task.status == "completed" and result_task.result:
            return result_task.result['task_graph']
        else:
            raise RuntimeError(f"Planning failed: {result_task.error}")

    async def adapt_plan(
        self,
        task_graph: TaskGraph,
        feedback: str,
        issues: List[str]
    ) -> TaskGraph:
        """
        Adapt an existing plan based on execution feedback.

        Args:
            task_graph: Original task graph
            feedback: Feedback from execution
            issues: List of issues encountered

        Returns:
            Updated task graph
        """
        logger.info("Adapting plan based on feedback")

        # Convert task graph to dict for refinement
        original_plan = {
            "subtasks": [
                {
                    "id": st.id,
                    "description": st.description,
                    "agent_type": st.agent_type,
                    "dependencies": st.dependencies
                }
                for st in task_graph.subtasks.values()
            ]
        }

        try:
            # Invoke refinement chain
            result = await self.refinement_chain.ainvoke({
                "task_description": "Refine task decomposition",
                "original_plan": json.dumps(original_plan, indent=2),
                "feedback": feedback,
                "issues": "\n".join(f"- {issue}" for issue in issues)
            })

            # Create new task graph from refined plan
            new_task_graph = TaskGraph()

            for subtask_data in result.get('subtasks', []):
                subtask = SubTaskModel(
                    id=subtask_data.get('id', f'subtask_{len(new_task_graph.subtasks)}'),
                    description=subtask_data.get('description', ''),
                    agent_type=subtask_data.get('agent_type', 'ExecutorAgent'),
                    dependencies=subtask_data.get('dependencies', []),
                    estimated_duration=subtask_data.get('estimated_duration', 30.0),
                    priority=subtask_data.get('priority', 0),
                    parameters=subtask_data.get('parameters', {}),
                    status=TaskStatus.PENDING
                )
                new_task_graph.add_subtask(subtask)

            logger.info(f"Plan adapted: {len(new_task_graph.subtasks)} subtasks")
            return new_task_graph

        except Exception as e:
            logger.error(f"Plan adaptation failed: {e}, returning original plan")
            return task_graph

    def get_parallel_tasks(self, task_graph: TaskGraph) -> List[List[SubTaskModel]]:
        """
        Get tasks that can be executed in parallel.

        Args:
            task_graph: Task graph

        Returns:
            List of parallel task groups
        """
        execution_order = task_graph.get_execution_order()
        parallel_groups = []

        for task_ids in execution_order:
            group = [task_graph.subtasks[task_id] for task_id in task_ids]
            parallel_groups.append(group)

        return parallel_groups