| """TRL environment adapter for SQLEnv.""" |
|
|
| from __future__ import annotations |
|
|
| import collections |
|
|
| try: |
| from sql_env.models import SQLAction |
| except ImportError: |
| from models import SQLAction |
|
|
| try: |
| from sql_env.server.sql_environment import SQLEnvironment |
| except ImportError: |
| from server.sql_environment import SQLEnvironment |
|
|
|
|
def _parse_docstring_sections(lines, param_names):
    """Extract parameter and return descriptions from Google-style docstring lines.

    Args:
        lines: Docstring lines following the summary line.
        param_names: Container of the method's real parameter names; only
            ``name: description`` entries whose name is in it start a new
            parameter entry.

    Returns:
        A ``(param_descriptions, return_description)`` tuple. Multi-line
        descriptions are joined with single spaces.
    """
    # Headers that terminate the Args/Returns sections. Without this, text
    # from e.g. a trailing "Raises:" section leaks into the return description.
    other_headers = {
        "raises:",
        "yields:",
        "example:",
        "examples:",
        "note:",
        "notes:",
        "attributes:",
        "warning:",
        "warnings:",
        "see also:",
        "todo:",
        "references:",
    }

    param_parts: dict[str, list[str]] = {}
    return_parts: list[str] = []
    section = ""
    current = None  # parameter whose description may continue on later lines

    for line in lines:
        stripped = line.strip()
        lowered = stripped.lower()

        if lowered.startswith("args:"):
            section, current = "args", None
            continue
        if lowered.startswith("returns:"):
            section, current = "returns", None
            continue
        if lowered in other_headers:
            section, current = "", None
            continue
        if not stripped:
            continue

        if section == "args":
            head, sep, tail = stripped.partition(":")
            # Accept both "name: desc" and "name (type): desc".
            candidate = head.split("(", 1)[0].strip()
            if sep and candidate in param_names:
                current = candidate
                text = tail.strip()
                param_parts[current] = [text] if text else []
            elif current is not None:
                # Wrapped continuation of the previous parameter's description.
                param_parts[current].append(stripped)
        elif section == "returns":
            return_parts.append(stripped)

    descriptions = {key: " ".join(parts) for key, parts in param_parts.items()}
    return descriptions, " ".join(return_parts)


def get_tool_definitions(env_cls: type | None = None) -> list[dict]:
    """Extract tool definitions from an environment class via introspection.

    Inspects the class's public plain-function members (skipping names that
    start with an underscore, plus ``reset`` and ``reward``) and builds one
    function-calling schema per method from its signature and Google-style
    docstring. The intent is for SFT and GRPO to see identical tool
    definitions.

    Args:
        env_cls: Class to introspect. Defaults to ``SQLEnvTRL``.

    Returns:
        Tool definition dicts sorted by function name. Every parameter is
        typed as ``"string"``; parameters without a default are listed as
        required. A ``"return"`` entry is present only when the docstring
        has a non-empty ``Returns:`` section.
    """
    import inspect

    if env_cls is None:
        env_cls = SQLEnvTRL

    skipped = {"reset", "reward"}
    tools = []

    for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction):
        if name.startswith("_") or name in skipped:
            continue

        signature = inspect.signature(method)
        doc_lines = (inspect.getdoc(method) or "").split("\n")

        # str.split always yields at least one element, so the fallback must
        # test the summary text itself rather than the list's truthiness.
        description = doc_lines[0].strip() or name

        param_descriptions, return_description = _parse_docstring_sections(
            doc_lines[1:], signature.parameters
        )

        properties = {}
        required = []
        for param_name, param in signature.parameters.items():
            if param_name == "self":
                continue
            properties[param_name] = {
                "type": "string",
                "description": param_descriptions.get(
                    param_name, f"{param_name} parameter."
                ),
            }
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        tool = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
        if return_description:
            tool["function"]["return"] = {
                "type": "string",
                "description": return_description,
            }

        tools.append(tool)

    # Explicit sort keeps the output order independent of discovery order.
    tools.sort(key=lambda t: t["function"]["name"])
    return tools
|
|
|
|
class _MinimalTokenizer:
    """Stand-in tokenizer passed to SQLEnvironment at construction time."""

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        *,
        tokenize: bool = False,
        add_generation_prompt: bool = False,
    ) -> str:
        """Render nothing: every call yields an empty prompt string.

        Args:
            messages: Chat message payload (ignored).
            tokenize: Tokenizer flag (ignored).
            add_generation_prompt: Generation-prompt flag (ignored).

        Returns:
            Always the empty string.
        """
        # The arguments exist only to satisfy the expected call signature;
        # drop them explicitly so it is clear none of them is consulted.
        del messages, tokenize, add_generation_prompt
        rendered = ""
        return rendered
|
|
|
|
# Added to an adapter's accumulated reward each time a tool is invoked after
# the episode has already finished.
_POST_EPISODE_PENALTY: float = -0.3

# Added to an adapter's accumulated reward when a tool call repeats an
# (action, argument) pair that is still in its short recent-call window.
_REPEAT_PENALTY: float = -0.2
|
|
|
|
class SQLEnvTRL:
    """TRL-compatible adapter that exposes SQLEnv actions as tool methods.

    Configuration is class-level: call ``SQLEnvTRL._configure(...)`` once
    before any instance is created, because instances are constructed with
    no arguments. Each instance wraps its own ``SQLEnvironment`` and
    accumulates the episode's reward in ``self.reward``.

    The public methods ``describe``, ``sample``, ``query`` and ``answer`` are
    the tools. Their docstrings are read at runtime by
    ``get_tool_definitions`` to build tool schemas, so their text must be
    treated as data, not free-form commentary.
    """

    # Class-level configuration shared by all instances; set by _configure().
    _questions_path: str | None = None
    _db_dir: str | None = None
    _step_budget: int = 10

    @classmethod
    def _configure(
        cls,
        *,
        questions_path: str,
        db_dir: str,
        step_budget: int = 10,
    ) -> None:
        """Store class-level adapter configuration before TRL instantiation.

        Args:
            questions_path: Path handed to ``SQLEnvironment`` as ``questions_path``.
            db_dir: Directory handed to ``SQLEnvironment`` as ``db_dir``.
            step_budget: Per-episode step budget; must be positive.

        Raises:
            ValueError: If ``questions_path`` or ``db_dir`` is empty, or
                ``step_budget`` is not positive.
        """
        if not questions_path:
            raise ValueError("questions_path must be a non-empty string")
        if not db_dir:
            raise ValueError("db_dir must be a non-empty string")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")

        cls._questions_path = questions_path
        cls._db_dir = db_dir
        cls._step_budget = step_budget

    def __init__(self) -> None:
        """Initialize a configured SQLEnvironment-backed adapter instance.

        Raises:
            RuntimeError: If ``_configure`` has not been called yet.
        """
        cls = type(self)
        if cls._questions_path is None or cls._db_dir is None:
            # The message names the method that actually exists on the class.
            raise RuntimeError(
                "SQLEnvTRL._configure() must be called before SQLEnvTRL()"
            )

        self._env = SQLEnvironment(
            questions_path=cls._questions_path,
            db_dir=cls._db_dir,
            tokenizer=_MinimalTokenizer(),
            step_budget=cls._step_budget,
        )
        self.reward = 0.0
        self._done = False
        # Sliding window of the last three (action, argument) calls, used to
        # detect immediate repeats.
        self._recent_calls: collections.deque[tuple[str, str]] = collections.deque(
            maxlen=3
        )
        self._repeat_count = 0

    def reset(self, **kwargs: object) -> str | None:
        """Initialize a new episode and return the initial observation text.

        TRL passes dataset columns as kwargs. If ``question_text`` is a
        non-empty string matching at least one loaded question, the
        environment is reset with its question pool temporarily narrowed to
        the matches; otherwise it is reset over the full pool.

        Args:
            kwargs: Dataset columns from TRL, may include question_text.

        Returns:
            A short hint listing the available tables and the tool names.
        """
        self.reward = 0.0
        self._done = False
        self._recent_calls.clear()
        self._repeat_count = 0

        question_text = kwargs.get("question_text")
        matching = []
        if question_text and isinstance(question_text, str):
            matching = [
                q for q in self._env.questions if q.question_text == question_text
            ]

        if matching:
            # Swap in the narrowed pool by rebinding (never mutating) so the
            # environment's own list object can be put back untouched.
            all_questions = self._env.questions
            self._env.questions = matching
            try:
                self._obs = self._env.reset(seed=None)
            finally:
                self._env.questions = all_questions
        else:
            self._obs = self._env.reset(seed=None)

        # Reduce the schema listing to bare table names: drop the header line
        # and any leading "- " bullet markers.
        tables = []
        for line in (self._obs.schema_info or "").split("\n"):
            entry = line.strip().lstrip("- ").strip()
            if entry and entry != "Available tables:":
                tables.append(entry)
        return (
            f"Tables: {', '.join(tables)}. "
            "Use describe, sample, query, and answer tools."
        )

    def _dispatch(self, action_type: str, argument: str) -> str:
        """Execute an action with repeat detection and reward accumulation.

        Args:
            action_type: Action name forwarded to the environment.
            argument: Action argument forwarded to the environment.

        Returns:
            The observation's result text, else ``"Error: <error>"``, else
            ``"No output."``.

        Raises:
            ValueError: If the episode is already over; the post-episode
                penalty is applied to ``self.reward`` before raising.
        """
        if self._done:
            self.reward += _POST_EPISODE_PENALTY
            raise ValueError("Episode is over")

        # Penalize a call identical to one still in the recent-call window.
        call_key = (action_type.lower(), argument)
        if call_key in self._recent_calls:
            self.reward += _REPEAT_PENALTY
            self._repeat_count += 1
        self._recent_calls.append(call_key)

        observation = self._env.step(
            SQLAction(action_type=action_type, argument=argument)
        )
        if observation.reward is not None:
            self.reward += observation.reward
        self._done = observation.done

        if observation.result:
            return observation.result
        if observation.error:
            return f"Error: {observation.error}"
        return "No output."

    def describe(self, table_name: str) -> str:
        """Show schema details for a database table.

        Args:
            table_name: Name of the table to describe.

        Returns:
            Schema information for the specified table.
        """
        return self._dispatch("DESCRIBE", table_name)

    def sample(self, table_name: str) -> str:
        """Show sample rows from a database table.

        Args:
            table_name: Name of the table to sample.

        Returns:
            Sample row output for the specified table.
        """
        return self._dispatch("SAMPLE", table_name)

    def query(self, sql: str) -> str:
        """Execute a read-only SQL query.

        Args:
            sql: SELECT SQL statement to execute.

        Returns:
            Query output text.
        """
        return self._dispatch("QUERY", sql)

    def answer(self, value: str) -> str:
        """Submit a final answer for the active episode.

        Args:
            value: Final answer value to submit.

        Returns:
            Feedback text for the submitted answer.
        """
        return self._dispatch("ANSWER", value)
|
|
|
|
def sql_env_reward_func(environments, **kwargs):
    """Collect the accumulated reward of every environment as a float.

    Args:
        environments: Completed environment instances (passed by TRL); each
            must expose a ``reward`` attribute.
        kwargs: Additional TRL reward kwargs (ignored).

    Returns:
        One float per environment, in the same order as the input.
    """
    rewards = []
    for environment in environments:
        rewards.append(float(environment.reward))
    return rewards
|
|