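"""Benchmark harness for the Text2SQL execution-reward functions.

Compares sequential vs. parallel reward computation with the result cache
enabled and disabled, optionally profiles where CPU time goes (regex parsing,
query planning, DB execution), and writes a JSON summary plus a bar chart to
<PROJECT_ROOT>/results/.

Example invocation (a sketch; "benchmark_execution_reward.py" is a placeholder
for whatever this file is saved as, and the flags are the ones defined in
main() below):

    python benchmark_execution_reward.py --n 200 --max-workers 8 --skip-profile
"""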
import os
# Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
os.environ.setdefault("MPLBACKEND", "Agg")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
import time
import json
import argparse
import matplotlib.pyplot as plt
import numpy as np
import sys
from pathlib import Path

# ==========================================
# RELATIVE PATH RESOLUTION
# ==========================================
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))

# Dynamically resolve where the databases are kept
_preferred_db_dir = PROJECT_ROOT / "data" / "database"
if _preferred_db_dir.exists() and any(_preferred_db_dir.rglob("*.sqlite")):
    DB_ROOT = _preferred_db_dir
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
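# Note: rglob("*.sqlite") matches at any depth, so both flat layouts
# (e.g. final_databases/foo.sqlite) and nested ones
# (e.g. data/database/foo/foo.sqlite) are found; generate_mock_rollouts()
# below reuses the same glob against DB_ROOT.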

from src.execution_reward import (
    execution_reward_batch_sequential,
    execution_reward_batch_parallel,
    execution_reward_batch_parallel_by_db,
    execution_reward_timed,
    set_use_cache,
    set_use_schema_validation,
    clear_result_cache
)
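# Assumption made explicit: execution_reward_timed(pred, db, gold, measure_plan=True)
# returns a (reward, timings) pair whose timings dict carries 'parse_s', 'plan_s'
# and 'exec_s' keys; profile_bottlenecks() below relies on exactly that shape.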

def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Generate heavy queries spread across multiple databases to exercise true cross-DB concurrency."""
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)
    
    # Smart search for real databases
    real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]
    
    if real_dbs:
        print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)
    else:
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
        
    rollouts = []
    for i in range(num_rollouts):
        db_path = real_dbs[i % len(real_dbs)]
            
        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {heavy_n + (i % 10_000)}
        )
        SELECT sum(x) FROM cnt;
        """
        clean_sql = heavy_sql.replace("\n", " ").strip()
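        # The CTE upper bound varies with i, so each rollout's SQL text is distinct;
        # presumably this keeps the result cache from answering every rollout from a
        # single cached entry when caching is enabled.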
        rollouts.append((clean_sql, db_path, clean_sql))
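        # Each appended rollout is a (predicted_sql, db_path, gold_sql) tuple with
        # identical predicted and gold SQL, so the benchmark measures pure execution
        # cost rather than the reward's ability to catch mismatches.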
        if num_rollouts >= 500 and (i + 1) % 250 == 0:
            print(f"  generated {i + 1}/{num_rollouts}...", flush=True)
        
    return rollouts

def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profiles CPU usage to identify time spent in parsing, planning, and execution."""
    sample_size = min(int(sample_size), len(rollouts))
    print("\n" + "="*65)
    print(f" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print("="*65)
    
    clear_result_cache()
    set_use_cache(False) # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation
    
    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0
    
    # Profile a small subset by default so the script prints quickly.
    sample_rollouts = rollouts[:sample_size]
    
    for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if print_every and (i % int(print_every) == 0 or i == sample_size):
            print(f"  profiled {i}/{sample_size}...", flush=True)
        
    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        total_time = 0.0001  # Prevent division by zero
    
    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("="*65 + "\n")

def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed
    
    # Sequential
    clear_result_cache()
    start_time = time.perf_counter()
    execution_reward_batch_sequential(rollouts)
    sequential_s = time.perf_counter() - start_time

    # Parallel
    clear_result_cache()
    start_time = time.perf_counter()
    # 1 thread per DB (recommended)
    execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    parallel_s = time.perf_counter() - start_time

    speedup = sequential_s / parallel_s if parallel_s > 0 else 0

    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": speedup
    }
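# Only the by-db parallel variant is benchmarked above; per the "1 thread per DB"
# comment, it presumably pins each database file to a single worker, so raising
# --max-workers beyond the number of distinct .sqlite files is unlikely to help.
# execution_reward_batch_parallel is imported but not exercised by this benchmark.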

def print_comparison_table(results):
    print("="*65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    for setting, key in [("With Cache", "with_cache"), ("Without Cache", "without_cache")]:
        seq = results[key]['sequential_s']
        par = results[key]['parallel_s']
        spd = results[key]['speedup']
        print(f"{setting:<16} | {seq:<14.4f} | {par:<14.4f} | {spd:<9.2f}x")
    print("="*65 + "\n")

def plot_results(results, output_path: str):
    labels = ['With Cache', 'Without Cache']
    seq_times = [results['with_cache']['sequential_s'], results['without_cache']['sequential_s']]
    par_times = [results['with_cache']['parallel_s'], results['without_cache']['parallel_s']]

    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(x - width/2, seq_times, width, label='Sequential', color='#4C72B0')
    ax.bar(x + width/2, par_times, width, label='Parallel', color='#DD8452')

    ax.set_ylabel('Execution Time (seconds)')
    ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()
    
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3)

    fig.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()

def main():
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    os.makedirs(str(PROJECT_ROOT / "results"), exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)
    
    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)
    
    print("Starting Main Scalability Benchmarks...")

    print("Running Experiment A: Cache ENABLED...")
    results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)

    print("Running Experiment B: Cache DISABLED...")
    results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": results_with_cache,
        "without_cache": results_without_cache
    }

    json_path = str(PROJECT_ROOT / "results" / "task1_results.json")
    with open(json_path, 'w') as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(PROJECT_ROOT / "results" / "task1_plot.png"))

if __name__ == "__main__":
    main()