File size: 6,307 Bytes
fbe5cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645e967
fbe5cb3
 
 
 
 
25a922a
fbe5cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a922a
fbe5cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a922a
fbe5cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a922a
fbe5cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from smolagents import (
    CodeAgent,
    VisitWebpageTool,
    WebSearchTool,
    WikipediaSearchTool,
    PythonInterpreterTool,
    FinalAnswerTool,
    LiteLLMModel,
)
from vision_tool import image_reasoning_tool
from throttle import consume
import os
import time

# ---- TOOLS ----

# Shared keyword arguments applied to every Groq-hosted LiteLLM model below.
# NOTE(review): GROQ_API_KEY is read once at import time — restarting is
# required if the environment variable changes.
common = {
    "api_key": os.getenv("GROQ_API_KEY"),
    "api_base": "https://api.groq.com/openai/v1",
    "flatten_messages_as_text": True,
}



# ---- MULTI-AGENT SYSTEM ----
class MultyAgentSystem:
    """Two-tier multi-agent pipeline built on Groq-hosted LiteLLM models.

    A manager ``CodeAgent`` delegates tasks to two managed agents — a
    web-browsing agent and an info/compute agent — then finalizes the
    answer.  Long or high-stakes answers are optionally re-checked by a
    DeepSeek model.  When calls run slow, the manager is downgraded to a
    cheaper fallback model for the rest of the session.
    """

    # Switch the manager to the fallback model if the very first call takes
    # longer than this many seconds.
    FIRST_CALL_FALLBACK_SECS = 30
    # ... or once cumulative runtime across all calls exceeds this budget.
    TOTAL_RUNTIME_FALLBACK_SECS = 300

    def __init__(self):
        # Verification model: only invoked for high-stakes / long answers.
        self.deepseek_model = LiteLLMModel(
            "groq/deepseek-r1-distill-llama-70b",
            max_tokens=512,
            **common,
        )
        self.qwen_model = LiteLLMModel("groq/qwen-qwq-32b", **common)
        # Cheaper model the manager is downgraded to when calls run slow.
        self.fallback_model = LiteLLMModel("groq/llama3-70b-8k", **common)
        # Primary model shared by the manager and both managed agents.
        self.llama_model = LiteLLMModel("groq/llama-3.3-70b-versatile", **common)

        # Answers longer than this many words trigger DeepSeek verification.
        self.verification_limit = int(os.getenv("VERIFY_WORD_LIMIT", "75"))

        # --- Web agent: browsing, search, and Wikipedia retrieval ---
        self.web_agent = CodeAgent(
            model=self.llama_model,
            tools=[WebSearchTool(), VisitWebpageTool(), WikipediaSearchTool()],
            name="web_agent",
            description=(
                "You are a web browsing agent. Whenever the given {task} involves browsing "
                "the web or a specific website such as Wikipedia or YouTube, you will use "
                "the provided tools. For web-based factual and retrieval tasks, be as precise and source-reliable as possible."
            ),
            additional_authorized_imports=[
                "markdownify",
                "json",
                "requests",
                "urllib.request",
                "urllib.parse",
                "wikipedia-api",
            ],
            verbosity_level=0,
            max_steps=10,
        )

        # --- Info agent: math, code execution, OCR, and image reasoning ---
        self.info_agent = CodeAgent(
            model=self.llama_model,
            tools=[PythonInterpreterTool(), image_reasoning_tool],
            name="info_agent",
            description=(
                "You are an agent tasked with cleaning, parsing, calculating information, and performing OCR if images are provided in the {task}. "
                "You can also analyze images using a vision model. You handle all math, code, and data manipulation. Use numpy, math, and available libraries. "
                "For image or chess tasks, use pytesseract, PIL, chess, or the image_reasoning_tool as required."
            ),
            additional_authorized_imports=[
                "numpy",
                "math",
                "pytesseract",
                "PIL",
                "chess",
            ],
        )

        # --- Manager agent: plans, delegates, and submits the final answer ---
        manager_planning_interval = int(os.getenv("MANAGER_PLANNING_INTERVAL", "3"))
        manager_max_steps = int(os.getenv("MANAGER_MAX_STEPS", "8"))

        # The manager runs on the primary Llama model; DeepSeek is reserved
        # for verifying critical answers in run().
        self.manager_agent = CodeAgent(
            model=self.llama_model,
            tools=[FinalAnswerTool()],
            managed_agents=[self.web_agent, self.info_agent],
            name="manager_agent",
            description=(
                "You are the manager. Given a {task}, plan which agent to use: "
                "If web data is needed, delegate to web_agent. If math, parsing, image reasoning, or code is needed, use info_agent. "
                "After collecting outputs, optionally cross-validate and check correctness, then finalize and submit the best answer using FinalAnswerTool. "
                "For each task, explicitly explain your planning steps and reasons for choosing which agent, and always prefer the most accurate and complete answer possible."
            ),
            additional_authorized_imports=[
                "json",
                "pandas",
                "numpy",
            ],
            planning_interval=manager_planning_interval,
            verbosity_level=2,
            max_steps=manager_max_steps,
        )

        # Runtime tracking used to decide when to downgrade the manager model.
        self.total_runtime = 0.0
        self.first_call_duration = None
        self.model_switched = False

    def _switch_to_fallback(self):
        """Downgrade the manager to the fallback model (idempotent)."""
        if self.model_switched:
            return
        self.manager_agent.model = self.fallback_model
        self.model_switched = True

    def run(self, question, high_stakes: bool = False, **kwargs):
        """Answer *question* via the manager agent, optionally verifying it.

        Args:
            question: The task/question forwarded to the manager agent.
            high_stakes: When True, always verify the answer with DeepSeek
                regardless of its length.
            **kwargs: Forwarded to the manager agent call;
                ``max_completion_tokens`` (default 512) is also used for
                throttle budgeting.

        Returns:
            The final answer — the DeepSeek-verified rewrite when
            verification ran and succeeded, otherwise the manager's answer.
        """
        start_time = time.time()
        max_completion_tokens = kwargs.get("max_completion_tokens", 512)
        # Reserve rate-limit budget up front.  Word count is a rough proxy
        # for prompt tokens — TODO confirm against the actual tokenizer.
        print("Generating initial answer with llama-3.3-70b-versatile")
        consume(len(question.split()) + max_completion_tokens)
        initial_answer = self.manager_agent(question, **kwargs)
        call_duration = time.time() - start_time

        answer = initial_answer
        # str() guards against agents returning a non-string payload.
        if high_stakes or len(str(initial_answer).split()) > self.verification_limit:
            print("Verifying answer using DeepSeek-70B")
            verification_prompt = (
                "Review the following answer for accuracy and rewrite if needed:"
                f"\n\n{initial_answer}"
            )
            try:
                consume(len(verification_prompt.split()) + max_completion_tokens)
                answer = self.deepseek_model(
                    verification_prompt, max_completion_tokens=max_completion_tokens
                )
            except Exception as e:
                # Verification is best-effort: keep the unverified answer.
                print(f"Verification failed: {e}. Using initial answer.")
                answer = initial_answer

        # Record the first call's duration and downgrade immediately if it
        # was already too slow.
        if self.first_call_duration is None:
            self.first_call_duration = call_duration
            if call_duration > self.FIRST_CALL_FALLBACK_SECS:
                self._switch_to_fallback()

        self.total_runtime += call_duration
        if self.total_runtime > self.TOTAL_RUNTIME_FALLBACK_SECS:
            self._switch_to_fallback()

        return answer

    def __call__(self, question, high_stakes: bool = False, **kwargs):
        """Alias for :meth:`run` so the system itself is callable."""
        return self.run(question, high_stakes=high_stakes, **kwargs)