""" Generalized optimizer: tests across diverse synthetic songs, saves best config. Changes from v1: - Tests each config against MULTIPLE songs (rock/funk/halftime/vocal/sfx) - Averages scores across all test songs for robust evaluation - Saves winning configs to HF Hub with leaderboard scores - Diagnostic-driven parameter tuning (same as before, improved) """ import json, time, traceback, numpy as np from copy import deepcopy from dataclasses import dataclass, field from synth_generator import generate_test_song from config_store import PipelineConfig, save_config @dataclass class IterationResult: iteration: int params: dict scores: dict # {pattern: score} avg_score: float duration_s: float changes: list timestamp: str = "" @dataclass class OptimizerState: history: list = field(default_factory=list) best_config: dict = field(default_factory=dict) best_score: float = 0.0 iteration: int = 0 def run_extraction_eval(song, config: PipelineConfig): """Run extraction + evaluation on a single song. Returns eval dict.""" from sample_extractor import (detect_onsets, classify_and_separate, compute_embeddings, cluster_hits, select_best, synthesize_from_cluster, sample_quality_score) from evaluation import evaluate_extraction hits = detect_onsets(song.drums_only, song.sr, pre_pad=config.pre_pad, min_dur=config.min_dur, max_dur=config.max_dur, min_gap=config.min_gap, energy_threshold_db=config.energy_threshold_db, mode=config.onset_mode) if not hits: return {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0, 'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0} hits = classify_and_separate(hits, separate_overlaps=config.separate_overlaps, overlap_threshold=config.overlap_threshold) embs = compute_embeddings(hits) clusters = cluster_hits(hits, embs) select_best(clusters) if config.synthesize: for c in clusters: if c.count >= 2: c.synthesized = synthesize_from_cluster(c) gt_samples = {name: s.audio for name, s in song.samples.items()} gt_hits = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity} for h in song.hits] report = evaluate_extraction(extracted_clusters=clusters, gt_samples=gt_samples, gt_hit_map=gt_hits, sr=song.sr, all_hits=hits) return { 'overall_score': report.overall_score, 'mean_si_sdr': report.mean_si_sdr, 'mean_sample_score': report.mean_sample_score, 'mean_env_corr': report.mean_env_corr, 'mean_onset_error_ms': report.mean_onset_error_ms, 'hit_count_accuracy': report.hit_count_accuracy, } def eval_config_across_songs(config: PipelineConfig, seeds: list, patterns: list, bpms: list) -> dict: """Evaluate a config across multiple test songs. Returns averaged metrics.""" all_scores = [] for seed, pattern, bpm in zip(seeds, patterns, bpms): try: song = generate_test_song(pattern_name=pattern, bars=4, bpm=bpm, variation='medium', seed=seed) result = run_extraction_eval(song, config) all_scores.append(result) except Exception as e: all_scores.append({'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0, 'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}) # Average across all songs avg = {} for key in all_scores[0]: vals = [s[key] for s in all_scores] avg[key] = float(np.mean(vals)) avg['n_songs'] = len(all_scores) avg['per_song'] = all_scores return avg def diagnose_and_perturb(config: PipelineConfig, metrics: dict, rng) -> tuple: """Diagnose issues from metrics and perturb config. Returns (new_config, changes).""" c = PipelineConfig.from_dict(config.to_dict()) changes = [] if metrics.get('mean_onset_error_ms', 0) > 20: c.pre_pad = max(0.001, config.pre_pad * rng.uniform(0.5, 0.9)) c.min_gap = max(0.008, config.min_gap * rng.uniform(0.6, 0.9)) changes.append(f"onset_err={metrics['mean_onset_error_ms']:.0f}ms → tightened timing") if metrics.get('hit_count_accuracy', 1) < 0.7: c.energy_threshold_db = max(-65, config.energy_threshold_db - rng.uniform(2, 8)) c.min_dur = max(0.008, config.min_dur * rng.uniform(0.5, 0.8)) changes.append(f"hit_acc={metrics['hit_count_accuracy']:.2f} → lowered threshold") if metrics.get('mean_si_sdr', 0) < 5: c.overlap_threshold += rng.uniform(-0.05, 0.05) c.overlap_threshold = np.clip(c.overlap_threshold, 0.05, 0.4) changes.append(f"SI-SDR={metrics['mean_si_sdr']:.1f} → adjusted overlap") if metrics.get('mean_env_corr', 1) < 0.7: c.max_dur = min(2.0, config.max_dur * rng.uniform(1.1, 1.3)) changes.append(f"env_corr={metrics['mean_env_corr']:.2f} → increased max_dur") if not changes: c.energy_threshold_db += rng.uniform(-3, 3) c.pre_pad = max(0.001, c.pre_pad + rng.uniform(-0.002, 0.002)) c.min_dur = max(0.008, c.min_dur + rng.uniform(-0.005, 0.005)) changes.append("random exploration") return c, changes def run_optimization(n_iterations: int = 10, config_name: str = "optimized", author: str = "", save_to_hub: bool = True, seed: int = 42, log_fn=None) -> OptimizerState: """Run optimization loop, testing each config across diverse songs.""" rng = np.random.RandomState(seed) state = OptimizerState() # Test suite: diverse songs test_patterns = ['rock', 'funk', 'halftime'] * 2 # 6 songs test_seeds = [seed + i * 17 for i in range(6)] test_bpms = [120, 100, 140, 130, 110, 150] config = PipelineConfig(name=config_name, author=author) def log(msg): if log_fn: log_fn(msg) print(msg) log(f"Optimization: {n_iterations} iters, {len(test_patterns)} test songs each") for i in range(n_iterations): t0 = time.time() log(f"\n{'='*50}\nITERATION {i+1}/{n_iterations}\n{'='*50}") try: log(f" Testing config across {len(test_patterns)} songs...") metrics = eval_config_across_songs(config, test_seeds, test_patterns, test_bpms) avg_score = metrics['overall_score'] log(f" Score: {avg_score:.1f}/100 (SI-SDR={metrics['mean_si_sdr']:.1f}, " f"sample={metrics['mean_sample_score']:.1f}, " f"env={metrics['mean_env_corr']:.2f})") if avg_score > state.best_score: state.best_score = avg_score state.best_config = config.to_dict() log(f" ★ NEW BEST: {avg_score:.1f}") # Perturb new_config, changes = diagnose_and_perturb(config, metrics, rng) log(f" Changes: {'; '.join(changes)}") state.history.append(IterationResult( iteration=i, params=config.to_dict(), scores={f"song_{j}": s['overall_score'] for j, s in enumerate(metrics.get('per_song', []))}, avg_score=avg_score, duration_s=time.time()-t0, changes=changes, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'), )) config = new_config except Exception as e: log(f" ERROR: {e}") config.energy_threshold_db += rng.uniform(-5, 5) state.history.append(IterationResult( iteration=i, params=config.to_dict(), scores={}, avg_score=0, duration_s=time.time()-t0, changes=[str(e)], )) state.iteration = i + 1 # Save best config if save_to_hub and state.best_config: log(f"\nSaving best config (score={state.best_score:.1f})...") best = PipelineConfig.from_dict(state.best_config) best.name = config_name best.author = author best.overall_score = state.best_score best.n_test_songs = len(test_patterns) try: save_config(best) log(f" ✓ Saved to {best.name}") except Exception as e: log(f" ⚠ Could not save to Hub: {e}") log(f"\nBest score: {state.best_score:.1f}/100") return state