# text2sql_tani / src/execution_reward_soft.py
# tjhalanigrid: "Added full project" (cf17729)
import random
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.execution_reward import (
_normalize_sql,
is_valid_select,
execute_sql_cached,
execute_sql_cached_conn,
EXECUTION_ERROR,
validate_sql_schema,
USE_SCHEMA_VALIDATION,
_connect_readonly,
)
# =========================================================
# 🔥 SOFT REWARD CORE
# =========================================================
def compute_soft_reward(pred_res, gold_res, sample_k=10):
    """
    Soft similarity reward between predicted and gold query results.

    pred_res / gold_res: sequences of rows (each row an iterable of cells).
    sample_k: maximum number of gold rows that are sampled for comparison.

    Returns:
        1.0   when both result sets are empty
        0.3   when gold is empty but the prediction is not
        -0.05 when the prediction is empty but gold is not, or on any exception
        0.3 + 0.7 * score otherwise, where score is the fraction of sampled
        gold rows found in the prediction (duplicates are matched one-for-one),
        reduced when the prediction has more than 1.5x as many rows as gold.

    Note: when sample_k is smaller than len(gold_res) the sampled rows come
    from the module-level `random` state, so the result can vary between calls.
    """

    def _row_key(record):
        # Compare rows by the string form of every cell so that unhashable
        # cell values can still be counted.
        return tuple(map(str, record))

    try:
        # ---- Edge cases: one or both result sets are empty ----
        if not gold_res:
            if pred_res:
                return 0.3
            return 1.0
        if not pred_res:
            return -0.05

        # ---- Multiset of predicted rows, consumed as gold rows match ----
        available = Counter(map(_row_key, pred_res))

        # ---- Draw at most sample_k gold rows ----
        picked = random.sample(gold_res, min(sample_k, len(gold_res)))

        # ---- Count sampled gold rows that still have a predicted twin ----
        hits = 0
        for record in picked:
            signature = _row_key(record)
            if available.get(signature, 0) > 0:
                available[signature] -= 1
                hits += 1
        fraction = hits / max(len(picked), 1)

        # ---- Length penalty: discourage dumping far more rows than gold ----
        ratio = len(pred_res) / max(len(gold_res), 1)
        if ratio > 1.5:
            fraction = fraction / (ratio ** 0.5)

        # ---- Keep the score inside [0, 1] before scaling ----
        fraction = max(0.0, min(1.0, fraction))

        return 0.3 + 0.7 * fraction
    except Exception:
        return -0.05
# =========================================================
# 🔥 MAIN EXECUTION REWARD
# =========================================================
# Per-thread storage; each thread lazily gets its own `conns` dict.
_TLS = threading.local()


def _get_thread_conn(db_path: str):
    """
    Return the calling thread's cached connection for db_path.

    On first use within a thread a `conns` dict is attached to the
    thread-local storage; on first use of a given db_path the connection is
    opened with `_connect_readonly` and remembered in that dict.
    """
    cache = getattr(_TLS, "conns", None)
    if cache is None:
        cache = _TLS.conns = {}

    cached = cache.get(db_path)
    if cached is not None:
        return cached

    fresh = _connect_readonly(db_path)
    cache[db_path] = fresh
    return fresh
def execution_reward_soft_pooled(pred_sql, db_path, gold_sql, *, sample_k: int = 10):
    """
    Soft execution reward that runs both queries on the connection returned by
    `_get_thread_conn`, so repeated calls from one thread share a connection
    instead of paying connect/close overhead in RL loops.

    Returns -0.05 when the predicted SQL is rejected by `is_valid_select`,
    fails schema validation (only checked when USE_SCHEMA_VALIDATION is set),
    either query yields EXECUTION_ERROR, or anything raises. Otherwise returns
    `compute_soft_reward(pred_rows, gold_rows, sample_k=int(sample_k))`.
    """
    failure = -0.05
    try:
        candidate = _normalize_sql(pred_sql)
        reference = _normalize_sql(gold_sql)

        if not is_valid_select(candidate):
            return failure

        if USE_SCHEMA_VALIDATION:
            schema_ok, _detail = validate_sql_schema(candidate, db_path)
            if not schema_ok:
                return failure

        connection = _get_thread_conn(db_path)

        # Prediction first, then gold; stop at the first execution error.
        fetched = []
        for statement in (candidate, reference):
            rows = execute_sql_cached_conn(connection, db_path, statement)
            if rows == EXECUTION_ERROR:
                return failure
            fetched.append(rows)

        predicted_rows, gold_rows = fetched
        return compute_soft_reward(predicted_rows, gold_rows, sample_k=int(sample_k))
    except Exception:
        return failure
def execution_reward_soft(pred_sql, db_path, gold_sql, *, sample_k: int = 10):
    """
    Soft execution reward for a single (prediction, gold) SQL pair.

    pred_sql: predicted SQL text.
    db_path:  database the queries are executed against.
    gold_sql: reference SQL text.
    sample_k: keyword-only; forwarded to `compute_soft_reward`. Defaults to 10,
              the same default `compute_soft_reward` uses, so existing
              three-argument callers get unchanged behavior. Accepting it here
              keeps this entry point's signature in line with the pooled and
              batch variants.

    Returns -0.05 when the predicted SQL is rejected by `is_valid_select`,
    fails schema validation (only checked when USE_SCHEMA_VALIDATION is set),
    either query yields EXECUTION_ERROR, or anything raises. Otherwise returns
    the value of `compute_soft_reward` for the two result sets.
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        # =================================================
        # BASIC VALIDATION
        # =================================================
        if not is_valid_select(sql):
            return -0.05

        if USE_SCHEMA_VALIDATION:
            ok, _ = validate_sql_schema(sql, db_path)
            if not ok:
                return -0.05

        # =================================================
        # EXECUTION
        # =================================================
        pred_res = execute_sql_cached(db_path, sql)
        if pred_res == EXECUTION_ERROR:
            return -0.05

        gold_res = execute_sql_cached(db_path, gold)
        if gold_res == EXECUTION_ERROR:
            return -0.05

        return compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k))
    except Exception:
        # Any failure while scoring is reported as the penalty reward.
        return -0.05
def execution_reward_soft_batch_parallel_by_db(rollouts, *, max_workers: int = 20, sample_k: int = 10):
    """
    Score a batch of rollouts with one worker task per distinct database.

    rollouts: Sequence[(pred_sql, db_path, gold_sql)]
    max_workers: upper bound on pool threads; values below 1 are treated as 1.
    sample_k: forwarded to `compute_soft_reward` for every rollout.

    Rollouts are grouped by db_path so each worker opens a single connection
    and reuses it for every query against that database.

    Returns rewards in the same order as input. A rollout that cannot be
    scored (SQL rejected, schema validation failed, EXECUTION_ERROR, any
    exception, or a database whose connection cannot be opened) gets -0.05.
    """
    if not rollouts:
        return []

    failure = -0.05

    # Group by DB so each worker can hold a single connection and reuse it.
    by_db = {}
    for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
        by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))

    # Start from the failure reward so a slot that never gets scored cannot be
    # mistaken for a neutral 0.0.
    out = [failure] * len(rollouts)

    def _score(conn, db_path, pred_sql, gold_sql):
        # Reward for one rollout on an already-open connection.
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not is_valid_select(sql):
            return failure

        if USE_SCHEMA_VALIDATION:
            ok, _ = validate_sql_schema(sql, db_path)
            if not ok:
                return failure

        pred_res = execute_sql_cached_conn(conn, db_path, sql)
        if pred_res == EXECUTION_ERROR:
            return failure

        gold_res = execute_sql_cached_conn(conn, db_path, gold)
        if gold_res == EXECUTION_ERROR:
            return failure

        return float(compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k)))

    def _worker(db_path, items):
        # Do NOT use the global thread-local here; this worker owns the connection.
        try:
            conn = _connect_readonly(db_path)
        except Exception:
            # One unopenable database must not abort the whole batch: its
            # rollouts keep the pre-filled failure reward.
            return
        try:
            for idx, pred_sql, gold_sql in items:
                try:
                    out[idx] = _score(conn, db_path, pred_sql, gold_sql)
                except Exception:
                    out[idx] = failure
        finally:
            conn.close()

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least one.
    pool_size = max(1, int(max_workers))
    with ThreadPoolExecutor(max_workers=pool_size) as ex:
        futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
        for fut in as_completed(futures):
            # Surface anything unexpected that escaped a worker.
            fut.result()

    return out