# Demo: SQLEnv, a Flat OpenEnv Environment with Action Dispatch
> **Generated:** 2026-02-28T14:26Z
> **Branch:** `refactor-openenv-tutorial-project-structure` @ `f28bfaa`
> **Environment:** Python 3.12.3, torch 2.2.2, MockTokenizer (no Ollama required)
---
## What This Branch Does
This branch refactors the `sql-env` project from a nested `envs/sql_env/` layout into the canonical flat `openenv init` structure, and integrates the `action-feature` branch's core action dispatch system.
The result: a working RL environment where an agent sends natural language messages (e.g. _"describe the students table"_), the environment classifies them into action types (describe/sample/query), dispatches to the appropriate handler, and returns tokenized observations for RL training. All of this runs without external services: `MockTokenizer` replaces HuggingFace tokenizers, and Ollama failures are handled gracefully.
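As a hedged sketch of that classify-then-dispatch flow (the handler names below are illustrative placeholders, not the branch's actual identifiers):

```python
# Illustrative sketch of the dispatch step described above; these handler
# names and return strings are hypothetical, not the branch's actual code.
def handle_describe(message: str) -> str:
    return "schema description for the requested table"

def handle_sample(message: str) -> str:
    return "SELECT * FROM <table> LIMIT 10;"

def handle_query(message: str) -> str:
    return "LLM-generated SQL, or a graceful error without Ollama"

HANDLERS = {
    "describe": handle_describe,
    "sample": handle_sample,
    "query": handle_query,
}

def dispatch(action_type: str, message: str) -> str:
    # Look up the handler for the classified action type and run it.
    return HANDLERS[action_type](message)
```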
---
## Quickstart
```bash
git checkout refactor-openenv-tutorial-project-structure
uv sync
uv run pytest tests/ -v # 21 tests, ~3.5s
```
**Prerequisites:** Python 3.11-3.12, `uv`.
**Optional:** Ollama with `llama3.2` for LLM-guided table selection (not needed for demo).
---
## Evidence
### 1. All 21 Tests Pass
```
$ uv run pytest tests/ -v
tests/test_smoke.py::TestModels::test_action_creation PASSED [ 4%]
tests/test_smoke.py::TestModels::test_action_with_tokens PASSED [ 9%]
tests/test_smoke.py::TestModels::test_observation_creation PASSED [ 14%]
tests/test_smoke.py::TestModels::test_state_creation PASSED [ 19%]
tests/test_smoke.py::TestEnvironment::test_instantiation PASSED [ 23%]
tests/test_smoke.py::TestEnvironment::test_reset_returns_observation PASSED [ 28%]
tests/test_smoke.py::TestEnvironment::test_reset_with_empty_prompt PASSED [ 33%]
tests/test_smoke.py::TestEnvironment::test_reset_creates_new_episode PASSED [ 38%]
tests/test_smoke.py::TestEnvironment::test_step_describe PASSED [ 42%]
tests/test_smoke.py::TestEnvironment::test_step_sample PASSED [ 47%]
tests/test_smoke.py::TestEnvironment::test_tokens_grow_across_turns PASSED [ 52%]
tests/test_smoke.py::TestActionDetection::test_describe_keywords PASSED [ 57%]
tests/test_smoke.py::TestActionDetection::test_sample_keywords PASSED [ 61%]
tests/test_smoke.py::TestActionDetection::test_query_default PASSED [ 66%]
tests/test_smoke.py::TestMessageToAction::test_creates_action PASSED [ 71%]
tests/test_smoke.py::TestMessageToAction::test_appends_to_history PASSED [ 76%]
tests/test_smoke.py::TestMessageToAction::test_validates_input PASSED [ 80%]
tests/test_smoke.py::TestClientSerialization::test_step_payload_serialization PASSED [ 85%]
tests/test_smoke.py::TestClientSerialization::test_parse_result_deserialization PASSED [ 90%]
tests/test_smoke.py::TestSchemaIntrospection::test_get_table_schema PASSED [ 95%]
tests/test_smoke.py::TestSchemaIntrospection::test_unknown_table PASSED [100%]
============================== 21 passed in 3.56s ==============================
```
Tests cover: Pydantic models, environment lifecycle, action detection, message-to-action conversion, client tensor serialization, and schema introspection.
### 2. Lint and Format Clean
```
$ uv run ruff check .
All checks passed!
$ uv run ruff format --check .
14 files already formatted
```
### 3. Pydantic Model Contracts
```python
>>> from sql_env.models import SQLAction, SQLObservation, SQLState
>>> list(SQLAction.model_fields)
['metadata', 'action_type', 'action_description', 'tokens']
>>> list(SQLObservation.model_fields)
['done', 'reward', 'metadata', 'messages', 'tokens']
>>> list(SQLState.model_fields)
['episode_id', 'step_count', 'history_messages', 'history_tokens', 'current_action_type']
```
`SQLAction.tokens` and `SQLObservation.tokens` carry torch tensors. `SQLState.history_messages` / `history_tokens` accumulate the full conversation for RL context.
### 4. Action Type Detection
The environment classifies natural language messages into action types via keyword matching:
```
[PASS] "describe the students table..." -> describe
[PASS] "what columns does Course have..." -> describe
[PASS] "show me the schema..." -> describe
[PASS] "show me sample rows from students..." -> sample
[PASS] "give me example data..." -> sample
[PASS] "how many rows are in Courses..." -> sample
[PASS] "find all students enrolled in CS101..." -> query
[PASS] "select count(*) from students..." -> query
[PASS] "what is the average score..." -> query
```
Keywords like "describe"/"schema"/"columns" trigger describe; "sample"/"example"/"rows" trigger sample; everything else defaults to query.
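That rule fits in a few lines. The following is a minimal sketch of the behavior described above, not the branch's actual code (which lives in `server/sql_environment.py` and may differ in detail):

```python
# Minimal sketch of the keyword classification described above.
DESCRIBE_KEYWORDS = ("describe", "schema", "columns")
SAMPLE_KEYWORDS = ("sample", "example", "rows")

def detect_action_type(message: str) -> str:
    text = message.lower()
    if any(kw in text for kw in DESCRIBE_KEYWORDS):
        return "describe"
    if any(kw in text for kw in SAMPLE_KEYWORDS):
        return "sample"
    return "query"  # default when no keyword matches
```

This reproduces all nine classifications shown above, including "how many rows are in Courses" mapping to sample via the "rows" keyword.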
### 5. MockTokenizer Roundtrip
```python
>>> from server.test_sql_env import MockTokenizer
>>> tok = MockTokenizer()
>>> msg = [{'role': 'user', 'content': 'describe the students table'}]
>>> tokens = tok.apply_chat_template(msg, return_tensors='pt')
>>> tokens.shape
torch.Size([1, 27])
>>> tokens[0][:10].tolist()
[100, 101, 115, 99, 114, 105, 98, 101, 32, 116]
>>> tok.decode(tokens[0].tolist())
'describe the students table'
```
`MockTokenizer` encodes each character as `ord(c)` and decodes via `chr(t)`. Deterministic, no downloads, perfect for tests.
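The roundtrip reduces to two one-liners. This is a sketch of the described behavior only; the real class in `server/test_sql_env.py` additionally wraps the result in a torch tensor:

```python
def mock_encode(text: str) -> list[int]:
    # Each character becomes its Unicode code point, as described above.
    return [ord(c) for c in text]

def mock_decode(token_ids: list[int]) -> str:
    # Inverse mapping: code points back to characters.
    return "".join(chr(t) for t in token_ids)

# mock_encode("describe the students table")[:10]
# -> [100, 101, 115, 99, 114, 105, 98, 101, 32, 116]
```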
### 6. Schema Introspection
SQLAlchemy ORM models are introspected at runtime to produce natural language schema descriptions:
```python
>>> env._get_table_schema('Student')
Table 'Student' has the following columns:
- student_id: integer number
- student_details: text (up to 255 characters)
>>> env._get_table_schema('NonexistentTable')
Table 'NonexistentTable' not found in schema.
```
9 tables available: Address, Person, Student, Course, PersonAddress, StudentCourseRegistration, StudentCourseAttendance, Candidate, CandidateAssessment.
### 7. Full Environment Interaction (Mock Path)
A complete multi-turn episode with no external services:
```python
>>> from server.sql_environment import SQLEnvironment
>>> from server.test_sql_env import MockTokenizer
>>> env = SQLEnvironment(system_prompt='You are a helpful SQL assistant.', tokenizer=MockTokenizer())
>>> obs = env.reset()
>>> obs.messages # 1 message (system prompt)
>>> obs.tokens.shape
torch.Size([32])
>>> obs.done
False
```
**Turn 1: Describe**
```python
>>> action = env.message_to_action({'role': 'user', 'content': 'describe the Student table'})
>>> action.action_type
'describe'
>>> obs = env.step(action)
>>> obs.messages[-1]
{'role': 'assistant', 'content': "Table 'Address' has the following columns:\n\n- address_id: integer number\n..."}
>>> obs.tokens.shape
torch.Size([91])
```
Without Ollama, the describe action falls back to the first table (Address). With Ollama, it would correctly select "Student".
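The fallback can be sketched as follows. The signature is hypothetical; the branch's `_call_ollama_to_select_table()` differs in detail:

```python
def select_table(message: str, tables: list[str], llm=None) -> str:
    # Prefer the LLM's choice when an LLM callable is available; otherwise
    # (or on any failure) fall back to the first table, as described above.
    if llm is not None:
        try:
            choice = llm(message, tables)
            if choice in tables:
                return choice
        except Exception:
            pass
    return tables[0]
```

With no LLM, `select_table("describe the Student table", ["Address", "Student"])` returns `"Address"`, matching the fallback behavior shown above.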
**Turn 2: Sample**
```python
>>> action = env.message_to_action({'role': 'user', 'content': 'show me sample rows from Course'})
>>> action.action_type
'sample'
>>> obs = env.step(action)
>>> obs.messages[-1]['content']
"Here's a query to sample data from Address:\n\nSELECT * FROM Address LIMIT 10;"
>>> obs.tokens.shape
torch.Size([503])
```
As with describe, table selection falls back to `Address` without Ollama, even though `Course` was requested.
**Turn 3: Query (no Ollama)**
```python
>>> action = env.message_to_action({'role': 'user', 'content': 'find all students enrolled in CS101'})
>>> action.action_type
'query'
>>> obs = env.step(action)
>>> obs.messages[-1]['content']
'Error: Ollama returned status 404'
>>> obs.tokens.shape
torch.Size([1028])
```
The error is handled gracefully: it becomes part of the conversation history. The token tensor grows monotonically across turns (32 -> 91 -> 503 -> 1028).
### 8. Client Serialization
`SQLEnvClient` converts tensors to lists for JSON WebSocket transport:
```python
>>> from sql_env.client import SQLEnvClient
>>> from sql_env.models import SQLAction
>>> import torch
>>> action = SQLAction(action_type='query', action_description='select * from students', tokens=torch.tensor([[1, 2, 3, 4, 5]]))
>>> client._step_payload(action)  # client: an SQLEnvClient instance
{
'action_type': 'query',
'action_description': 'select * from students',
'tokens': [[1, 2, 3, 4, 5]],
'metadata': {}
}
```
Tensor -> list on send, list -> tensor on receive. Symmetric roundtrip verified in tests.
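A sketch of that symmetric transport using plain `json` (an assumption for illustration; the real client presumably rebuilds tensors from the received lists via `torch.tensor(...)`):

```python
import json

def to_wire(action_type: str, description: str, token_lists: list[list[int]]) -> str:
    # Serialize: tensors have already been converted to nested lists
    # (what torch.Tensor.tolist() produces), so the payload is pure JSON.
    return json.dumps({
        "action_type": action_type,
        "action_description": description,
        "tokens": token_lists,
        "metadata": {},
    })

def from_wire(payload: str) -> dict:
    # Deserialize: on receive, obj["tokens"] would be handed back to
    # a tensor constructor to rebuild the tensor.
    return json.loads(payload)
```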
### 9. Spider Question Data
```python
>>> import json
>>> data = json.load(open('data/questions/student_assessment.json'))
>>> len(data)
53
>>> data[0]['question']
'which course has most number of registered students?'
>>> data[0]['query']
'SELECT T1.course_name FROM courses AS T1 JOIN student_course_registrations AS T2 ON T1.course_id = T2.course_Id GROUP BY T1.course_id ORDER BY count(*) DESC LIMIT 1'
```
53 question-answer pairs from the Spider dataset's `student_assessment` database. Each entry has `db_id`, `query`, `question`, `query_toks`, `query_toks_no_value`, and `question_toks`.
---
## What Changed from `main`
| Area | Before (main) | After (this branch) |
|------|---------------|---------------------|
| **Layout** | `envs/sql_env/` nested | Flat root = package |
| **Build** | hatchling | setuptools |
| **Python** | 3.13 | 3.11-3.12 (torch compat) |
| **Models** | Structured obs (question, schema, result) | Chat-based obs (messages + tokens) |
| **Action** | `argument` field | `action_description` + `tokens` tensor |
| **Environment** | Scaffold stubs | Real SQLite + Ollama + keyword dispatch |
| **Client** | Basic EnvClient | Tensor <-> list serialization |
| **Data** | Empty .gitkeep dirs | 9 ORM models + 53 Spider questions |
| **Tests** | 0 | 21 (all passing) |
| **Empty dirs** | `training_pipeline/`, `submission_artifacts/` | Removed |
---
## Known Behaviors (Not Bugs)
1. **Ollama fallback:** Without Ollama, `_call_ollama_to_select_table()` falls back to the first table (`Address`). Query actions return `Error: Ollama returned status 404`. This is by design: the mock path is for dev/test, not production.
2. **`message_to_action()` mutates state:** It appends the message to `_state.history_messages` before tokenizing. This is intentional: the tokenizer needs the full conversation context.
3. **`MockTokenizer` in production code:** `server/app.py` imports `MockTokenizer` from `server/test_sql_env.py` when `transformers` is unavailable. This is the teammate's design for running without GPU dependencies.
---
## Verification Checklist
- [x] `uv sync` succeeds (all deps install)
- [x] `uv run pytest tests/ -v`: 21/21 pass
- [x] `uv run ruff check .`: all checks passed
- [x] `uv run ruff format --check .`: 14 files formatted
- [x] Pydantic models import from `sql_env.models`
- [x] Environment instantiates with MockTokenizer
- [x] `reset()` returns valid SQLObservation with system prompt
- [x] Action detection: 9/9 keyword classifications correct
- [x] `message_to_action()` creates typed SQLAction with tokens
- [x] `step(describe)` returns schema from SQLAlchemy introspection
- [x] `step(sample)` returns SQL query text
- [x] `step(query)` returns graceful error without Ollama
- [x] Multi-turn conversation state grows correctly
- [x] Client tensor <-> list serialization roundtrips
- [x] Spider data loads (53 questions)
---
## What's Next
**Phase 3:** Reward computation (`server/reward.py`) and answer verification (`server/verifier.py`). Both are currently stubs.
---
*All output captured live on 2026-02-28. Reproduce with `uv sync && uv run pytest tests/ -v`.*