""" Autonomous parameter optimizer for the drum extraction pipeline. Runs a loop: 1. Generate synthetic songs with known ground truth 2. Run the extraction pipeline with current params 3. Evaluate extraction quality against ground truth 4. Use results to tune parameters for next iteration Uses Bayesian-ish optimization: maintain a history of (params → score), then perturb the best-so-far params toward improving weak metrics. """ import json import time import traceback import numpy as np from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path from synth_generator import generate_test_song, SyntheticSong from evaluation import evaluate_extraction, report_to_dict, EvalReport from quality_metrics import drum_sample_score @dataclass class PipelineParams: """All tunable parameters of the extraction pipeline.""" # Onset detection pre_pad: float = 0.005 min_hit_dur: float = 0.03 max_hit_dur: float = 0.8 min_gap: float = 0.02 energy_threshold_db: float = -40.0 # Overlap separation separate_overlaps: bool = True overlap_energy_threshold: float = 0.15 # band energy ratio to count as significant # Clustering use_clap: bool = False # Selection weights (must sum to 1.0) w_completeness: float = 0.30 w_cleanness: float = 0.40 w_onset: float = 0.20 w_representativeness: float = 0.10 # Synthesis synthesize: bool = True synth_best_weight: float = 2.0 # weight multiplier for best sample in cluster def to_dict(self) -> dict: return self.__dict__.copy() @classmethod def from_dict(cls, d: dict) -> 'PipelineParams': valid_keys = cls.__dataclass_fields__.keys() return cls(**{k: v for k, v in d.items() if k in valid_keys}) @dataclass class IterationResult: """Result of one optimization iteration.""" iteration: int params: dict eval_report: dict overall_score: float duration_seconds: float test_config: dict # which synthetic song was used timestamp: str @dataclass class OptimizerState: """Persistent state of the optimizer.""" history: list = field(default_factory=list) # [IterationResult] best_params: dict = field(default_factory=dict) best_score: float = 0.0 iteration: int = 0 # ───────────────────────────────────────────────────────────────────────────── # Parameter perturbation strategies # ───────────────────────────────────────────────────────────────────────────── def diagnose_and_perturb(params: PipelineParams, report: EvalReport, rng: np.random.RandomState) -> PipelineParams: """Analyze evaluation report and intelligently perturb parameters. Instead of random search, we diagnose specific failure modes from the evaluation metrics and adjust the relevant parameters. """ new_params = deepcopy(params) changes = [] # ── Diagnosis 1: Poor onset precision (>20ms mean error) ── if report.mean_onset_error_ms > 20: # Reduce pre_pad to tighten onset capture new_params.pre_pad = max(0.001, params.pre_pad * rng.uniform(0.5, 0.9)) # Reduce min_gap to catch faster sequences new_params.min_gap = max(0.01, params.min_gap * rng.uniform(0.6, 0.9)) changes.append(f"onset_error={report.mean_onset_error_ms:.1f}ms → tightened pre_pad/min_gap") # ── Diagnosis 2: Missing hits (low hit count accuracy) ── if report.hit_count_accuracy < 0.7: # Lower energy threshold to catch quieter hits new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(2, 8)) # Reduce min_hit_dur to catch shorter sounds new_params.min_hit_dur = max(0.01, params.min_hit_dur * rng.uniform(0.5, 0.8)) changes.append(f"hit_acc={report.hit_count_accuracy:.2f} → lowered threshold/min_dur") # ── Diagnosis 3: Too many false hits (extracted >> GT) ── total_ext = sum(m.n_hits_extracted for m in report.matches) if report.matches else 0 total_gt = sum(m.n_hits_gt for m in report.matches) if report.matches else 1 if total_ext > total_gt * 1.5: # Raise energy threshold new_params.energy_threshold_db = min(-20, params.energy_threshold_db + rng.uniform(2, 5)) new_params.min_hit_dur = min(0.08, params.min_hit_dur * rng.uniform(1.1, 1.5)) changes.append(f"over-extraction ({total_ext} vs {total_gt} GT) → raised threshold") # ── Diagnosis 4: Low SI-SDR (poor sample quality) ── if report.mean_si_sdr < 5: # The extracted samples don't match GT well # Try adjusting overlap separation threshold new_params.overlap_energy_threshold = params.overlap_energy_threshold + rng.uniform(-0.05, 0.05) new_params.overlap_energy_threshold = np.clip(new_params.overlap_energy_threshold, 0.05, 0.4) changes.append(f"SI-SDR={report.mean_si_sdr:.1f}dB → adjusted overlap threshold") # ── Diagnosis 5: Low sample scores (poor completeness/cleanness) ── if report.mean_sample_score < 50: # Adjust selection weights # More weight on cleanness if we're getting bleed-heavy samples new_params.w_cleanness = min(0.6, params.w_cleanness + rng.uniform(0, 0.1)) new_params.w_completeness = max(0.15, params.w_completeness + rng.uniform(-0.05, 0.05)) # Renormalize total_w = new_params.w_cleanness + new_params.w_completeness + new_params.w_onset + new_params.w_representativeness new_params.w_cleanness /= total_w new_params.w_completeness /= total_w new_params.w_onset /= total_w new_params.w_representativeness /= total_w changes.append(f"sample_score={report.mean_sample_score:.1f} → adjusted selection weights") # ── Diagnosis 6: Low envelope correlation (transient mismatch) ── if report.mean_env_corr < 0.7: new_params.max_hit_dur = min(1.5, params.max_hit_dur * rng.uniform(1.1, 1.3)) changes.append(f"env_corr={report.mean_env_corr:.2f} → increased max_hit_dur") # ── Diagnosis 7: Unmatched GT samples (some drums never found) ── if len(report.unmatched_gt) > 0: new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(3, 6)) changes.append(f"missed {report.unmatched_gt} → lowered energy threshold") # If no specific diagnosis triggered, apply small random perturbation if not changes: # Explore nearby parameter space new_params.energy_threshold_db += rng.uniform(-3, 3) new_params.pre_pad += rng.uniform(-0.002, 0.002) new_params.pre_pad = max(0.001, new_params.pre_pad) new_params.min_hit_dur += rng.uniform(-0.01, 0.01) new_params.min_hit_dur = max(0.01, new_params.min_hit_dur) changes.append("no specific issue → random exploration") return new_params, changes # ───────────────────────────────────────────────────────────────────────────── # Main optimization loop # ───────────────────────────────────────────────────────────────────────────── def run_extraction_with_params(song: SyntheticSong, params: PipelineParams) -> tuple: """Run the extraction pipeline with given params on a song. Returns (clusters, all_hits) or raises on failure.""" from drum_extractor import ( detect_onsets, classify_and_separate_hits, compute_librosa_embeddings, cluster_hits, select_best_representatives, synthesize_from_cluster, ) # Stage 2: Onset detection hits = detect_onsets( song.drums_only, song.sr, pre_pad=params.pre_pad, min_hit_dur=params.min_hit_dur, max_hit_dur=params.max_hit_dur, min_gap=params.min_gap, energy_threshold_db=params.energy_threshold_db, ) if len(hits) == 0: return [], [] # Stage 3: Classify & separate hits = classify_and_separate_hits(hits, separate_overlaps=params.separate_overlaps) # Stage 4: Embed & cluster embeddings = compute_librosa_embeddings(hits) clusters = cluster_hits(hits, embeddings) # Stage 5: Select best (using our improved scoring) for cluster in clusters: if cluster.count == 1: cluster.best_hit_idx = 0 continue scores = [] base_label = cluster.label.rsplit('_', 1)[0] # Compute cluster radius for representativeness scoring hit_features = [] for hit in cluster.hits: import librosa feat = np.concatenate([ librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1), [hit.rms_energy, hit.spectral_centroid, hit.duration] ]) hit_features.append(feat) hit_features = np.array(hit_features) mean_f = hit_features.mean(axis=0) std_f = hit_features.std(axis=0) + 1e-8 hit_features_norm = (hit_features - mean_f) / std_f centroid = hit_features_norm.mean(axis=0) dists = np.linalg.norm(hit_features_norm - centroid, axis=1) radius = dists.max() + 1e-8 for i, hit in enumerate(cluster.hits): score = drum_sample_score( hit.audio, hit.sr, base_label, centroid_dist=dists[i], cluster_radius=radius, ) scores.append(score['total']) cluster.best_hit_idx = int(np.argmax(scores)) # Stage 6: Synthesis if params.synthesize: for cluster in clusters: if cluster.count >= 2: cluster.synthesized = synthesize_from_cluster(cluster) return clusters, hits def run_optimization_loop( n_iterations: int = 10, patterns: list = None, initial_params: PipelineParams = None, seed: int = 42, log_callback=None, ) -> OptimizerState: """Run the full autonomous optimization loop. Args: n_iterations: number of optimization iterations patterns: list of pattern names to test with (cycles through them) initial_params: starting pipeline parameters seed: random seed log_callback: function(str) called with log messages """ if patterns is None: patterns = ['rock', 'funk', 'halftime'] if initial_params is None: initial_params = PipelineParams() rng = np.random.RandomState(seed) state = OptimizerState(best_params=initial_params.to_dict()) current_params = deepcopy(initial_params) def log(msg): if log_callback: log_callback(msg) print(msg) log(f"Starting optimization loop: {n_iterations} iterations") log(f"Patterns: {patterns}") for i in range(n_iterations): t0 = time.time() pattern_name = patterns[i % len(patterns)] song_seed = seed + i * 17 # different song each iteration log(f"\n{'='*60}") log(f"ITERATION {i+1}/{n_iterations} — pattern={pattern_name}, seed={song_seed}") log(f"{'='*60}") try: # 1. Generate synthetic song log(" Generating synthetic song...") song = generate_test_song( pattern_name=pattern_name, bars=4, bpm=100 + rng.randint(0, 40) * 2, # vary BPM variation='medium', seed=song_seed, ) log(f" → {song.duration:.1f}s, {song.bpm}BPM, " f"{len(song.hits)} hits, {len(song.samples)} sample types") # 2. Run extraction log(f" Running extraction with params: threshold={current_params.energy_threshold_db:.1f}dB, " f"pre_pad={current_params.pre_pad:.3f}, min_dur={current_params.min_hit_dur:.3f}") clusters, all_hits = run_extraction_with_params(song, current_params) log(f" → {len(clusters)} clusters, {len(all_hits)} total hits") # 3. Evaluate log(" Evaluating against ground truth...") gt_samples = {name: s.audio for name, s in song.samples.items()} gt_hit_map = [ {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity} for h in song.hits ] report = evaluate_extraction( extracted_clusters=clusters, gt_samples=gt_samples, gt_hit_map=gt_hit_map, sr=song.sr, all_hits=all_hits, pipeline_params=current_params.to_dict(), ) duration = time.time() - t0 log(f" RESULTS:") log(f" Overall Score: {report.overall_score:.1f}/100") log(f" SI-SDR: {report.mean_si_sdr:.1f} dB") log(f" Sample Score: {report.mean_sample_score:.1f}/100") log(f" Env Corr: {report.mean_env_corr:.3f}") log(f" Onset Error: {report.mean_onset_error_ms:.1f} ms") log(f" Hit Count Acc: {report.hit_count_accuracy:.2f}") log(f" Matched: {len(report.matches)}/{len(song.samples)}") if report.unmatched_gt: log(f" ⚠ Unmatched GT: {report.unmatched_gt}") # Record iteration result = IterationResult( iteration=i, params=current_params.to_dict(), eval_report=report_to_dict(report), overall_score=report.overall_score, duration_seconds=duration, test_config={'pattern': pattern_name, 'bpm': song.bpm, 'seed': song_seed}, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'), ) state.history.append(result) # Update best if report.overall_score > state.best_score: state.best_score = report.overall_score state.best_params = current_params.to_dict() log(f" ★ NEW BEST SCORE: {report.overall_score:.1f}") # 4. Tune parameters for next iteration new_params, changes = diagnose_and_perturb(current_params, report, rng) log(f" Parameter adjustments:") for change in changes: log(f" → {change}") current_params = new_params except Exception as e: log(f" ✗ ERROR: {e}") log(traceback.format_exc()) # On error, try random perturbation current_params.energy_threshold_db += rng.uniform(-5, 5) state.history.append(IterationResult( iteration=i, params=current_params.to_dict(), eval_report={'error': str(e)}, overall_score=0.0, duration_seconds=time.time() - t0, test_config={'pattern': pattern_name}, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'), )) state.iteration = i + 1 log(f"\n{'='*60}") log(f"OPTIMIZATION COMPLETE") log(f"{'='*60}") log(f" Best score: {state.best_score:.1f}/100") log(f" Best params: {json.dumps(state.best_params, indent=2)}") return state