# Source: text2sql_final_space/scripts/benchmark_parallel_reward.py
# Uploaded by tjhalanigrid — commit f0e5200 ("Step 2: added code folders")
import os
# Ensure headless-safe matplotlib + a writable config/cache dir when this
# script is launched from Gradio or a subprocess. These are set here so they
# are already in the environment when matplotlib is imported below;
# setdefault keeps any value the caller already exported.
import tempfile

os.environ.setdefault("MPLBACKEND", "Agg")
# Use the platform temp dir (honours TMPDIR etc.) instead of a hard-coded /tmp.
os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "mplconfig"))
import time
import json
import argparse
import matplotlib.pyplot as plt
import numpy as np
import sys
from pathlib import Path
# ==========================================
# RELATIVE PATH RESOLUTION
# ==========================================
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))
# Dynamically resolve where the databases are kept
if (PROJECT_ROOT / "data" / "database").exists() and list((PROJECT_ROOT / "data" / "database").rglob("*.sqlite")):
DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
DB_ROOT = PROJECT_ROOT / "final_databases"
from src.execution_reward import (
execution_reward_batch_sequential,
execution_reward_batch_parallel,
execution_reward_batch_parallel_by_db,
execution_reward_timed,
set_use_cache,
set_use_schema_validation,
clear_result_cache
)
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000, db_root=None):
    """Generate heavy queries across multiple databases to properly test true concurrency.

    Each rollout is a ``(pred_sql, db_path, gold_sql)`` tuple whose predicted
    and gold SQL are the same recursive-CTE query, so every rollout executes
    real work against a real SQLite file.

    Args:
        num_rollouts: number of rollouts to build (``<= 0`` yields ``[]``).
        heavy_n: base upper bound of the recursive CTE; rollout ``i`` counts
            to ``heavy_n + (i % 10_000)`` so the queries are not all identical.
        db_root: directory searched recursively for ``*.sqlite`` files;
            defaults to the module-level ``DB_ROOT``.

    Returns:
        list of ``(pred_sql, db_path, gold_sql)`` tuples, with databases
        assigned round-robin in sorted path order.

    Raises:
        SystemExit: (exit code 1) if no ``*.sqlite`` file exists under the root.
    """
    root = DB_ROOT if db_root is None else Path(db_root)
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # Sort so the rollout -> database assignment is reproducible across runs
    # and filesystems (rglob yields paths in directory-listing order).
    real_dbs = sorted(str(p) for p in root.rglob("*.sqlite"))
    if not real_dbs:
        print(f"❌ CRITICAL ERROR: No real databases found in {root}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(real_dbs)} real SQLite databases in {root}. Distributing workload...", flush=True)

    rollouts = []
    for i in range(num_rollouts):
        db_path = real_dbs[i % len(real_dbs)]
        upper = heavy_n + (i % 10_000)
        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {upper}
        )
        SELECT sum(x) FROM cnt;
        """
        # Collapse all whitespace so the query text is a single line that does
        # not depend on this file's indentation.
        clean_sql = " ".join(heavy_sql.split())
        rollouts.append((clean_sql, db_path, clean_sql))
        if num_rollouts >= 500 and (i + 1) % 250 == 0:
            print(f" generated {i + 1}/{num_rollouts}...", flush=True)
    return rollouts
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profile per-rollout time spent in parsing, planning, and execution.

    Runs ``execution_reward_timed(..., measure_plan=True)`` on the first
    ``sample_size`` rollouts and prints the average time per phase plus each
    phase's share of the summed time.

    Side effects: clears the result cache and leaves result caching and
    schema validation turned OFF (later callers set them again as needed).

    Args:
        rollouts: sequence of ``(pred_sql, db_path, gold_sql)`` tuples.
        sample_size: maximum number of rollouts to profile; capped at
            ``len(rollouts)``. A non-positive result skips profiling.
        print_every: print a progress line every N rollouts; ``0`` disables it.
    """
    # Profile a small subset by default so the script prints quickly.
    sample_size = min(int(sample_size), len(rollouts))
    if sample_size <= 0:
        # Nothing to time; also avoids dividing by zero in the averages below.
        print("\nCPU profiling skipped: nothing to profile (0 rollouts).", flush=True)
        return
    print_every = int(print_every) if print_every else 0
    sample_rollouts = rollouts[:sample_size]

    print("\n" + "=" * 65)
    # Report the number of rollouts actually profiled.
    print(f" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print("=" * 65)

    clear_result_cache()
    set_use_cache(False)  # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation

    total_parse = total_plan = total_exec = 0.0
    for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings["parse_s"]
        total_plan += timings["plan_s"]
        total_exec += timings["exec_s"]
        if print_every and (i % print_every == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)

    # Keep the percentage columns finite if every measurement was zero.
    total_time = (total_parse + total_plan + total_exec) or 0.0001

    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    for label, total in (
        ("Regex Parsing", total_parse),
        ("Query Planning", total_plan),
        ("DB Execution", total_exec),
    ):
        print(f"{label:<15} | {(total / sample_size) * 1000:<15.2f} | {(total / total_time) * 100:<14.1f}%")
    print("=" * 65 + "\n")
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time one sequential pass and one per-database parallel pass over ``rollouts``.

    The result cache is cleared before each pass so neither pass reuses the
    other's results. Schema validation is disabled because this benchmark
    focuses on execution speed.

    Returns:
        dict with ``sequential_s`` and ``parallel_s`` (seconds measured with
        ``time.perf_counter``) and ``speedup`` (sequential / parallel, or 0
        when the parallel time is not positive).
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)

    def _time_pass(run_pass):
        # Fresh cache for every pass, then wall-clock the pass itself.
        clear_result_cache()
        began = time.perf_counter()
        run_pass()
        return time.perf_counter() - began

    sequential_s = _time_pass(lambda: execution_reward_batch_sequential(rollouts))
    # Parallel pass grouped by database ("1 thread per DB", the recommended mode).
    parallel_s = _time_pass(
        lambda: execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    )

    speedup = 0 if parallel_s <= 0 else sequential_s / parallel_s
    return {"sequential_s": sequential_s, "parallel_s": parallel_s, "speedup": speedup}
def print_comparison_table(results):
    """Print a fixed-width table of sequential vs parallel timings per cache setting.

    Args:
        results: mapping with ``"with_cache"`` and ``"without_cache"`` entries,
            each holding ``sequential_s``, ``parallel_s`` and ``speedup``.
    """
    rule = "=" * 65
    print(rule)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    rows = (("With Cache", "with_cache"), ("Without Cache", "without_cache"))
    for label, key in rows:
        entry = results[key]
        seq_s, par_s, ratio = entry["sequential_s"], entry["parallel_s"], entry["speedup"]
        print(f"{label:<16} | {seq_s:<14.4f} | {par_s:<14.4f} | {ratio:<9.2f}x")
    print(rule + "\n")
def plot_results(results, output_path: str):
    """Save a grouped bar chart of sequential vs parallel time per cache setting.

    Args:
        results: mapping with ``"with_cache"`` and ``"without_cache"`` entries,
            each holding ``sequential_s`` and ``parallel_s`` (seconds).
        output_path: destination image file; its parent directory is created
            if missing. The image is written at 300 dpi.
    """
    settings = (("With Cache", "with_cache"), ("Without Cache", "without_cache"))
    labels = [label for label, _ in settings]
    seq_times = [results[key]["sequential_s"] for _, key in settings]
    par_times = [results[key]["parallel_s"] for _, key in settings]

    x = np.arange(len(labels))
    width = 0.35

    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.bar(x - width / 2, seq_times, width, label="Sequential", color="#4C72B0")
        ax.bar(x + width / 2, par_times, width, label="Parallel", color="#DD8452")
        ax.set_ylabel("Execution Time (seconds)")
        ax.set_title("Text2SQL Reward Execution: Sequential vs Parallel")
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()
        # Annotate each bar with its value in seconds.
        for container in ax.containers:
            ax.bar_label(container, fmt="%.2f", padding=3)
        fig.tight_layout()
        # Save through the explicit Figure handle rather than pyplot's
        # implicit "current figure".
        fig.savefig(output_path, dpi=300)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def main(argv=None):
    """Command-line entry point for the SQL execution-reward benchmark.

    Generates synthetic rollouts, optionally profiles per-phase cost, and times
    sequential vs parallel reward execution with the cache enabled and then
    disabled. Writes ``results/task1_results.json`` and
    ``results/task1_plot.png`` under the project root.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args(argv)

    # Fail fast with a usage error instead of crashing deep inside the benchmark.
    if args.n <= 0:
        parser.error("--n must be a positive integer")
    if args.max_workers <= 0:
        parser.error("--max-workers must be a positive integer")

    results_dir = PROJECT_ROOT / "results"
    os.makedirs(results_dir, exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)
    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")
    print("Running Experiment A: Cache ENABLED...")
    results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)
    print("Running Experiment B: Cache DISABLED...")
    results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": results_with_cache,
        "without_cache": results_without_cache,
    }

    json_path = results_dir / "task1_results.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_path = results_dir / "task1_plot.png"
    plot_results(final_results, str(plot_path))
    print(f"Saved results to {json_path} and plot to {plot_path}")
if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0
    # on success.
    raise SystemExit(main())