| """ |
| Generalized optimizer: tests across diverse synthetic songs, saves best config. |
| |
| Changes from v1: |
| - Tests each config against MULTIPLE songs (rock/funk/halftime/vocal/sfx) |
| - Averages scores across all test songs for robust evaluation |
| - Saves winning configs to HF Hub with leaderboard scores |
| - Diagnostic-driven parameter tuning (same as before, improved) |
| """ |
|
|
import time

import numpy as np
from dataclasses import dataclass, field
|
|
| from synth_generator import generate_test_song |
| from config_store import PipelineConfig, save_config |
|
|
|
|
| @dataclass |
| class IterationResult: |
| iteration: int |
| params: dict |
| scores: dict |
| avg_score: float |
| duration_s: float |
| changes: list |
| timestamp: str = "" |
|
|
|
|
| @dataclass |
| class OptimizerState: |
| history: list = field(default_factory=list) |
| best_config: dict = field(default_factory=dict) |
| best_score: float = 0.0 |
| iteration: int = 0 |
|
|
|
|
| def run_extraction_eval(song, config: PipelineConfig): |
| """Run extraction + evaluation on a single song. Returns eval dict.""" |
| from sample_extractor import (detect_onsets, classify_and_separate, |
| compute_embeddings, cluster_hits, select_best, |
| synthesize_from_cluster, sample_quality_score) |
| from evaluation import evaluate_extraction |
|
|
| hits = detect_onsets(song.drums_only, song.sr, pre_pad=config.pre_pad, |
| min_dur=config.min_dur, max_dur=config.max_dur, |
| min_gap=config.min_gap, energy_threshold_db=config.energy_threshold_db, |
| mode=config.onset_mode) |
    # No onsets detected: return worst-case metrics so this config is penalized.
    if not hits:
        return {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
                'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}
|
|
| hits = classify_and_separate(hits, separate_overlaps=config.separate_overlaps, |
| overlap_threshold=config.overlap_threshold) |
| embs = compute_embeddings(hits) |
| clusters = cluster_hits(hits, embs) |
| select_best(clusters) |
|
|
| if config.synthesize: |
| for c in clusters: |
| if c.count >= 2: |
| c.synthesized = synthesize_from_cluster(c) |
|
|
| gt_samples = {name: s.audio for name, s in song.samples.items()} |
| gt_hits = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity} |
| for h in song.hits] |
|
|
| report = evaluate_extraction(extracted_clusters=clusters, gt_samples=gt_samples, |
| gt_hit_map=gt_hits, sr=song.sr, all_hits=hits) |
| return { |
| 'overall_score': report.overall_score, |
| 'mean_si_sdr': report.mean_si_sdr, |
| 'mean_sample_score': report.mean_sample_score, |
| 'mean_env_corr': report.mean_env_corr, |
| 'mean_onset_error_ms': report.mean_onset_error_ms, |
| 'hit_count_accuracy': report.hit_count_accuracy, |
| } |
|
|
|
|
| def eval_config_across_songs(config: PipelineConfig, seeds: list, patterns: list, |
| bpms: list) -> dict: |
| """Evaluate a config across multiple test songs. Returns averaged metrics.""" |
| all_scores = [] |
| for seed, pattern, bpm in zip(seeds, patterns, bpms): |
| try: |
| song = generate_test_song(pattern_name=pattern, bars=4, bpm=bpm, |
| variation='medium', seed=seed) |
| result = run_extraction_eval(song, config) |
| all_scores.append(result) |
        except Exception:
            # Song generation or extraction crashed: count it as a total miss so
            # fragile configs are penalized rather than silently skipped.
            all_scores.append({'overall_score': 0, 'mean_si_sdr': -50,
                               'mean_sample_score': 0, 'mean_env_corr': 0,
                               'mean_onset_error_ms': 50, 'hit_count_accuracy': 0})
|
|
| |
    # Average each metric across songs; keep the per-song breakdown for reporting.
    avg = {}
    for key in all_scores[0]:
        vals = [s[key] for s in all_scores]
        avg[key] = float(np.mean(vals))
| avg['n_songs'] = len(all_scores) |
| avg['per_song'] = all_scores |
| return avg |
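
# Shape of the summary returned above (illustrative numbers, not real output):
#   {'overall_score': 61.8, 'mean_si_sdr': 6.9, 'mean_sample_score': 55.0,
#    'mean_env_corr': 0.81, 'mean_onset_error_ms': 12.4, 'hit_count_accuracy': 0.92,
#    'n_songs': 6, 'per_song': [...]}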
|
|
|
|
| def diagnose_and_perturb(config: PipelineConfig, metrics: dict, rng) -> tuple: |
| """Diagnose issues from metrics and perturb config. Returns (new_config, changes).""" |
| c = PipelineConfig.from_dict(config.to_dict()) |
| changes = [] |
|
|
    # Onsets off by >20 ms on average: shrink pre-pad and min gap to tighten timing.
    if metrics.get('mean_onset_error_ms', 0) > 20:
        c.pre_pad = max(0.001, config.pre_pad * rng.uniform(0.5, 0.9))
        c.min_gap = max(0.008, config.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_err={metrics['mean_onset_error_ms']:.0f}ms → tightened timing")
|
|
    # Missing hits: lower the energy threshold and minimum duration to catch more.
    if metrics.get('hit_count_accuracy', 1) < 0.7:
        c.energy_threshold_db = max(-65, config.energy_threshold_db - rng.uniform(2, 8))
        c.min_dur = max(0.008, config.min_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={metrics['hit_count_accuracy']:.2f} → lowered threshold")
|
|
    # Poor separation quality: nudge the overlap threshold within safe bounds.
    if metrics.get('mean_si_sdr', 0) < 5:
        c.overlap_threshold = float(np.clip(
            c.overlap_threshold + rng.uniform(-0.05, 0.05), 0.05, 0.4))
        changes.append(f"SI-SDR={metrics['mean_si_sdr']:.1f} → adjusted overlap")
|
|
    # Low envelope correlation: tails are likely truncated; allow longer samples.
    if metrics.get('mean_env_corr', 1) < 0.7:
        c.max_dur = min(2.0, config.max_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={metrics['mean_env_corr']:.2f} → increased max_dur")
|
|
    # Nothing diagnosed: explore nearby configs with small random perturbations.
    if not changes:
        c.energy_threshold_db += rng.uniform(-3, 3)
        c.pre_pad = max(0.001, c.pre_pad + rng.uniform(-0.002, 0.002))
        c.min_dur = max(0.008, c.min_dur + rng.uniform(-0.005, 0.005))
        changes.append("random exploration")
|
|
| return c, changes |
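
# Example of a diagnosis trace (hypothetical numbers): a config scoring
# mean_onset_error_ms=35 and hit_count_accuracy=0.50 triggers both the timing and
# threshold branches above, yielding changes like
#   ["onset_err=35ms → tightened timing", "hit_acc=0.50 → lowered threshold"]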
|
|
|
|
| def run_optimization(n_iterations: int = 10, config_name: str = "optimized", |
| author: str = "", save_to_hub: bool = True, |
| seed: int = 42, log_fn=None) -> OptimizerState: |
| """Run optimization loop, testing each config across diverse songs.""" |
| rng = np.random.RandomState(seed) |
| state = OptimizerState() |
|
|
| |
    # Fixed evaluation suite: six songs over three patterns with varied BPMs and
    # deterministic per-song seeds, so scores stay comparable across iterations.
    test_patterns = ['rock', 'funk', 'halftime'] * 2
    test_seeds = [seed + i * 17 for i in range(6)]
    test_bpms = [120, 100, 140, 130, 110, 150]
|
|
| config = PipelineConfig(name=config_name, author=author) |
|
|
    def log(msg):
        if log_fn:
            log_fn(msg)
        print(msg)
|
|
| log(f"Optimization: {n_iterations} iters, {len(test_patterns)} test songs each") |
|
|
| for i in range(n_iterations): |
| t0 = time.time() |
| log(f"\n{'='*50}\nITERATION {i+1}/{n_iterations}\n{'='*50}") |
|
|
| try: |
| log(f" Testing config across {len(test_patterns)} songs...") |
| metrics = eval_config_across_songs(config, test_seeds, test_patterns, test_bpms) |
| avg_score = metrics['overall_score'] |
|
|
| log(f" Score: {avg_score:.1f}/100 (SI-SDR={metrics['mean_si_sdr']:.1f}, " |
| f"sample={metrics['mean_sample_score']:.1f}, " |
| f"env={metrics['mean_env_corr']:.2f})") |
|
|
            if avg_score > state.best_score:
                state.best_score = avg_score
                state.best_config = config.to_dict()
                log(f"  ✓ NEW BEST: {avg_score:.1f}")
|
|
| |
| new_config, changes = diagnose_and_perturb(config, metrics, rng) |
| log(f" Changes: {'; '.join(changes)}") |
|
|
| state.history.append(IterationResult( |
| iteration=i, params=config.to_dict(), |
| scores={f"song_{j}": s['overall_score'] |
| for j, s in enumerate(metrics.get('per_song', []))}, |
| avg_score=avg_score, duration_s=time.time()-t0, |
| changes=changes, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'), |
| )) |
| config = new_config |
|
|
        except Exception as e:
            log(f"  ERROR: {e}")
            # Randomly kick the threshold so a crashing config isn't retried verbatim.
            config.energy_threshold_db += rng.uniform(-5, 5)
| state.history.append(IterationResult( |
| iteration=i, params=config.to_dict(), scores={}, |
| avg_score=0, duration_s=time.time()-t0, changes=[str(e)], |
| )) |
|
|
| state.iteration = i + 1 |
|
|
| |
| if save_to_hub and state.best_config: |
| log(f"\nSaving best config (score={state.best_score:.1f})...") |
| best = PipelineConfig.from_dict(state.best_config) |
| best.name = config_name |
| best.author = author |
| best.overall_score = state.best_score |
| best.n_test_songs = len(test_patterns) |
        try:
            save_config(best)
            log(f"  ✓ Saved to {best.name}")
        except Exception as e:
            log(f"  ✗ Could not save to Hub: {e}")
|
|
| log(f"\nBest score: {state.best_score:.1f}/100") |
| return state |
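

# A minimal usage sketch (argument values here are illustrative). It assumes
# synth_generator and config_store are importable, and that save_config is
# authenticated against the HF Hub when save_to_hub=True.
if __name__ == "__main__":
    final_state = run_optimization(n_iterations=10, config_name="optimized",
                                   author="anonymous", save_to_hub=False)
    print(f"Finished {final_state.iteration} iterations; "
          f"best score {final_state.best_score:.1f}/100")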
|
|