# NOTE: the "Spaces: Sleeping" lines below this file's original location were
# Hugging Face Spaces page banners captured during extraction; they are not
# part of the script.
import os

# Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
# These must be set BEFORE matplotlib is imported, which is why they sit above the
# import block rather than with the rest of the configuration.
os.environ.setdefault("MPLBACKEND", "Agg")
# setdefault already keeps any pre-existing value, so re-reading the variable
# via os.environ.get(...) (as the previous version did) was redundant.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")

import argparse
import json
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
# ==========================================
# RELATIVE PATH RESOLUTION
# ==========================================
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))

# Dynamically resolve where the databases are kept: prefer data/database when it
# actually contains at least one .sqlite file, otherwise fall back to
# final_databases. next(..., None) stops at the first match instead of
# materializing every path the way list(rglob(...)) did.
_db_candidate = PROJECT_ROOT / "data" / "database"
if _db_candidate.exists() and next(_db_candidate.rglob("*.sqlite"), None) is not None:
    DB_ROOT = _db_candidate
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
| from src.execution_reward import ( | |
| execution_reward_batch_sequential, | |
| execution_reward_batch_parallel, | |
| execution_reward_batch_parallel_by_db, | |
| execution_reward_timed, | |
| set_use_cache, | |
| set_use_schema_validation, | |
| clear_result_cache | |
| ) | |
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Generates heavy queries across multiple databases to properly test true concurrency."""
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # Discover every real SQLite database under DB_ROOT; the workload is spread over them.
    real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]
    if not real_dbs:
        # Without at least one database there is nothing meaningful to benchmark.
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    db_count = len(real_dbs)
    show_progress = num_rollouts >= 500
    rollouts = []
    for i in range(num_rollouts):
        # Round-robin across databases so per-DB parallel execution gets real fan-out.
        db_path = real_dbs[i % db_count]
        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {heavy_n + (i % 10_000)}
        )
        SELECT sum(x) FROM cnt;
        """
        clean_sql = heavy_sql.replace("\n", " ").strip()
        # (pred, db, gold) with pred == gold: the reward is trivially correct; timing is the point.
        rollouts.append((clean_sql, db_path, clean_sql))
        if show_progress and (i + 1) % 250 == 0:
            print(f" generated {i + 1}/{num_rollouts}...", flush=True)
    return rollouts
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profiles CPU usage to identify time spent in parsing, planning, and execution.

    Args:
        rollouts: list of (pred_sql, db_path, gold_sql) tuples; only the first
            ``sample_size`` entries are profiled so the report prints quickly.
        sample_size: how many rollouts to time (clamped to ``len(rollouts)``).
        print_every: progress-print cadence; falsy disables progress lines.
    """
    # Clamp first so the header and per-phase averages agree with reality, and
    # bail out early on an empty sample — previously this divided by zero below.
    sample_size = min(int(sample_size), len(rollouts))
    if sample_size == 0:
        print("⚠️ No rollouts available to profile; skipping CPU profiling.", flush=True)
        return
    print("\n" + "="*65)
    # Header used to hard-code "(100 Rollouts)"; report the actual sample size.
    print(f" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print("="*65)
    clear_result_cache()
    set_use_cache(False)  # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation
    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0
    for i, (pred, db, gold) in enumerate(rollouts[:sample_size], 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if print_every and (i % int(print_every) == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)
    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        total_time = 0.0001  # Prevent div by zero in the percentage column
    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("="*65 + "\n")
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time sequential vs. parallel reward execution for one cache setting.

    Returns a dict with ``sequential_s``, ``parallel_s`` and ``speedup`` keys.
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed

    def _timed(run):
        # Fresh result cache before each run so both paths do the same work.
        clear_result_cache()
        t0 = time.perf_counter()
        run()
        return time.perf_counter() - t0

    sequential_s = _timed(lambda: execution_reward_batch_sequential(rollouts))
    # 1 thread per DB (recommended)
    parallel_s = _timed(
        lambda: execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    )
    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": sequential_s / parallel_s if parallel_s > 0 else 0,
    }
def print_comparison_table(results):
    """Render the with/without-cache timing results as an aligned console table."""
    print("=" * 65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    for label, key in (("With Cache", "with_cache"), ("Without Cache", "without_cache")):
        entry = results[key]
        print(
            f"{label:<16} | {entry['sequential_s']:<14.4f} | "
            f"{entry['parallel_s']:<14.4f} | {entry['speedup']:<9.2f}x"
        )
    print("=" * 65 + "\n")
def plot_results(results, output_path: str):
    """Save a grouped bar chart comparing sequential vs. parallel execution times."""
    labels = ['With Cache', 'Without Cache']
    keys = ('with_cache', 'without_cache')
    seq_times = [results[k]['sequential_s'] for k in keys]
    par_times = [results[k]['parallel_s'] for k in keys]

    positions = np.arange(len(labels))
    bar_width = 0.35
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(positions - bar_width / 2, seq_times, bar_width, label='Sequential', color='#4C72B0')
    ax.bar(positions + bar_width / 2, par_times, bar_width, label='Parallel', color='#DD8452')
    ax.set_ylabel('Execution Time (seconds)')
    ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()
    # Annotate each bar with its height so exact timings are readable off the chart.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3)
    fig.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()
def main():
    """CLI entry point: generate rollouts, optionally profile, benchmark, and report."""
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    results_dir = PROJECT_ROOT / "results"
    os.makedirs(str(results_dir), exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)
    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")
    print("Running Experiment A: Cache ENABLED...")
    with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)
    print("Running Experiment B: Cache DISABLED...")
    without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": with_cache,
        "without_cache": without_cache
    }
    # Persist raw numbers alongside the rendered table and plot.
    with open(str(results_dir / "task1_results.json"), 'w') as f:
        json.dump(final_results, f, indent=4)
    print_comparison_table(final_results)
    plot_results(final_results, str(results_dir / "task1_plot.png"))


if __name__ == "__main__":
    main()