# sql_env/training/trl_adapter.py
# Last commit by hjerpe: "Upload folder using huggingface_hub" (9e64e71, verified)
"""TRL environment adapter for SQLEnv."""
from __future__ import annotations
import collections
try:
from sql_env.models import SQLAction
except ImportError: # pragma: no cover
from models import SQLAction # type: ignore[no-redef]
try:
from sql_env.server.sql_environment import SQLEnvironment
except ImportError: # pragma: no cover
from server.sql_environment import SQLEnvironment # type: ignore[no-redef]
def get_tool_definitions(env_cls: type | None = None) -> list[dict]:
    """Extract tool definitions from an environment class via introspection.

    Inspects public methods (excluding reset and dunder) to build the
    same JSON schema that TRL generates for environment_factory. This
    guarantees SFT and GRPO see identical tool definitions.

    The tool description is the first docstring line (falling back to the
    method name when there is no docstring). Parameter and return
    descriptions are read from Google-style ``Args:`` / ``Returns:``
    sections; wrapped description lines are joined with single spaces.

    Args:
        env_cls: Class to inspect. Defaults to ``SQLEnvTRL`` when ``None``.

    Returns:
        One ``{"type": "function", "function": {...}}`` dict per public
        method, sorted by function name. Every parameter is declared with
        JSON type ``"string"``; parameters without a default are listed
        under ``"required"``.
    """
    import inspect

    if env_cls is None:
        env_cls = SQLEnvTRL

    skipped_names = {"reset", "reward"}
    # ``*args`` / ``**kwargs`` cannot be expressed as a single named string
    # property, so they are left out of the schema.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    # Section headers that end an ``Args:`` / ``Returns:`` section.
    other_sections = {
        "raises:",
        "yields:",
        "example:",
        "examples:",
        "note:",
        "notes:",
    }

    tools = []
    for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction):
        if name.startswith("_") or name in skipped_names:
            continue

        sig = inspect.signature(method)
        doc_lines = (inspect.getdoc(method) or "").split("\n")
        # ``"".split("\n")`` yields ``[""]``, so test the stripped text
        # (not the list) to decide whether to fall back to the name.
        description = doc_lines[0].strip() or name

        param_descriptions: dict[str, str] = {}
        return_parts: list[str] = []
        section = ""
        current_param: str | None = None
        for line in doc_lines[1:]:
            stripped = line.strip()
            lowered = stripped.lower()
            if lowered.startswith("args:"):
                section, current_param = "args", None
                continue
            if lowered.startswith("returns:"):
                section = "returns"
                # Keep text written on the same line as the header.
                inline = stripped[len("returns:"):].strip()
                if inline:
                    return_parts.append(inline)
                continue
            if lowered in other_sections:
                section = "other"
                continue
            if not stripped:
                continue

            if section == "args":
                head, sep, tail = stripped.partition(":")
                # Accept both ``name: desc`` and ``name (type): desc``.
                candidate = head.split("(", 1)[0].strip().lstrip("*")
                if sep and candidate.isidentifier():
                    current_param = candidate
                    param_descriptions[candidate] = tail.strip()
                elif current_param is not None:
                    # Wrapped continuation of the previous description.
                    joined = f"{param_descriptions[current_param]} {stripped}"
                    param_descriptions[current_param] = joined.strip()
            elif section == "returns":
                return_parts.append(stripped)
        return_description = " ".join(return_parts)

        # Build parameters schema from signature
        properties = {}
        required = []
        for param_name, param in sig.parameters.items():
            if param_name == "self" or param.kind in variadic_kinds:
                continue
            properties[param_name] = {
                "type": "string",
                "description": param_descriptions.get(
                    param_name, f"{param_name} parameter."
                ),
            }
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        tool = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
        if return_description:
            tool["function"]["return"] = {
                "type": "string",
                "description": return_description,
            }
        tools.append(tool)

    # Sort by name for deterministic ordering
    tools.sort(key=lambda t: t["function"]["name"])
    return tools
class _MinimalTokenizer:
    """Placeholder tokenizer handed to SQLEnvironment at construction time."""

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        *,
        tokenize: bool = False,
        add_generation_prompt: bool = False,
    ) -> str:
        """Render nothing: every call yields an empty prompt.

        Parameters
        ----------
        messages
            Chat messages; accepted for interface compatibility and ignored.
        tokenize
            Ignored.
        add_generation_prompt
            Ignored.

        Returns
        -------
        str
            The empty string, regardless of input.
        """
        # The arguments exist only to mirror the tokenizer call signature.
        _ = (messages, tokenize, add_generation_prompt)
        return ""
# Added to the adapter's accumulated reward every time a tool is invoked
# after the episode has already finished (see SQLEnvTRL._dispatch, which
# then raises ValueError).
_POST_EPISODE_PENALTY = -0.3
# Adapter-level repeat penalty (on top of environment's -0.03 in reward.py).
# Intentionally harsher: the env penalty shapes per-step reward, while this
# penalty shapes the episode-level signal that GRPO sees.
# Applied in SQLEnvTRL._dispatch when the same (action, argument) pair is
# still present in the window of the three most recent calls.
_REPEAT_PENALTY = -0.2
class SQLEnvTRL:
    """TRL-compatible adapter shell for SQLEnv.

    Wraps a ``SQLEnvironment`` and exposes its actions as the plain methods
    ``describe``, ``sample``, ``query`` and ``answer``. Rewards returned by
    the environment, plus the adapter-level repeat and post-episode
    penalties, are accumulated in ``self.reward`` (read by
    ``sql_env_reward_func``).

    ``__init__`` takes no arguments, so paths and the step budget must be
    stored on the class with ``_configure`` before instantiation.
    """

    # Class-level configuration, populated by ``_configure``.
    _questions_path: str | None = None
    _db_dir: str | None = None
    _step_budget: int = 10

    @classmethod
    def _configure(
        cls,
        *,
        questions_path: str,
        db_dir: str,
        step_budget: int = 10,
    ) -> None:
        """Store class-level adapter configuration before TRL instantiation.

        Args:
            questions_path: Path passed to ``SQLEnvironment`` as
                ``questions_path``; must be non-empty.
            db_dir: Directory passed to ``SQLEnvironment`` as ``db_dir``;
                must be non-empty.
            step_budget: Step budget passed to ``SQLEnvironment``; must be
                positive.

        Raises:
            ValueError: If any argument fails the checks above.
        """
        if not questions_path:
            raise ValueError("questions_path must be a non-empty string")
        if not db_dir:
            raise ValueError("db_dir must be a non-empty string")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")
        cls._questions_path = questions_path
        cls._db_dir = db_dir
        cls._step_budget = step_budget

    def __init__(self) -> None:
        """Initialize a configured SQLEnvironment-backed adapter instance.

        Raises:
            RuntimeError: If ``_configure`` has not been called yet.
        """
        cls = self.__class__
        if cls._questions_path is None or cls._db_dir is None:
            # Name the method that actually exists (``_configure``).
            raise RuntimeError(
                "SQLEnvTRL._configure() must be called before SQLEnvTRL()"
            )
        self._env = SQLEnvironment(
            questions_path=cls._questions_path,
            db_dir=cls._db_dir,
            tokenizer=_MinimalTokenizer(),
            step_budget=cls._step_budget,
        )
        # Latest observation returned by the environment's reset; set in
        # ``reset``.
        self._obs = None
        self.reward = 0.0
        self._done = False
        # Sliding window of the last three (action, argument) calls, used
        # for repeat detection in ``_dispatch``.
        self._recent_calls: collections.deque[tuple[str, str]] = collections.deque(
            maxlen=3
        )
        self._repeat_count = 0

    def reset(self, **kwargs: object) -> str | None:
        """Initialize a new episode and return the initial observation text.

        TRL passes dataset columns as kwargs. If ``question_text`` is
        present, the environment resets to the matching question (and
        therefore the correct database).

        Args:
            kwargs: Dataset columns from TRL, may include question_text.

        Returns:
            Short observation hint for the language model listing the
            available tables and the tool names.
        """
        self.reward = 0.0
        self._done = False
        self._recent_calls.clear()
        self._repeat_count = 0

        question_text = kwargs.get("question_text")
        matching = []
        if question_text and isinstance(question_text, str):
            matching = [
                q for q in self._env.questions if q.question_text == question_text
            ]

        if matching:
            # Temporarily narrow the question pool so the environment's
            # reset can only pick the requested question, then restore the
            # full pool even if the reset raises.
            original = list(self._env.questions)
            self._env.questions = matching
            try:
                self._obs = self._env.reset(seed=None)
            finally:
                self._env.questions = original
        else:
            # No (usable) question_text, or no question matched it: fall
            # back to an unfiltered reset.
            self._obs = self._env.reset(seed=None)

        # Return concise hint — full observation via describe/sample
        tables = [
            entry
            for entry in (
                line.strip().lstrip("- ").strip()
                for line in (self._obs.schema_info or "").split("\n")
            )
            if entry and entry != "Available tables:"
        ]
        return (
            f"Tables: {', '.join(tables)}. "
            "Use describe, sample, query, and answer tools."
        )

    def _dispatch(self, action_type: str, argument: str) -> str:
        """Execute an action with repeat detection and reward accumulation.

        Args:
            action_type: Action name forwarded to ``SQLAction``.
            argument: Action argument forwarded to ``SQLAction``.

        Returns:
            The observation's ``result`` if truthy, else
            ``"Error: <error>"`` if an error is set, else ``"No output."``.

        Raises:
            ValueError: If the episode is already over; the post-episode
                penalty is added to ``self.reward`` before raising.
        """
        if self._done:
            self.reward += _POST_EPISODE_PENALTY
            raise ValueError("Episode is over")

        call_key = (action_type.lower(), argument)
        if call_key in self._recent_calls:
            self.reward += _REPEAT_PENALTY
            self._repeat_count += 1
        self._recent_calls.append(call_key)

        observation = self._env.step(
            SQLAction(action_type=action_type, argument=argument)
        )
        if observation.reward is not None:
            self.reward += observation.reward
        self._done = observation.done

        if observation.result:
            return observation.result
        if observation.error:
            return f"Error: {observation.error}"
        return "No output."

    # NOTE: the docstrings of the four tool methods below are parsed by
    # ``get_tool_definitions`` into the tool schema, so their wording is
    # effectively runtime data.

    def describe(self, table_name: str) -> str:
        """Show schema details for a database table.

        Args:
            table_name: Name of the table to describe.

        Returns:
            Schema information for the specified table.
        """
        return self._dispatch("DESCRIBE", table_name)

    def sample(self, table_name: str) -> str:
        """Show sample rows from a database table.

        Args:
            table_name: Name of the table to sample.

        Returns:
            Sample row output for the specified table.
        """
        return self._dispatch("SAMPLE", table_name)

    def query(self, sql: str) -> str:
        """Execute a read-only SQL query.

        Args:
            sql: SELECT SQL statement to execute.

        Returns:
            Query output text.
        """
        return self._dispatch("QUERY", sql)

    def answer(self, value: str) -> str:
        """Submit a final answer for the active episode.

        Args:
            value: Final answer value to submit.

        Returns:
            Feedback text for the submitted answer.
        """
        return self._dispatch("ANSWER", value)
def sql_env_reward_func(environments, **kwargs):
    """Collect the accumulated reward of every environment as a float.

    Args:
        environments: Completed environment instances (passed by TRL); each
            must expose a ``reward`` attribute.
        kwargs: Additional TRL reward kwargs (ignored).

    Returns:
        List of float rewards, in the same order as ``environments``.
    """
    rewards = []
    for environment in environments:
        rewards.append(float(environment.reward))
    return rewards