File size: 4,009 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""hint_generator.py — Template-based hint generator (v0.1 starter).

Composer 2.5 inserts text hints at error-turn sites:
  "Reminder: Available tools are: …"  (when a tool-call refs a non-existent tool)
  "Reminder: tool arguments must be valid JSON"  (on JSONDecodeError)
  ... etc.

This module provides a registry of hint templates keyed by error_kind. The
data collator (in trl_path/data_collator.py) calls dispatch(error_kind, ctx)
to get the hint text to splice into ctx_teacher.

v0.2 will replace these templates with an LLM-driven hint generator (likely
Sonnet 4.6 or Opus 4.7 via OpenRouter) for cases where templates are too rigid
(style violations, wasteful explanations).
"""

from __future__ import annotations

import json
from collections.abc import Callable
from typing import TypedDict


class HintContext(TypedDict, total=False):
    """Per-error context the hint templates may read.

    ``total=False`` makes every key optional, so each template reads fields
    with ``ctx.get(...)`` and supplies its own fallback text when a field is
    missing.
    """
    # e.g. "tool_not_found", "json_decode", "type_error". Not read by the
    # templates in this module; dispatch() receives the kind as its own argument.
    error_kind: str          # e.g. "tool_not_found", "json_decode", "type_error"
    # Raw error text from the env; embedded by hint_runtime_error.
    error_message: str       # raw error from the env
    # Tool names listed by hint_tool_not_found.
    available_tools: list[str]  # for tool_not_found
    # Name of the failing tool; used by hint_type_error together with tool_schema.
    tool_name: str           # the failing tool, if known
    # Schema of the failing tool; shown by hint_type_error when tool_name is set too.
    tool_schema: dict        # the schema, if known
    # Not read by any template in this module (presumably reserved for the
    # planned LLM-driven generator — TODO confirm).
    intent: str              # student's apparent intent, if extractable


# ---------------------------------------------------------------------------
# Hint templates
# ---------------------------------------------------------------------------

def hint_tool_not_found(ctx: HintContext) -> str:
    """Build the hint for a call to a tool that does not exist.

    If ``ctx`` carries a non-empty ``available_tools`` list, the hint names
    every tool (each wrapped in backticks). Otherwise a generic reminder is
    returned.
    """
    names = ctx.get("available_tools", [])
    if not names:
        # Nothing to enumerate: fall back to the generic reminder.
        return "Reminder: the tool you tried to call does not exist. Use only available tools."
    quoted = ", ".join(map("`{}`".format, names))
    return f"Reminder: Available tools are: {quoted}. Please use one of these."


def hint_json_decode(ctx: HintContext) -> str:
    """Return the fixed reminder for tool arguments that failed to parse as JSON.

    ``ctx`` exists only to match the common template signature; the text is
    the same for every JSON decode failure.
    """
    return "".join((
        "Reminder: tool arguments must be valid JSON. Common mistakes: ",
        "single quotes (use double), trailing commas, unescaped newlines in strings.",
    ))


def hint_type_error(ctx: HintContext) -> str:
    """Build the hint for tool arguments that do not match the tool's schema.

    When both ``tool_name`` and a non-empty ``tool_schema`` are present, the
    hint shows the schema rendered as JSON. Otherwise a generic reminder is
    returned.

    Args:
        ctx: Error context; reads ``tool_name`` and ``tool_schema``.

    Returns:
        The hint text to splice into the teacher context.
    """
    name = ctx.get("tool_name")
    schema = ctx.get("tool_schema")
    if not (name and schema):
        return "Reminder: tool arguments do not match the expected types. Check the schema."
    # Render the schema as JSON, not as the dict's Python repr. A repr uses
    # single quotes and True/None, which contradicts the "valid JSON" hint
    # and would teach the student to imitate invalid JSON.
    try:
        schema_text = json.dumps(schema, ensure_ascii=False)
    except (TypeError, ValueError):
        # TypeError: non-JSON-serializable values; ValueError: circular
        # references. Best effort: show the repr rather than fail hint generation.
        schema_text = str(schema)
    return (
        f"Reminder: `{name}` expects arguments matching this schema:\n"
        f"  {schema_text}\n"
        "Re-issue the call with arguments matching the schema."
    )


def hint_runtime_error(ctx: HintContext) -> str:
    """Build the hint for a tool call that raised an error at runtime.

    The raw ``error_message`` is embedded with surrounding whitespace
    stripped (tracebacks often end in a newline, which would break the
    sentence). A missing, ``None``, or blank message falls back to the
    phrase "an exception", so the hint never reads "raised None." or
    "raised .".

    Args:
        ctx: Error context; reads ``error_message``.

    Returns:
        The hint text to splice into the teacher context.
    """
    # `or` (rather than a .get default) also catches None / "" values that
    # are present under the key.
    msg = str(ctx.get("error_message") or "").strip() or "an exception"
    return (
        f"Reminder: the previous tool call raised {msg}. "
        "Reconsider the inputs or read the relevant code first to understand state."
    )


def hint_repeated_failure(ctx: HintContext) -> str:
    """Return a fixed "change your strategy" hint for a streak of failures.

    Intended for when the same kind of error keeps recurring (the original
    note says 3+ times in a row). Detecting that streak is the caller's
    responsibility; ``ctx`` is not consulted here.
    """
    return "".join((
        "Reminder: this approach has failed multiple times. ",
        "Step back and consider an alternative approach: read more files, ",
        "search for similar patterns elsewhere, or break the task down differently.",
    ))


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

# Maps an error_kind string to the template that renders its hint.
# dispatch() looks kinds up here and returns None for any kind not present;
# register() adds or replaces entries at runtime, so this module-level dict is
# mutable shared state.
HINT_TEMPLATES: dict[str, Callable[[HintContext], str]] = {
    "tool_not_found":   hint_tool_not_found,
    "json_decode":      hint_json_decode,
    "type_error":       hint_type_error,
    "runtime_error":    hint_runtime_error,
    "repeated_failure": hint_repeated_failure,
}


def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None:
    """Render the hint registered for ``error_kind``.

    Args:
        error_kind: Key into ``HINT_TEMPLATES``.
        ctx: Optional error context. ``None`` (or an empty context) is
            replaced by a fresh empty dict before the template is called.

    Returns:
        The hint text, or ``None`` when no template is registered for
        ``error_kind``.
    """
    try:
        template = HINT_TEMPLATES[error_kind]
    except KeyError:
        # Unknown kind: the caller decides what to do without a hint.
        return None
    return template(ctx if ctx else {})


def register(error_kind: str, fn: Callable[[HintContext], str]) -> None:
    """Add or replace the hint template for ``error_kind``.

    An existing entry for the same kind (including a built-in one) is
    overwritten, so this can also be used to customize the default hints.

    Args:
        error_kind: Key that ``dispatch`` will look up.
        fn: Callable that takes a ``HintContext`` and returns the hint text.

    Raises:
        TypeError: If ``fn`` is not callable. Checking here surfaces the
            mistake at registration time, instead of later as an opaque
            "object is not callable" error inside ``dispatch``.
    """
    if not callable(fn):
        raise TypeError(
            f"hint template for {error_kind!r} must be callable, "
            f"got {type(fn).__name__}"
        )
    HINT_TEMPLATES[error_kind] = fn


__all__ = ["dispatch", "register", "HintContext", "HINT_TEMPLATES"]