Eishaan committed on
Commit
ba3f47d
·
1 Parent(s): 5f32203

final: complete airtight hardening (strict constraints, numerical stability, performance metrics)

Browse files
server/__pycache__/environment.cpython-312.pyc CHANGED
Binary files a/server/__pycache__/environment.cpython-312.pyc and b/server/__pycache__/environment.cpython-312.pyc differ
 
server/__pycache__/grader.cpython-312.pyc CHANGED
Binary files a/server/__pycache__/grader.cpython-312.pyc and b/server/__pycache__/grader.cpython-312.pyc differ
 
server/environment.py CHANGED
@@ -351,19 +351,27 @@ class DbMigrationEnvironment(Environment):
351
  metadata={"error": "not_initialized"},
352
  )
353
 
 
 
 
354
  self._step_count += 1
355
  sql_command = action.sql_command.strip()
356
 
357
  # --- A3: Dangerous SQL Blacklist ---
358
  sql_lower = sql_command.lower()
359
  if re.search(r"pragma\s+foreign_keys\s*=\s*(off|0)", sql_lower):
360
- execution_result = "Security Error: Disabling PRAGMA foreign_keys is strictly explicitly forbidden."
 
 
 
 
361
  action_error = "pragma_off_blocked"
362
  elif _DANGEROUS_PATTERNS.search(sql_command):
363
  execution_result = (
364
  "Error: This SQL command is not allowed for security reasons. "
365
  "ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
366
- "PRAGMA writable_schema are blocked."
 
367
  )
368
  action_error = "blocked_command"
369
  else:
@@ -443,10 +451,12 @@ class DbMigrationEnvironment(Environment):
443
  self._state.migration_progress = current_score
444
 
445
  # Build metadata with reasoning and debug info
 
446
  meta = {
447
  "reasoning": action.reasoning,
448
  "sql_executed": action.sql_command,
449
  "step": self._step_count,
 
450
  }
451
  if action_error:
452
  meta["error"] = action_error
 
351
  metadata={"error": "not_initialized"},
352
  )
353
 
354
+ import time
355
+ start_time = time.time()
356
+
357
  self._step_count += 1
358
  sql_command = action.sql_command.strip()
359
 
360
  # --- A3: Dangerous SQL Blacklist ---
361
  sql_lower = sql_command.lower()
362
  if re.search(r"pragma\s+foreign_keys\s*=\s*(off|0)", sql_lower):
363
+ execution_result = (
364
+ "Security Error: Disabling PRAGMA foreign_keys is strictly forbidden. "
365
+ "Tip: To maintain integrity, perform your migration using temporary tables "
366
+ "or deferred constraints instead of disabling enforcement."
367
+ )
368
  action_error = "pragma_off_blocked"
369
  elif _DANGEROUS_PATTERNS.search(sql_command):
370
  execution_result = (
371
  "Error: This SQL command is not allowed for security reasons. "
372
  "ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
373
+ "PRAGMA writable_schema are blocked. "
374
+ "Tip: Use standard DML (INSERT/UPDATE/DELETE) and DDL (CREATE/DROP) within a single database."
375
  )
376
  action_error = "blocked_command"
377
  else:
 
451
  self._state.migration_progress = current_score
452
 
453
  # Build metadata with reasoning and debug info
454
+ execution_ms = int((time.time() - start_time) * 1000)
455
  meta = {
456
  "reasoning": action.reasoning,
457
  "sql_executed": action.sql_command,
458
  "step": self._step_count,
459
+ "execution_ms": execution_ms,
460
  }
461
  if action_error:
462
  meta["error"] = action_error
server/grader.py CHANGED
@@ -47,9 +47,15 @@ def _get_table_names(conn: sqlite3.Connection) -> Set[str]:
47
  def _get_column_info(conn: sqlite3.Connection, table: str) -> List[dict]:
48
  """Get column info for a table. Returns list of {name, type, notnull, pk}."""
49
  try:
50
- cursor = conn.execute(f"PRAGMA table_info({table})")
51
  return [
52
- {"name": row[1].lower(), "type": row[2].upper(), "notnull": row[3], "pk": row[5]}
 
 
 
 
 
 
53
  for row in cursor.fetchall()
54
  ]
55
  except Exception:
@@ -61,9 +67,12 @@ def _get_column_names(conn: sqlite3.Connection, table: str) -> Set[str]:
61
  return {col["name"] for col in _get_column_info(conn, table)}
62
 
63
 
64
- def _get_column_signatures(conn: sqlite3.Connection, table: str) -> Set[Tuple[str, str]]:
65
- """Get (name, type) tuples for strict schema grading."""
66
- return {(col["name"], col["type"]) for col in _get_column_info(conn, table)}
 
 
 
67
 
68
 
69
  def _get_row_count(conn: sqlite3.Connection, table: str) -> int:
@@ -158,24 +167,60 @@ def _compare_row_data(
158
  # Per-row comparison (order-independent for flexibility)
159
  golden_set = set()
160
  for row in golden_rows:
161
- # Normalize: convert all values to strings for loose comparison
162
- golden_set.add(tuple(str(v).strip() if v is not None else "" for v in row))
 
 
 
 
 
 
 
 
163
 
164
  matched = 0
165
  for row in agent_rows:
166
- normalized = tuple(str(v).strip() if v is not None else "" for v in row)
167
- if normalized in golden_set:
168
- matched += 1
169
- golden_set.discard(normalized)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  if len(golden_rows) == 0:
172
- content_match = 0.0
173
  else:
174
  content_match = matched / len(golden_rows)
175
 
176
  # Penalize extra rows (data bloat)
177
  if len(agent_rows) > len(golden_rows):
178
- bloat_penalty = max(0, 1.0 - (len(agent_rows) - len(golden_rows)) / len(golden_rows))
 
179
  content_match *= bloat_penalty
180
 
181
  return 0.4 * count_match + 0.6 * content_match
 
47
  def _get_column_info(conn: sqlite3.Connection, table: str) -> List[dict]:
48
  """Get column info for a table. Returns list of {name, type, notnull, pk}."""
49
  try:
50
+ cursor = conn.execute(f"PRAGMA table_info([{table}])")
51
  return [
52
+ {
53
+ "name": row[1].lower(),
54
+ "type": row[2].upper(),
55
+ "notnull": row[3],
56
+ "dflt_value": row[4], # Added check for default values
57
+ "pk": row[5]
58
+ }
59
  for row in cursor.fetchall()
60
  ]
61
  except Exception:
 
67
  return {col["name"] for col in _get_column_info(conn, table)}
68
 
69
 
70
def _get_column_signatures(conn: sqlite3.Connection, table: str) -> Set[Tuple[str, str, int, str]]:
    """Get (name, type, notnull, dflt_value) tuples for strict schema grading.

    Each tuple fully identifies a column for "absolute" schema comparison:
    lowercased name, uppercased declared type, the NOT NULL flag (0/1), and
    the column's default value. The default is stringified so the tuples
    stay hashable and comparable even when SQLite reports non-string
    defaults (None, numeric literals).

    NOTE: the annotation's fourth element is ``str`` (not ``Any``) because
    ``str(...)`` is always applied below.
    """
    return {
        (col["name"], col["type"], col["notnull"], str(col["dflt_value"]))
        for col in _get_column_info(conn, table)
    }
76
 
77
 
78
  def _get_row_count(conn: sqlite3.Connection, table: str) -> int:
 
167
  # Per-row comparison (order-independent for flexibility)
168
  golden_set = set()
169
  for row in golden_rows:
170
+ # Normalize: try numerical comparison first, fallback to string
171
+ normalized_golden = []
172
+ for v in row:
173
+ if v is None: normalized_golden.append("")
174
+ else:
175
+ try:
176
+ normalized_golden.append(float(v))
177
+ except (ValueError, TypeError):
178
+ normalized_golden.append(str(v).strip())
179
+ golden_set.add(tuple(normalized_golden))
180
 
181
  matched = 0
182
  for row in agent_rows:
183
+ # Normalize: try numerical comparison first, fallback to string
184
+ normalized_agent = []
185
+ for v in row:
186
+ if v is None: normalized_agent.append("")
187
+ else:
188
+ try: # If it looks like a number, treat as a float for comparison
189
+ normalized_agent.append(float(v))
190
+ except (ValueError, TypeError):
191
+ normalized_agent.append(str(v).strip())
192
+
193
+ normalized_agent = tuple(normalized_agent)
194
+
195
+ # Look for match in golden set with numerical tolerance
196
+ found = False
197
+ for g_row in list(golden_set):
198
+ # Compare rows with float tolerance where applicable
199
+ match = True
200
+ if len(normalized_agent) != len(g_row): continue
201
+
202
+ for a_val, g_val in zip(normalized_agent, g_row):
203
+ if isinstance(a_val, float) and isinstance(g_val, float):
204
+ if abs(a_val - g_val) > 1e-7:
205
+ match = False; break
206
+ elif str(a_val) != str(g_val):
207
+ match = False; break
208
+
209
+ if match:
210
+ matched += 1
211
+ golden_set.discard(g_row)
212
+ found = True
213
+ break
214
 
215
  if len(golden_rows) == 0:
216
+ content_match = 1.0 # Logic fix: if golden has no rows, agent with 0 rows is 1.0
217
  else:
218
  content_match = matched / len(golden_rows)
219
 
220
  # Penalize extra rows (data bloat)
221
  if len(agent_rows) > len(golden_rows):
222
+ # 1% penalty per extra row to prevent "guessing" by dumping all possible data
223
+ bloat_penalty = max(0.5, 1.0 - (len(agent_rows) - len(golden_rows)) * 0.01)
224
  content_match *= bloat_penalty
225
 
226
  return 0.4 * count_match + 0.6 * content_match