# NOTE: removed non-Python residue from the hosting UI ("Spaces / Sleeping / File size")
# that was captured along with this file; it is not part of the script.
import os
# Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
os.environ.setdefault("MPLBACKEND", "Agg")
# setdefault already keeps any pre-existing MPLCONFIGDIR, so the original
# inner os.environ.get("MPLCONFIGDIR", ...) lookup was redundant — the
# default is only applied when the variable is unset either way.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
import time
import json
import argparse
import matplotlib.pyplot as plt
import numpy as np
import sys
from pathlib import Path
# --- Relative path resolution -------------------------------------------
# The script lives one level below the project root; put the root on
# sys.path so `src.*` imports below resolve when run as a plain script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))

# Prefer data/database when it exists AND actually contains .sqlite files;
# otherwise fall back to the bundled final_databases directory.
_primary_db_dir = PROJECT_ROOT / "data" / "database"
if _primary_db_dir.exists() and any(_primary_db_dir.rglob("*.sqlite")):
    DB_ROOT = _primary_db_dir
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
from src.execution_reward import (
execution_reward_batch_sequential,
execution_reward_batch_parallel,
execution_reward_batch_parallel_by_db,
execution_reward_timed,
set_use_cache,
set_use_schema_validation,
clear_result_cache
)
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Generates heavy queries across multiple databases to properly test true concurrency."""
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)
    # Discover real SQLite databases under DB_ROOT; the benchmark cannot run without them.
    db_paths = [str(p) for p in DB_ROOT.rglob("*.sqlite")]
    if not db_paths:
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(db_paths)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    batch = []
    for idx in range(num_rollouts):
        # Round-robin across the discovered databases so parallel-by-db has work to split.
        db_path = db_paths[idx % len(db_paths)]
        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {heavy_n + (idx % 10_000)}
        )
        SELECT sum(x) FROM cnt;
        """
        clean_sql = heavy_sql.replace("\n", " ").strip()
        # (pred_sql, db_path, gold_sql) — gold == pred so the reward comparison always matches.
        batch.append((clean_sql, db_path, clean_sql))
        if num_rollouts >= 500 and (idx + 1) % 250 == 0:
            print(f" generated {idx + 1}/{num_rollouts}...", flush=True)
    return batch
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profiles CPU usage to identify time spent in parsing, planning, and execution.

    Args:
        rollouts: list of (pred_sql, db_path, gold_sql) tuples.
        sample_size: how many rollouts (from the front) to profile.
        print_every: progress cadence; falsy disables progress lines.
    """
    print("\n" + "=" * 65)
    print(" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS (100 Rollouts)")
    print("=" * 65)
    clear_result_cache()
    set_use_cache(False)  # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation

    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0
    # Profile a small subset by default so the script prints quickly.
    sample_size = min(int(sample_size), len(rollouts))
    # BUG FIX: guard against sample_size == 0 (empty rollouts or --profile-n 0),
    # which previously caused ZeroDivisionError in the per-phase averages below
    # (total_time had a guard but sample_size did not).
    if sample_size <= 0:
        print("No rollouts to profile; skipping CPU profiling section.", flush=True)
        print("=" * 65 + "\n")
        return
    sample_rollouts = rollouts[:sample_size]

    for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if print_every and (i % int(print_every) == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)

    total_time = total_parse + total_plan + total_exec
    if total_time == 0: total_time = 0.0001  # Prevent div by zero
    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("=" * 65 + "\n")
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time the same rollout batch sequentially and in parallel under one cache setting.

    Returns a dict with 'sequential_s', 'parallel_s', and 'speedup'.
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed

    def _timed_run(fn):
        # Fresh cache before each measurement so both runs do comparable work.
        clear_result_cache()
        t0 = time.perf_counter()
        fn()
        return time.perf_counter() - t0

    sequential_s = _timed_run(lambda: execution_reward_batch_sequential(rollouts))
    # 1 thread per DB (recommended)
    parallel_s = _timed_run(
        lambda: execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    )

    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": sequential_s / parallel_s if parallel_s > 0 else 0,
    }
def print_comparison_table(results):
    """Render the benchmark results as a fixed-width console table.

    Expects results['with_cache'] and results['without_cache'], each holding
    'sequential_s', 'parallel_s', and 'speedup'.
    """
    rows = (
        ("With Cache", results["with_cache"]),
        ("Without Cache", results["without_cache"]),
    )
    print("=" * 65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    for label, timing in rows:
        print(
            f"{label:<16} | {timing['sequential_s']:<14.4f} | "
            f"{timing['parallel_s']:<14.4f} | {timing['speedup']:<9.2f}x"
        )
    print("=" * 65 + "\n")
def plot_results(results, output_path: str):
    """Save a grouped bar chart comparing sequential vs parallel wall-times.

    Args:
        results: dict with 'with_cache' / 'without_cache' timing dicts.
        output_path: destination image path (written at 300 dpi).
    """
    labels = ['With Cache', 'Without Cache']
    keys = ('with_cache', 'without_cache')
    seq_times = [results[k]['sequential_s'] for k in keys]
    par_times = [results[k]['parallel_s'] for k in keys]

    positions = np.arange(len(labels))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(positions - bar_width / 2, seq_times, bar_width, label='Sequential', color='#4C72B0')
    ax.bar(positions + bar_width / 2, par_times, bar_width, label='Parallel', color='#DD8452')
    ax.set_ylabel('Execution Time (seconds)')
    ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()

    # Annotate every bar with its height for quick reading.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3)

    fig.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()
def main():
    """CLI entry point: generate rollouts, optionally profile, benchmark both cache
    settings, then write JSON results, a comparison table, and a plot."""
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    opts = parser.parse_args()

    results_dir = PROJECT_ROOT / "results"
    os.makedirs(str(results_dir), exist_ok=True)

    rollouts = generate_mock_rollouts(opts.n, heavy_n=opts.heavy_n)
    if not opts.skip_profile:
        profile_bottlenecks(rollouts, sample_size=opts.profile_n)

    print("Starting Main Scalability Benchmarks...")
    print("Running Experiment A: Cache ENABLED...")
    cache_on = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=opts.max_workers)
    print("Running Experiment B: Cache DISABLED...")
    cache_off = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=opts.max_workers)

    final_results = {
        "with_cache": cache_on,
        "without_cache": cache_off,
    }
    with open(str(results_dir / "task1_results.json"), 'w') as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(results_dir / "task1_plot.png"))


if __name__ == "__main__":
    main()