File size: 15,858 Bytes
d34b37f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""
Autonomous parameter optimizer for the drum extraction pipeline.

Runs a loop:
  1. Generate synthetic songs with known ground truth
  2. Run the extraction pipeline with current params
  3. Evaluate extraction quality against ground truth
  4. Use results to tune parameters for next iteration

Uses Bayesian-ish optimization: maintain a history of (params β†’ score),
then perturb the best-so-far params toward improving weak metrics.
"""

import json
import time
import traceback
import numpy as np
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path

from synth_generator import generate_test_song, SyntheticSong
from evaluation import evaluate_extraction, report_to_dict, EvalReport
from quality_metrics import drum_sample_score


@dataclass
class PipelineParams:
    """All tunable parameters of the extraction pipeline."""
    # Onset detection
    pre_pad: float = 0.005
    min_hit_dur: float = 0.03
    max_hit_dur: float = 0.8
    min_gap: float = 0.02
    energy_threshold_db: float = -40.0
    
    # Overlap separation
    separate_overlaps: bool = True
    overlap_energy_threshold: float = 0.15  # band energy ratio to count as significant
    
    # Clustering
    use_clap: bool = False
    
    # Selection weights (must sum to 1.0)
    w_completeness: float = 0.30
    w_cleanness: float = 0.40
    w_onset: float = 0.20
    w_representativeness: float = 0.10
    
    # Synthesis
    synthesize: bool = True
    synth_best_weight: float = 2.0  # weight multiplier for best sample in cluster
    
    def to_dict(self) -> dict:
        """Return a fresh dict snapshot of every parameter."""
        return dict(vars(self))
    
    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineParams':
        """Build params from a dict, silently dropping unknown keys."""
        known = cls.__dataclass_fields__
        return cls(**{key: val for key, val in d.items() if key in known})


@dataclass
class IterationResult:
    """Result of one optimization iteration."""
    iteration: int           # 0-based index within the optimization run
    params: dict             # PipelineParams.to_dict() snapshot used for this run
    eval_report: dict        # serialized EvalReport, or {'error': str} when the run failed
    overall_score: float     # composite quality score; 0.0 when the iteration errored
    duration_seconds: float  # wall-clock time for generate + extract + evaluate
    test_config: dict        # which synthetic song was used
    timestamp: str           # local time, 'YYYY-MM-DD HH:MM:SS'


@dataclass
class OptimizerState:
    """Persistent state of the optimizer."""
    history: list = field(default_factory=list)  # [IterationResult]
    best_params: dict = field(default_factory=dict)  # params dict of the best-scoring iteration
    best_score: float = 0.0  # highest overall_score observed so far
    iteration: int = 0       # number of iterations attempted (including failed ones)


# ─────────────────────────────────────────────────────────────────────────────
# Parameter perturbation strategies
# ─────────────────────────────────────────────────────────────────────────────

def diagnose_and_perturb(params: 'PipelineParams', report: 'EvalReport',
                          rng: np.random.RandomState) -> 'tuple[PipelineParams, list]':
    """Analyze evaluation report and intelligently perturb parameters.
    
    Instead of random search, we diagnose specific failure modes from the
    evaluation metrics and adjust the relevant parameters.  Several
    diagnoses may fire in the same call; each edits a deep copy, so the
    input *params* is never mutated.
    
    Args:
        params: current pipeline parameters (read-only; a deepcopy is edited).
        report: evaluation metrics from the latest extraction run.
        rng: seeded RNG so perturbation magnitudes are reproducible.
    
    Returns:
        Tuple ``(new_params, changes)`` — the perturbed parameter set and a
        list of human-readable strings describing which diagnoses fired.
        (The original annotation claimed a bare PipelineParams return,
        which did not match the actual tuple return value.)
    """
    new_params = deepcopy(params)
    changes = []

    # ── Diagnosis 1: Poor onset precision (>20ms mean error) ──
    if report.mean_onset_error_ms > 20:
        # Reduce pre_pad to tighten onset capture
        new_params.pre_pad = max(0.001, params.pre_pad * rng.uniform(0.5, 0.9))
        # Reduce min_gap to catch faster sequences
        new_params.min_gap = max(0.01, params.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_error={report.mean_onset_error_ms:.1f}ms β†’ tightened pre_pad/min_gap")

    # ── Diagnosis 2: Missing hits (low hit count accuracy) ──
    if report.hit_count_accuracy < 0.7:
        # Lower energy threshold to catch quieter hits
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(2, 8))
        # Reduce min_hit_dur to catch shorter sounds
        new_params.min_hit_dur = max(0.01, params.min_hit_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={report.hit_count_accuracy:.2f} β†’ lowered threshold/min_dur")

    # ── Diagnosis 3: Too many false hits (extracted >> GT) ──
    total_ext = sum(m.n_hits_extracted for m in report.matches) if report.matches else 0
    total_gt = sum(m.n_hits_gt for m in report.matches) if report.matches else 1
    if total_ext > total_gt * 1.5:
        # Raise energy threshold
        new_params.energy_threshold_db = min(-20, params.energy_threshold_db + rng.uniform(2, 5))
        new_params.min_hit_dur = min(0.08, params.min_hit_dur * rng.uniform(1.1, 1.5))
        changes.append(f"over-extraction ({total_ext} vs {total_gt} GT) β†’ raised threshold")

    # ── Diagnosis 4: Low SI-SDR (poor sample quality) ──
    if report.mean_si_sdr < 5:
        # The extracted samples don't match GT well
        # Try adjusting overlap separation threshold
        new_params.overlap_energy_threshold = params.overlap_energy_threshold + rng.uniform(-0.05, 0.05)
        new_params.overlap_energy_threshold = np.clip(new_params.overlap_energy_threshold, 0.05, 0.4)
        changes.append(f"SI-SDR={report.mean_si_sdr:.1f}dB β†’ adjusted overlap threshold")

    # ── Diagnosis 5: Low sample scores (poor completeness/cleanness) ──
    if report.mean_sample_score < 50:
        # Adjust selection weights
        # More weight on cleanness if we're getting bleed-heavy samples
        new_params.w_cleanness = min(0.6, params.w_cleanness + rng.uniform(0, 0.1))
        new_params.w_completeness = max(0.15, params.w_completeness + rng.uniform(-0.05, 0.05))
        # Renormalize so the four selection weights keep summing to 1.0
        total_w = new_params.w_cleanness + new_params.w_completeness + new_params.w_onset + new_params.w_representativeness
        new_params.w_cleanness /= total_w
        new_params.w_completeness /= total_w
        new_params.w_onset /= total_w
        new_params.w_representativeness /= total_w
        changes.append(f"sample_score={report.mean_sample_score:.1f} β†’ adjusted selection weights")

    # ── Diagnosis 6: Low envelope correlation (transient mismatch) ──
    if report.mean_env_corr < 0.7:
        new_params.max_hit_dur = min(1.5, params.max_hit_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={report.mean_env_corr:.2f} β†’ increased max_hit_dur")

    # ── Diagnosis 7: Unmatched GT samples (some drums never found) ──
    if len(report.unmatched_gt) > 0:
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(3, 6))
        changes.append(f"missed {report.unmatched_gt} β†’ lowered energy threshold")

    # If no specific diagnosis triggered, apply small random perturbation
    if not changes:
        # Explore nearby parameter space
        new_params.energy_threshold_db += rng.uniform(-3, 3)
        new_params.pre_pad += rng.uniform(-0.002, 0.002)
        new_params.pre_pad = max(0.001, new_params.pre_pad)
        new_params.min_hit_dur += rng.uniform(-0.01, 0.01)
        new_params.min_hit_dur = max(0.01, new_params.min_hit_dur)
        changes.append("no specific issue β†’ random exploration")

    return new_params, changes


# ─────────────────────────────────────────────────────────────────────────────
# Main optimization loop
# ─────────────────────────────────────────────────────────────────────────────

def run_extraction_with_params(song: 'SyntheticSong', params: 'PipelineParams') -> tuple:
    """Run the extraction pipeline with given params on a song.
    
    Mirrors pipeline stages 2-6 with every tunable exposed through
    *params* so the optimizer can vary them between iterations.
    
    Args:
        song: synthetic song providing ``drums_only`` audio and ``sr``.
        params: pipeline parameters under evaluation.
    
    Returns:
        ``(clusters, all_hits)`` — both empty lists when no onsets are
        detected.  Raises on any pipeline-stage failure.
    """
    from drum_extractor import (
        detect_onsets, classify_and_separate_hits,
        compute_librosa_embeddings, cluster_hits,
        synthesize_from_cluster,
    )
    # Hoisted: the original re-executed `import librosa` inside the
    # per-hit loop below; importing once per call is sufficient.
    import librosa

    # Stage 2: Onset detection
    hits = detect_onsets(
        song.drums_only, song.sr,
        pre_pad=params.pre_pad,
        min_hit_dur=params.min_hit_dur,
        max_hit_dur=params.max_hit_dur,
        min_gap=params.min_gap,
        energy_threshold_db=params.energy_threshold_db,
    )

    if len(hits) == 0:
        return [], []

    # Stage 3: Classify & separate
    hits = classify_and_separate_hits(hits, separate_overlaps=params.separate_overlaps)

    # Stage 4: Embed & cluster
    embeddings = compute_librosa_embeddings(hits)
    clusters = cluster_hits(hits, embeddings)

    # Stage 5: Select best (using our improved scoring)
    for cluster in clusters:
        # Singleton clusters have only one candidate — nothing to score.
        if cluster.count == 1:
            cluster.best_hit_idx = 0
            continue

        scores = []
        # Strip the trailing "_<n>" suffix to recover the base drum label.
        base_label = cluster.label.rsplit('_', 1)[0]
        
        # Compute cluster radius for representativeness scoring:
        # per-hit MFCC means plus a few scalar features, z-normalized
        # within the cluster, then distance to the cluster centroid.
        hit_features = []
        for hit in cluster.hits:
            feat = np.concatenate([
                librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
                [hit.rms_energy, hit.spectral_centroid, hit.duration]
            ])
            hit_features.append(feat)
        hit_features = np.array(hit_features)
        mean_f = hit_features.mean(axis=0)
        std_f = hit_features.std(axis=0) + 1e-8  # epsilon avoids divide-by-zero
        hit_features_norm = (hit_features - mean_f) / std_f
        centroid = hit_features_norm.mean(axis=0)
        dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
        radius = dists.max() + 1e-8

        for i, hit in enumerate(cluster.hits):
            score = drum_sample_score(
                hit.audio, hit.sr, base_label,
                centroid_dist=dists[i],
                cluster_radius=radius,
            )
            scores.append(score['total'])
        
        cluster.best_hit_idx = int(np.argmax(scores))

    # Stage 6: Synthesis (only meaningful with >= 2 source hits)
    if params.synthesize:
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)

    return clusters, hits


def run_optimization_loop(
    n_iterations: int = 10,
    patterns: list = None,
    initial_params: PipelineParams = None,
    seed: int = 42,
    log_callback=None,
) -> OptimizerState:
    """Run the full autonomous optimization loop.
    
    Each iteration: generate a synthetic song, extract with the current
    params, evaluate against ground truth, then diagnose and perturb the
    params for the next iteration.  A failing iteration is logged and
    recorded (score 0.0) but does not abort the loop.
    
    Args:
        n_iterations: number of optimization iterations
        patterns: list of pattern names to test with (cycles through them)
        initial_params: starting pipeline parameters
        seed: random seed
        log_callback: function(str) called with log messages
    
    Returns:
        OptimizerState holding the per-iteration history and the best
        (params, score) pair observed across all iterations.
    """
    if patterns is None:
        patterns = ['rock', 'funk', 'halftime']
    if initial_params is None:
        initial_params = PipelineParams()

    rng = np.random.RandomState(seed)
    state = OptimizerState(best_params=initial_params.to_dict())
    current_params = deepcopy(initial_params)

    # Forward every message to the optional callback AND stdout.
    def log(msg):
        if log_callback:
            log_callback(msg)
        print(msg)

    log(f"Starting optimization loop: {n_iterations} iterations")
    log(f"Patterns: {patterns}")

    for i in range(n_iterations):
        t0 = time.time()
        # Cycle round-robin through the test patterns.
        pattern_name = patterns[i % len(patterns)]
        song_seed = seed + i * 17  # different song each iteration

        log(f"\n{'='*60}")
        log(f"ITERATION {i+1}/{n_iterations} β€” pattern={pattern_name}, seed={song_seed}")
        log(f"{'='*60}")

        try:
            # 1. Generate synthetic song
            log("  Generating synthetic song...")
            song = generate_test_song(
                pattern_name=pattern_name,
                bars=4,
                bpm=100 + rng.randint(0, 40) * 2,  # vary BPM (even values 100-178)
                variation='medium',
                seed=song_seed,
            )
            log(f"    β†’ {song.duration:.1f}s, {song.bpm}BPM, "
                f"{len(song.hits)} hits, {len(song.samples)} sample types")

            # 2. Run extraction
            log(f"  Running extraction with params: threshold={current_params.energy_threshold_db:.1f}dB, "
                f"pre_pad={current_params.pre_pad:.3f}, min_dur={current_params.min_hit_dur:.3f}")
            clusters, all_hits = run_extraction_with_params(song, current_params)
            log(f"    β†’ {len(clusters)} clusters, {len(all_hits)} total hits")

            # 3. Evaluate against the known ground truth of the synthetic song
            log("  Evaluating against ground truth...")
            gt_samples = {name: s.audio for name, s in song.samples.items()}
            gt_hit_map = [
                {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
                for h in song.hits
            ]

            report = evaluate_extraction(
                extracted_clusters=clusters,
                gt_samples=gt_samples,
                gt_hit_map=gt_hit_map,
                sr=song.sr,
                all_hits=all_hits,
                pipeline_params=current_params.to_dict(),
            )

            duration = time.time() - t0

            log(f"  RESULTS:")
            log(f"    Overall Score: {report.overall_score:.1f}/100")
            log(f"    SI-SDR:        {report.mean_si_sdr:.1f} dB")
            log(f"    Sample Score:  {report.mean_sample_score:.1f}/100")
            log(f"    Env Corr:      {report.mean_env_corr:.3f}")
            log(f"    Onset Error:   {report.mean_onset_error_ms:.1f} ms")
            log(f"    Hit Count Acc: {report.hit_count_accuracy:.2f}")
            log(f"    Matched:       {len(report.matches)}/{len(song.samples)}")
            if report.unmatched_gt:
                log(f"    ⚠ Unmatched GT: {report.unmatched_gt}")

            # Record iteration
            result = IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report=report_to_dict(report),
                overall_score=report.overall_score,
                duration_seconds=duration,
                test_config={'pattern': pattern_name, 'bpm': song.bpm, 'seed': song_seed},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            )
            state.history.append(result)

            # Update best (greedy: keep whichever single run scored highest)
            if report.overall_score > state.best_score:
                state.best_score = report.overall_score
                state.best_params = current_params.to_dict()
                log(f"  β˜… NEW BEST SCORE: {report.overall_score:.1f}")

            # 4. Tune parameters for next iteration
            new_params, changes = diagnose_and_perturb(current_params, report, rng)
            log(f"  Parameter adjustments:")
            for change in changes:
                log(f"    β†’ {change}")
            current_params = new_params

        except Exception as e:
            log(f"  βœ— ERROR: {e}")
            log(traceback.format_exc())
            # On error, try random perturbation
            # NOTE(review): this mutates current_params BEFORE recording it,
            # so the history entry stores the post-perturbation value — confirm intended.
            current_params.energy_threshold_db += rng.uniform(-5, 5)
            state.history.append(IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report={'error': str(e)},
                overall_score=0.0,
                duration_seconds=time.time() - t0,
                test_config={'pattern': pattern_name},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))

        # Counts attempted iterations, including ones that errored.
        state.iteration = i + 1

    log(f"\n{'='*60}")
    log(f"OPTIMIZATION COMPLETE")
    log(f"{'='*60}")
    log(f"  Best score: {state.best_score:.1f}/100")
    log(f"  Best params: {json.dumps(state.best_params, indent=2)}")

    return state