| # Implementation Specification |
|
|
| **Change:** F006 -- GRPO Training Pipeline |
| **Date:** 2026-03-27 |
| **Research Summary:** [specs/F006-RESEARCH_SUMMARY.md](F006-RESEARCH_SUMMARY.md) |
| **Verification Spec:** See VERIFICATION_SPEC.md (generated by autocode-verification-planner) |
| **Behavior Delta:** Archived to [specs/behavior/training.md](behavior/training.md) |
| |
| **Plan Status:** |
| - [x] Draft |
| - [x] Approved for Implementation |
| - [x] Implementation Complete |
| - [x] Verification Passed |
| |
| --- |
| |
| ## Core Intent (Immutable) |
| |
| > **DO NOT MODIFY THIS SECTION DURING REFINEMENT** |
| > Changes to Core Intent mean you are describing a different feature. |
| > If refinement reveals the need to change this section, create a new feature instead. |
| |
| **User Problem:** |
| Train a model that learns SQL exploration strategy through RL. The "before vs after" comparison is the competition's money shot -- untrained agent flails randomly, trained agent explores strategically. |
| |
| **Success Criteria:** |
| - Training notebook runs end-to-end in one click |
| - Learning curve clearly shows improvement over episodes |
| - Side-by-side episode transcripts: random vs trained |
| - Reproducible results (deterministic given seed) |
| |
| **Avoid:** |
| - Training that does not converge at all (no learning signal) |
| - Requiring an expensive GPU for hours to see any signal |
| - Notebook with hidden dependencies that break on fresh setup |
| |
| **Out of Scope:** |
| - wandb / TensorBoard integration (MVP: print metrics) |
| - vLLM inference (use HF generate for simplicity) |
| - Hard-difficulty questions in training set (add later) |
| - WebSocket-based training (use local env) |
| - Multi-GPU / distributed training |
| - Custom RLHF algorithms beyond GRPO |
| |
| --- |
| |
| ## 0. Slicing & Scope Budget (Anti-Waterfall) |
| |
| This spec must be executable in **small, mergeable increments**. |
| |
| ### Scope Budget |
| - Target: **3 slices** |
| - Hard max: **<= 10 steps total** |
| - Each step must end in: **implement -> verify -> merge** |
| |
| ### Slice Definition |
| |
| | Slice | Name | Value | |
| |-------|------|-------| |
| | S1 | Training Config + Prompts | Configurable training setup, system prompt for SQL agent | |
| | S2 | Rollout + Rewards | TRL-compatible rollout function and reward callables | |
| | S3 | Training Notebook | End-to-end notebook with learning curve and comparison | |
| |
| ## Status Icons |
| |
| **Step Status:** |
| - !! Not Started |
| - >> In Progress |
| - OK Completed |
| - XX Blocked/Failed |
| |
| **Result Outcome:** |
| - OK Fully Successful (all tests passed, no issues) |
| - !! Completed with Issues (needs follow-up) |
| - XX Failed/Blocked |
| |
| --- |
| |
| ## 1. Implementation Overview |
| |
| ### Summary |
| |
| Add a `training/` subpackage with configuration, rollout, reward wrappers, and prompt modules that integrate with TRL's GRPOTrainer. Provide a `notebooks/train_grpo.ipynb` notebook as the user-facing entry point that trains a small LLM (default: Qwen3-1.7B) to play SQLEnv, then produces learning curves and before/after episode comparisons. |
|
|
| ### Scope |
|
|
| **In Scope:** |
| - `training/config.py` -- dataclass with all hyperparameters and model name |
| - `training/prompts.py` -- system prompt for SQL exploration agent |
| - `training/rollout.py` -- `rollout_func` that plays SQLEnv episodes via HF generate |
| - `training/rewards.py` -- reward callables matching TRL `reward_funcs` signature |
| - `notebooks/train_grpo.ipynb` -- end-to-end training notebook |
| - `training/__init__.py` -- public exports |
|
|
| **Out of Scope:** |
| - vLLM inference backend |
| - wandb/TensorBoard logging |
| - Training on hard-difficulty questions |
| - Distributed or multi-GPU training |
|
|
| --- |
|
|
| ## 1a. Execution Status |
|
|
| **Progress:** 6/6 steps complete |
| **Current Step:** None (implementation complete) |
| **Last Updated:** 2026-03-28T07:37:20Z |
| **Latest Result:** OK Fully Successful - Step 3.1 complete, 68/68 tests passed |
| **Blockers:** None |
|
|
| --- |
|
|
| ## 1b. Risk Assessment |
|
|
| **Risk Tier:** Medium |
|
|
| **Risk Tier Definitions:** |
| - **Low:** Pure logic, non-user-facing, no security implications |
| - **Medium:** User input handling, data validation, API changes |
| - **High:** Authentication, payments, secrets management, untrusted input |
|
|
| **High-Risk Indicators Present:** None |
|
|
| **Security Review Required:** No |
|
|
| **Justification:** |
| External model loading from HuggingFace Hub and GPU resource management require care, but the feature handles no security-sensitive data. The main risks are training convergence and resource requirements. |
|
|
| --- |
|
|
| ## 2. Change Manifest |
|
|
| ### Files to Create |
|
|
| | File | Purpose | |
| |------|---------| |
| | `training/__init__.py` | Package init, public exports | |
| | `training/config.py` | `GRPOConfig` dataclass with hyperparameters | |
| | `training/prompts.py` | System prompt for SQL exploration agent | |
| | `training/rollout.py` | `rollout_func` for TRL GRPOTrainer | |
| | `training/rewards.py` | Reward callables: correctness, progress, operational | |
| | `training/data_loading.py` | Model/question loading helpers for notebook runtime and tests | |
| | `training/notebook_pipeline.py` | Notebook orchestration helpers for trainer setup, baseline, and metrics | |
| | `notebooks/train_grpo.ipynb` | End-to-end training notebook | |
| | `tests/integration/test_training_pipeline.py` | Integration verification for rollout + rewards pipeline | |
| | `tests/e2e/test_training_e2e.py` | Notebook smoke verification and pipeline behavior checks | |
| | `tests/unit/test_error_handling.py` | Error-path verification for model/questions loading and fallback logging | |
|
|
| ### Files to Modify |
|
|
| | File | Changes | |
| |------|---------| |
| | `pyproject.toml` | Add a `training` optional dependency group containing `trl` and `accelerate` | |
|
|
| ### Files to Delete |
|
|
| None. |
|
|
| --- |
|
|
| ## 3. Interface Specifications |
|
|
| ### New Types |
|
|
```python
# Location: training/config.py

from dataclasses import dataclass, field


@dataclass
class GRPOConfig:
    """All hyperparameters for GRPO training on SQLEnv."""

    # Model
    model_name: str = "Qwen/Qwen3-1.7B"
    max_new_tokens: int = 256

    # Training
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 5e-6
    num_generations: int = 4  # G in GRPO (completions per prompt)

    # Environment
    questions_path: str = "data/questions/questions_train.json"
    db_dir: str = "data/databases"
    step_budget: int = 10  # Shorter budget for training
    difficulty_filter: list[str] = field(default_factory=lambda: ["easy", "medium"])

    # Reproducibility
    seed: int = 42

    # Output
    output_dir: str = "outputs/grpo_run"
    logging_steps: int = 10
```
|
|
| ### New Functions |
|
|
```python
# Location: training/prompts.py

def get_system_prompt() -> str:
    """Return the system prompt for the SQL exploration agent.

    Returns:
        System prompt string instructing the model on SQLEnv action format.
    """


def format_observation(obs: "SQLObservation") -> str:
    """Format an SQLObservation into a user-turn string for the model.

    Args:
        obs: The observation from the environment.

    Returns:
        Formatted string suitable as a user message in chat history.
    """
```
|
|
```python
# Location: training/rollout.py

from typing import Any


def rollout_func(
    prompts: list[str],
    model: Any,
    tokenizer: Any,
    config: "GRPOConfig",
) -> list[dict[str, Any]]:
    """Play SQLEnv episodes for a batch of question prompts.

    Each prompt is a question text. The function:
    1. Creates a local SQLEnvironment
    2. Resets with the question
    3. Loops: model.generate() -> parse action -> env.step()
    4. Collects completions and metadata

    Args:
        prompts: List of question texts (from training dataset).
        model: HuggingFace model for generation.
        tokenizer: HuggingFace tokenizer.
        config: Training configuration.

    Returns:
        List of dicts with keys:
        - "prompt": str (the input prompt)
        - "completion": str (full model output trajectory)
        - "metadata": dict with episode_id, step_count, done, answer_correct,
          cumulative_progress, operational_signals
    """
```
|
|
```python
# Location: training/rewards.py

from typing import Any


def reward_correctness(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Binary reward: 1.0 if episode ended with correct answer, 0.0 otherwise.

    Args:
        completions: Batch of completion message lists (TRL format).
        **kwargs: Additional metadata from rollout (includes 'metadata' key).

    Returns:
        List of float rewards, one per completion.
    """


def reward_progress(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Progress reward: cumulative progress score from environment.

    Args:
        completions: Batch of completion message lists (TRL format).
        **kwargs: Additional metadata from rollout.

    Returns:
        List of float rewards, one per completion.
    """


def reward_operational(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Operational reward: sum of per-step L1 signals (exec_ok, new_info, etc.).

    Args:
        completions: Batch of completion message lists (TRL format).
        **kwargs: Additional metadata from rollout.

    Returns:
        List of float rewards, one per completion.
    """
```
|
|
| --- |
|
|
| ## 4. Data Flow |
|
|
| ### Primary Flow (Training Loop) |
|
|
```
1. Notebook loads GRPOConfig and model/tokenizer from HuggingFace
   - Input: config.model_name
   - Output: model, tokenizer, config

2. Load training questions filtered by difficulty
   - Input: config.questions_path, config.difficulty_filter
   - Output: list[str] of question texts as prompts

3. GRPOTrainer calls rollout_func for each batch of prompts
   - Input: prompts, model, tokenizer, config
   - Action: For each prompt, play a full SQLEnv episode
     a. Create local SQLEnvironment
     b. env.reset(question) -> initial observation
     c. Loop: format obs -> model.generate() -> parse SQLAction -> env.step()
     d. Collect full trajectory as completion string
   - Output: completions + metadata (correctness, progress, operational signals)

4. GRPOTrainer calls each reward_func on completions
   - Input: completions list, metadata kwargs
   - Output: list[float] per reward function

5. GRPOTrainer computes GRPO loss and updates model weights
   - Input: completions, rewards, model
   - Output: updated model weights, logged metrics

6. Repeat steps 3-5 for num_train_epochs
```
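
Step 5 happens inside TRL, but the role of `num_generations` (G) is easier to see with the group-relative advantage written out. The sketch below is not TRL's code. It assumes the commonly described GRPO normalization (reward minus group mean, divided by group standard deviation); the `eps` guard and the use of the population standard deviation are assumptions of this sketch.

```python
from statistics import fmean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize rewards within one group of G completions for the same prompt.

    Illustrative only: `eps` and the population standard deviation are
    assumptions of this sketch, not values taken from TRL.
    """
    mean = fmean(rewards)
    std = pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# One prompt, G = 4 completions, binary correctness reward.
mixed = group_relative_advantages([1.0, 0.0, 0.0, 0.0])
# Every completion failed: no completion is better than its peers.
uniform = group_relative_advantages([0.0, 0.0, 0.0, 0.0])
```

When every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal. This is one reason the denser progress and operational rewards are useful alongside the binary correctness reward.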
|
|
| ### Alternative Flow: Unparseable Model Output |
|
|
```
1. Model generates text that cannot be parsed as SQLAction
2. rollout_func defaults to QUERY action with raw text as argument
3. Environment returns an error observation
4. Episode continues (agent can recover in subsequent steps)
```
|
|
| ### Alternative Flow: Episode Exceeds Token Budget |
|
|
```
1. Conversation history grows beyond the model's token window
   (max_new_tokens caps only the generated reply, not the prompt)
2. rollout_func truncates conversation history, keeping:
   a. System prompt (always)
   b. Most recent 3 observation-action pairs
3. Episode continues with truncated context
```
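
The truncation rule above can be sketched as a small pure function. This is an illustration, not the code in `training/rollout.py`: the name `truncate_history` and the assumption that the first message is the system prompt followed by alternating user (observation) and assistant (action) messages are both hypothetical.

```python
Message = dict[str, str]


def truncate_history(messages: list[Message], keep_pairs: int = 3) -> list[Message]:
    """Keep the system prompt plus the most recent `keep_pairs` observation-action pairs.

    Assumes messages[0] is the system message and the remaining messages
    alternate user (observation) / assistant (action).
    """
    if not messages:
        return []
    system, turns = messages[:1], messages[1:]
    max_turns = 2 * keep_pairs
    if len(turns) <= max_turns:
        return list(messages)
    tail = turns[-max_turns:]
    # Start the window on an observation so the pairs stay aligned.
    if tail[0]["role"] != "user":
        tail = tail[1:]
    return system + tail
```

Truncating by message count is the simplest option. A token-based check (counting tokens after applying the chat template) would track the real limit more closely at the cost of an extra tokenization pass.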
|
|
| --- |
|
|
| ## 5. Error Handling |
|
|
| ### Error Types |
|
|
| | Error | When | Strategy | |
| |-------|------|----------| |
| | `ModelLoadError` | Model not found on HuggingFace | Fail fast with clear message naming model_name | |
| | `ActionParseError` | Model output not parseable as SQLAction | Default to QUERY with raw text, log warning | |
| | `OOMError` | GPU out of memory during training | Print guidance: reduce batch_size or num_generations | |
| | `QuestionLoadError` | Questions file missing or empty | Fail fast with path in error message | |
| | `EnvironmentError` | SQLEnv database missing | Fail fast pointing to data download instructions | |
| |
| ### Error Handling Strategy |
| |
```python
# In rollout_func: graceful degradation
try:
    action = parse_action(model_output)
except ActionParseError:
    action = SQLAction(action_type="QUERY", argument=model_output)

# In notebook: fail-fast on setup
try:
    model = AutoModelForCausalLM.from_pretrained(config.model_name)
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"Cannot load model '{config.model_name}': {e}") from e
```
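
The same fail-fast approach applies to the questions file. The sketch below assumes a hypothetical schema (a JSON list of objects with `question` and `difficulty` keys) and hypothetical helper names. The real loader lives in `training/data_loading.py` and may differ.

```python
import json
from pathlib import Path


class QuestionLoadError(RuntimeError):
    """Raised when the questions file is missing, invalid, or empty after filtering."""


def select_prompts(records: list[dict], difficulty_filter: list[str]) -> list[str]:
    """Return question texts whose difficulty is in the filter (assumed schema)."""
    allowed = set(difficulty_filter)
    return [r["question"] for r in records if r.get("difficulty") in allowed]


def load_question_prompts(questions_path: str, difficulty_filter: list[str]) -> list[str]:
    """Load and filter questions, naming the path in every error message."""
    path = Path(questions_path)
    if not path.is_file():
        raise QuestionLoadError(f"Questions file not found: {path}")
    try:
        records = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise QuestionLoadError(f"Questions file is not valid JSON: {path}") from exc
    prompts = select_prompts(records, difficulty_filter)
    if not prompts:
        raise QuestionLoadError(f"No questions match {sorted(difficulty_filter)} in {path}")
    return prompts
```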
| |
| ### Retry Strategy |
|
|
| | Operation | Retry? | Strategy | |
| |-----------|--------|----------| |
| | Model download | No | Fail fast, user must fix network/model name | |
| | Episode rollout | No | Single attempt per episode, errors become low-reward signal | |
| | Training step | No | OOM is fatal for that config, must adjust params | |
|
|
| --- |
|
|
| ## 6. Slice Plan (What we will ship, in order) |
|
|
| ### Slice S1 -- Training Config + Prompts |
| **Value:** Centralized, documented configuration and system prompt ready for training integration |
| **User-visible change:** No (internal infrastructure) |
| **Interfaces introduced/changed:** `GRPOConfig`, `get_system_prompt()`, `format_observation()` |
| **Rollback safety:** Additive only -- new files, no existing code changed |
|
|
| ### Slice S2 -- Rollout + Rewards |
| **Value:** TRL-compatible rollout and reward functions that can drive GRPO training |
| **User-visible change:** No (library code) |
| **Interfaces introduced/changed:** `rollout_func()`, `reward_correctness()`, `reward_progress()`, `reward_operational()` |
| **Rollback safety:** Additive only -- new files in training/ package |
|
|
| ### Slice S3 -- Training Notebook |
| **Value:** Users can run one notebook to train a model and see before/after results |
| **User-visible change:** Yes -- the notebook is the primary deliverable |
| **Interfaces introduced/changed:** `notebooks/train_grpo.ipynb`, `pyproject.toml` training deps |
| **Rollback safety:** Notebook is standalone; pyproject.toml change is additive (optional deps group) |
|
|
| --- |
|
|
| ## 7. Implementation Steps |
|
|
| > **VERIFICATION NOTE:** Test criteria for each step are defined in VERIFICATION_SPEC.md. |
| > The verification-planner (separate agent) generated independent test criteria. |
| > Run the tests specified there after implementing each step. |
| |
| ### Step 1.1: Training Config Dataclass |
| **Slice:** S1 |
| **Goal:** Create `training/config.py` with `GRPOConfig` dataclass holding all hyperparameters. |
| |
| **Files:** |
| - `training/__init__.py` - create - package init with public exports |
| - `training/config.py` - create - GRPOConfig dataclass |
| |
| **Interface Changes:** |
| - New type: `GRPOConfig` with fields as specified in Section 3 |
| |
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
|
|
| **Risk Tier for This Step:** Low |
|
|
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
| |
| **Status:** OK Completed |
| |
| **Completed:** 2026-03-28T06:44:31Z |
| **Changes Made:** |
| - Created `training/config.py` with `GRPOConfig` dataclass and input validation in `__post_init__` |
| - Created `training/__init__.py` exporting `GRPOConfig` |
| - Added `tests/unit/test_grpo_config.py` covering defaults, overrides, required fields, and validation failures |
| |
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_grpo_config.py -v |
| Result: 7 passed in 17.06s |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_grpo_config.py -v` |
| - **Notes:** |
| - Added explicit validation for numeric bounds and non-empty difficulty filter to fail fast during setup |
| - `uv run pytest ...` failed because pytest is not installed by default; used `uv run --with pytest pytest ...` for scoped test dependency |
| - Kept config required fields (`questions_path`, `db_dir`, `output_dir`) positional/required per verification criteria |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - GRPOConfig available for import by prompts.py and rollout.py |
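
The notes above mention validation of numeric bounds and a non-empty difficulty filter in `__post_init__`. A minimal sketch of that pattern follows. `TrainingConfigSketch` is a hypothetical stand-in reduced to a few fields, and the specific bounds are assumptions, not the checks implemented in `training/config.py`.

```python
from dataclasses import dataclass, field


@dataclass
class TrainingConfigSketch:
    """Hypothetical stand-in for GRPOConfig, reduced to a few validated fields."""

    learning_rate: float = 5e-6
    num_generations: int = 4
    step_budget: int = 10
    difficulty_filter: list[str] = field(default_factory=lambda: ["easy", "medium"])

    def __post_init__(self) -> None:
        # Fail fast at construction time instead of deep inside training.
        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be > 0, got {self.learning_rate}")
        for name in ("num_generations", "step_budget"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if not self.difficulty_filter:
            raise ValueError("difficulty_filter must not be empty")
```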
|
|
| --- |
|
|
| ### Step 1.2: System Prompt and Observation Formatter |
| **Slice:** S1 |
| **Goal:** Create `training/prompts.py` with system prompt and observation formatting for model input. |
|
|
| **Files:** |
| - `training/prompts.py` - create - system prompt and observation formatter |
|
|
| **Interface Changes:** |
| - New functions: `get_system_prompt() -> str`, `format_observation(obs: SQLObservation) -> str` |
|
|
| **Details:** |
| - System prompt should instruct the model on: |
| - Available actions: DESCRIBE, SAMPLE, QUERY, ANSWER |
| - Action format: `ACTION_TYPE: argument` |
| - Exploration strategy guidance (describe tables first, then query, then answer) |
| - Budget awareness |
| - `format_observation` converts SQLObservation fields into a readable user-turn string |
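
As a rough illustration of these two helpers, the sketch below builds a prompt covering the listed points and formats an observation with a length guard. The real `SQLObservation` fields are not listed in this spec, so `ObservationSketch`, its fields, the prompt wording, and `MAX_RESULT_CHARS` are all hypothetical.

```python
from dataclasses import dataclass

MAX_RESULT_CHARS = 2000  # hypothetical bound for this sketch

SYSTEM_PROMPT_SKETCH = (
    "You are exploring a SQL database to answer a question.\n"
    "Reply with exactly one action per turn in the form ACTION_TYPE: argument.\n"
    "Available actions: DESCRIBE, SAMPLE, QUERY, ANSWER.\n"
    "Describe tables first, then query, then answer.\n"
    "You have a limited number of steps, so avoid repeating actions."
)


@dataclass
class ObservationSketch:
    """Hypothetical stand-in; the real SQLObservation fields may differ."""

    question: str
    result: str = ""
    error: str = ""
    steps_remaining: int = 0


def format_observation_sketch(obs: ObservationSketch) -> str:
    """Render an observation as a user-turn string, truncating long results."""
    result = obs.result
    if len(result) > MAX_RESULT_CHARS:
        result = result[:MAX_RESULT_CHARS] + "\n... [truncated]"
    lines = [f"Question: {obs.question}"]
    if result:
        lines.append(f"Result:\n{result}")
    if obs.error:
        lines.append(f"Error: {obs.error}")
    lines.append(f"Steps remaining: {obs.steps_remaining}")
    return "\n".join(lines)
```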
|
|
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
| |
| **Risk Tier for This Step:** Low |
| |
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
|
|
| **Status:** OK Completed |
|
|
| **Completed:** 2026-03-28T06:47:49Z |
| **Changes Made:** |
| - Created `training/prompts.py` with deterministic `get_system_prompt()` and `format_observation()` helpers |
| - Added truncation guard for long observation results to keep prompt payload bounded |
| - Updated `training/__init__.py` exports to include prompt helpers |
| - Added `tests/unit/test_prompts.py` covering prompt content and observation formatting edge cases |
|
|
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_prompts.py -v |
| Result: 8 passed in 2.92s |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_prompts.py -v` |
| - **Notes:** |
| - `uv run pytest ...` failed because pytest is not installed in the base env; used `uv run --with pytest pytest ...` for scoped dependency execution |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - Prompt module ready for use in rollout.py |
|
|
| --- |
|
|
| ### Step 2.1: Action Parser Utility |
| **Slice:** S2 |
| **Goal:** Create a robust parser that extracts `SQLAction` from free-form model output text. |
|
|
| **Files:** |
| - `training/rollout.py` - create - contains `parse_model_output(text: str) -> SQLAction` |
|
|
| **Interface Changes:** |
| - New function: `parse_model_output(text: str) -> SQLAction` |
| - Parses `ACTION_TYPE: argument` format from model text |
| - Falls back to `SQLAction(action_type="QUERY", argument=text)` on parse failure |
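
The parsing and fallback behavior can be sketched with a single regular expression. `SQLActionSketch` is a hypothetical stand-in for the project's `SQLAction`, and the choice to take the first matching line is an assumption of this sketch rather than a description of the shipped parser.

```python
import re
from dataclasses import dataclass


@dataclass
class SQLActionSketch:
    """Hypothetical stand-in for the project's SQLAction."""

    action_type: str
    argument: str


_ACTION_RE = re.compile(
    r"^\s*(DESCRIBE|SAMPLE|QUERY|ANSWER)\s*:\s*(.+?)\s*$",
    re.IGNORECASE,
)


def parse_model_output_sketch(text: str) -> SQLActionSketch:
    """Return the first `ACTION_TYPE: argument` line, else fall back to QUERY."""
    for line in text.splitlines():
        match = _ACTION_RE.match(line)
        if match:
            return SQLActionSketch(match.group(1).upper(), match.group(2))
    # Parse failure: let the environment report an error the agent can recover from.
    return SQLActionSketch("QUERY", text)
```

A line-based pattern keeps the parser simple but captures only the first line of a multi-line SQL argument, so a fuller parser would need to handle that case explicitly.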
|
|
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
| |
| **Risk Tier for This Step:** Low |
| |
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
|
|
| **Status:** OK Completed |
|
|
| **Completed:** 2026-03-28T06:51:50Z |
| **Changes Made:** |
| - Created `training/rollout.py` with `parse_model_output(text)` and a focused line parser helper |
| - Added action parsing for DESCRIBE/SAMPLE/QUERY/ANSWER with case-insensitive matching |
| - Added robust fallback behavior to `SQLAction(action_type="QUERY", argument=<raw_text>)` on parse failure |
| - Added `tests/unit/test_rollout.py` with coverage for happy path, edge cases, multiline output, and fallback behavior |
|
|
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_rollout.py -v |
| Result: 11 passed in 2.44s |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_rollout.py -v` |
| - **Notes:** |
| - `uv run pytest ...` failed because pytest is not installed in the base env; used `uv run --with pytest pytest ...` for scoped dependency execution |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - parse_model_output is available in `training/rollout.py` for Step 2.2 rollout integration |
|
|
| --- |
|
|
| ### Step 2.2: Rollout Function |
| **Slice:** S2 |
| **Goal:** Implement `rollout_func` that plays full SQLEnv episodes using HF generate. |
|
|
| **Files:** |
| - `training/rollout.py` - modify - add `rollout_func` and `play_episode` helper |
|
|
| **Interface Changes:** |
| - New function: `rollout_func(prompts, model, tokenizer, config) -> list[dict]` |
| - New helper: `play_episode(question_text, model, tokenizer, config, env) -> dict` |
| - Creates local SQLEnvironment for the episode |
| - Loops: format obs -> generate -> parse -> step until done or budget exhausted |
| - Returns completion string and metadata dict |
|
|
| **Details:** |
| - Use `model.generate()` (HF native, not vLLM) for inference |
| - Build chat messages using tokenizer.apply_chat_template |
| - Truncate conversation history if it exceeds token window (keep system prompt + last 3 turns) |
| - Metadata includes: episode_id, step_count, done, answer_correct, cumulative_progress, operational_signals |
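
The control flow of `play_episode` can be sketched without a model or database. Everything here is a stand-in: `ToyEnv` replaces `SQLEnvironment`, `generate_fn` replaces the tokenizer plus `model.generate()` call, and the reduced metadata dict carries only three of the fields listed above.

```python
from typing import Callable

Message = dict[str, str]


class ToyEnv:
    """Hypothetical stand-in for SQLEnvironment: ends when the agent answers."""

    def reset(self, question: str) -> str:
        self.question = question
        return f"Question: {question}"

    def step(self, action_text: str) -> tuple[str, bool, bool]:
        """Return (observation, done, answer_correct)."""
        if action_text.upper().startswith("ANSWER:"):
            answer = action_text.split(":", 1)[1].strip()
            return "Episode finished.", True, answer == "42"
        return f"Executed: {action_text}", False, False


def play_episode_sketch(
    question: str,
    generate_fn: Callable[[list[Message]], str],
    env: ToyEnv,
    step_budget: int,
) -> dict:
    messages: list[Message] = [
        {"role": "system", "content": "Reply with ACTION_TYPE: argument."},
        {"role": "user", "content": env.reset(question)},
    ]
    done, correct, steps = False, False, 0
    while not done and steps < step_budget:
        reply = generate_fn(messages)  # stands in for tokenizer + model.generate()
        observation, done, correct = env.step(reply)
        messages.append({"role": "assistant", "content": reply})
        messages.append({"role": "user", "content": observation})
        steps += 1
    completion = "\n".join(m["content"] for m in messages if m["role"] == "assistant")
    return {
        "prompt": question,
        "completion": completion,
        "metadata": {"step_count": steps, "done": done, "answer_correct": correct},
    }


def scripted_policy(messages: list[Message]) -> str:
    """Deterministic policy used only to exercise the loop."""
    turns = sum(1 for m in messages if m["role"] == "assistant")
    script = ["DESCRIBE: users", "QUERY: SELECT COUNT(*) FROM users", "ANSWER: 42"]
    return script[min(turns, len(script) - 1)]


episode = play_episode_sketch("How many users?", scripted_policy, ToyEnv(), step_budget=10)
```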
| |
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
|
|
| **Risk Tier for This Step:** Medium |
| > Core integration point between model and environment -- most likely source of bugs. |
|
|
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
| |
| **Status:** OK Completed |
| |
| **Completed:** 2026-03-28T07:04:59Z |
| **Changes Made:** |
| - Expanded `training/rollout.py` with `rollout_func`, `play_episode`, message-history truncation, prompt-aware environment reset, and HF `model.generate()` integration paths for both list and tensor-like outputs. |
| - Added rollout metadata fields (`episode_id`, `step_count`, `done`, `answer_correct`, `cumulative_progress`, `operational_signals`) and top-level compatibility keys (`content`, `correct`, `progress`, `operational`). |
| - Extended `tests/unit/test_rollout.py` with Step 2.2 coverage for batch behavior, step-budget termination, metadata shape, unparseable-action fallback continuity, history truncation, HF-style generation decoding, prompt binding, and incorrect-answer correctness guard. |
|
|
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_rollout.py -v |
| Result: 21 passed in 2.58s |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_rollout.py -v` |
| - **Notes:** |
| - Used `uv run --with pytest ...` because `pytest` is not available in the base environment. |
| - Medium-risk reviewer gate executed and resolved to APPROVE after decoder/correctness fixes. |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - rollout metadata now carries correctness/progress/operational signals needed by `training/rewards.py` in Step 2.3 |
|
|
| --- |
|
|
| ### Step 2.3: Reward Functions |
| **Slice:** S2 |
| **Goal:** Implement three TRL-compatible reward callables that consume rollout metadata. |
|
|
| **Files:** |
| - `training/rewards.py` - create - reward_correctness, reward_progress, reward_operational |
| |
| **Interface Changes:** |
| - New functions (all with TRL reward_func signature): |
| - `reward_correctness(completions, **kwargs) -> list[float]` |
| - `reward_progress(completions, **kwargs) -> list[float]` |
| - `reward_operational(completions, **kwargs) -> list[float]` |
|
|
| **Details:** |
| - `reward_correctness`: Binary 1.0/0.0 based on metadata["answer_correct"] |
| - `reward_progress`: Float from metadata["cumulative_progress"], normalized to [0, 1] |
| - `reward_operational`: Sum of per-step operational signals from metadata["operational_signals"] |
| - All functions access metadata via kwargs (TRL passes extra data from rollout return) |
| - Each function must handle missing metadata gracefully (return 0.0) |
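
A compact sketch of the three callables is shown below. It assumes that `kwargs["metadata"]` is a list aligned with `completions`, that "normalized to [0, 1]" can be approximated by clamping, and that `operational_signals` is a list of numbers. All three are assumptions of this sketch, and the function names carry a `_sketch` suffix to keep them distinct from the real implementations.

```python
from typing import Any


def _metadata_list(n: int, kwargs: dict[str, Any]) -> list[dict[str, Any]]:
    """Return one metadata dict per completion, padding with empty dicts."""
    raw = kwargs.get("metadata") or []
    items = [m if isinstance(m, dict) else {} for m in raw]
    return (items + [{}] * n)[:n]


def reward_correctness_sketch(completions: list[Any], **kwargs: Any) -> list[float]:
    metas = _metadata_list(len(completions), kwargs)
    return [1.0 if m.get("answer_correct") else 0.0 for m in metas]


def reward_progress_sketch(completions: list[Any], **kwargs: Any) -> list[float]:
    metas = _metadata_list(len(completions), kwargs)
    # Clamp to [0, 1]; missing values count as zero progress.
    return [min(1.0, max(0.0, float(m.get("cumulative_progress", 0.0)))) for m in metas]


def reward_operational_sketch(completions: list[Any], **kwargs: Any) -> list[float]:
    metas = _metadata_list(len(completions), kwargs)
    return [float(sum(m.get("operational_signals", []))) for m in metas]
```

Returning 0.0 for missing metadata keeps a malformed rollout from crashing a training step; the cost is that such failures are silent unless they are also logged.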
| |
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
|
|
| **Risk Tier for This Step:** Low |
|
|
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
| |
| **Status:** OK Completed |
| |
| **Completed:** 2026-03-28T07:07:32Z |
| **Changes Made:** |
| - Created `training/rewards.py` with TRL-compatible `reward_correctness`, `reward_progress`, and `reward_operational` callables |
| - Added robust metadata extraction paths so reward functions support both nested `metadata` payloads and flattened rollout kwargs |
| - Updated `training/__init__.py` exports for reward helper imports from the package root |
| - Added `tests/unit/test_rewards.py` covering correctness/progress/operational behavior across happy path, edge, and batch scenarios |
|
|
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_rewards.py -v |
| Result: 19 passed in 3.35s |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_rewards.py -v` |
| - **Notes:** |
| - Used `uv run --with pytest ...` because `pytest` is not available in the base environment. |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - `training/` now exposes config, prompts, rollout parsing/execution, and reward callables; next step is notebook wiring plus optional training dependencies in `pyproject.toml` |
|
|
| --- |
|
|
| ### Step 3.1: Training Notebook |
| **Slice:** S3 |
| **Goal:** Create end-to-end training notebook that loads model, trains with GRPO, and produces learning curves. |
|
|
| **Files:** |
| - `notebooks/train_grpo.ipynb` - create - end-to-end training notebook |
| - `pyproject.toml` - modify - add `[project.optional-dependencies] training` group |
|
|
| **Interface Changes:** |
| - New optional dependency group: `training = ["trl>=0.12.0", "accelerate>=0.34.0"]` |
|
|
| **Details:** |
| Notebook cells (linear flow): |
| 1. **Setup**: Install dependencies, import modules, set seed |
| 2. **Config**: Instantiate GRPOConfig (users can override model_name here) |
| 3. **Load Model**: `AutoModelForCausalLM.from_pretrained(config.model_name)` |
| 4. **Load Dataset**: Load questions, filter by difficulty, format as prompts |
| 5. **Initialize GRPOTrainer**: Pass model, tokenizer, rollout_func, reward_funcs, config |
| 6. **Train**: `trainer.train()` with progress bar and metric printing |
| 7. **Learning Curve**: Plot reward over training steps (matplotlib) |
| 8. **Comparison**: Run 5 episodes with random actions vs trained model, display side-by-side transcripts |
| 9. **Save**: Save trained model to config.output_dir |
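
For the learning-curve cell, one option is to read the log entries that the Hugging Face trainer accumulates in `trainer.state.log_history` (a list of dicts). The helper below is a sketch with a hypothetical name. The `"reward"` key is an assumption about what the GRPO trainer logs, and the project's own extraction logic in `training/notebook_pipeline.py` may differ.

```python
def extract_reward_curve(
    log_history: list[dict],
    key: str = "reward",
) -> tuple[list[int], list[float]]:
    """Collect (step, value) points for one logged metric.

    Entries that lack either the metric or a step number are skipped.
    """
    steps: list[int] = []
    values: list[float] = []
    for entry in log_history:
        if key in entry and "step" in entry:
            steps.append(int(entry["step"]))
            values.append(float(entry[key]))
    return steps, values
```

The two returned lists can be passed directly to matplotlib's `plot` to draw reward against training step.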
|
|
| **Verification:** |
| > See VERIFICATION_SPEC.md for test criteria defined by independent verification planner. |
| |
| **Risk Tier for This Step:** Medium |
| > User-facing deliverable; must work on fresh setup. |
| |
| **Merge Criteria:** |
| - [x] Tests from VERIFICATION_SPEC.md pass |
| - [x] No TODOs left in changed code (or explicitly tracked) |
| - [x] Backwards compatible (or flag/migration documented) |
|
|
| **Status:** OK Completed |
|
|
| **Completed:** 2026-03-28T07:37:20Z |
| **Changes Made:** |
| - Created `notebooks/train_grpo.ipynb` as the primary user-facing training notebook for F006, with one-pass setup, model/question loading, trainer construction, training execution, learning-curve plotting, random-baseline vs trained transcript comparison, and artifact save steps. |
| - Added `[project.optional-dependencies].training` in `pyproject.toml` with `trl>=0.14.0,<0.15.0` and `accelerate>=0.34.0` to keep TRL/torch compatibility stable for this repository. |
| - Added `training/data_loading.py` to centralize notebook error handling for model loading and question filtering/loading. |
| - Added `training/notebook_pipeline.py` to centralize trainer wiring, random baseline generation, training execution, and metrics extraction. |
| - Updated `training/__init__.py` exports to include notebook-facing helpers. |
| - Added `tests/e2e/test_training_e2e.py` for notebook smoke structure + pipeline behavior checks. |
| - Added `tests/integration/test_training_pipeline.py` for rollout/reward integration scenarios. |
| - Added `tests/unit/test_error_handling.py` for model/question loading failures, OOM guidance messaging, and parse-fallback warning logging. |
|
|
| **Result:** |
| - **Outcome:** OK Fully Successful |
| - **Evidence Captured:** |
| ``` |
| Command: uv run --with pytest pytest tests/unit/test_grpo_config.py tests/unit/test_prompts.py tests/unit/test_rollout.py tests/unit/test_rewards.py tests/unit/test_error_handling.py tests/integration/test_training_pipeline.py tests/e2e/test_training_e2e.py -v |
| Result: 68 passed in 5.79s |
| Command: uv run --extra training python -c "from trl import GRPOConfig, GRPOTrainer; print('ok')" |
| Result: ok |
| ``` |
| - **Tests run:** `uv run --with pytest pytest tests/unit/test_grpo_config.py tests/unit/test_prompts.py tests/unit/test_rollout.py tests/unit/test_rewards.py tests/unit/test_error_handling.py tests/integration/test_training_pipeline.py tests/e2e/test_training_e2e.py -v` |
| - **Notes:** |
| - Added concrete integration/e2e/error test files that were listed in `VERIFICATION_SPEC.md` but missing from the repository. |
| - Notebook now compares random-policy baseline transcripts against trained-policy transcripts, matching the feature's user-facing comparison goal. |
| - Parse fallback now emits a warning log to align behavior with error-handling verification expectations. |
| - **Issues:** None |
| - **Follow-ups Created:** None |
| - **Human Review Completed:** N/A |
|
|
| **Context for Next Step:** |
| - All implementation deliverables complete; feature is ready for final verification/finalization bookkeeping. |
|
|
| --- |
|
|
| ## 8. Rollout Considerations |
|
|
| ### Feature Flags |
| - [ ] Required: No |
|
|
| ### Migration |
| - [ ] Data migration needed: No |
|
|
| ### Rollback Plan |
| All changes are additive (new `training/` package and `notebooks/` directory). Rollback is simply removing those directories and reverting the pyproject.toml optional deps change. |
|
|
| --- |
|
|
| ## 9. Execution Tracking |
|
|
| All execution state is tracked within this document: |
| - **Section 1a:** Overall progress summary |
| - **Section 7:** Per-step completion details, test results, and handoff context |
| - **FEATURES.json:** Feature-level status/progress metadata used by `/autocode-next-step` and `opencode-ctx ralph run` |
| - **Git history:** Full audit trail of changes to this file |
|
|
| The implementing agent updates this document after each step and keeps the matching `FEATURES.json` entry in sync during implementation/finalization. Humans can monitor progress by: |
| - Checking Section 1a for summary |
| - Reviewing Section 7 for detailed step status |
| - Inspecting the feature's `progress` and `status` fields in `FEATURES.json` |
| - Running `git log --oneline IMPLEMENTATION_SPEC.md` for change history |
|
|
| --- |
|
|
| ## 9a. Slice Completion Protocol |
|
|
| After all steps in a slice pass verification: |
|
|
| 1. **Run verifier subagent** for spec compliance |
| - Validates against VERIFICATION_SPEC.md criteria |
| - Ensures no TODOs or incomplete work in slice |
| |
| 2. **Run compound-engineer subagent** to extract learnings |
| - **Mandatory invocation** after every slice completion |
| - Updates CLAUDE.md Learnings section (if durable patterns found) |
| - May exit with "no update needed" (valid for routine work) |
| |
| 3. **Commit** the slice changes |
| - Follow commit message format in CLAUDE.md |
| - Each slice gets its own atomic commit |
| |
| 4. **Continue to next slice** (if more slices remain) |
| - Or proceed to final verification if all slices complete |
| |
| **Note:** PR creation happens only after ALL slices are complete. Use `/commit-push-pr` manually when ready. |
| |
| --- |
| |
| ## 10. User Value Summary |
| |
| **Status:** Generated |
| |
| ### What Users Can Now Do |
| Users can now run a single notebook (`notebooks/train_grpo.ipynb`) to configure GRPO training, load a compatible TRL stack, train a model on SQLEnv prompts, and inspect both reward-curve output and transcript comparisons between random and trained policies. |
|
|
| ### How to Access/Test |
| 1. Install training extras: `uv sync --extra training` |
| 2. Open `notebooks/train_grpo.ipynb` |
| 3. Run all cells to train and save artifacts to `outputs/grpo_run` |
|
|
| ### Demo |
| - **Command:** `jupyter notebook notebooks/train_grpo.ipynb` |
| - **Verification command:** `uv run --with pytest pytest tests/unit/test_grpo_config.py tests/unit/test_prompts.py tests/unit/test_rollout.py tests/unit/test_rewards.py tests/unit/test_error_handling.py tests/integration/test_training_pipeline.py tests/e2e/test_training_e2e.py -v` |
|
|
| ### Release Notes Snippet |
| Add a GRPO training pipeline for SQLEnv with a runnable notebook, pinned TRL training dependencies, robust loading/error helpers, and verification coverage across unit, integration, and notebook-smoke paths. |
|
|
| --- |
|
|
| ## 11. PR Contract (Auto-Generated by autocode-next-step) |
|
|
| **Status:** Generated |
|
|
| ### Scope |
| - Finalized Step 3.1 (Training Notebook) for F006. |
| - Added a `training` optional-dependency group in `pyproject.toml`, with a TRL pin compatible with the repo's torch version. |
| - Added notebook support helpers for model/question loading and trainer orchestration. |
| - Added/expanded verification tests for notebook smoke, pipeline integration, and error handling. |
|
|
| ### Files Changed |
| - `pyproject.toml` |
| - `notebooks/train_grpo.ipynb` |
| - `training/__init__.py` |
| - `training/data_loading.py` |
| - `training/notebook_pipeline.py` |
| - `training/rollout.py` |
| - `tests/e2e/test_training_e2e.py` |
| - `tests/integration/test_training_pipeline.py` |
| - `tests/unit/test_error_handling.py` |
| - `specs/F006-IMPLEMENTATION_SPEC.md` |
| - `specs/behavior/training.md` |
|
|
| ### Verification Evidence |
| - `uv run --with pytest pytest tests/unit/test_grpo_config.py tests/unit/test_prompts.py tests/unit/test_rollout.py tests/unit/test_rewards.py tests/unit/test_error_handling.py tests/integration/test_training_pipeline.py tests/e2e/test_training_e2e.py -v` -> 68 passed |
| - `uv run --extra training python -c "from trl import GRPOConfig, GRPOTrainer; print('ok')"` -> ok |
| - Verifier verdict: APPROVED (`specs/F006-VERIFICATION_REPORT.md`) |
|
|
| ### Risk and Rollback |
| - Risk tier: Medium (training dependencies and user-facing notebook workflow). |
| - Rollback: remove notebook/training helper additions and revert `pyproject.toml` training extra. |
|
|
| ### Ready for Next Command |
| All implementation and verification criteria for F006 are complete. Run `/commit-push-pr` when ready. |
|
|
| --- |
|
|
| ## Stop Conditions (When to Split This Spec) |
|
|
| Stop and create a new IMPLEMENTATION_SPEC if: |
| - A step requires touching more than **3 files** in unrelated areas |
| - You need to introduce **multiple new abstractions** "just in case" |
| - Verification cannot be made targeted and concrete |
| - You discover new unknowns that change the plan materially |
| - The next slice cannot be merged safely without finishing later slices |
| |
| When splitting, ensure the current slice ends in a merged, stable state. |
| |
| --- |
| |
| ## Human Checkpoint |
| |
| **Before handing to AI agent:** |
| |
| - [ ] Interface specifications are complete |
| - [ ] Data flow is accurate |
| - [ ] Error handling is specified |
| - [ ] Implementation order makes sense |
| - [ ] VERIFICATION_SPEC.md has been generated |
|
|
| **Questions:** |
| 1. Confirm that Qwen3-1.7B is accessible on the Hugging Face Hub from the target environment. |
| 2. Verify that the TRL `GRPOTrainer` API matches the `rollout_func` / `reward_funcs` signatures assumed here. |
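Question 2 can be partly answered mechanically. The following is a minimal sketch, not part of the shipped `training/` package: it uses the standard-library `inspect` module to report which expected keyword names a constructor does not accept. The helper names `missing_parameters` and `check_grpo_trainer_api` are hypothetical, and whether `rollout_func` is accepted depends on the pinned TRL version, so treat the result as a check rather than a guarantee.

```python
import inspect
from typing import Callable, Iterable


def missing_parameters(func: Callable, expected: Iterable[str]) -> list[str]:
    """Return the names in `expected` that `func` does not accept by name."""
    params = inspect.signature(func).parameters
    return [name for name in expected if name not in params]


def check_grpo_trainer_api(
    expected: Iterable[str] = ("reward_funcs", "rollout_func"),
) -> list[str]:
    """Check the installed TRL GRPOTrainer constructor for the expected keywords.

    Raises ImportError if TRL is not installed (for example, when the
    training extra has not been synced).
    """
    # Imported lazily so missing_parameters stays usable without TRL.
    from trl import GRPOTrainer

    return missing_parameters(GRPOTrainer.__init__, expected)
```

Calling `check_grpo_trainer_api()` inside the training environment returns an empty list when both keywords are accepted by name. This only checks parameter names; it does not confirm that the callables' own signatures match what the trainer passes at runtime.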
|
|
| --- |
|
|
| ## Handoff Notes |
|
|
| **For the implementing AI agent:** |
|
|
| ``` |
| Context: See RESEARCH_SUMMARY.md for system understanding |
| Spec: Follow this document exactly |
| Verification: Use tests from VERIFICATION_SPEC.md (independent agent) |
| Ambiguity: Stop and ask rather than assume |
| Order: Follow implementation order exactly |
| Key decisions: |
| - HF generate (not vLLM) for inference |
| - Model name is a config parameter (default Qwen3-1.7B) |
| - Start with easy+medium questions only |
| - Follow TRL GRPOTrainer Wordle tutorial pattern |
| - reward_funcs are separate callables |
| ``` |
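To illustrate the "reward_funcs are separate callables" decision, here is a minimal sketch of the shape TRL expects for custom reward functions: each callable receives `completions` (plus other keyword arguments) and returns one float per completion. The two functions below are hypothetical placeholders, not the project's actual rewards, and the sketch assumes plain-string completions rather than conversational message lists.

```python
from typing import Any


def reward_nonempty(completions: list[str], **kwargs: Any) -> list[float]:
    """Score 1.0 for any completion with non-whitespace text, else 0.0."""
    return [1.0 if c.strip() else 0.0 for c in completions]


def reward_starts_with_select(completions: list[str], **kwargs: Any) -> list[float]:
    """Score 1.0 when the completion begins with SELECT (case-insensitive)."""
    return [
        1.0 if c.lstrip().upper().startswith("SELECT") else 0.0
        for c in completions
    ]


# Hypothetical list: each callable scores every completion independently.
reward_funcs = [reward_nonempty, reward_starts_with_select]

if __name__ == "__main__":
    sample = ["SELECT name FROM users", "   ", "describe users"]
    for fn in reward_funcs:
        # Extra keyword arguments such as `prompts` are absorbed by **kwargs.
        print(fn.__name__, fn(completions=sample, prompts=["q"] * len(sample)))
```

Keeping each signal in its own callable means a single reward can be unit-tested, reweighted, or dropped without touching the others.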
|
|
| --- |
|
|
| *Specification completed: 2026-03-27* |
| *Approved by: [pending]* |
| *Verification spec: VERIFICATION_SPEC.md* |
| *Target agent: Claude Code* |
|
|