drum-sample-extractor / optimizer_v2.py
rikhoffbauer2's picture
v2: Update optimizer_v2.py
1d29056 verified
"""
Generalized optimizer: tests across diverse synthetic songs, saves best config.
Changes from v1:
- Tests each config against MULTIPLE songs (rock/funk/halftime/vocal/sfx)
- Averages scores across all test songs for robust evaluation
- Saves winning configs to HF Hub with leaderboard scores
- Diagnostic-driven parameter tuning (same as before, improved)
"""
import json, time, traceback, numpy as np
from copy import deepcopy
from dataclasses import dataclass, field
from synth_generator import generate_test_song
from config_store import PipelineConfig, save_config
@dataclass
class IterationResult:
    """Record of a single optimizer iteration (one config scored on the test suite)."""

    # 0-based iteration index.
    iteration: int
    # Config parameters recorded for this iteration (a ``PipelineConfig.to_dict()`` snapshot).
    params: dict
    # Per-song overall scores keyed ``song_<j>``; empty when the iteration errored.
    scores: dict
    # Mean overall score across the test songs (0 for an errored iteration).
    avg_score: float
    # Wall-clock seconds spent on the iteration.
    duration_s: float
    # Human-readable descriptions of the proposed parameter changes (or the error text).
    changes: list
    # Local time the record was made ("YYYY-MM-DD HH:MM:SS"); may be left empty.
    timestamp: str = field(default="")
@dataclass
class OptimizerState:
    """Progress of an optimization run; built up and returned by ``run_optimization``."""

    # One IterationResult per iteration run (errored iterations included), in order.
    history: list = field(default_factory=list)
    # ``to_dict()`` snapshot of the best-scoring config so far; {} until a score beats best_score.
    best_config: dict = field(default_factory=dict)
    # Highest average score seen; starts at 0.0, so only scores above 0 register as a best.
    best_score: float = field(default=0.0)
    # Number of iterations completed so far.
    iteration: int = field(default=0)
def run_extraction_eval(song, config: PipelineConfig):
    """Score one synthetic song: run the extraction pipeline, then evaluate it.

    Pipeline order: onset detection -> classification / overlap separation ->
    embeddings -> clustering -> best-hit selection -> optional synthesis
    (only when ``config.synthesize`` is set, and only for clusters with at
    least two hits). The output is compared against the song's ground truth.

    Args:
        song: Test song exposing ``drums_only``, ``sr``, ``samples``
            (name -> object with ``.audio``) and ``hits`` (objects with
            ``sample_name``, ``onset_time`` and ``velocity``).
        config: Pipeline parameters.

    Returns:
        Dict with ``overall_score``, ``mean_si_sdr``, ``mean_sample_score``,
        ``mean_env_corr``, ``mean_onset_error_ms`` and ``hit_count_accuracy``.
        If no onsets are detected, fixed penalty values are returned and the
        rest of the pipeline is skipped.
    """
    # Project modules are imported at call time; ``sample_quality_score`` is
    # imported but not used in this function.
    from sample_extractor import (detect_onsets, classify_and_separate,
                                  compute_embeddings, cluster_hits, select_best,
                                  synthesize_from_cluster, sample_quality_score)
    from evaluation import evaluate_extraction

    drums, sr = song.drums_only, song.sr
    detected = detect_onsets(
        drums, sr,
        pre_pad=config.pre_pad,
        min_dur=config.min_dur,
        max_dur=config.max_dur,
        min_gap=config.min_gap,
        energy_threshold_db=config.energy_threshold_db,
        mode=config.onset_mode,
    )
    if not detected:
        # Nothing to cluster: report fixed penalty metrics.
        return {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
                'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}

    labelled = classify_and_separate(detected,
                                     separate_overlaps=config.separate_overlaps,
                                     overlap_threshold=config.overlap_threshold)
    clusters = cluster_hits(labelled, compute_embeddings(labelled))
    # Return value ignored -- presumably updates the clusters in place (unverified).
    select_best(clusters)

    if config.synthesize:
        for cluster in (cl for cl in clusters if cl.count >= 2):
            cluster.synthesized = synthesize_from_cluster(cluster)

    truth_samples = {name: sample.audio for name, sample in song.samples.items()}
    truth_hits = [dict(sample=hit.sample_name, onset=hit.onset_time, velocity=hit.velocity)
                  for hit in song.hits]
    report = evaluate_extraction(extracted_clusters=clusters, gt_samples=truth_samples,
                                 gt_hit_map=truth_hits, sr=sr, all_hits=labelled)

    metric_names = ('overall_score', 'mean_si_sdr', 'mean_sample_score',
                    'mean_env_corr', 'mean_onset_error_ms', 'hit_count_accuracy')
    return {name: getattr(report, name) for name in metric_names}
def eval_config_across_songs(config: "PipelineConfig", seeds: list, patterns: list,
                             bpms: list, bars: int = 4,
                             variation: str = 'medium') -> dict:
    """Evaluate ``config`` on a suite of synthetic test songs and average the metrics.

    Song ``k`` is generated from ``(seeds[k], patterns[k], bpms[k])``; as with
    ``zip``, surplus entries in the longer lists are ignored.

    Args:
        config: Pipeline configuration passed to ``run_extraction_eval``.
        seeds: Per-song generator seeds.
        patterns: Per-song pattern names (e.g. 'rock').
        bpms: Per-song tempos.
        bars: Bars per generated song (default 4, the previous fixed value).
        variation: Variation level passed to ``generate_test_song``
            (default 'medium', the previous fixed value).

    Returns:
        Dict holding the mean of every per-song metric, plus ``'n_songs'``
        (songs evaluated) and ``'per_song'`` (list of per-song metric dicts).
        A song whose generation or evaluation raises is reported on stdout /
        stderr and scored with fixed penalty values (score 0, SI-SDR -50,
        onset error 50 ms). With no songs at all, those penalty values are
        returned with ``n_songs == 0`` instead of raising ``IndexError``.
    """
    failed = {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
              'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}
    per_song = []
    for seed, pattern, bpm in zip(seeds, patterns, bpms):
        try:
            song = generate_test_song(pattern_name=pattern, bars=bars, bpm=bpm,
                                      variation=variation, seed=seed)
            per_song.append(run_extraction_eval(song, config))
        except Exception as e:
            # One broken song must not sink the whole suite, but the failure
            # has to be visible rather than silently scored as zero.
            print(f"  ⚠ eval failed (pattern={pattern}, seed={seed}, bpm={bpm}): {e!r}")
            traceback.print_exc()
            per_song.append(dict(failed))

    # Average every metric over all songs (failed songs count with penalty
    # values). Fall back to the penalty row so an empty suite doesn't crash.
    rows = per_song if per_song else [failed]
    avg = {key: float(np.mean([row[key] for row in rows])) for key in rows[0]}
    avg['n_songs'] = len(per_song)
    avg['per_song'] = per_song
    return avg
def diagnose_and_perturb(config: PipelineConfig, metrics: dict, rng) -> tuple:
    """Propose the next config by nudging the parameters tied to weak metrics.

    Each metric outside its target triggers a targeted, randomised adjustment;
    several adjustments can apply in the same call. If nothing is flagged, a
    small random step is taken instead so the search keeps exploring.

    Args:
        config: Config to start from. It is copied via ``to_dict``/``from_dict``
            and never modified.
        metrics: Averaged metrics (as produced by ``eval_config_across_songs``).
            A missing or ``None`` metric is treated as "not diagnosed".
        rng: Random source with ``uniform(low, high)`` (``run_optimization``
            passes a ``numpy.random.RandomState``).

    Returns:
        ``(new_config, changes)``, where ``changes`` lists human-readable
        descriptions of what was adjusted and why.
    """
    c = PipelineConfig.from_dict(config.to_dict())
    changes = []

    # Read each metric once. Missing/None metrics are skipped; defaulting a
    # missing SI-SDR to 0 would fire its branch and then fail on formatting.
    onset_err = metrics.get('mean_onset_error_ms')
    hit_acc = metrics.get('hit_count_accuracy')
    si_sdr = metrics.get('mean_si_sdr')
    env_corr = metrics.get('mean_env_corr')

    # Large onset error: shrink the pre-onset padding and the minimum onset gap.
    if onset_err is not None and onset_err > 20:
        c.pre_pad = max(0.001, config.pre_pad * rng.uniform(0.5, 0.9))
        c.min_gap = max(0.008, config.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_err={onset_err:.0f}ms → tightened timing")

    # Poor hit-count accuracy: lower the energy threshold (floored at -65 dB)
    # and the minimum hit duration.
    if hit_acc is not None and hit_acc < 0.7:
        c.energy_threshold_db = max(-65, config.energy_threshold_db - rng.uniform(2, 8))
        c.min_dur = max(0.008, config.min_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={hit_acc:.2f} → lowered threshold")

    # Low SI-SDR: random-walk the overlap threshold, kept within [0.05, 0.4].
    if si_sdr is not None and si_sdr < 5:
        nudged = c.overlap_threshold + rng.uniform(-0.05, 0.05)
        # float(): store a plain Python float, not a NumPy scalar, in the config.
        c.overlap_threshold = float(np.clip(nudged, 0.05, 0.4))
        changes.append(f"SI-SDR={si_sdr:.1f} → adjusted overlap")

    # Low envelope correlation: allow longer hits (capped at 2.0).
    if env_corr is not None and env_corr < 0.7:
        c.max_dur = min(2.0, config.max_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={env_corr:.2f} → increased max_dur")

    # Nothing flagged: take a small random step so the search doesn't stall.
    # The -65 dB floor matches the one applied in the hit-count branch.
    if not changes:
        c.energy_threshold_db = max(-65, c.energy_threshold_db + rng.uniform(-3, 3))
        c.pre_pad = max(0.001, c.pre_pad + rng.uniform(-0.002, 0.002))
        c.min_dur = max(0.008, c.min_dur + rng.uniform(-0.005, 0.005))
        changes.append("random exploration")

    return c, changes
def run_optimization(n_iterations: int = 10, config_name: str = "optimized",
                     author: str = "", save_to_hub: bool = True,
                     seed: int = 42, log_fn=None) -> OptimizerState:
    """Iteratively tune pipeline parameters, scoring each config on a fixed song suite.

    Every iteration evaluates the current config on the same 6 synthetic songs
    (rock/funk/halftime twice, at different seeds and tempos), records the
    result, remembers the best-scoring config so far, then derives the next
    config with ``diagnose_and_perturb``. The next config is always derived
    from the *current* config, not from the best one.

    Args:
        n_iterations: Number of configs to evaluate.
        config_name: Name of the starting config and of the saved best config.
        author: Author recorded on the configs.
        save_to_hub: If true and some iteration scored above 0, persist the best
            config with ``save_config``. Save failures are logged, not raised.
        seed: Seeds both the perturbation RNG and the test-song seeds.
        log_fn: Optional callback receiving every log line. Lines are printed
            either way.

    Returns:
        The final ``OptimizerState`` (history, best config/score, iterations run).
    """
    rng = np.random.RandomState(seed)
    state = OptimizerState()

    # Fixed test suite reused every iteration so scores are comparable.
    test_patterns = ['rock', 'funk', 'halftime'] * 2  # 6 songs
    test_seeds = [seed + i * 17 for i in range(len(test_patterns))]
    test_bpms = [120, 100, 140, 130, 110, 150]

    config = PipelineConfig(name=config_name, author=author)

    def log(msg):
        """Forward ``msg`` to ``log_fn`` (if given) and print it."""
        if log_fn:
            log_fn(msg)
        print(msg)

    log(f"Optimization: {n_iterations} iters, {len(test_patterns)} test songs each")
    for i in range(n_iterations):
        t0 = time.time()
        log(f"\n{'='*50}\nITERATION {i+1}/{n_iterations}\n{'='*50}")
        try:
            log(f" Testing config across {len(test_patterns)} songs...")
            metrics = eval_config_across_songs(config, test_seeds, test_patterns, test_bpms)
            avg_score = metrics['overall_score']
            log(f" Score: {avg_score:.1f}/100 (SI-SDR={metrics['mean_si_sdr']:.1f}, "
                f"sample={metrics['mean_sample_score']:.1f}, "
                f"env={metrics['mean_env_corr']:.2f})")
            # Strictly greater: if no score ever exceeds 0, best_config stays empty.
            if avg_score > state.best_score:
                state.best_score = avg_score
                state.best_config = config.to_dict()
                log(f" ★ NEW BEST: {avg_score:.1f}")
            # Derive the next config from this config's diagnostics.
            new_config, changes = diagnose_and_perturb(config, metrics, rng)
            log(f" Changes: {'; '.join(changes)}")
            state.history.append(IterationResult(
                iteration=i, params=config.to_dict(),
                scores={f"song_{j}": s['overall_score']
                        for j, s in enumerate(metrics.get('per_song', []))},
                avg_score=avg_score, duration_s=time.time() - t0,
                changes=changes, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
            config = new_config
        except Exception as e:
            log(f" ERROR: {e}")
            log(traceback.format_exc().rstrip())
            # Snapshot the params that actually failed *before* jittering them,
            # so the history reflects what was run rather than what comes next.
            failed_params = config.to_dict()
            config.energy_threshold_db += rng.uniform(-5, 5)
            state.history.append(IterationResult(
                iteration=i, params=failed_params, scores={},
                avg_score=0, duration_s=time.time() - t0, changes=[str(e)],
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
        state.iteration = i + 1

    # Save best config
    if save_to_hub and state.best_config:
        log(f"\nSaving best config (score={state.best_score:.1f})...")
        best = PipelineConfig.from_dict(state.best_config)
        best.name = config_name
        best.author = author
        best.overall_score = state.best_score
        best.n_test_songs = len(test_patterns)
        try:
            save_config(best)
            log(f" ✓ Saved to {best.name}")
        except Exception as e:
            log(f" ⚠ Could not save to Hub: {e}")
    log(f"\nBest score: {state.best_score:.1f}/100")
    return state