"""Benchmark the Text2SQL execution reward: sequential vs. parallel, with/without cache.

Generates a synthetic workload of CPU-heavy recursive-CTE queries spread across
the SQLite databases found under ``DB_ROOT``. It can optionally profile the
parse/plan/execute phases per query. It then times sequential reward execution
against ``execution_reward_batch_parallel_by_db``, with the result cache
enabled and disabled.

Results are written to ``<PROJECT_ROOT>/results/task1_results.json`` and
``<PROJECT_ROOT>/results/task1_plot.png``.
"""
import argparse
import json
import os
import sys
import tempfile
import time
from pathlib import Path

# Ensure headless-safe matplotlib + a writable cache dir when called from
# Gradio/subprocess. These must be set BEFORE matplotlib is imported.
os.environ.setdefault("MPLBACKEND", "Agg")
os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "mplconfig"))

import matplotlib.pyplot as plt  # noqa: E402  (must follow the env setup above)
import numpy as np  # noqa: E402

# ==========================================
# RELATIVE PATH RESOLUTION
# ==========================================
# The project root is taken to be two levels up from this file
# (parent.parent). It is appended to sys.path so the ``src`` package below
# is importable when this file is run as a plain script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))

# Prefer <root>/data/database when it actually contains SQLite files;
# otherwise fall back to <root>/final_databases.
# any() stops at the first match instead of walking the whole tree.
_DATA_DB_DIR = PROJECT_ROOT / "data" / "database"
if _DATA_DB_DIR.exists() and any(_DATA_DB_DIR.rglob("*.sqlite")):
    DB_ROOT = _DATA_DB_DIR
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"

from src.execution_reward import (  # noqa: E402  (needs PROJECT_ROOT on sys.path)
    execution_reward_batch_sequential,
    execution_reward_batch_parallel,
    execution_reward_batch_parallel_by_db,
    execution_reward_timed,
    set_use_cache,
    set_use_schema_validation,
    clear_result_cache,
)


def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Build a synthetic workload of heavy queries spread across real databases.

    Each rollout is a ``(pred_sql, db_path, gold_sql)`` tuple whose predicted
    and gold SQL are the same recursive-CTE query. Databases are assigned
    round-robin over the sorted list of ``*.sqlite`` files under ``DB_ROOT``.

    Args:
        num_rollouts: Number of rollouts to generate.
        heavy_n: Base upper bound of the recursive CTE (larger = heavier).
            A per-rollout offset of ``i % 10_000`` is added.

    Returns:
        A list of ``(pred_sql, db_path, gold_sql)`` tuples.

    Exits the process with status 1 if no ``*.sqlite`` file exists under
    ``DB_ROOT``.
    """
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # Sorted so the rollout -> database assignment is reproducible across runs
    # (rglob's traversal order depends on the filesystem).
    real_dbs = sorted(str(p) for p in DB_ROOT.rglob("*.sqlite"))
    if not real_dbs:
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    rollouts = []
    for i in range(num_rollouts):
        db_path = real_dbs[i % len(real_dbs)]
        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout
        # depending on machine). The offset makes rollouts whose indices differ
        # modulo 10,000 produce different SQL text.
        upper = heavy_n + (i % 10_000)
        heavy_sql = (
            "WITH RECURSIVE cnt(x) AS ( SELECT 1 UNION ALL SELECT x+1 FROM cnt "
            f"WHERE x < {upper} ) SELECT sum(x) FROM cnt;"
        )
        rollouts.append((heavy_sql, db_path, heavy_sql))
        if num_rollouts >= 500 and (i + 1) % 250 == 0:
            print(f" generated {i + 1}/{num_rollouts}...", flush=True)
    return rollouts


def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Time the parse / plan / execute phases for a prefix of ``rollouts``.

    Clears the result cache, disables caching and schema validation so every
    query does real work, and then prints each phase's average time per query
    and share of the total. The cache/validation settings are NOT restored
    afterwards.

    Args:
        rollouts: ``(pred_sql, db_path, gold_sql)`` tuples.
        sample_size: Number of leading rollouts to profile, clamped to
            ``[0, len(rollouts)]``. When it is 0, nothing is profiled.
        print_every: Print progress every this many queries; 0/None disables it.
    """
    # Clamp first so the header reports the real sample size and so a negative
    # value cannot turn the slice below into "all but the last k".
    sample_size = max(0, min(int(sample_size), len(rollouts)))

    print("\n" + "=" * 65)
    print(f" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print("=" * 65)

    clear_result_cache()
    set_use_cache(False)  # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation

    if sample_size == 0:
        # Nothing to average over; the per-query averages below would divide by zero.
        print("No rollouts to profile.")
        print("=" * 65 + "\n")
        return

    print_every = int(print_every) if print_every else 0
    total_parse = total_plan = total_exec = 0.0

    for i, (pred, db, gold) in enumerate(rollouts[:sample_size], 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings["parse_s"]
        total_plan += timings["plan_s"]
        total_exec += timings["exec_s"]
        if print_every and (i % print_every == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)

    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        total_time = 0.0001  # Prevent div by zero in the percentage column

    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("=" * 65 + "\n")


def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time sequential vs. parallel reward execution under one cache setting.

    Schema validation is disabled because the benchmark focuses on execution
    speed. The result cache is cleared before each of the two runs. Presumably
    this means caching can only help with queries repeated within a single
    run; confirm against ``src.execution_reward``.

    Args:
        rollouts: ``(pred_sql, db_path, gold_sql)`` tuples.
        use_cache: Value passed to ``set_use_cache`` for both runs.
        max_workers: Passed through to ``execution_reward_batch_parallel_by_db``.

    Returns:
        A dict with ``sequential_s`` and ``parallel_s`` (wall-clock seconds)
        and ``speedup`` (sequential / parallel, or 0.0 when the parallel time
        is not positive).
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed

    # Sequential
    clear_result_cache()
    start_time = time.perf_counter()
    execution_reward_batch_sequential(rollouts)
    sequential_s = time.perf_counter() - start_time

    # Parallel: 1 thread per DB (recommended)
    clear_result_cache()
    start_time = time.perf_counter()
    execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    parallel_s = time.perf_counter() - start_time

    speedup = sequential_s / parallel_s if parallel_s > 0 else 0.0
    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": speedup,
    }


def print_comparison_table(results):
    """Print a fixed-width table of sequential/parallel times and speedups.

    ``results`` must contain ``"with_cache"`` and ``"without_cache"`` entries,
    each shaped like the dict returned by ``run_benchmark_for_setting``.
    """
    print("=" * 65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    for setting, key in [("With Cache", "with_cache"), ("Without Cache", "without_cache")]:
        row = results[key]
        print(
            f"{setting:<16} | {row['sequential_s']:<14.4f} | "
            f"{row['parallel_s']:<14.4f} | {row['speedup']:<9.2f}x"
        )
    print("=" * 65 + "\n")


def plot_results(results, output_path: str):
    """Save a grouped bar chart (sequential vs. parallel) to ``output_path``.

    ``results`` has the same shape as for ``print_comparison_table``. The
    figure is always closed, even if drawing or saving raises, so repeated
    calls from a long-lived process do not accumulate open figures.
    """
    labels = ["With Cache", "Without Cache"]
    keys = ["with_cache", "without_cache"]
    seq_times = [results[k]["sequential_s"] for k in keys]
    par_times = [results[k]["parallel_s"] for k in keys]

    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.bar(x - width / 2, seq_times, width, label="Sequential", color="#4C72B0")
        ax.bar(x + width / 2, par_times, width, label="Parallel", color="#DD8452")
        ax.set_ylabel("Execution Time (seconds)")
        ax.set_title("Text2SQL Reward Execution: Sequential vs Parallel")
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()
        for container in ax.containers:
            ax.bar_label(container, fmt="%.2f", padding=3)
        fig.tight_layout()
        # Save this specific figure rather than pyplot's implicit "current" one.
        fig.savefig(output_path, dpi=300)
    finally:
        plt.close(fig)


def main():
    """CLI entry point: build the workload, optionally profile, benchmark, and save results."""
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    # Reject degenerate settings up front instead of failing deep inside the run.
    if args.n < 1:
        parser.error("--n must be >= 1")
    if args.max_workers < 1:
        parser.error("--max-workers must be >= 1")

    results_dir = PROJECT_ROOT / "results"
    results_dir.mkdir(parents=True, exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)

    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")

    print("Running Experiment A: Cache ENABLED...")
    results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)

    print("Running Experiment B: Cache DISABLED...")
    results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": results_with_cache,
        "without_cache": results_without_cache,
    }

    json_path = results_dir / "task1_results.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(results_dir / "task1_plot.png"))


if __name__ == "__main__":
    main()