File size: 2,059 Bytes
31616e2
 
 
6065fb1
c97e46a
31616e2
 
d55b613
6065fb1
 
31616e2
 
6065fb1
 
31616e2
 
 
c97e46a
 
 
31616e2
6065fb1
31616e2
6065fb1
 
d55b613
31616e2
6065fb1
 
d55b613
31616e2
 
6065fb1
31616e2
d55b613
31616e2
 
 
 
 
6065fb1
31616e2
 
 
 
 
6065fb1
31616e2
 
c97e46a
31616e2
d55b613
 
31616e2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import TypedDict, Literal, Annotated
from langgraph.graph import StateGraph, START, END, add_messages
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from src.tools import TOOLMAP
from src.prompts import REASON
from pathlib import Path

class AgentSchema(TypedDict):
  """State dictionary passed between the nodes of the agent graph."""
  # Conversation history. The add_messages annotation tells LangGraph to
  # combine each node's returned messages with this list via that reducer.
  messages: Annotated[list, add_messages]
  # Reasoning-step counter: ReActAgent.__call__ seeds it with 0 and
  # ReActAgent.reason returns it incremented by one.
  niter: int

class ReActAgent():
  """Tool-using agent built as a two-node LangGraph loop.

  The "reason" node calls the chat model; when the model's reply carries tool
  calls the "toolnode" node runs them and control returns to "reason". The
  loop ends when the model answers without tool calls or when ``maxreason``
  reasoning steps have been spent.
  """

  def __init__(self, modelid: str, verbose: bool = False, maxreason: int = 5):
    """Create the model, wire the graph and save a PNG rendering of it.

    Args:
      modelid: model name passed to ChatOllama.
      verbose: when True, ``reason`` prints the reasoning trace and response
        of every model call.
      maxreason: maximum number of reasoning steps per query; once reached,
        the run ends even if the model is still requesting tools.
    """
    self.verbose = verbose
    self.maxreason = maxreason
    tools = list(TOOLMAP.values())
    self.brain = ChatOllama(
      model=modelid, temperature=0, validate_model_on_init=True, reasoning=True
    ).bind_tools(tools)

    workflow = StateGraph(AgentSchema)

    # nodes #
    workflow.add_node("reason", self.reason)
    workflow.add_node("toolnode", ToolNode(tools))

    # edges #
    workflow.add_edge(START, "reason")
    # Explicit path map: declares the only two targets next_step can return,
    # so the compiled graph (and its drawing) shows exactly those edges.
    workflow.add_conditional_edges(
      "reason", self.next_step, {"toolnode": "toolnode", END: END}
    )
    workflow.add_edge("toolnode", "reason")

    # compile #
    self.workflow = workflow.compile()
    imagepath = Path("assets", "agent_graph.png")
    imagepath.parent.mkdir(exist_ok=True)
    # Render first, then overwrite: a failed render leaves any previous image
    # in place instead of deleting it up front.
    imagepath.write_bytes(self.workflow.get_graph().draw_mermaid_png())

  def __call__(self, query: str):
    """Run the agent on ``query`` and return the final graph state.

    The conversation is seeded with the REASON system prompt followed by the
    user query, and the step counter starts at 0.
    """
    # Each reasoning step can be followed by one tool step, so keep LangGraph's
    # own recursion limit above the budget implied by maxreason; otherwise a
    # large maxreason would be cut short by the default limit.
    step_budget = max(25, 2 * int(self.maxreason) + 1)
    return self.workflow.invoke(
      {
        "messages": [{'role': 'system', 'content': REASON}, {'role': 'user', 'content': query}],
        "niter": 0
      },
      {"recursion_limit": step_budget},
    )

  # nodes #
  def reason(self, state: "AgentSchema") -> "AgentSchema":
    """Call the model on the current messages and return the state update.

    Returns a dict holding the model response (as a one-element ``messages``
    list) and ``niter`` incremented by one.
    """
    response = self.brain.invoke(state.get("messages", []))
    if self.verbose:
      # Pulled into a local so the f-string contains no nested double quotes,
      # which only parse on Python 3.12+.
      reasoning = response.additional_kwargs.get("reasoning_content", "")
      print(f"Reasoning:\n{reasoning}\nResponse:{response.content}")
    # A missing/None counter is treated as 0 instead of raising TypeError.
    return {"messages": [response], "niter": int(state.get("niter") or 0) + 1}

  # edges #
  def next_step(self, state: "AgentSchema") -> str:
    """Pick the node that follows "reason".

    Returns "toolnode" when the latest message requests tool calls and fewer
    than ``maxreason`` reasoning steps have run; otherwise returns END. When
    the cap is hit the run stops with the pending tool calls left unexecuted.
    """
    messages = state.get("messages") or []
    wants_tools = bool(messages) and bool(getattr(messages[-1], "tool_calls", None))
    if wants_tools and int(state.get("niter") or 0) < self.maxreason:
      return "toolnode"
    return END