| from typing import Dict, Any, List, Optional, Tuple |
| import random |
| import json |
| import os |
| from pydantic import Field |
|
|
| from ...workflow.action_graph import ActionGraph |
| from ...models.model_configs import LLMConfig, OpenAILLMConfig |
| from ...models.base_model import LLMOutputParser |
| from ...workflow.operators import QAScEnsemble |
| from .utils import ( |
| format_transcript, |
| collect_last_round_candidates, |
| collect_round_candidates, |
| ) |
| from ...prompts.workflow.multi_agent_debate import ( |
| DEBATER_AGENT_PROMPT, |
| JUDGE_AGENT_PROMPT, |
| build_agent_prompt, |
| build_judge_prompt, |
| get_default_personas, |
| ) |
| from .pruning import PruningPipeline |
| from ...agents.customize_agent import CustomizeAgent |
|
|
|
|
class DebateAgentOutput(LLMOutputParser):
    """Structured reply of a single debater for one debate round.

    Used as the XML parser when debaters are driven directly through the
    graph's own LLM instead of a CustomizeAgent.
    """

    # Reasoning behind this turn.
    thought: str = Field("", description="Thinking process")
    # Argument or rebuttal presented in this round.
    argument: str = Field("", description="Argument or rebuttal for this round")
    # Answer the debater currently supports; None when the reply omits it.
    answer: Optional[str] = Field(None, description="Current answer for this round (optional)")
|
|
|
|
class DebateJudgeOutput(LLMOutputParser):
    """Structured verdict returned by the judge once the debate has ended.

    Used as the XML parser when judging runs through the graph's own LLM
    (i.e. no custom judge agent is configured).
    """

    # Free-text explanation of the judge's decision.
    rationale: str = Field("", description="Judging rationale")
    # Zero-based id of the debater the judge picked as winner.
    winning_agent_id: int = Field(0, description="Winning debater ID (starting from 0)")
    # Answer the judge settles on.
    final_answer: str = Field("", description="Final answer")
|
|
|
|
class MultiAgentDebateActionGraph(ActionGraph):
    """Multi-Agent Debate ActionGraph implementation (Google MAD style).

    Several debaters argue over a problem for a fixed number of rounds. Each
    turn is appended to a shared transcript. A consensus step then produces
    the final answer in one of two ways:

    - ``judge_mode="self_consistency"``: a self-consistency ensemble over the
      debaters' last-round answers.
    - any other ``judge_mode``: an LLM judge, either ``judge_agent`` or the
      graph's own LLM.

    Debaters come from exactly one source, checked in this order:

    1. Group workflow graphs, when ``group_graphs_enabled`` is set.
    2. User-supplied ``debater_agents``.
    3. Default ``CustomizeAgent`` debaters built from ``agent_llm_configs``,
       then ``llm_config_pool``, then ``llm_config``.
    """

    name: str = "MultiAgentDebate"
    description: str = "Multi-agent debate workflow framework"

    llm_config: LLMConfig = Field(default_factory=lambda: OpenAILLMConfig(
        model="gpt-4o-mini",
        openai_key=os.getenv("OPENAI_API_KEY")
    ), description="Default LLM configuration for all agents")

    debater_agents: Optional[List[CustomizeAgent]] = Field(default=None, description="Optional: multiple debater CustomizeAgents, randomly selected during execution")
    judge_agent: Optional[CustomizeAgent] = Field(default=None, description="Optional: judge CustomizeAgent, used for judging phase if provided")

    llm_config_pool: Optional[List[LLMConfig]] = Field(default=None, description="Optional: LLM configuration pool for random selection, provides choices for agents without specified models")

    group_graphs_enabled: bool = Field(default=False, description="Enable group graph mode: replace individual debaters with workflow graphs")
    group_graphs: Optional[List[ActionGraph]] = Field(default=None, description="When group graph mode is enabled, provide workflow graph list (length >= 1)")

    # Self-consistency operator used by judge_mode="self_consistency"; built in init_module.
    _sc_ensemble: Optional[QAScEnsemble] = None

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    def init_module(self):
        """Initialize the module, validate debater settings, build reusable operators.

        Raises:
            ValueError: If group-graph mode is combined with ``debater_agents``,
                is enabled without any ``group_graphs``, or if ``group_graphs``
                are supplied while the mode is disabled.
        """
        super().init_module()

        if self.group_graphs_enabled and self.debater_agents:
            raise ValueError(
                "Configuration conflict: cannot configure debater_agents when group_graphs_enabled is enabled."
            )
        if self.group_graphs_enabled and not self.group_graphs:
            raise ValueError(
                "Configuration error: must provide non-empty group_graphs list when group graph mode is enabled."
            )
        if not self.group_graphs_enabled and self.group_graphs:
            raise ValueError(
                "Configuration error: provided group_graphs but did not enable group_graphs_enabled. Please enable both or remove group_graphs."
            )

        # NOTE(review): self._llm is presumably created by ActionGraph.init_module
        # from llm_config — confirm against the base class.
        self._sc_ensemble = QAScEnsemble(self._llm)

    # ------------------------------------------------------------------
    # Agent factories
    # ------------------------------------------------------------------

    def _pick_llm_config(self) -> LLMConfig:
        """Return a random ``llm_config_pool`` entry if a pool is set, else ``llm_config``."""
        return random.choice(self.llm_config_pool) if self.llm_config_pool else self.llm_config

    def _create_default_debater_agent(self) -> CustomizeAgent:
        """Create a default debater agent (XML output: thought/argument/answer).

        The LLM config is chosen by :meth:`_pick_llm_config`.
        """
        return self._create_debater_agent_with_llm(self._pick_llm_config())

    def _create_default_judge_agent(self) -> CustomizeAgent:
        """Create a default judge agent (XML output: rationale/winning_agent_id/final_answer).

        The LLM config is chosen by :meth:`_pick_llm_config`.
        """
        return CustomizeAgent(
            name="JudgeAgent",
            description="Deliver final decision and answer based on debate transcript.",
            prompt=JUDGE_AGENT_PROMPT,
            llm_config=self._pick_llm_config(),
            inputs=[
                {"name": "problem", "type": "str", "description": "Problem statement"},
                {"name": "transcript_text", "type": "str", "description": "Formatted debate transcript"},
                {"name": "roles_text", "type": "str", "description": "Roles listing text"},
            ],
            outputs=[
                {"name": "rationale", "type": "str", "description": "Judging rationale", "required": True},
                {"name": "winning_agent_id", "type": "str", "description": "Winning agent id (integer as string)", "required": True},
                {"name": "final_answer", "type": "str", "description": "Final answer", "required": True},
            ],
            parse_mode="xml",
        )

    def _create_debater_agent_with_llm(self, llm_cfg: LLMConfig) -> CustomizeAgent:
        """Create a debater CustomizeAgent with the standard inputs/outputs, backed by ``llm_cfg``."""
        return CustomizeAgent(
            name="DebaterAgent",
            description="Generate argument/rebuttal and optional answer per debate round.",
            prompt=DEBATER_AGENT_PROMPT,
            llm_config=llm_cfg,
            inputs=[
                {"name": "problem", "type": "str", "description": "Problem statement"},
                {"name": "transcript_text", "type": "str", "description": "Formatted debate transcript so far"},
                {"name": "role", "type": "str", "description": "Debater role/persona"},
                {"name": "agent_id", "type": "str", "description": "Debater id (string)"},
                {"name": "round_index", "type": "str", "description": "1-based round index"},
                {"name": "total_rounds", "type": "str", "description": "Total rounds"},
            ],
            outputs=[
                {"name": "thought", "type": "str", "description": "Brief reasoning", "required": True},
                {"name": "argument", "type": "str", "description": "Argument or rebuttal", "required": True},
                {"name": "answer", "type": "str", "description": "Optional current answer", "required": False},
            ],
            parse_mode="xml",
        )

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def execute(
        self,
        problem: str,
        num_agents: int = 3,
        num_rounds: int = 3,
        judge_mode: str = "llm_judge",
        personas: Optional[List[str]] = None,
        return_transcript: bool = True,
        agent_llm_configs: Optional[List[LLMConfig]] = None,
        enable_pruning: bool = False,
        pruning_qp_threshold: float = 0.15,
        pruning_dp_similarity_threshold: float = 0.92,
        pruning_enable_mr: bool = False,
        pruning_mr_llm_config: Optional[LLMConfig] = None,
        pruning_snapshot_mode: bool = False,
        transcript_mode: str = "prev",
        **kwargs,
    ) -> dict:
        """Run the full debate synchronously.

        Args:
            problem: Problem statement given to every debater and to the judge.
            num_agents: Number of debaters; must be greater than 1.
            num_rounds: Number of debate rounds; must be positive.
            judge_mode: ``"self_consistency"`` ensembles the debaters'
                last-round answers. Any other value uses the LLM judge.
            personas: Optional role descriptions, at least ``num_agents`` of
                them. Defaults to ``get_default_personas(num_agents)``.
            return_transcript: Include the full transcript in the result.
            agent_llm_configs: Optional LLM configs assigned to debater
                positions cyclically. Only used when neither group graphs nor
                ``debater_agents`` are configured.
            enable_pruning: Run ``PruningPipeline`` over the last-round
                candidates and report before/after in ``result["pruning"]``.
            pruning_qp_threshold: Forwarded to ``PruningPipeline(qp_threshold=...)``.
            pruning_dp_similarity_threshold: Forwarded as ``dp_similarity_threshold``.
            pruning_enable_mr: Forwarded as ``enable_mr``.
            pruning_mr_llm_config: Forwarded as ``mr_llm_config``.
            pruning_snapshot_mode: Also prune each round separately and report
                the results in ``result["pruning_rounds"]``.
            transcript_mode: What each debater sees. ``"prev"`` shows earlier
                rounds only. ``"all"`` also shows earlier speakers of the
                current round.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Dict with ``final_answer``, ``winner`` and ``rationale``. It also
            has ``transcript``, ``pruning`` and ``pruning_rounds`` when those
            features are enabled.

        Raises:
            ValueError: On invalid ``num_agents``/``num_rounds`` or too few personas.
        """
        state = self._setup_debate(problem, num_agents, num_rounds, personas, agent_llm_configs)
        transcript = self._run_debate_rounds(problem, state, transcript_mode)

        pruning_info, pruning_debug, pruning_rounds_debug = (None, None, None)
        if enable_pruning:
            pruning_info, pruning_debug, pruning_rounds_debug = self._apply_pruning(
                problem,
                state,
                transcript,
                qp_threshold=pruning_qp_threshold,
                dp_similarity_threshold=pruning_dp_similarity_threshold,
                enable_mr=pruning_enable_mr,
                mr_llm_config=pruning_mr_llm_config,
                snapshot_mode=pruning_snapshot_mode,
            )

        consensus = self._generate_consensus(problem, state, transcript, judge_mode, pruning_info)
        return self._assemble_result(
            consensus, transcript, return_transcript, pruning_debug, pruning_rounds_debug
        )

    async def async_execute(
        self,
        problem: str,
        num_agents: int = 3,
        num_rounds: int = 3,
        judge_mode: str = "llm_judge",
        personas: Optional[List[str]] = None,
        return_transcript: bool = True,
        agent_llm_configs: Optional[List[LLMConfig]] = None,
        enable_pruning: bool = False,
        pruning_qp_threshold: float = 0.15,
        pruning_dp_similarity_threshold: float = 0.92,
        pruning_enable_mr: bool = False,
        pruning_mr_llm_config: Optional[LLMConfig] = None,
        pruning_snapshot_mode: bool = False,
        transcript_mode: str = "prev",
        **kwargs,
    ) -> dict:
        """Run the full debate asynchronously.

        Accepts the same arguments as :meth:`execute` and returns the same
        keys, ``rationale`` included.

        When debaters are agents, each round's debaters run concurrently. Their
        transcript views are built before any of them answers, so
        ``transcript_mode="all"`` behaves like ``"prev"`` in that case.
        """
        state = self._setup_debate(problem, num_agents, num_rounds, personas, agent_llm_configs)
        transcript = await self._run_debate_rounds_async(problem, state, transcript_mode)

        pruning_info, pruning_debug, pruning_rounds_debug = (None, None, None)
        if enable_pruning:
            pruning_info, pruning_debug, pruning_rounds_debug = self._apply_pruning(
                problem,
                state,
                transcript,
                qp_threshold=pruning_qp_threshold,
                dp_similarity_threshold=pruning_dp_similarity_threshold,
                enable_mr=pruning_enable_mr,
                mr_llm_config=pruning_mr_llm_config,
                snapshot_mode=pruning_snapshot_mode,
            )

        consensus = await self._generate_consensus_async(problem, state, transcript, judge_mode, pruning_info)
        return self._assemble_result(
            consensus, transcript, return_transcript, pruning_debug, pruning_rounds_debug
        )

    @staticmethod
    def _assemble_result(
        consensus: Dict[str, Any],
        transcript: List[dict],
        return_transcript: bool,
        pruning_debug: Optional[Dict[str, Any]],
        pruning_rounds_debug: Optional[List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Build the public result dict shared by ``execute`` and ``async_execute``.

        The pruning entries are only non-None when pruning (and, for rounds,
        snapshot mode) was enabled, so they are added only in that case.
        """
        result: Dict[str, Any] = {
            "final_answer": consensus["final_answer"],
            "winner": consensus.get("winner"),
            "rationale": consensus.get("rationale"),
        }
        if return_transcript:
            result["transcript"] = transcript
        if pruning_debug is not None:
            result["pruning"] = pruning_debug
        if pruning_rounds_debug is not None:
            result["pruning_rounds"] = pruning_rounds_debug
        return result

    def _apply_pruning(
        self,
        problem: str,
        state: Dict[str, Any],
        transcript: List[dict],
        qp_threshold: float,
        dp_similarity_threshold: float,
        enable_mr: bool,
        mr_llm_config: Optional[LLMConfig],
        snapshot_mode: bool,
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
        """Run the pruning pipeline over the debate candidates.

        Returns:
            A tuple ``(pruning_info, pruning_debug, pruning_rounds_debug)``:

            - ``pruning_info``: raw pipeline output for the last round.
            - ``pruning_debug``: before/after summary of that output.
            - ``pruning_rounds_debug``: per-round summaries. Only populated
              when ``snapshot_mode`` is set, otherwise None.

            Consensus only reads ``pruning_info["mr_suggested"]``. The pruned
            candidate list itself is reported but not fed back into judging.
        """
        # Keep roughly 30% of debaters (at least one) after pruning.
        min_keep = max(1, int(round(state["num_agents"] * 0.3)))
        pipeline = PruningPipeline(
            enable_qp=True,
            enable_dp=True,
            enable_mr=enable_mr,
            qp_threshold=qp_threshold,
            dp_similarity_threshold=dp_similarity_threshold,
            mr_llm_config=mr_llm_config,
            min_keep_count=min_keep,
        )

        pruning_rounds_debug: Optional[List[Dict[str, Any]]] = None
        if snapshot_mode:
            pruning_rounds_debug = []
            for r in range(state["num_rounds"]):
                rcands = collect_round_candidates(
                    transcript=transcript, num_agents=state["num_agents"], round_index=r
                )
                info_r = pipeline.apply(problem=problem, candidates=rcands)
                pruning_rounds_debug.append(
                    {
                        "round": r,
                        "before_candidates": rcands,
                        "after_candidates": info_r.get("candidates", []),
                        "mr_suggested": info_r.get("mr_suggested"),
                    }
                )

        candidates = collect_last_round_candidates(
            transcript=transcript, num_agents=state["num_agents"], last_round_index=state["num_rounds"] - 1
        )
        pruning_info = pipeline.apply(problem=problem, candidates=candidates)
        try:
            pruning_debug: Optional[Dict[str, Any]] = {
                "before_candidates": candidates,
                "after_candidates": pruning_info.get("candidates", []),
                "mr_suggested": pruning_info.get("mr_suggested"),
            }
        except AttributeError:
            # The pipeline returned something without .get (e.g. None); skip the summary.
            pruning_debug = None
        return pruning_info, pruning_debug, pruning_rounds_debug

    # ------------------------------------------------------------------
    # Debate setup
    # ------------------------------------------------------------------

    def _setup_debate(
        self,
        problem: str,
        num_agents: int,
        num_rounds: int,
        personas: Optional[List[str]],
        agent_llm_configs: Optional[List[LLMConfig]] = None,
    ) -> dict:
        """Validate arguments and build the per-run debate state.

        Returns:
            Dict with keys ``problem``, ``num_agents``, ``num_rounds``,
            ``roles`` (one persona per debater) and ``agents`` (one debater per
            position; empty in group-graph mode).

        Raises:
            ValueError: If ``num_agents <= 1``, ``num_rounds <= 0`` or there
                are fewer roles than debaters.
        """
        if num_agents <= 1:
            raise ValueError("num_agents must be greater than 1")
        if num_rounds <= 0:
            raise ValueError("num_rounds must be positive")

        roles: List[str] = personas or get_default_personas(num_agents)
        # Every turn reads roles[agent_id]. Fail here instead of with an
        # IndexError part-way through the debate, after LLM calls were spent.
        if len(roles) < num_agents:
            raise ValueError(
                f"personas must contain at least num_agents ({num_agents}) entries, got {len(roles)}"
            )

        agents_for_ids: List[CustomizeAgent] = self._prepare_runtime_debaters(num_agents, agent_llm_configs)
        return {
            "problem": problem,
            "num_agents": num_agents,
            "num_rounds": num_rounds,
            "roles": roles,
            "agents": agents_for_ids,
        }

    def _prepare_runtime_debaters(self, num_agents: int, agent_llm_configs: Optional[List[LLMConfig]]) -> List[CustomizeAgent]:
        """Assign one CustomizeAgent to each debater position for the whole debate.

        Priority:
            0) Group-graph mode → empty list (graphs are used instead).
            1) ``debater_agents`` → cycled over positions.
            2) ``agent_llm_configs`` → one default debater per position, configs cycled.
            3) ``llm_config_pool`` → one default debater per position, config chosen at random.
            4) Otherwise → a single default debater (``llm_config``) shared by all positions.
        """
        if self.group_graphs_enabled:
            return []

        if self.debater_agents:
            pool = self.debater_agents
            return [pool[i % len(pool)] for i in range(num_agents)]

        if agent_llm_configs:
            return [
                self._create_debater_agent_with_llm(agent_llm_configs[i % len(agent_llm_configs)])
                for i in range(num_agents)
            ]

        if self.llm_config_pool:
            return [self._create_debater_agent_with_llm(random.choice(self.llm_config_pool)) for _ in range(num_agents)]

        default_agent = self._create_default_debater_agent()
        return [default_agent] * num_agents

    # ------------------------------------------------------------------
    # Debate rounds
    # ------------------------------------------------------------------

    @staticmethod
    def _build_debater_inputs(
        problem: str,
        transcript_text: str,
        roles: List[str],
        agent_id: int,
        round_index: int,
        num_rounds: int,
    ) -> Dict[str, str]:
        """Build the string-valued input dict used by debater agents and group graphs.

        ``round_index`` is 0-based here and is sent 1-based.
        """
        return {
            "problem": problem,
            "transcript_text": transcript_text,
            "role": roles[agent_id],
            "agent_id": str(agent_id),
            "round_index": str(round_index + 1),
            "total_rounds": str(num_rounds),
        }

    @staticmethod
    def _structured_from_graph_output(g_out: Dict[str, Any]) -> Dict[str, Any]:
        """Map a group graph's output to argument/answer/thought.

        Falls back to the graph's ``output`` key when it has no ``argument``.
        """
        return {
            "argument": g_out.get("argument", g_out.get("output", "")),
            "answer": g_out.get("answer"),
            "thought": g_out.get("thought", ""),
        }

    @staticmethod
    def _make_turn_record(agent_id: int, round_index: int, role: str, structured: Dict[str, Any]) -> Dict[str, Any]:
        """Build one transcript entry from a debater's structured reply."""
        return {
            "agent_id": agent_id,
            "round": round_index,
            "role": role,
            "argument": structured.get("argument", ""),
            "answer": structured.get("answer"),
            "thought": structured.get("thought", ""),
        }

    @staticmethod
    def _log_turn(round_index: int, agent_id: int, role: str, structured: Dict[str, Any]) -> None:
        """Print a debater's argument and answer to stdout (best effort)."""
        try:
            print(
                f"[Round {round_index + 1}] Agent#{agent_id} ({role})\n"
                f"Argument: {str(structured.get('argument', '')).strip()}\n"
                f"Answer: {str(structured.get('answer') or '').strip()}\n"
            )
        except Exception:
            # Console output is informational only; never let it break the debate.
            pass

    def _select_debater(self, state: dict, agent_id: int) -> Optional[CustomizeAgent]:
        """Return the debater for ``agent_id``.

        Uses the per-position agent from ``state`` first, then a random
        ``debater_agents`` entry. Returns None when neither is available, in
        which case the graph's own LLM is used.
        """
        agents_for_ids: Optional[List[CustomizeAgent]] = state.get("agents")
        if agents_for_ids:
            return agents_for_ids[agent_id]
        if self.debater_agents:
            return random.choice(self.debater_agents)
        return None

    def _run_debate_rounds(self, problem: str, state: dict, transcript_mode: str = "prev") -> List[dict]:
        """Run all debate rounds sequentially and return the transcript.

        Debaters speak one at a time, in id order, within each round.

        Args:
            problem: Problem statement.
            state: State produced by :meth:`_setup_debate`.
            transcript_mode: Which part of the transcript each debater sees.
                ``"prev"`` (default) shows earlier rounds only. ``"all"`` also
                shows earlier speakers of the current round.

        Returns:
            List of turn records with keys ``agent_id``, ``round``, ``role``,
            ``argument``, ``answer`` and ``thought``.
        """
        transcript: List[dict] = []
        num_agents: int = state["num_agents"]
        num_rounds: int = state["num_rounds"]
        roles: List[str] = state["roles"]

        for round_index in range(num_rounds):
            for agent_id in range(num_agents):
                transcript_text = self._get_transcript_for_agent(
                    transcript, round_index, agent_id, transcript_mode, num_agents
                )

                if self.group_graphs_enabled and self.group_graphs:
                    graph = self.group_graphs[agent_id % len(self.group_graphs)]
                    g_out = graph.execute(
                        **self._build_debater_inputs(problem, transcript_text, roles, agent_id, round_index, num_rounds)
                    )
                    structured = self._structured_from_graph_output(g_out)
                else:
                    selected_agent = self._select_debater(state, agent_id)
                    if selected_agent is not None:
                        try:
                            msg = selected_agent(
                                inputs=self._build_debater_inputs(
                                    problem, transcript_text, roles, agent_id, round_index, num_rounds
                                )
                            )
                            structured = msg.content.get_structured_data()
                        except Exception as e:
                            # One failing debater should not abort the debate; record an empty turn.
                            print(f"Agent execution error: {e}")
                            structured = {"argument": "", "answer": "", "thought": ""}
                    else:
                        prompt = build_agent_prompt(
                            problem=problem,
                            transcript_text=transcript_text,
                            role=roles[agent_id],
                            agent_id=agent_id,
                            round_index=round_index,
                            total_rounds=num_rounds,
                        )
                        response = self._llm.generate(
                            prompt=prompt,
                            parser=DebateAgentOutput,
                            parse_mode="xml",
                        )
                        structured = response.get_structured_data()

                transcript.append(self._make_turn_record(agent_id, round_index, roles[agent_id], structured))
                self._log_turn(round_index, agent_id, roles[agent_id], structured)
        return transcript

    def _get_transcript_for_agent(self, transcript: List[dict], round_index: int, agent_id: int,
                                  transcript_mode: str, num_agents: int) -> str:
        """Return the formatted part of the transcript a debater may see.

        Args:
            transcript: Full transcript so far.
            round_index: Current (0-based) round.
            agent_id: Id of the debater about to speak.
            transcript_mode: Visibility rule.
                ``"prev"`` shows only earlier rounds.
                ``"all"`` additionally shows current-round turns of debaters
                with a smaller id.
                Any other value is treated as ``"prev"``.
            num_agents: Total number of debaters (unused; kept for interface stability).

        Returns:
            Output of ``format_transcript`` over the visible entries, in transcript order.
        """
        see_current_round = transcript_mode == "all"

        def is_visible(entry: dict) -> bool:
            if entry["round"] < round_index:
                return True
            return see_current_round and entry["round"] == round_index and entry["agent_id"] < agent_id

        return format_transcript([entry for entry in transcript if is_visible(entry)])

    async def _run_debate_rounds_async(self, problem: str, state: dict, transcript_mode: str = "prev") -> List[dict]:
        """Run all debate rounds asynchronously and return the transcript.

        Each round uses one of three strategies:

        - Group graphs: run one after another via their synchronous
          ``execute``.
        - Agents: all debaters of a round run concurrently. A failing debater
          is logged and recorded as an empty turn, matching the sync path.
        - Otherwise: prompts are sent in one ``batch_generate_async`` call.

        In the concurrent strategies, every transcript view is built before the
        round starts. ``transcript_mode="all"`` therefore cannot include
        same-round speeches there.

        Args:
            problem: Problem statement.
            state: State produced by :meth:`_setup_debate`.
            transcript_mode: ``"prev"`` (default) or ``"all"``; see
                :meth:`_get_transcript_for_agent`.

        Returns:
            List of turn records, in the same format as :meth:`_run_debate_rounds`.
        """
        import asyncio

        transcript: List[dict] = []
        num_agents: int = state["num_agents"]
        num_rounds: int = state["num_rounds"]
        roles: List[str] = state["roles"]

        for round_index in range(num_rounds):
            if self.group_graphs_enabled and self.group_graphs:
                for agent_id in range(num_agents):
                    graph = self.group_graphs[agent_id % len(self.group_graphs)]
                    transcript_text = self._get_transcript_for_agent(
                        transcript, round_index, agent_id, transcript_mode, num_agents
                    )
                    g_out = graph.execute(
                        **self._build_debater_inputs(problem, transcript_text, roles, agent_id, round_index, num_rounds)
                    )
                    structured = self._structured_from_graph_output(g_out)
                    transcript.append(self._make_turn_record(agent_id, round_index, roles[agent_id], structured))
                    self._log_turn(round_index, agent_id, roles[agent_id], structured)

            elif state.get("agents") or self.debater_agents:
                agent_ids: List[int] = []
                calls = []
                for agent_id in range(num_agents):
                    # Never None in this branch: either state["agents"] or debater_agents is non-empty.
                    selected_agent = self._select_debater(state, agent_id)
                    transcript_text = self._get_transcript_for_agent(
                        transcript, round_index, agent_id, transcript_mode, num_agents
                    )
                    # NOTE(review): relies on CustomizeAgent.__call__ returning an awaitable
                    # inside a running event loop.
                    calls.append(
                        selected_agent(
                            inputs=self._build_debater_inputs(
                                problem, transcript_text, roles, agent_id, round_index, num_rounds
                            )
                        )
                    )
                    agent_ids.append(agent_id)

                outcomes = await asyncio.gather(*calls, return_exceptions=True)
                for agent_id, outcome in zip(agent_ids, outcomes):
                    try:
                        if isinstance(outcome, BaseException):
                            # Re-raise so non-Exception errors (e.g. cancellation) still propagate.
                            raise outcome
                        structured = outcome.content.get_structured_data()
                    except Exception as e:
                        print(f"Agent execution error: {e}")
                        structured = {"argument": "", "answer": "", "thought": ""}
                    transcript.append(self._make_turn_record(agent_id, round_index, roles[agent_id], structured))
                    self._log_turn(round_index, agent_id, roles[agent_id], structured)

            else:
                prompts: List[Tuple[int, str]] = []
                for agent_id in range(num_agents):
                    transcript_text = self._get_transcript_for_agent(
                        transcript, round_index, agent_id, transcript_mode, num_agents
                    )
                    prompt = build_agent_prompt(
                        problem=problem,
                        transcript_text=transcript_text,
                        role=roles[agent_id],
                        agent_id=agent_id,
                        round_index=round_index,
                        total_rounds=num_rounds,
                    )
                    prompts.append((agent_id, prompt))

                results = await self._llm.batch_generate_async(
                    batch_messages=[[{"role": "user", "content": p}] for _, p in prompts]
                )
                parsed_list = self._llm.parse_generated_texts(
                    texts=results, parser=DebateAgentOutput, parse_mode="xml"
                )
                for (agent_id, _), parsed in zip(prompts, parsed_list):
                    structured = parsed.get_structured_data()
                    transcript.append(self._make_turn_record(agent_id, round_index, roles[agent_id], structured))
                    self._log_turn(round_index, agent_id, roles[agent_id], structured)
        return transcript

    # ------------------------------------------------------------------
    # Consensus
    # ------------------------------------------------------------------

    def _self_consistency_candidates(self, state: dict, transcript: List[dict]) -> List[str]:
        """Return the debaters' last-round answers.

        If no debater gave an answer, every non-empty argument in the
        transcript is returned instead.
        """
        answers = self._collect_agent_final_answers(state, transcript)
        if answers:
            return answers
        return [t["argument"] for t in transcript if t.get("argument")]

    def _build_judge_inputs(
        self, problem: str, state: dict, transcript: List[dict], pruning_info: Optional[Dict[str, Any]]
    ) -> Dict[str, str]:
        """Build the inputs for a custom ``judge_agent``.

        When the pruning pipeline produced a non-empty ``mr_suggested["corrected"]``,
        a hint is appended to the problem text.
        """
        roles_text = "\n".join(f"#{i}: {r}" for i, r in enumerate(state["roles"]))
        inputs = {
            "problem": problem,
            "transcript_text": format_transcript(transcript),
            "roles_text": roles_text,
        }
        mr_suggested = (pruning_info or {}).get("mr_suggested")
        if mr_suggested and mr_suggested.get("corrected", ""):
            # NOTE(review): only a generic hint is appended; the corrected text
            # itself is not passed to the judge — confirm this is intended.
            inputs["problem"] = problem + "\n\n(Consider corrected consolidation, if helpful.)"
        return inputs

    @staticmethod
    def _build_default_judge_prompt(problem: str, state: dict, transcript: List[dict]):
        """Build the prompt used when judging with the graph's own LLM."""
        return build_judge_prompt(
            problem=problem,
            transcript_text=format_transcript(transcript),
            roles=state["roles"],
        )

    @staticmethod
    def _parse_winner_id(raw: Any, default: int = 0) -> int:
        """Convert the judge's ``winning_agent_id`` to an int without crashing.

        The judge agent declares this output as a string, so the value may
        carry extra text (e.g. ``"Agent 1"``). Numbers are converted directly.
        Otherwise the first integer in the text is used, and ``default`` is
        returned when there is none.
        """
        import re

        if isinstance(raw, (int, float)):
            return int(raw)
        if raw is None:
            return default
        match = re.search(r"-?\d+", str(raw))
        return int(match.group()) if match else default

    def _finalize_judgement(self, jd: Dict[str, Any], state: dict, transcript: List[dict]) -> dict:
        """Turn the judge's structured output into the consensus dict."""
        winner_id = self._parse_winner_id(jd.get("winning_agent_id", 0))
        return {
            "final_answer": jd.get("final_answer", ""),
            "winner": winner_id,
            "winner_answer": self._get_winner_answer(transcript, winner_id, state["num_rounds"]),
            "rationale": jd.get("rationale", ""),
        }

    def _generate_consensus(
        self, problem: str, state: dict, transcript: List[dict], judge_mode: str, pruning_info: Optional[Dict[str, Any]] = None
    ) -> dict:
        """Produce the final consensus synchronously according to ``judge_mode``.

        Returns:
            For ``"self_consistency"``: ``{"final_answer", "winner": None}``,
            with an empty answer when there were no candidates.
            Otherwise: ``final_answer``, ``winner``, ``winner_answer`` and
            ``rationale`` from the judge.
        """
        if judge_mode == "self_consistency":
            candidates = self._self_consistency_candidates(state, transcript)
            if not candidates:
                return {"final_answer": "", "winner": None}
            sc = self._sc_ensemble.execute(solutions=candidates)
            return {"final_answer": sc["response"], "winner": None}

        if self.judge_agent is not None:
            msg = self.judge_agent(inputs=self._build_judge_inputs(problem, state, transcript, pruning_info))
            jd = msg.content.get_structured_data()
        else:
            judge_resp = self._llm.generate(
                prompt=self._build_default_judge_prompt(problem, state, transcript),
                parser=DebateJudgeOutput,
                parse_mode="xml",
            )
            jd = judge_resp.get_structured_data()
        return self._finalize_judgement(jd, state, transcript)

    async def _generate_consensus_async(
        self, problem: str, state: dict, transcript: List[dict], judge_mode: str, pruning_info: Optional[Dict[str, Any]] = None
    ) -> dict:
        """Asynchronous counterpart of :meth:`_generate_consensus`; returns the same shapes."""
        if judge_mode == "self_consistency":
            candidates = self._self_consistency_candidates(state, transcript)
            if not candidates:
                return {"final_answer": "", "winner": None}
            sc = await self._sc_ensemble.async_execute(solutions=candidates)
            return {"final_answer": sc["response"], "winner": None}

        if self.judge_agent is not None:
            msg = await self.judge_agent(inputs=self._build_judge_inputs(problem, state, transcript, pruning_info))
            jd = msg.content.get_structured_data()
        else:
            judge_resp = await self._llm.async_generate(
                prompt=self._build_default_judge_prompt(problem, state, transcript),
                parser=DebateJudgeOutput,
                parse_mode="xml",
            )
            jd = judge_resp.get_structured_data()
        return self._finalize_judgement(jd, state, transcript)

    @staticmethod
    def _last_round_record(transcript: List[dict], agent_id: int, num_rounds: int) -> Optional[dict]:
        """Return the last transcript entry of ``agent_id`` in the final round, or None."""
        last_round = num_rounds - 1
        found: Optional[dict] = None
        for entry in transcript:
            if entry["agent_id"] == agent_id and entry["round"] == last_round:
                found = entry
        return found

    def _collect_agent_final_answers(self, state: dict, transcript: List[dict]) -> List[str]:
        """Collect each debater's non-blank final-round answer, in agent order (unstripped)."""
        answers: List[str] = []
        for agent_id in range(state["num_agents"]):
            record = self._last_round_record(transcript, agent_id, state["num_rounds"])
            if record is None:
                continue
            ans = record.get("answer")
            if isinstance(ans, str) and ans.strip():
                answers.append(ans)
        return answers

    def _get_winner_answer(self, transcript: List[dict], winner_id: int, num_rounds: int) -> Optional[str]:
        """Return the winner's final-round answer (stripped).

        Falls back to the winner's argument when the answer is blank. Returns
        None when the winner has no final-round entry or both are empty.
        """
        record = self._last_round_record(transcript, winner_id, num_rounds)
        if record is None:
            return None

        answer = record.get("answer")
        if isinstance(answer, str) and answer.strip():
            return answer.strip()

        argument = record.get("argument", "")
        return str(argument).strip() if argument else None

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    @staticmethod
    def _sidecar_path(path: str, suffix: str) -> str:
        """Return ``<path without extension><suffix>.json`` for a file saved next to ``path``.

        Unlike ``str.replace('.json', ...)``, this never produces ``path``
        itself (which would overwrite the main config) and leaves directory
        names untouched.
        """
        base, _ext = os.path.splitext(path)
        return f"{base}{suffix}.json"

    def save_module(self, path: str, ignore: Optional[List[str]] = None, **kwargs) -> str:
        """Save the module configuration to ``path`` (``llm_config_pool`` is not saved).

        Debater agents are saved as ``<base>_agents_<i>.json``, indexed by
        ``<base>_agents.json``. A judge agent is saved as ``<base>_judge.json``.

        Args:
            path: Target file for the main JSON config.
            ignore: Unused; kept for interface compatibility.
            **kwargs: Unused.

        Returns:
            ``path``.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

        agent_pool_path = self._sidecar_path(path, "_agents")
        if self.debater_agents:
            agent_data = []
            for i, agent in enumerate(self.debater_agents):
                agent_path = self._sidecar_path(path, f"_agents_{i}")
                agent.save_module(agent_path)
                agent_data.append({
                    "name": agent.name,
                    "description": agent.description,
                    "file_path": agent_path
                })

            with open(agent_pool_path, 'w', encoding='utf-8') as f:
                json.dump(agent_data, f, ensure_ascii=False, indent=2)

        judge_agent_path = self._sidecar_path(path, "_judge")
        if self.judge_agent:
            self.judge_agent.save_module(judge_agent_path)

        config = {
            "llm_config": self._serialize_llm_config(self.llm_config),
            "name": self.name,
            "description": self.description,
            "agent_pool_file": agent_pool_path if self.debater_agents else None,
            "judge_agent_file": judge_agent_path if self.judge_agent else None
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        print(f"模块配置已保存到: {path}")
        return path

    @staticmethod
    def _describe_agent(agent: CustomizeAgent) -> Dict[str, Any]:
        """Summarize an agent as ``{"name", "description", "config"}``."""
        return {
            "name": agent.name,
            "description": agent.description,
            "config": agent.get_config()
        }

    def get_config(self) -> dict:
        """Return this module's configuration dict (``llm_config_pool`` excluded).

        Debater and judge agents are included inline when configured.
        """
        config = {
            "llm_config": self._serialize_llm_config(self.llm_config),
            "name": self.name,
            "description": self.description,
        }
        if self.debater_agents:
            config["debater_agents"] = [self._describe_agent(agent) for agent in self.debater_agents]
        if self.judge_agent:
            config["judge_agent"] = self._describe_agent(self.judge_agent)
        return config

    def _rebuild_agent_from_info(self, agent_info: Dict[str, Any]) -> CustomizeAgent:
        """Recreate a CustomizeAgent from an entry produced by :meth:`_describe_agent`.

        Falls back to this module's ``llm_config`` when the entry's LLM config
        cannot be deserialized.
        """
        agent_config = agent_info.get("config", {})
        llm_config = self._deserialize_llm_config(agent_config.get("llm_config")) or self.llm_config
        # name/description/llm_config are passed explicitly below.
        agent_config_clean = {k: v for k, v in agent_config.items()
                              if k not in ('name', 'description', 'llm_config')}
        return CustomizeAgent(
            name=agent_info["name"],
            description=agent_info["description"],
            llm_config=llm_config,
            **agent_config_clean
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> 'MultiAgentDebateActionGraph':
        """Create an instance from a dict like :meth:`get_config` output (no ``llm_config_pool``).

        Agents that fail to rebuild are skipped with a printed warning.
        """
        instance = cls()

        if data.get("llm_config"):
            restored = instance._deserialize_llm_config(data["llm_config"])
            # Keep the default config when the stored type is unknown.
            if restored is not None:
                instance.llm_config = restored

        if data.get("name"):
            instance.name = data["name"]

        if data.get("description"):
            instance.description = data["description"]

        if data.get("debater_agents"):
            agents = []
            for agent_info in data["debater_agents"]:
                try:
                    agents.append(instance._rebuild_agent_from_info(agent_info))
                except Exception as e:
                    print(f"警告: 重建agent {agent_info.get('name', 'unknown')}失败: {e}")
            instance.debater_agents = agents

        if data.get("judge_agent"):
            try:
                instance.judge_agent = instance._rebuild_agent_from_info(data["judge_agent"])
            except Exception as e:
                print(f"警告: 重建judge agent失败: {e}")

        return instance

    @classmethod
    def load_module(cls, path: str, llm_config: Optional[LLMConfig] = None, **kwargs) -> 'MultiAgentDebateActionGraph':
        """Load an instance from a file written by :meth:`save_module` (no ``llm_config_pool``).

        Args:
            path: Main JSON config file.
            llm_config: Fallback LLM config for loaded agents. Used only when
                the instance has no ``llm_config``.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the file is not valid JSON.
            RuntimeError: If the file cannot be read for another reason.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"模块配置文件不存在: {path}")

        try:
            with open(path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"配置文件格式错误: {e}") from e
        except Exception as e:
            raise RuntimeError(f"读取配置文件失败: {e}") from e

        instance = cls()

        if config.get("llm_config"):
            try:
                restored = instance._deserialize_llm_config(config["llm_config"])
                # Keep the default config when the stored type is unknown.
                if restored is not None:
                    instance.llm_config = restored
            except Exception as e:
                print(f"警告: 重建llm_config失败: {e}")

        if config.get("name"):
            instance.name = config["name"]

        if config.get("description"):
            instance.description = config["description"]

        agent_pool_file = config.get("agent_pool_file")
        if agent_pool_file and os.path.exists(agent_pool_file):
            try:
                with open(agent_pool_file, 'r', encoding='utf-8') as f:
                    agent_data = json.load(f)

                agents = []
                for agent_info in agent_data:
                    try:
                        agent_path = agent_info.get("file_path")
                        if agent_path and os.path.exists(agent_path):
                            agents.append(CustomizeAgent.from_file(
                                path=agent_path,
                                llm_config=instance.llm_config or llm_config
                            ))
                        else:
                            print(f"警告: agent文件不存在: {agent_path}")
                    except Exception as e:
                        print(f"警告: 加载agent {agent_info.get('name', 'unknown')}失败: {e}")

                instance.debater_agents = agents
                print(f"从 {agent_pool_file} 加载了 {len(agents)} 个agents")
            except Exception as e:
                print(f"警告: 加载agent pool失败: {e}")

        judge_agent_file = config.get("judge_agent_file")
        if judge_agent_file and os.path.exists(judge_agent_file):
            try:
                instance.judge_agent = CustomizeAgent.from_file(
                    path=judge_agent_file,
                    llm_config=instance.llm_config or llm_config
                )
                print(f"从 {judge_agent_file} 加载了judge agent")
            except Exception as e:
                print(f"警告: 加载judge agent失败: {e}")

        print(f"从 {path} 加载了模块配置")
        return instance

    def _serialize_llm_config(self, llm_config) -> Optional[Dict[str, Any]]:
        """Serialize an LLM config to its model name, temperature and class name.

        API keys are deliberately not stored. Returns None for a falsy config.
        """
        if not llm_config:
            return None
        return {
            "model": getattr(llm_config, 'model', None),
            "temperature": getattr(llm_config, 'temperature', None),
            "config_type": type(llm_config).__name__
        }

    def _deserialize_llm_config(self, config_info: Optional[Dict[str, Any]]) -> Optional[LLMConfig]:
        """Rebuild an LLM config from :meth:`_serialize_llm_config` output.

        API keys are read from the environment (``OPENAI_API_KEY`` /
        ``OPENROUTER_API_KEY``). The stored temperature is not restored.

        Returns:
            The rebuilt config, or None for empty input or an unsupported
            ``config_type``.
        """
        if not config_info:
            return None

        config_type = config_info.get("config_type", "OpenAILLMConfig")

        if config_type == "OpenAILLMConfig":
            # ``or`` also covers a stored ``"model": None``.
            return OpenAILLMConfig(
                model=config_info.get("model") or "gpt-4o-mini",
                openai_key=os.getenv("OPENAI_API_KEY")
            )
        if config_type == "OpenRouterConfig":
            from ...models.model_configs import OpenRouterConfig
            return OpenRouterConfig(
                model=config_info.get("model") or "meta-llama/llama-3.1-70b-instruct",
                openrouter_key=os.getenv("OPENROUTER_API_KEY")
            )

        return None
|