"""
Autonomous parameter optimizer for the drum extraction pipeline.
Runs a loop:
1. Generate synthetic songs with known ground truth
2. Run the extraction pipeline with current params
3. Evaluate extraction quality against ground truth
4. Use results to tune parameters for next iteration
Uses Bayesian-ish optimization: maintain a history of (params → score),
then perturb the best-so-far params toward improving weak metrics.
"""
import json
import time
import traceback
import numpy as np
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from synth_generator import generate_test_song, SyntheticSong
from evaluation import evaluate_extraction, report_to_dict, EvalReport
from quality_metrics import drum_sample_score
@dataclass
class PipelineParams:
"""All tunable parameters of the extraction pipeline."""
# Onset detection
pre_pad: float = 0.005
min_hit_dur: float = 0.03
max_hit_dur: float = 0.8
min_gap: float = 0.02
energy_threshold_db: float = -40.0
# Overlap separation
separate_overlaps: bool = True
overlap_energy_threshold: float = 0.15 # band energy ratio to count as significant
# Clustering
use_clap: bool = False
# Selection weights (must sum to 1.0)
w_completeness: float = 0.30
w_cleanness: float = 0.40
w_onset: float = 0.20
w_representativeness: float = 0.10
# Synthesis
synthesize: bool = True
synth_best_weight: float = 2.0 # weight multiplier for best sample in cluster
def to_dict(self) -> dict:
return self.__dict__.copy()
@classmethod
def from_dict(cls, d: dict) -> 'PipelineParams':
valid_keys = cls.__dataclass_fields__.keys()
return cls(**{k: v for k, v in d.items() if k in valid_keys})
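    # Round-trip sketch (illustrative only): params are meant to survive JSON
    # serialization, and from_dict() silently drops unknown keys, e.g. fields
    # saved by an older version of this file:
    #
    #     p = PipelineParams(pre_pad=0.003)
    #     restored = PipelineParams.from_dict(json.loads(json.dumps(p.to_dict())))
    #     assert restored.pre_pad == 0.003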
@dataclass
class IterationResult:
"""Result of one optimization iteration."""
iteration: int
params: dict
eval_report: dict
overall_score: float
duration_seconds: float
test_config: dict # which synthetic song was used
timestamp: str
@dataclass
class OptimizerState:
"""Persistent state of the optimizer."""
history: list = field(default_factory=list) # [IterationResult]
best_params: dict = field(default_factory=dict)
best_score: float = 0.0
iteration: int = 0
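# OptimizerState and IterationResult are plain dataclasses with JSON-friendly
# fields, so a finished run can be persisted with something like this sketch
# (not wired in here; the file name is an assumption):
#
#     from dataclasses import asdict
#     Path('optimizer_state.json').write_text(json.dumps(asdict(state), indent=2))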
# ─────────────────────────────────────────────────────────────────────────────
# Parameter perturbation strategies
# ─────────────────────────────────────────────────────────────────────────────
def diagnose_and_perturb(params: PipelineParams, report: EvalReport,
                         rng: np.random.RandomState) -> tuple:
"""Analyze evaluation report and intelligently perturb parameters.
Instead of random search, we diagnose specific failure modes from the
evaluation metrics and adjust the relevant parameters.
"""
new_params = deepcopy(params)
changes = []
    # ── Diagnosis 1: Poor onset precision (>20ms mean error) ──
if report.mean_onset_error_ms > 20:
# Reduce pre_pad to tighten onset capture
new_params.pre_pad = max(0.001, params.pre_pad * rng.uniform(0.5, 0.9))
# Reduce min_gap to catch faster sequences
new_params.min_gap = max(0.01, params.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_error={report.mean_onset_error_ms:.1f}ms → tightened pre_pad/min_gap")
    # ── Diagnosis 2: Missing hits (low hit count accuracy) ──
if report.hit_count_accuracy < 0.7:
# Lower energy threshold to catch quieter hits
new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(2, 8))
# Reduce min_hit_dur to catch shorter sounds
new_params.min_hit_dur = max(0.01, params.min_hit_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={report.hit_count_accuracy:.2f} → lowered threshold/min_dur")
    # ── Diagnosis 3: Too many false hits (extracted >> GT) ──
total_ext = sum(m.n_hits_extracted for m in report.matches) if report.matches else 0
total_gt = sum(m.n_hits_gt for m in report.matches) if report.matches else 1
if total_ext > total_gt * 1.5:
# Raise energy threshold
new_params.energy_threshold_db = min(-20, params.energy_threshold_db + rng.uniform(2, 5))
new_params.min_hit_dur = min(0.08, params.min_hit_dur * rng.uniform(1.1, 1.5))
        changes.append(f"over-extraction ({total_ext} vs {total_gt} GT) → raised threshold")
    # ── Diagnosis 4: Low SI-SDR (poor sample quality) ──
if report.mean_si_sdr < 5:
# The extracted samples don't match GT well
# Try adjusting overlap separation threshold
new_params.overlap_energy_threshold = params.overlap_energy_threshold + rng.uniform(-0.05, 0.05)
new_params.overlap_energy_threshold = np.clip(new_params.overlap_energy_threshold, 0.05, 0.4)
        changes.append(f"SI-SDR={report.mean_si_sdr:.1f}dB → adjusted overlap threshold")
    # ── Diagnosis 5: Low sample scores (poor completeness/cleanness) ──
if report.mean_sample_score < 50:
# Adjust selection weights
# More weight on cleanness if we're getting bleed-heavy samples
new_params.w_cleanness = min(0.6, params.w_cleanness + rng.uniform(0, 0.1))
new_params.w_completeness = max(0.15, params.w_completeness + rng.uniform(-0.05, 0.05))
# Renormalize
total_w = new_params.w_cleanness + new_params.w_completeness + new_params.w_onset + new_params.w_representativeness
new_params.w_cleanness /= total_w
new_params.w_completeness /= total_w
new_params.w_onset /= total_w
new_params.w_representativeness /= total_w
        changes.append(f"sample_score={report.mean_sample_score:.1f} → adjusted selection weights")
    # ── Diagnosis 6: Low envelope correlation (transient mismatch) ──
if report.mean_env_corr < 0.7:
new_params.max_hit_dur = min(1.5, params.max_hit_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={report.mean_env_corr:.2f} → increased max_hit_dur")
    # ── Diagnosis 7: Unmatched GT samples (some drums never found) ──
if len(report.unmatched_gt) > 0:
new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(3, 6))
        changes.append(f"missed {report.unmatched_gt} → lowered energy threshold")
# If no specific diagnosis triggered, apply small random perturbation
if not changes:
# Explore nearby parameter space
new_params.energy_threshold_db += rng.uniform(-3, 3)
new_params.pre_pad += rng.uniform(-0.002, 0.002)
new_params.pre_pad = max(0.001, new_params.pre_pad)
new_params.min_hit_dur += rng.uniform(-0.01, 0.01)
new_params.min_hit_dur = max(0.01, new_params.min_hit_dur)
        changes.append("no specific issue → random exploration")
return new_params, changes
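# Standalone sketch of one perturbation step (illustrative; `some_report` stands
# in for an EvalReport returned by evaluation.evaluate_extraction):
#
#     rng = np.random.RandomState(0)
#     tuned, notes = diagnose_and_perturb(PipelineParams(), some_report, rng)
#     # `tuned` is the adjusted PipelineParams; `notes` lists the applied changes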
# ─────────────────────────────────────────────────────────────────────────────
# Main optimization loop
# ─────────────────────────────────────────────────────────────────────────────
def run_extraction_with_params(song: SyntheticSong, params: PipelineParams) -> tuple:
"""Run the extraction pipeline with given params on a song.
Returns (clusters, all_hits) or raises on failure."""
from drum_extractor import (
detect_onsets, classify_and_separate_hits,
compute_librosa_embeddings, cluster_hits,
select_best_representatives, synthesize_from_cluster,
)
# Stage 2: Onset detection
hits = detect_onsets(
song.drums_only, song.sr,
pre_pad=params.pre_pad,
min_hit_dur=params.min_hit_dur,
max_hit_dur=params.max_hit_dur,
min_gap=params.min_gap,
energy_threshold_db=params.energy_threshold_db,
)
if len(hits) == 0:
return [], []
# Stage 3: Classify & separate
hits = classify_and_separate_hits(hits, separate_overlaps=params.separate_overlaps)
# Stage 4: Embed & cluster
embeddings = compute_librosa_embeddings(hits)
clusters = cluster_hits(hits, embeddings)
# Stage 5: Select best (using our improved scoring)
for cluster in clusters:
if cluster.count == 1:
cluster.best_hit_idx = 0
continue
scores = []
base_label = cluster.label.rsplit('_', 1)[0]
# Compute cluster radius for representativeness scoring
        import librosa  # hoisted out of the per-hit loop; importing once is enough
        hit_features = []
        for hit in cluster.hits:
            feat = np.concatenate([
                librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
                [hit.rms_energy, hit.spectral_centroid, hit.duration],
            ])
            hit_features.append(feat)
hit_features = np.array(hit_features)
mean_f = hit_features.mean(axis=0)
std_f = hit_features.std(axis=0) + 1e-8
hit_features_norm = (hit_features - mean_f) / std_f
centroid = hit_features_norm.mean(axis=0)
dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
radius = dists.max() + 1e-8
for i, hit in enumerate(cluster.hits):
score = drum_sample_score(
hit.audio, hit.sr, base_label,
centroid_dist=dists[i],
cluster_radius=radius,
)
scores.append(score['total'])
cluster.best_hit_idx = int(np.argmax(scores))
# Stage 6: Synthesis
if params.synthesize:
for cluster in clusters:
if cluster.count >= 2:
cluster.synthesized = synthesize_from_cluster(cluster)
return clusters, hits
def run_optimization_loop(
n_iterations: int = 10,
patterns: list = None,
initial_params: PipelineParams = None,
seed: int = 42,
log_callback=None,
) -> OptimizerState:
"""Run the full autonomous optimization loop.
Args:
n_iterations: number of optimization iterations
patterns: list of pattern names to test with (cycles through them)
initial_params: starting pipeline parameters
seed: random seed
log_callback: function(str) called with log messages
"""
if patterns is None:
patterns = ['rock', 'funk', 'halftime']
if initial_params is None:
initial_params = PipelineParams()
rng = np.random.RandomState(seed)
state = OptimizerState(best_params=initial_params.to_dict())
current_params = deepcopy(initial_params)
def log(msg):
if log_callback:
log_callback(msg)
print(msg)
log(f"Starting optimization loop: {n_iterations} iterations")
log(f"Patterns: {patterns}")
for i in range(n_iterations):
t0 = time.time()
pattern_name = patterns[i % len(patterns)]
song_seed = seed + i * 17 # different song each iteration
log(f"\n{'='*60}")
        log(f"ITERATION {i+1}/{n_iterations} - pattern={pattern_name}, seed={song_seed}")
log(f"{'='*60}")
try:
# 1. Generate synthetic song
log(" Generating synthetic song...")
song = generate_test_song(
pattern_name=pattern_name,
bars=4,
bpm=100 + rng.randint(0, 40) * 2, # vary BPM
variation='medium',
seed=song_seed,
)
            log(f" → {song.duration:.1f}s, {song.bpm}BPM, "
f"{len(song.hits)} hits, {len(song.samples)} sample types")
# 2. Run extraction
log(f" Running extraction with params: threshold={current_params.energy_threshold_db:.1f}dB, "
f"pre_pad={current_params.pre_pad:.3f}, min_dur={current_params.min_hit_dur:.3f}")
clusters, all_hits = run_extraction_with_params(song, current_params)
            log(f" → {len(clusters)} clusters, {len(all_hits)} total hits")
# 3. Evaluate
log(" Evaluating against ground truth...")
gt_samples = {name: s.audio for name, s in song.samples.items()}
gt_hit_map = [
{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
for h in song.hits
]
report = evaluate_extraction(
extracted_clusters=clusters,
gt_samples=gt_samples,
gt_hit_map=gt_hit_map,
sr=song.sr,
all_hits=all_hits,
pipeline_params=current_params.to_dict(),
)
duration = time.time() - t0
log(f" RESULTS:")
log(f" Overall Score: {report.overall_score:.1f}/100")
log(f" SI-SDR: {report.mean_si_sdr:.1f} dB")
log(f" Sample Score: {report.mean_sample_score:.1f}/100")
log(f" Env Corr: {report.mean_env_corr:.3f}")
log(f" Onset Error: {report.mean_onset_error_ms:.1f} ms")
log(f" Hit Count Acc: {report.hit_count_accuracy:.2f}")
log(f" Matched: {len(report.matches)}/{len(song.samples)}")
if report.unmatched_gt:
                log(f" ⚠ Unmatched GT: {report.unmatched_gt}")
# Record iteration
result = IterationResult(
iteration=i,
params=current_params.to_dict(),
eval_report=report_to_dict(report),
overall_score=report.overall_score,
duration_seconds=duration,
test_config={'pattern': pattern_name, 'bpm': song.bpm, 'seed': song_seed},
timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
)
state.history.append(result)
# Update best
if report.overall_score > state.best_score:
state.best_score = report.overall_score
state.best_params = current_params.to_dict()
                log(f" ★ NEW BEST SCORE: {report.overall_score:.1f}")
# 4. Tune parameters for next iteration
new_params, changes = diagnose_and_perturb(current_params, report, rng)
log(f" Parameter adjustments:")
for change in changes:
                log(f" → {change}")
current_params = new_params
except Exception as e:
            log(f" ✗ ERROR: {e}")
log(traceback.format_exc())
# On error, try random perturbation
current_params.energy_threshold_db += rng.uniform(-5, 5)
state.history.append(IterationResult(
iteration=i,
params=current_params.to_dict(),
eval_report={'error': str(e)},
overall_score=0.0,
duration_seconds=time.time() - t0,
test_config={'pattern': pattern_name},
timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
))
state.iteration = i + 1
log(f"\n{'='*60}")
log(f"OPTIMIZATION COMPLETE")
log(f"{'='*60}")
log(f" Best score: {state.best_score:.1f}/100")
log(f" Best params: {json.dumps(state.best_params, indent=2)}")
return state
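

# Minimal command-line entry point, added as a sketch (the iteration count is an
# arbitrary default; adjust or remove as needed):
if __name__ == '__main__':
    final_state = run_optimization_loop(n_iterations=10)
    print(json.dumps(final_state.best_params, indent=2))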