Abhinav Singh committed on
Commit
7d21a80
·
1 Parent(s): e8bc352

feat(inference): add OpenEnv-compliant inference script

Browse files

inference.py runs the agent loop against all 3 tasks in sequence.

Strict stdout format (as required by hackathon spec):
[START] task=<id> env=sql-optim-env model=<MODEL_NAME>
[STEP] step=<n> action=suggestions=<n>,score=<f> reward=<f> done=<bool> error=<msg|null>
[END] success=<bool> steps=<n> score=<f> rewards=<r1,...,rn>

Agent strategy:
- SYSTEM prompt instructs model to output strict JSON with suggestions,
optimized_query, summary, estimated_improvement, approved fields
- USER prompt includes schema_info, sql_query, dialect, step context,
and issues_found_so_far from previous steps
- parse_action() strips markdown fences and falls back gracefully on parse error
- Episode success threshold: max reward >= 0.5
- Configurable via: API_BASE_URL, MODEL_NAME, HF_TOKEN env vars

Files changed (1) hide show
  1. inference.py +229 -0
inference.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — SQL Query Optimization Environment
3
+ ===================================================
4
+ OpenEnv Hackathon Phase 1 Submission
5
+
6
+ Required environment variables:
7
+ API_BASE_URL The API endpoint for the LLM (default: HuggingFace router)
8
+ MODEL_NAME The model identifier (default: Qwen/Qwen2.5-72B-Instruct)
9
+ HF_TOKEN Your HuggingFace / API key
10
+
11
+ stdout format (strictly followed):
12
+ [START] task=<task_name> env=<benchmark> model=<model_name>
13
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
14
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
15
+ """
16
+
17
+ import os
18
+ import json
19
+ import sys
20
+ from typing import List, Optional
21
+ from openai import OpenAI
22
+
23
+ # ── Resolve paths so we can import env/models from root ───────────────────
24
# Absolute directory of this script. It is placed first on sys.path so the
# project-local `env` and `models` modules imported below resolve from here.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)
26
+
27
+ from env import SQLOptimEnv
28
+ from models import Action
29
+
30
+ # ── Configuration ─────────────────────────────────────────────────────────
31
# Endpoint, model and credential come from the environment. The key is taken
# from HF_TOKEN, falling back to API_KEY when HF_TOKEN is unset or empty.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN", "") or os.getenv("API_KEY", "")

# Fixed settings for every run: benchmark label used in the [START] line and
# the sampling parameters passed to the chat-completion call.
BENCHMARK = "sql-optim-env"
TEMPERATURE = 0.0
MAX_TOKENS = 1500

# Task identifiers, attempted in this order.
TASK_IDS = [
    "task_1_basic_antipatterns",
    "task_2_join_optimization",
    "task_3_advanced_optimization",
]
44
+
45
+ SYSTEM_PROMPT = """\
46
+ You are an expert database engineer and SQL performance specialist with deep knowledge of \
47
+ PostgreSQL internals, query planning, and index design.
48
+
49
+ You will receive a SQL query, its database schema, and a task description. \
50
+ Your job is to:
51
+ 1. Identify ALL performance issues and anti-patterns in the query.
52
+ 2. Produce an optimized rewrite of the query.
53
+ 3. Estimate the expected performance improvement.
54
+
55
+ Respond ONLY with a valid JSON object in this exact format (no markdown, no extra text):
56
+ {
57
+ "suggestions": [
58
+ {
59
+ "issue_type": "string (e.g. select_star, non_sargable_predicate, correlated_subquery, missing_index, etc.)",
60
+ "line": <integer line number in the query>,
61
+ "description": "clear explanation of why this is a problem",
62
+ "severity": "critical | high | medium | low",
63
+ "fix": "specific fix or rewritten clause"
64
+ }
65
+ ],
66
+ "optimized_query": "the full rewritten SQL query with all improvements applied",
67
+ "summary": "2-4 sentence overall analysis of the query performance profile",
68
+ "estimated_improvement": "e.g. '10-50x faster on large tables due to index usage', '~80% reduction in I/O'",
69
+ "approved": false
70
+ }
71
+
72
+ Be thorough and precise. Every issue you identify should have a concrete fix.
73
+ """
74
+
75
+
76
+ # ── Logging helpers ───────────────────────────────────────────────────────
77
+
78
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens one task episode.

    Args:
        task: Task identifier written after ``task=``.
        env: Benchmark name written after ``env=``.
        model: Model identifier written after ``model=``.
    """
    record = "[START] task={} env={} model={}".format(task, env, model)
    # Flush immediately so the record is visible even if the run dies later.
    print(record, flush=True)
80
+
81
+
82
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one single-line ``[STEP]`` record to stdout.

    Args:
        step: 1-based step number within the episode.
        action: Short action summary written after ``action=``.
        reward: Step reward, formatted with two decimals.
        done: Whether the episode finished; written as ``true``/``false``.
        error: Error text for this step, or ``None``/empty for ``null``.
    """
    # Exception text can contain newlines or tabs; collapsing all whitespace
    # runs keeps the record on exactly one line so the output stays parseable.
    error_val = " ".join(error.split()) if error else ""
    if not error_val:
        error_val = "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
89
+
90
+
91
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes one task episode.

    Args:
        success: Episode outcome; written lower-cased (``true``/``false``).
        steps: Number of steps taken.
        score: Final episode score, formatted with two decimals.
        rewards: Per-step rewards, written comma-separated with two decimals
            each (an empty list yields an empty ``rewards=`` field).
    """
    formatted_rewards = ",".join(map("{:.2f}".format, rewards))
    success_flag = str(success).lower()
    record = (
        f"[END] success={success_flag} steps={steps} "
        f"score={score:.2f} rewards={formatted_rewards}"
    )
    print(record, flush=True)
97
+
98
+
99
+ # ── Model interaction ─────────────────────────────────────────────────────
100
+
101
def parse_action(response_text: str) -> dict:
    """Parse the model's reply into an action dict.

    The reply is expected to be a JSON object, optionally wrapped in a
    Markdown code fence. If the cleaned text does not decode, the span from
    the first ``{`` to the last ``}`` of the raw reply is tried, which
    recovers replies that surround the object with prose or put the fence on
    a single line.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The decoded JSON object. When nothing decodes, or the decoded value
        is not an object (e.g. a list or string), an empty placeholder action
        is returned so callers can always use ``.get``.
    """
    fallback = {
        "suggestions": [],
        "optimized_query": "",
        "summary": "JSON parse error — model returned malformed output.",
        "estimated_improvement": "unknown",
        "approved": False,
    }

    raw = (response_text or "").strip()
    clean = raw
    if clean.startswith("```"):
        lines = clean.split("\n")
        # Drop the opening fence line and, if present, the closing one.
        inner = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        clean = "\n".join(inner).strip()
    if clean.startswith("json"):
        clean = clean[4:].strip()

    try:
        parsed = json.loads(clean)
    except json.JSONDecodeError:
        # Second chance: pull the outermost {...} span out of the raw reply.
        start, end = raw.find("{"), raw.rfind("}")
        if not 0 <= start < end:
            return fallback
        try:
            parsed = json.loads(raw[start:end + 1])
        except json.JSONDecodeError:
            return fallback

    # Valid JSON that is not an object would break callers that use .get().
    return parsed if isinstance(parsed, dict) else fallback
120
+
121
+
122
def get_model_action(client: OpenAI, obs) -> tuple[dict, Optional[str]]:
    """Query the LLM about the current observation.

    Builds a user prompt from the observation's task, schema, query and
    previously found issues, sends it with the system prompt, and parses the
    reply with ``parse_action``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current environment observation.

    Returns:
        ``(action_dict, None)`` on success, or ``(placeholder_dict, message)``
        when the call or the parsing raised, where ``message`` is the
        exception text.
    """
    known_issues = obs.issues_found_so_far or 'None yet'
    prompt = f"""Task: {obs.task_name}
Difficulty: {obs.difficulty}
SQL Dialect: {obs.dialect}

Instructions:
{obs.task_description}

Database Schema:
{obs.schema_info}

SQL Query to Analyze (step {obs.step_count + 1}/{obs.max_steps}):
```sql
{obs.sql_query}
```

Issues identified in previous steps: {known_issues}

Provide your complete analysis and optimized rewrite now.
"""
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        reply = client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        # The message content may be None; treat that as an empty reply.
        raw_text = reply.choices[0].message.content or ""
        return parse_action(raw_text), None
    except Exception as exc:
        # Any failure becomes an empty action so the episode can continue;
        # the message is handed back for the caller to log.
        failure = str(exc)
        placeholder = {
            "suggestions": [],
            "optimized_query": "",
            "summary": f"Model call failed: {failure}",
            "estimated_improvement": "unknown",
            "approved": False,
        }
        return placeholder, failure
165
+
166
+
167
+ # ── Main loop ─────────────────────────────────────────────────────────────
168
+
169
def main():
    """Run the agent on every task in ``TASK_IDS`` and return the results.

    For each task the environment is reset, a ``[START]`` record is printed,
    and the model is queried once per step until the environment reports
    ``done`` or ``max_steps`` is reached. Each step prints a ``[STEP]``
    record; an ``[END]`` record is always printed, even if a step raises.
    The episode score is the best step reward, and the episode counts as a
    success when that score is at least 0.5.

    Exits with status 1 if no API key is configured.

    Returns:
        Dict mapping task id to ``{"task_name", "final_score", "steps_taken"}``.
    """
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN environment variable is not set.", flush=True)
        sys.exit(1)

    client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
    local_env = SQLOptimEnv()
    results = {}

    for task_id in TASK_IDS:
        obs = local_env.reset(task_id=task_id)
        log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

        rewards: List[float] = []
        steps_taken = 0
        score = 0.0
        success = False

        try:
            for step in range(1, obs.max_steps + 1):
                parsed, error = get_model_action(client, obs)

                # Valid JSON that is not an object has no .get(); treat it as
                # an empty action and surface the problem in the step log.
                if not isinstance(parsed, dict):
                    parsed = {}
                    error = error or "model response was not a JSON object"

                # `or <default>` rather than .get(key, default): the model may
                # emit an explicit null, which .get() would pass through.
                action = Action(
                    suggestions=parsed.get("suggestions") or [],
                    optimized_query=parsed.get("optimized_query") or "",
                    summary=parsed.get("summary") or "",
                    estimated_improvement=parsed.get("estimated_improvement") or "",
                    approved=parsed.get("approved") or False,
                )

                result = local_env.step(action)
                reward = result.reward.score
                done = result.done

                rewards.append(reward)
                steps_taken = step
                obs = result.observation

                action_summary = f"suggestions={len(action.suggestions)},score={reward:.2f}"
                log_step(step=step, action=action_summary, reward=reward, done=done, error=error)

                if done:
                    break

            score = max(rewards) if rewards else 0.0
            success = score >= 0.5

        finally:
            # Always close the episode record, even when a step raised.
            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

        results[task_id] = {
            "task_name": obs.task_name,
            "final_score": round(score, 4),
            "steps_taken": steps_taken,
        }

    return results
226
+
227
+
228
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()