#!/usr/bin/env python3
"""
Baseline Inference Script for SQL Migration Environment.

Runs all 7 migration tasks sequentially using an LLM via an OpenAI-compatible API.
Outputs structured [START]/[STEP]/[END] format for automated evaluation.

Fixes Applied:
- D1: Task description injected into system prompt
- D2: Hardcoded system prompt traps removed (no more audit_log/INTEGER traps)
- D3: Data discovery rule added (agent runs SELECT before DDL)
- D4: Submit guard added (agent must verify before submitting)
- D5: Context window bloat fixed (schema not repeated every step)
- D6: Parse error counter tracks consecutive errors only
- D7: response_format JSON mode with fallback

Usage:
    python inference.py

Environment Variables:
    API_BASE_URL: LLM inference endpoint (default: HF router)
    MODEL_NAME: Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
    HF_TOKEN or API_KEY: Authentication token
"""
import json
import os
import re
import sys
import time
import traceback

# Server URL for the environment.
# NOTE(review): ENV_BASE_URL is not referenced anywhere else in this script
# (run_task_local uses an in-process environment) — confirm whether it is
# consumed by an external harness before removing.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")

# LLM Configuration: OpenAI-compatible endpoint, model identifier, auth token.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# First non-empty value wins: OPENAI_API_KEY, then HF_TOKEN, then API_KEY.
API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")
# --- D2: Cleaned system prompt — no hardcoded table names or type traps ---
# D1: {task_description} and {target_ddl} are filled in per task by
# run_task_local; the template itself stays task-agnostic.  The double braces
# on the last line escape str.format so the JSON example survives formatting.
SYSTEM_PROMPT_TEMPLATE = """You are an autonomous SQLite database migration engine. You receive the current schema and a target schema. Write SQL to transform the current state to the target state without losing row data.
TASK OBJECTIVE:
{task_description}
CRITICAL SQLite-specific rules (violations cause immediate errors):
1. SQLite does NOT support ALTER TABLE ADD CONSTRAINT, ALTER COLUMN, or ADD PRIMARY KEY.
2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
4. Execute exactly ONE SQL statement per step.
5. If a table already exists, you MUST drop it before recreating it (e.g., DROP TABLE IF EXISTS users_new).
6. SQLite strictly expects `INSERT INTO tbl VALUES (...)`, not `VALUE (...)`. Ensure column counts match exactly.
7. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
8. For orphaned FK rows: check the TARGET SCHEMA for the anomaly/issues table name. Log invalid records there before dropping.
9. For text currency (e.g. '$90,000'): strip '$' and ',' then cast to the target type (INTEGER/REAL).
10. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 to inspect the data format.
11. Do NOT set submit_final to true until you run SELECT COUNT(*) and verify data matches the task.
TARGET SCHEMA (achieve this exactly):
{target_ddl}
Respond ONLY with a valid JSON object. Do not use markdown backticks (```json). No conversational text.
{{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""
# The seven migration scenarios (keys into seeds.TASKS), run in this order
# by main().
ALL_TASKS = [
    "column-restructure",
    "soft-delete-restoration",
    "table-normalization",
    "schema-version-merge",
    "multi-entity-extraction",
    "cascade-migration",
    "dual-source-consolidation",
]
MAX_PARSE_ERRORS = 5  # Consecutive parse errors tolerated before giving up
AUTO_SUBMIT_THRESHOLD = 0.95  # Progress at which the loop force-submits
MAX_HISTORY_PAIRS = 4  # Keep maximum of 4 user/assistant turn pairs


def build_messages(system_prompt: str, history: list, current_obs_msg: dict) -> list:
    """Assemble the chat payload while pruning stale turns (D5 context fix).

    The result is: system prompt first, then at most MAX_HISTORY_PAIRS recent
    user/assistant pairs from ``history``, then the current observation.
    """
    # Discard any system messages that leaked into the running history; the
    # (task-specific) system prompt is always injected fresh at the front.
    turns = [msg for msg in history if msg["role"] != "system"]
    # A negative-index slice keeps only the most recent pairs and is a
    # no-op when the history is already short enough.
    recent = turns[-(MAX_HISTORY_PAIRS * 2):]
    return [{"role": "system", "content": system_prompt}] + recent + [current_obs_msg]
def call_llm(messages: list, timeout: int = 90) -> str:
    """Call the LLM API, preferring strict JSON mode with a plain fallback.

    Args:
        messages: Chat messages in OpenAI format.
        timeout: Per-request timeout in seconds.

    Returns:
        The stripped completion text.

    Raises:
        TimeoutError: If both request attempts fail (wraps any API error).
    """
    from openai import OpenAI

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=timeout)
    base_kwargs = dict(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.0,
        max_tokens=1024,
    )

    # --- D7: Try JSON mode first; some backends reject response_format, so
    # any failure here falls through to a plain-text retry below. ---
    try:
        reply = client.chat.completions.create(
            response_format={"type": "json_object"}, **base_kwargs
        )
        return reply.choices[0].message.content.strip()
    except Exception:
        pass

    # Fallback: plain text mode (no response_format constraint).
    try:
        reply = client.chat.completions.create(**base_kwargs)
        return reply.choices[0].message.content.strip()
    except Exception as exc:
        # Callers catch TimeoutError specifically; keep that contract.
        raise TimeoutError(f"LLM API error: {exc}")
def parse_action(raw_text: str) -> dict:
    """
    Parse LLM output into an action dict.

    Handles: raw JSON, markdown-wrapped JSON, <think>...</think> blocks,
    escaped quotes/backslashes in SQL, and truncated output recovery.

    Args:
        raw_text: Raw completion text from the model.

    Returns:
        A dict with at least "sql_command"; the regex recovery path also
        fills "reasoning" and "submit_final".

    Raises:
        ValueError: If no JSON object or sql_command field can be recovered.
    """
    text = raw_text.strip()
    # Strip <think>...</think> blocks (Qwen3, DeepSeek-R1); the second sub
    # removes an unterminated trailing <think> from truncated generations.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    text = re.sub(r"<think>.*$", "", text, flags=re.DOTALL).strip()
    # Strip markdown code block fences
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [l for l in lines if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()
    # Try direct JSON parse
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Try to find a JSON object embedded in surrounding prose
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            return json.loads(text[start:end])
        except json.JSONDecodeError:
            pass
    # --- D6: Last resort — regex-extract sql_command from malformed or
    # truncated JSON; the pattern tolerates escaped characters in the value.
    sql_match = re.search(r'"sql_command"\s*:\s*"((?:[^"\\]|\\.)*)"', text)
    if sql_match:
        # Single-pass JSON-style unescape.  (The previous chain of
        # str.replace calls was order-sensitive: replacing '\n' before '\\'
        # corrupted values containing an escaped backslash followed by 'n',
        # e.g. '\\name' became a newline + 'ame'.)
        escape_map = {"n": "\n", "t": "\t", "r": "\r"}
        sql = re.sub(
            r"\\(.)",
            lambda m: escape_map.get(m.group(1), m.group(1)),
            sql_match.group(1),
        )
        return {
            "sql_command": sql,
            "reasoning": "auto-extracted from malformed response",
            "submit_final": False,
        }
    raise ValueError(f"Could not parse JSON from LLM response: {text[:200]}")
def run_task_local(task_name: str) -> dict:
    """
    Run a single task using a local environment instance (no server needed).

    Drives the LLM step loop: build a pruned message list, call the model,
    parse its JSON action, execute it in the environment, and feed back a
    lean result summary.  Emits [START]/[STEP]/[END] lines for the evaluator.

    Args:
        task_name: Key into seeds.TASKS identifying the migration scenario.

    Returns:
        dict with "task_name", final "score", "steps" taken, and per-step
        "rewards".
    """
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from server.environment import DbMigrationEnvironment
    from models import MigrationAction
    import seeds

    env = DbMigrationEnvironment(task_name=task_name)
    task_config = seeds.TASKS[task_name]
    task_max_steps = task_config.get("max_steps", 20)

    print(f"[START] task={task_name} env=sql-migration-agent model={MODEL_NAME}", flush=True)
    obs = env.reset()

    # --- D1: Inject task description into system prompt ---
    task_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
        task_description=task_config["description"],
        target_ddl=obs.target_schema_sql,
    )

    # Initial observation.  The system prompt is supplied separately to
    # build_messages, so `history` holds only user/assistant turns.
    # (Fix: a dead `history = [system_msg]` assignment that was immediately
    # overwritten has been removed.)
    initial_msg = {
        "role": "user",
        "content": (
            f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
            f"Status: {obs.last_execution_result}\n"
            f"Migration progress: {obs.migration_progress:.2f}\n\n"
            f"Start by inspecting the source data with SELECT queries, then begin the migration."
        )
    }
    history = []
    rewards_list = []
    consecutive_parse_errors = 0  # D6: Track consecutive only
    final_score = 0.0
    steps_taken = 0
    done = False

    for step in range(task_max_steps):
        if done:
            break
        # --- D5: Context window fix: aggressively prune history via build_messages ---
        messages = build_messages(task_system_prompt, history, initial_msg)
        try:
            raw_response = call_llm(messages)
        except TimeoutError as e:
            error_msg = str(e)[:100]
            print(f"[STEP] step={step+1} action=API_TIMEOUT reward=0.00 done=true error={error_msg}", flush=True)
            done = True
            break

        # Parse the action
        try:
            action_dict = parse_action(raw_response)
            consecutive_parse_errors = 0  # D6: Reset on success
        except ValueError:
            consecutive_parse_errors += 1
            print(f"[STEP] step={step+1} action=PARSE_ERROR reward=0.00 done=false error=parse_error", flush=True)
            if consecutive_parse_errors >= MAX_PARSE_ERRORS:
                print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
                done = True
                break
            # CRITICAL: Strip <think> tags before appending to history to prevent 413 Context OOM
            stripped_response = re.sub(r"<think>.*?</think>", "", raw_response, flags=re.DOTALL).strip()
            stripped_response = re.sub(r"<think>.*$", "", stripped_response, flags=re.DOTALL).strip()
            # If it's still huge, truncate it to 500 chars to save context
            if len(stripped_response) > 500:
                stripped_response = stripped_response[:500] + "... [TRUNCATED DUE TO PARSE ERROR]"
            history.append(initial_msg)  # The prompt we sent
            history.append({"role": "assistant", "content": stripped_response})  # The stripped response
            initial_msg = {
                "role": "user",
                "content": 'ERROR: Your response was not a valid JSON object. Do not use markdown blocks. Respond strictly with: {"sql_command": "...", "reasoning": "...", "submit_final": false}'
            }
            continue

        # Build the MigrationAction
        try:
            action = MigrationAction(
                sql_command=action_dict.get("sql_command", ""),
                reasoning=action_dict.get("reasoning", ""),
                submit_final=action_dict.get("submit_final", False),
            )
        except Exception as e:
            print(f"[STEP] step={step+1} action=INVALID_ACTION reward=0.00 done=false error={str(e)[:50]}", flush=True)
            continue

        # Execute the action
        obs = env.step(action)
        steps_taken = step + 1
        step_reward = obs.reward if obs.reward is not None else 0.0
        rewards_list.append(step_reward)
        final_score = obs.migration_progress
        done = obs.done

        # AUTO-SUBMIT: If we reached near-perfect score, force submit
        if final_score >= AUTO_SUBMIT_THRESHOLD and not done:
            done = True
            submit_action = MigrationAction(
                sql_command="SELECT 1",
                reasoning="Migration complete — auto-submitting",
                submit_final=True,
            )
            obs = env.step(submit_action)
            final_score = obs.migration_progress

        # Log.  Fix: defensively stringify the metadata error — the previous
        # `metadata.get("error", "null")` crashed on slicing when the key
        # held None or a non-string value.
        sql_abbrev = action.sql_command[:50].replace("\n", " ")
        if len(action.sql_command) > 50:
            sql_abbrev += "..."
        raw_error = obs.metadata.get("error") if obs.metadata else None
        error_str = str(raw_error)[:80] if raw_error else "null"
        print(
            f"[STEP] step={steps_taken} action={sql_abbrev} "
            f"reward={step_reward:.2f} done={'true' if done else 'false'} "
            f"error={error_str}",
            flush=True,
        )

        # Add to conversation history
        history.append(initial_msg)
        history.append({"role": "assistant", "content": json.dumps(action_dict)})

        # --- D5: Lean feedback — NO schema repetition ---
        feedback_text = (
            f"EXECUTION RESULT: {obs.last_execution_result}\n"
            f"Progress: {obs.migration_progress:.2f}"
            f"\nSchema Diff (Missing/Extra constraints vs Target):\n{obs.schema_diff}"
        )
        if done:
            feedback_text += "\n\nEpisode complete."
        elif obs.migration_progress >= 0.9:
            feedback_text += (
                "\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
                "and compare to your expectations. If everything matches, set submit_final to true."
            )
        else:
            feedback_text += "\n\nContinue the migration. Write your next SQL command."
        initial_msg = {"role": "user", "content": feedback_text}

    # Print END
    rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
    success = "true" if final_score >= 0.8 else "false"
    print(
        f"[END] success={success} steps={steps_taken} "
        f"score={final_score:.2f} rewards={rewards_str}",
        flush=True,
    )
    env.close()
    return {
        "task_name": task_name,
        "score": final_score,
        "steps": steps_taken,
        "rewards": rewards_list,
    }
def main():
    """Run all 7 tasks sequentially and print a per-task score summary."""
    if not API_KEY:
        # No credentials: fail fast so the evaluator doesn't hang on API calls.
        print("WARNING: No API key found. Set HF_TOKEN or API_KEY.", file=sys.stderr)
        sys.exit(1)

    results = {}
    for task_name in ALL_TASKS:
        try:
            results[task_name] = run_task_local(task_name)["score"]
        except Exception as e:
            # A crashed task scores zero; keep going with the remaining tasks.
            print(f"[ERROR] task={task_name} error={str(e)[:200]}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            results[task_name] = 0.0

    # Summary
    score_values = list(results.values())
    mean_score = sum(score_values) / len(score_values) if score_values else 0.0
    per_task = " ".join(f"{t}={s:.2f}" for t, s in results.items())
    print(
        f"[SUMMARY] {per_task} avg={mean_score:.2f}",
        flush=True,
    )


if __name__ == "__main__":
    main()