File size: 5,056 Bytes
e3a472a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""SC-TIR-style agent loop adapted from AIMO3 (math) to coding.

Loop:
    user → LLM → (tool calls?) → tools → LLM → ... → final answer

Stops when:
  - LLM emits content with no tool calls, OR
  - max_steps hit (forces a final response without tools)

The pattern mirrors Sardor's AIMO3 SC-TIR pipeline: the model alternates
between thinking and tool-augmented action, with deterministic verification
on the tool side.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List

from serving.base import LLMClient, LLMResponse, ToolCall
from tools.base import ToolRegistry, ToolResult
from agent.prompts import SYSTEM, build_repo_overview, initial_user_prompt


@dataclass
class AgentTurn:
    """A single conversational turn suitable for an agent transcript.

    Attributes mirror the chat-message shape used elsewhere in this module.
    """

    role: str  # message role string — presumably "user"/"assistant"/"tool"; confirm against serving layer
    content: str  # the turn's text payload
    tool_calls: list[dict[str, Any]] = field(default_factory=list)  # fresh list per instance
    tool_call_id: str | None = None  # set only on tool-result turns — TODO confirm


@dataclass
class AgentRun:
    """Outcome of one :meth:`Agent.run` invocation."""

    answer: str  # last assistant text in the transcript, "" if none was produced
    transcript: list[dict[str, Any]]  # full chat-message history, in order
    tool_calls: list[dict[str, Any]]  # one {"name", "arguments"} record per tool invocation
    steps: int  # number of completed tool-using loop iterations
    finished: bool  # True when the model stopped on its own (no forced final answer)


class Agent:
    def __init__(
        self,
        llm: LLMClient,
        tools: ToolRegistry,
        max_steps: int = 6,
        max_tool_output_chars: int = 6000,
    ):
        self.llm = llm
        self.tools = tools
        self.max_steps = max_steps
        self.max_tool_output_chars = max_tool_output_chars

    def run(self, question: str, repo_summary: Dict[str, Any]) -> AgentRun:
        overview = build_repo_overview(
            repo=repo_summary.get("repo", ""),
            n_files=repo_summary.get("n_files", 0),
            n_chunks=repo_summary.get("n_chunks", 0),
            total_tokens=repo_summary.get("total_tokens", 0),
            top_paths=_pick_top_paths(repo_summary),
        )
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": initial_user_prompt(question, overview)},
        ]
        tool_schema = self.tools.schema()
        tool_calls_log: List[Dict[str, Any]] = []
        step = 0
        finished = False

        while step < self.max_steps:
            resp = self.llm.complete(messages, tools=tool_schema)
            assistant_msg: Dict[str, Any] = {"role": "assistant"}
            if resp.content:
                assistant_msg["content"] = resp.content
            if resp.tool_calls:
                assistant_msg["tool_calls"] = [self._tool_call_to_msg(tc) for tc in resp.tool_calls]
            else:
                assistant_msg.setdefault("content", "")
            messages.append(assistant_msg)

            if not resp.tool_calls:
                finished = True
                break

            for tc in resp.tool_calls:
                tool_calls_log.append({"name": tc.name, "arguments": tc.arguments})
                result = self.tools.call(tc.name, tc.arguments)
                tool_msg = {
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "name": tc.name,
                    "content": self._format_tool_output(result),
                }
                messages.append(tool_msg)
            step += 1

        # If we hit max_steps without a final answer, force one more text-only call.
        if not finished:
            messages.append({
                "role": "user",
                "content": "You've used the tool budget. Provide your best final answer now, without tool calls.",
            })
            resp = self.llm.complete(messages, tools=[])
            messages.append({"role": "assistant", "content": resp.content or ""})

        # Final answer = last assistant message with content
        answer = ""
        for m in reversed(messages):
            if m.get("role") == "assistant" and m.get("content"):
                answer = m["content"]
                break

        return AgentRun(
            answer=answer,
            transcript=messages,
            tool_calls=tool_calls_log,
            steps=step,
            finished=finished,
        )

    def _tool_call_to_msg(self, tc: ToolCall) -> Dict[str, Any]:
        return {
            "id": tc.id,
            "type": "function",
            "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)},
        }

    def _format_tool_output(self, result: ToolResult) -> str:
        body = result.output if result.ok else f"[error] {result.error}"
        if len(body) > self.max_tool_output_chars:
            body = body[: self.max_tool_output_chars] + "\n[... truncated]"
        return body


def _pick_top_paths(summary: Dict[str, Any]) -> List[str]:
    chunks = summary.get("chunks") or []
    seen: List[str] = []
    seen_set = set()
    # priority 0 first, then 1; keep insertion order
    for prio in (0, 1, 2):
        for c in chunks:
            if c.get("priority") == prio and c.get("path") not in seen_set:
                seen.append(c["path"])
                seen_set.add(c["path"])
                if len(seen) >= 60:
                    return seen
    return seen