File size: 4,655 Bytes
40deb66
a877f54
314810a
a877f54
40deb66
bcdc55d
40deb66
01faebd
7a0b5ad
 
c36efba
a394be7
 
c82d598
 
da8c40b
a877f54
 
 
 
40deb66
 
 
bcdc55d
c82d598
9236a37
 
 
 
 
 
 
 
40deb66
 
 
 
0019780
7a0b5ad
c82d598
 
7a0b5ad
 
ed2f57e
 
7a0b5ad
 
 
 
40deb66
a877f54
 
 
01faebd
40deb66
97fc18c
 
01faebd
 
 
 
 
97fc18c
 
c82d598
01faebd
97fc18c
 
 
01faebd
97fc18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a877f54
 
 
7a0b5ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from datetime import datetime, timezone
from dotenv import load_dotenv
from colorama import Fore, Style  # type: ignore[import]
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from agent.api.api import get_llm
from agent.tools.math_solver import math_solver
from agent.tools.file_downloader import file_downloader
from agent.tools.ocr_reader import ocr_reader
from agent.tools.list_files import list_files
from agent.tools.http_get import http_get

from agent.agents.websearchagents import web_search_agents

# from agent.agents.websearchagent import websearch_agent
from agent.agents.answer_extractor import extract_answer

# Load variables from a local .env file into os.environ at import time.
# NOTE(review): presumably provides credentials/config read by get_llm or the
# tools — confirm which keys are required.
load_dotenv()


def supervisor_agent():
    """Build the supervisor agent used by ``run`` and the interactive loop.

    The agent is created with ``create_agent`` using the model from
    ``get_llm()`` and these tools: ``math_solver``, ``web_search_agents``,
    ``file_downloader``, ``ocr_reader``, ``list_files`` and ``http_get``.

    The current UTC time is formatted into the system prompt when this
    function is called, so an agent that is reused for a long time will see a
    stale timestamp; build a fresh one when that matters.

    Returns:
        The agent object returned by ``langchain.agents.create_agent``.
    """
    return create_agent(
        model=get_llm(),
        tools=[
            math_solver,
            web_search_agents,
            file_downloader,
            ocr_reader,
            list_files,
            http_get,
        ],
        system_prompt=(
            "You are a supervisor agent. "
            # Tell the model the real date; its own knowledge has a cutoff.
            f"Current time is: {datetime.now(timezone.utc).isoformat()}. "
            "Your memory is out of date. "
            "For any math or calculation question, use the math_solver tool "
            "to check your work; accuracy is the most important thing. "
            "Any question that needs real-time information must use the "
            "web_search_agents tool to get a concise and accurate final answer. "
            "If an image or photo file is attached, download it and use the "
            "ocr_reader tool to extract and describe its content before answering. "
            "Once you have found the answer, respond immediately. "
            "Do NOT continue searching or verifying unnecessarily — "
            "you have a limited number of action steps and must avoid exceeding them. "
            "If you do not have enough information to answer the question "
            "and no tool can help, respond with: "
            # Keep this fallback phrase verbatim; callers may match on it.
            "'Insufficient information to provide an answer.'"
        ),
    )


def _content_to_text(content: object) -> str:
    """Flatten a message ``content`` value into plain text.

    ``content`` may be a plain string or a list of content parts. String parts
    are used as-is; dict parts contribute their ``"text"`` entry when it is a
    string (parts without one, e.g. non-text blocks, contribute nothing). Any
    other value is converted with ``str``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: list[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                if isinstance(text, str):
                    pieces.append(text)
        return "".join(pieces)
    return str(content)


def run(query: str, file_url: str | None = None, max_retries: int = 3) -> str:
    """Answer ``query`` with the supervisor agent, retrying on failure.

    Args:
        query: The user's question. It is also passed unchanged to
            ``extract_answer`` together with the agent's output.
        file_url: Optional URL of an attached file. When given, it is appended
            to the prompt sent to the agent.
        max_retries: Maximum number of agent attempts. After a failed attempt
            the error text is sent to the next attempt as an extra hint.

    Returns:
        ``extract_answer`` applied to the text of the agent's final message,
        or, if every attempt raised, to a message describing the last error.
    """
    last_error: str | None = None

    # The agent sees the attachment URL; extract_answer only sees the raw query.
    full_query = query
    if file_url:
        full_query += f"\n\nAttached file URL: {file_url}"

    for attempt in range(1, max_retries + 1):
        print(
            f"{Fore.CYAN}[Supervisor] Processing query (attempt {attempt}/{max_retries})...\n"
            f"[Supervisor] Query: {full_query}{Style.RESET_ALL}"
        )
        # A new agent per attempt (this also refreshes the timestamp in its prompt).
        agent = supervisor_agent()

        messages = [HumanMessage(content=full_query)]
        if last_error:
            messages.append(
                HumanMessage(
                    content=(
                        f"Previous attempt failed with error: {last_error}\n"
                        f"Try thinking about the problem more simply. "
                        f"Use fewer steps and a more straightforward approach."
                    )
                )
            )

        try:
            result = agent.invoke(
                {"messages": messages},
                # Limit the agent run to 30 steps.
                config={"recursion_limit": 30},
            )
            # Final message content may be a string or a list of parts;
            # join every text part instead of trusting only the first one.
            content = _content_to_text(result["messages"][-1].content)
            return extract_answer(content, query)
        except Exception as e:
            # Any failure (e.g. exceeding the step limit) counts as a failed
            # attempt; its message is fed back into the next attempt.
            last_error = str(e)
            print(
                f"{Fore.RED}[Supervisor] Attempt {attempt} failed: {last_error}{Style.RESET_ALL}"
            )

    print(f"{Fore.RED}[Supervisor] All {max_retries} attempts failed.{Style.RESET_ALL}")
    return extract_answer(
        f"Agent failed after {max_retries} attempts. Last error: {last_error}", query
    )


if __name__ == "__main__":
    # Interactive chat loop. The full message history returned by the agent is
    # kept and resent each turn, so the agent sees earlier turns.
    # Type "exit" or "quit" (or press Ctrl-D / Ctrl-C at the prompt) to leave.
    agent = supervisor_agent()
    chat_history: list = []
    while True:
        try:
            query = input("\nYou: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dumping a traceback.
            print()
            break
        if not query:
            continue
        if query.lower() in ("exit", "quit"):
            break
        chat_history.append(HumanMessage(content=query))
        try:
            result = agent.invoke(
                {"messages": chat_history},
                # Same step budget as run().
                config={"recursion_limit": 30},
            )
        except Exception as e:
            # Drop the unanswered turn so the history stays consistent, and
            # keep the session alive so the user can try again.
            chat_history.pop()
            print(f"{Fore.RED}Agent error: {e}{Style.RESET_ALL}")
            continue
        chat_history = result["messages"]
        content = chat_history[-1].content
        if isinstance(content, list):
            # Join every text part; dict parts without a "text" key are skipped.
            content = "".join(
                part if isinstance(part, str) else str(part.get("text", ""))
                for part in content
                if isinstance(part, (str, dict))
            )
        print(f"Agent: {content}")