File size: 24,158 Bytes
fcf261a
 
2d1fee2
d1fa59c
2d1fee2
 
fcf261a
 
d1fa59c
fcf261a
 
 
 
f89e0aa
fcf261a
f89e0aa
fcf261a
 
 
 
 
 
f89e0aa
 
 
2d1fee2
fcf261a
 
 
f89e0aa
 
fcf261a
2d1fee2
fcf261a
2d1fee2
f89e0aa
2d1fee2
f89e0aa
2d1fee2
 
 
 
f89e0aa
fcf261a
2d1fee2
d1fa59c
2d1fee2
 
 
d1fa59c
 
2d1fee2
 
 
 
 
 
 
 
 
 
39d4239
2d1fee2
39d4239
 
 
2d1fee2
 
 
 
 
 
 
 
 
 
 
 
d1fa59c
2d1fee2
d1fa59c
2d1fee2
 
 
 
 
 
 
 
 
 
 
 
 
39d4239
2d1fee2
39d4239
2d1fee2
39d4239
2d1fee2
39d4239
 
 
fcf261a
2d1fee2
 
 
 
 
 
 
 
 
 
39d4239
2d1fee2
 
39d4239
2d1fee2
 
 
 
fcf261a
39d4239
2d1fee2
 
 
 
 
f89e0aa
 
2d1fee2
 
 
 
39d4239
 
2d1fee2
 
 
39d4239
 
f89e0aa
39d4239
2d1fee2
 
 
 
39d4239
2d1fee2
 
 
39d4239
2d1fee2
 
 
39d4239
1243e53
2d1fee2
 
d1fa59c
39d4239
2d1fee2
 
 
 
 
 
39d4239
 
 
d1fa59c
2d1fee2
 
39d4239
 
 
 
 
2d1fee2
 
39d4239
2d1fee2
1243e53
2d1fee2
 
d1fa59c
2d1fee2
 
39d4239
fcf261a
2d1fee2
 
f89e0aa
2d1fee2
 
 
 
 
39d4239
2d1fee2
 
39d4239
 
2d1fee2
39d4239
2d1fee2
 
 
d1fa59c
39d4239
fcf261a
d1fa59c
2d1fee2
fcf261a
2d1fee2
d1fa59c
 
2d1fee2
39d4239
 
 
 
 
 
 
 
d1fa59c
39d4239
2d1fee2
39d4239
fcf261a
2d1fee2
 
 
d1fa59c
2d1fee2
fcf261a
 
2d1fee2
 
 
 
 
 
39d4239
2d1fee2
39d4239
 
 
d1fa59c
 
 
 
0a6fa16
d1fa59c
 
2d1fee2
 
39d4239
 
0a6fa16
2d1fee2
0a6fa16
2d1fee2
d1fa59c
39d4239
2d1fee2
f89e0aa
2d1fee2
 
 
 
39d4239
0a6fa16
2d1fee2
 
39d4239
2d1fee2
 
 
0a6fa16
2d1fee2
 
 
39d4239
d1fa59c
0a6fa16
d1fa59c
 
39d4239
 
2d1fee2
39d4239
2d1fee2
 
 
 
 
 
 
39d4239
2d1fee2
 
 
 
 
 
 
 
 
39d4239
2d1fee2
 
 
 
39d4239
 
2d1fee2
 
39d4239
2d1fee2
 
 
 
39d4239
2d1fee2
 
39d4239
2d1fee2
39d4239
2d1fee2
39d4239
2d1fee2
 
 
 
39d4239
2d1fee2
 
39d4239
2d1fee2
 
39d4239
 
 
 
 
 
 
 
2d1fee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39d4239
2d1fee2
39d4239
 
2d1fee2
 
39d4239
2d1fee2
39d4239
 
2d1fee2
 
 
39d4239
 
2d1fee2
 
39d4239
 
 
 
2d1fee2
 
 
39d4239
2d1fee2
39d4239
2d1fee2
39d4239
 
2d1fee2
 
39d4239
 
 
2d1fee2
 
 
39d4239
 
 
2d1fee2
 
39d4239
 
 
2d1fee2
 
 
39d4239
2d1fee2
39d4239
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
#!/usr/bin/env python3
"""
Sample Extractor v8 — Auto-tuning with parameter locking.

auto_tune() accepts a `locks` dict: any param name mapped to its fixed value
will be held constant during the search. All other params are swept freely.
"""

import argparse, json, os, sys, warnings, hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import librosa, numpy as np, soundfile as sf, torch
from scipy.signal import fftconvolve

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

@dataclass
class Hit:
    """One segment of audio cut around a detected onset.

    detect_onsets fills audio/sr/onset_time/duration/index/rms_energy/
    spectral_centroid; classify_hits later sets `label`.
    """
    audio: np.ndarray          # sample data of this segment
    sr: int                    # sample rate of `audio`
    onset_time: float          # onset time in seconds (segment starts pre_pad earlier)
    duration: float            # len(audio) / sr, in seconds
    index: int                 # position in the list returned by detect_onsets
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    label: str = ""
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1       # NOTE(review): not assigned anywhere in this module — confirm use

    def save(self, path):
        """Write this hit's audio to `path` as 24-bit PCM."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')

@dataclass
class Cluster:
    """A group of hits judged similar to each other.

    `best_hit_idx` indexes into `hits`; `synthesized` optionally holds an
    averaged sample; `midi_note` is reassigned by build_midi.
    """
    cluster_id: int
    label: str
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0
    synthesized: Optional[np.ndarray] = None
    midi_note: int = 60

    @property
    def best_hit(self):
        """The representative hit, i.e. hits[best_hit_idx]."""
        return self.hits[self.best_hit_idx]

    @property
    def count(self):
        """Number of hits in the cluster."""
        return len(self.hits)

# Pretrained Demucs model names, and the stem names listed for each of them.
DEMUCS_MODELS = [
    "htdemucs", "htdemucs_ft", "htdemucs_6s",
    "mdx", "mdx_extra", "mdx_extra_q",
]
# Every model lists the four standard stems except htdemucs_6s, which adds
# guitar and piano. Each model gets its own list object.
DEMUCS_STEMS = {
    name: (["drums", "bass", "other", "vocals", "guitar", "piano"]
           if name == "htdemucs_6s" else ["drums", "bass", "other", "vocals"])
    for name in DEMUCS_MODELS
}

# ─── Cache ────────────────────────────────────────────────────────────────────
_cache = {}  # module-level memo shared by the pipeline stages; never evicted


def _audio_hash(a):
    """Return a hex digest identifying the full contents of array `a`.

    dtype, shape and every sample are hashed. Hashing only a prefix made
    distinct signals that share their opening samples (e.g. tracks starting
    with silence) collide, so cached results were returned for the wrong audio.
    """
    arr = np.ascontiguousarray(a)
    digest = hashlib.md5()
    digest.update(f"{arr.dtype.str}{arr.shape}".encode())
    digest.update(arr.tobytes())
    return digest.hexdigest()


def cache_get(k):
    """Return the cached value for key `k`, or None when absent."""
    return _cache.get(k)


def cache_set(k, v):
    """Store `v` under `k` and return `v` (lets callers `return cache_set(...)`)."""
    _cache[k] = v
    return v


def cache_clear():
    """Drop every cached entry."""
    _cache.clear()

# ─── Stage 1 ──────────────────────────────────────────────────────────────────
def extract_stem(audio_path,stem="drums",device="cpu",model_name="htdemucs_ft",shifts=1,overlap=0.25):
    """Separate one stem from `audio_path` with a Demucs model, as mono float32.

    Args:
        audio_path: path of the audio file to separate.
        stem: source name to keep (must be in the model's `sources`), or "all"
            to skip separation and just load the file as mono at 44.1 kHz.
        device: torch device for the model and input tensor.
        model_name, shifts, overlap: passed to demucs get_model / apply_model.

    Returns:
        (samples, sr): 1-D float32 array and its sample rate. Separated stems
        are cached per (file contents, stem, model_name, shifts, overlap).

    Raises:
        ValueError: if `stem` is not one of the model's sources.
    """
    if stem == "all":
        y, sr = librosa.load(audio_path, sr=44100, mono=True)
        return y.astype(np.float32), sr
    # Key the cache on a digest of the WHOLE file. Hashing only the first
    # 200 kB let different files that share a header and a silent intro
    # collide and receive each other's stems.
    digest = hashlib.md5()
    with open(audio_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    ck = ("stem", digest.hexdigest(), stem, model_name, shifts, overlap)
    cached = cache_get(ck)
    if cached is not None:
        print(f"[Stage 1] Cached {stem}")
        return cached
    from demucs.pretrained import get_model
    from demucs.apply import apply_model
    print(f"[Stage 1] {stem} with {model_name}...")
    model = get_model(model_name)
    model.eval().to(device)
    sr = model.samplerate
    if stem not in model.sources:
        raise ValueError(f"'{stem}' not in {model.sources}")
    a, _ = librosa.load(audio_path, sr=sr, mono=False)
    # Force exactly two channels: duplicate mono, drop anything beyond two.
    if a.ndim == 1:
        a = np.stack([a, a])
    elif a.shape[0] > 2:
        a = a[:2]
    elif a.shape[0] == 1:
        a = np.concatenate([a, a], axis=0)
    wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
    with torch.no_grad():
        src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
    # src presumably (batch, source, channel, time): pick the stem, average channels.
    r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
    print(f"  βœ“ {len(r)/sr:.1f}s")
    return cache_set(ck, (r.astype(np.float32), sr))

# ─── Stage 2 ──────────────────────────────────────────────────────────────────
def detect_onsets(y,sr,pre_pad=0.005,min_dur=0.02,max_dur=1.5,min_gap=0.03,
                  energy_threshold_db=-35.0,mode="auto",backtrack=True,onset_delta=0.12):
    """Detect onsets in `y` and cut the signal into Hit segments.

    Args:
        y: audio signal (the pipeline passes the mono stem from extract_stem).
        sr: sample rate of `y`.
        pre_pad: seconds of audio kept before each onset.
        min_dur: segments shorter than this many seconds are dropped.
        max_dur: maximum segment length in seconds.
        min_gap: minimum onset spacing in seconds, converted to frames with an
            assumed 512-sample hop.
        energy_threshold_db: segments whose RMS is below 10**(dB/20) are dropped.
        mode: "percussive", "harmonic" or "broadband" select a single onset
            envelope; any other value uses the multi-band "auto" envelope.
        backtrack: forwarded to librosa.onset.onset_detect.
        onset_delta: peak-picking threshold, forwarded as `delta`.

    Returns:
        list[Hit]; each Hit.index equals its position in the list.
    """
    print(f"[Stage 2] Onsets (delta={onset_delta}, energyβ‰₯{energy_threshold_db}dB)...")
    # Onset-strength envelope, chosen by `mode`.
    if mode=="percussive": oe=librosa.onset.onset_strength(y=y,sr=sr,aggregate=np.median,fmax=8000)
    elif mode=="harmonic": yh,_=librosa.effects.hpss(y); oe=librosa.onset.onset_strength(y=yh,sr=sr,fmax=8000,lag=2,max_size=3)
    elif mode=="broadband": oe=librosa.onset.onset_strength(y=y,sr=sr)
    else:
        # "auto": element-wise max of four peak-normalized envelopes — low
        # (20-250 Hz), mid (250-4000 Hz) and high (4 kHz to min(sr//2, 20 kHz))
        # bands of y, plus the first HPSS output (the harmonic component).
        yh,_=librosa.effects.hpss(y)
        def _n(x): m=x.max(); return x/m if m>0 else x
        oe=np.maximum.reduce([_n(librosa.onset.onset_strength(y=y,sr=sr,fmin=20,fmax=250,aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y,sr=sr,fmin=250,fmax=4000,aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y,sr=sr,fmin=4000,fmax=min(sr//2,20000),aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=yh,sr=sr,lag=2))])
    # min_gap seconds -> onset frames (assumes a 512-sample hop), at least 1.
    w=max(1,int(min_gap*sr/512))
    fr=librosa.onset.onset_detect(onset_envelope=oe,sr=sr,wait=w,pre_avg=3,post_avg=3,
        pre_max=3,post_max=5,delta=onset_delta,backtrack=backtrack,units='frames')
    times=librosa.frames_to_time(fr,sr=sr); print(f"  Raw: {len(times)}")
    # dB -> linear amplitude threshold for the per-segment RMS gate.
    thr=10**(energy_threshold_db/20); hits=[]
    for i,t in enumerate(times):
        # Segment spans [t - pre_pad, next onset), capped at max_dur (last one runs to the end of y).
        s=max(0,int((t-pre_pad)*sr))
        e=min(int(times[i+1]*sr) if i+1<len(times) else len(y),s+int(max_dur*sr))
        seg=y[s:e]
        if len(seg)<int(min_dur*sr): continue
        rms=np.sqrt(np.mean(seg**2))
        if rms<thr: continue
        # Linear fade-out over min(5 ms, a quarter of the segment); copy first so y is untouched.
        fl=min(int(0.005*sr),len(seg)//4)
        if fl>0: seg=seg.copy(); seg[-fl:]*=np.linspace(1,0,fl)
        sc=float(librosa.feature.spectral_centroid(y=seg,sr=sr).mean())
        # onset_time stores t itself, although the audio starts pre_pad earlier.
        hits.append(Hit(audio=seg,sr=sr,onset_time=t,duration=len(seg)/sr,
                        index=len(hits),rms_energy=float(rms),spectral_centroid=sc))
    print(f"  βœ“ {len(hits)} hits"); return hits

# ─── Stage 3 ──────────────────────────────────────────────────────────────────
# Ordered (label, predicate) rules used by classify_hit. Each predicate gets,
# positionally: low / mid / high band-energy ratios, spectral centroid (Hz),
# mean zero-crossing rate, and duration (s). The first match wins.
# NOTE(review): "bass" can never match — anything with lo > 0.6 and cen < 400
# already satisfies "kick" (lo > 0.5, cen < 800). Reordering would relabel long
# kicks (durations run to the next onset), so the order is kept pending a decision.
LABEL_RULES = [
    ("kick",         lambda lo, mid, hi, cen, zc, dur: lo > 0.5 and cen < 800),
    ("hihat_closed", lambda lo, mid, hi, cen, zc, dur: hi > 0.35 and cen > 4000 and dur < 0.15),
    ("hihat_open",   lambda lo, mid, hi, cen, zc, dur: hi > 0.35 and cen > 4000 and dur >= 0.15),
    ("cymbal",       lambda lo, mid, hi, cen, zc, dur: hi > 0.25 and cen > 3000),
    ("snare",        lambda lo, mid, hi, cen, zc, dur: mid > 0.4 and zc > 0.1 and cen > 1000),
    ("tom",          lambda lo, mid, hi, cen, zc, dur: lo > 0.3 and mid > 0.3 and cen < 1500),
    ("bass",         lambda lo, mid, hi, cen, zc, dur: lo > 0.6 and cen < 400 and dur > 0.2),
    ("vocal",        lambda lo, mid, hi, cen, zc, dur: mid > 0.5 and 500 < cen < 3000 and zc < 0.15),
    ("bright",       lambda lo, mid, hi, cen, zc, dur: cen > 2500),
    ("mid",          lambda lo, mid, hi, cen, zc, dur: cen > 800),
]
def classify_hit(hit):
    """Return a coarse instrument label for one hit, using LABEL_RULES.

    Features: share of STFT power (n_fft=2048) in 20-200 Hz, 200-4000 Hz and
    >= 4000 Hz; the hit's stored spectral centroid; the mean zero-crossing rate;
    and the hit duration. The first rule that matches wins; else "other".
    """
    audio, rate = hit.audio, hit.sr
    mag = np.abs(librosa.stft(audio, n_fft=2048))
    freqs = librosa.fft_frequencies(sr=rate, n_fft=2048)
    low = np.sum(mag[(freqs >= 20) & (freqs < 200)] ** 2)
    mid = np.sum(mag[(freqs >= 200) & (freqs < 4000)] ** 2)
    high = np.sum(mag[freqs >= 4000] ** 2)
    total = low + mid + high + 1e-10
    zcr = float(librosa.feature.zero_crossing_rate(y=audio).mean())
    features = (low / total, mid / total, high / total,
                hit.spectral_centroid, zcr, hit.duration)
    return next((name for name, rule in LABEL_RULES if rule(*features)), "other")
def classify_hits(hits):
    """Set `label` on every hit via classify_hit, print per-label counts, and return `hits`.

    Counts are printed most frequent first (ties keep first-seen order).
    """
    print(f"[Stage 3] Classifying {len(hits)}...")
    tally = defaultdict(int)
    for hit in hits:
        hit.label = classify_hit(hit)
        tally[hit.label] += 1
    for name, count in sorted(tally.items(), key=lambda kv: kv[1], reverse=True):
        print(f"    {name}: {count}")
    return hits

# ─── Stage 4 ──────────────────────────────────────────────────────────────────
def ncc_max(a,b):
    """Peak absolute normalized cross-correlation of `a` and `b` over all lags.

    Both inputs are truncated to the shorter length and mean-removed. The value
    is max|xcorr(a, b)| / sqrt(<a,a> * <b,b>), which lies in [0, 1] up to
    rounding (Cauchy-Schwarz). Returns 0.0 when either input is empty or has
    (numerically) zero energy after mean removal.
    """
    n = min(len(a), len(b))
    if n == 0:
        return 0.0
    # Work on float64 copies: integer input no longer raises on the in-place
    # mean subtraction, and the dot products/FFT are more accurate.
    x = np.array(a[:n], dtype=np.float64)
    y = np.array(b[:n], dtype=np.float64)
    x -= x.mean()
    y -= y.mean()
    norm = np.sqrt(np.dot(x, x) * np.dot(y, y))
    if norm < 1e-10:
        return 0.0
    return float(np.max(np.abs(fftconvolve(x, y[::-1], mode='full')))) / norm
def build_ncc_distance_matrix(hits,max_compare_samples=8820):
    """Symmetric float32 matrix of 1 - ncc_max between hits, floored at 0.

    Only the first `max_compare_samples` samples of each hit are compared.
    Results are cached, keyed on every hit's audio hash and the window size.
    """
    key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
    cached = cache_get(key)
    if cached is not None:
        print("  Cached NCC matrix")
        return cached
    heads = [h.audio[:max_compare_samples] for h in hits]
    count = len(heads)
    dist = np.zeros((count, count), dtype=np.float32)
    for i, head_i in enumerate(heads):
        for j in range(i + 1, count):
            d = max(0.0, 1.0 - ncc_max(head_i, heads[j]))
            dist[i, j] = d
            dist[j, i] = d
    return cache_set(key, dist)
def _labels_to_clusters(labels,hits):
    """Group `hits` by clustering label into Cluster objects.

    Clusters are visited in ascending label order and named
    "<majority hit label>_<k>", where k counts earlier clusters with the same
    majority label. The result is sorted largest-first, and cluster_id is
    renumbered to match that order.
    """
    members = defaultdict(list)
    for pos, lab in enumerate(labels):
        members[lab].append(pos)
    clusters = []
    for _, positions in sorted(members.items()):
        votes = defaultdict(int)
        for pos in positions:
            votes[hits[pos].label] += 1
        majority = max(votes, key=votes.get)
        ordinal = sum(c.label.rsplit('_', 1)[0] == majority for c in clusters)
        clusters.append(Cluster(cluster_id=len(clusters), label=f"{majority}_{ordinal}",
                                hits=[hits[pos] for pos in positions]))
    clusters.sort(key=lambda c: c.count, reverse=True)
    for new_id, c in enumerate(clusters):
        c.cluster_id = new_id
    return clusters
def cluster_hits(hits,ncc_threshold=0.80,max_compare_ms=0,target_min=0,target_max=0,linkage='average'):
    """Group hits by waveform similarity using agglomerative clustering on 1 - NCC.

    Args:
        hits: list of Hit; hits[0].sr is used as the sample rate for all.
        ncc_threshold: fixed mode only. The distance threshold passed to
            sklearn is 1 - ncc_threshold (floored at 0.001).
        max_compare_ms: comparison window; <= 0 means the median hit length
            (at least 30 ms).
        target_min, target_max: when both are > 0 and target_max >= target_min,
            the distance threshold is bisected (30 steps) for a cluster count in
            the range (clipped to the number of hits). If no step lands in range,
            the count closest to the range centre is kept, and a fixed
            n_clusters at the centre (capped at N-1) is tried as a fallback.
        linkage: sklearn linkage criterion, used with a precomputed metric.

    Returns:
        list[Cluster], largest first; [] for no hits, one cluster for one hit.
    """
    from sklearn.cluster import AgglomerativeClustering as AC
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
    if max_compare_ms <= 0:
        ms = max(int(0.03 * sr), int(np.median([len(h.audio) for h in hits])))
    else:
        ms = int(max_compare_ms / 1000.0 * sr)
    print(f"[Stage 4] NCC ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
    print(f"  {N*(N-1)//2} pairs...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=ms)
    use_t = target_min > 0 and target_max > 0 and target_max >= target_min
    tmin = max(1, min(target_min or 1, N))
    tmax = max(max(1, target_min or 1), min(target_max or N, N))
    if use_t:
        print(f"  Target: {tmin}–{tmax}")
        centre = (tmin + tmax) / 2
        lo, hi = 0.001, 1.0
        bl, bn = None, -1
        for _ in range(30):
            mid = (lo + hi) / 2
            lb = AC(n_clusters=None, distance_threshold=max(0.001, mid),
                    metric='precomputed', linkage=linkage).fit_predict(D)
            n = len(set(lb))
            if tmin <= n <= tmax:
                bl, bn = lb, n
                break
            elif n > tmax:
                lo = mid  # too many clusters -> raise the merge threshold
            else:
                hi = mid
            if bl is None or abs(n - centre) < abs(bn - centre):
                bl, bn = lb, n
        if bn < tmin or bn > tmax:
            tm = min((tmin + tmax) // 2, N - 1)
            print(f"  Fallback n={tm}")
            try:
                bl = AC(n_clusters=tm, metric='precomputed', linkage=linkage).fit_predict(D)
                bn = tm
            except Exception as err:
                # Best effort: keep the closest bisection result instead of failing,
                # but say so (this used to be a silent bare `except: pass`).
                print(f"  Fallback failed ({err}); keeping {bn}")
        labels = bl
        print(f"  β†’ {bn}")
    else:
        dt = max(0.001, 1.0 - ncc_threshold)
        print(f"  Fixed dist≀{dt:.3f}")
        labels = AC(n_clusters=None, distance_threshold=dt, metric='precomputed',
                    linkage=linkage).fit_predict(D)
    print(f"  βœ“ {len(set(labels))} clusters")
    cl = _labels_to_clusters(labels, hits)
    for c in cl:
        print(f"    {c.label}: {c.count}")
    return cl

# ─── Stage 5 ──────────────────────────────────────────────────────────────────
def sample_quality_score(y,sr,label="other"):
    """Heuristic quality score for a one-shot sample `y` at rate `sr`.

    Returns a dict of floats:
        'completeness':  0.6*c1 + 0.4*c2. c1 = max(0, 1 - 5 * tail/peak) on the
                         RMS envelope after its peak (tail = mean of the last
                         fifth, at least 3 frames). c2 = r^2 of a linear fit to
                         the log post-peak RMS, counted only for a negative slope
                         and at least 5 frames. With < 10 RMS frames c1=0.5, c2=0.
        'cleanness':     mean of a percentile-SNR term (99th vs 10th percentile
                         of y^2) and an onset-vs-pre-onset energy term.
        'onset_quality': (peak / mean of the onset-strength envelope - 1) / 5,
                         clipped to [0, 1].
        'total':         100 * (0.30*completeness + 0.40*cleanness
                         + 0.20*onset_quality + 0.05).
    `label` is accepted for interface compatibility but is not used.
    """
    import scipy.stats
    rms_env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(rms_env) >= 10:
        peak_idx = np.argmax(rms_env)
        post = rms_env[peak_idx:]
        tail = post[-max(3, len(post) // 5):]
        c1 = max(0, 1.0 - np.mean(tail) / (rms_env[peak_idx] + 1e-8) * 5)
        c2 = 0.0
        if len(post) >= 5:
            # One regression (the previous code ran the identical fit twice).
            fit = scipy.stats.linregress(np.arange(len(post)), np.log(post + 1e-8))
            if fit.slope < 0:
                c2 = max(0, fit.rvalue ** 2)
    else:
        c1, c2 = 0.5, 0.0
    comp = c1 * 0.6 + c2 * 0.4
    snr = 10 * np.log10(np.percentile(y ** 2, 99) / (np.percentile(y ** 2, 10) + 1e-12))
    ns = np.clip((snr - 10) / 40, 0, 1)
    ons = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    np2 = 0.5
    if len(ons) > 0:
        o = int(ons[0])
        pre = y[max(0, o - int(sr * .02)):o]   # up to 20 ms before the first onset
        sig = y[o:o + int(sr * .1)]            # up to 100 ms after it
        if len(pre) > 10 and len(sig) > 10:
            np2 = np.clip((-10 * np.log10(np.mean(pre ** 2 + 1e-12) / np.mean(sig ** 2 + 1e-12)) - 5) / 30, 0, 1)
    cl = ns * 0.5 + np2 * 0.5
    oe = librosa.onset.onset_strength(y=y, sr=sr)
    # Subtract 1 from the ratio in BOTH cases. A conditional expression binds
    # looser than "-", so the old `r if cond else 1.0-1.0` never subtracted 1 from r.
    ratio = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
    oq = float(np.clip((ratio - 1.0) / 5.0, 0, 1))
    return {'total': float((comp * 0.30 + cl * 0.40 + oq * 0.20 + 0.5 * 0.10) * 100),
            'completeness': float(comp), 'cleanness': float(cl), 'onset_quality': float(oq)}
def select_best(clusters):
    """Point each cluster's best_hit_idx at its highest-scoring hit, in place.

    Multi-hit clusters are ranked by sample_quality_score(...)['total'] (the
    label's class prefix is passed along); clusters with <= 1 hit get index 0.
    """
    print("[Stage 5] Selecting best...")
    for cl in clusters:
        if cl.count > 1:
            kind = cl.label.rsplit('_', 1)[0]
            scores = [sample_quality_score(h.audio, h.sr, kind)['total'] for h in cl.hits]
            cl.best_hit_idx = int(np.argmax(scores))
        else:
            cl.best_hit_idx = 0

# ─── Stage 6 ──────────────────────────────────────────────────────────────────
def synthesize_from_cluster(cluster):
    """Average a cluster's hits into one synthetic float32 sample.

    Each hit is shifted so its absolute peak lines up with the first hit's,
    cut/zero-padded to the median hit length, and peak-normalized. The best hit
    gets double weight. The weighted mean is scaled to a 0.95 peak.
    Returns None for clusters with fewer than two hits.
    """
    if cluster.count < 2:
        return None
    target_len = int(np.median([len(h.audio) for h in cluster.hits]))
    aligned, weights = [], []
    anchor = None
    for idx, hit in enumerate(cluster.hits):
        wave = hit.audio.copy()
        peak_pos = np.argmax(np.abs(wave))
        if anchor is None:
            anchor = peak_pos
        shift = anchor - peak_pos
        if shift > 0:
            wave = np.pad(wave, (shift, 0))
        elif shift < 0:
            wave = wave[-shift:]
        if len(wave) >= target_len:
            wave = wave[:target_len]
        else:
            wave = np.pad(wave, (0, target_len - len(wave)))
        peak = np.abs(wave).max()
        if peak > 0:
            wave = wave / peak
        aligned.append(wave)
        weights.append(2.0 if idx == cluster.best_hit_idx else 1.0)
    stack = np.array(aligned)
    w = np.array(weights)
    w /= w.sum()
    mix = np.average(stack, axis=0, weights=w)
    peak = np.abs(mix).max()
    if peak > 0:
        return (mix * 0.95 / peak).astype(np.float32)
    return mix.astype(np.float32)

# ─── Stage 7 ──────────────────────────────────────────────────────────────────
def build_midi(clusters,bpm=120.0):
    """Build a single-track drum PrettyMIDI with one note per hit.

    Side effect: cluster i gets midi_note = min(36 + i, 127). Velocity is
    rms_energy / 0.3 * 127 clipped to [1, 127]; each note lasts the hit
    duration, at least 50 ms. Notes are sorted by start time.
    """
    import pretty_midi
    midi = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    for slot, cl in enumerate(clusters):
        cl.midi_note = min(36 + slot, 127)
    track = pretty_midi.Instrument(program=0, is_drum=True, name='Samples')
    midi.instruments.append(track)
    for cl in clusters:
        for hit in cl.hits:
            velocity = max(1, min(127, int(hit.rms_energy / 0.3 * 127)))
            track.notes.append(pretty_midi.Note(
                velocity=velocity, pitch=cl.midi_note,
                start=hit.onset_time, end=hit.onset_time + max(hit.duration, 0.05)))
    track.notes.sort(key=lambda note: note.start)
    return midi
def export_midi(clusters,path,bpm=120.0):
    """Build the reconstruction MIDI via build_midi, write it to `path`, and return it."""
    midi = build_midi(clusters, bpm)
    midi.write(path)
    note_count = len(midi.instruments[0].notes)
    print(f"  βœ“ MIDI: {note_count} notes")
    return midi
def detect_bpm(y,sr):
    """Estimate the tempo of `y` in BPM, folded by octaves into [70, 200].

    Candidates are librosa's tempo estimate and, when beat tracking finds more
    than two beats, the tempo implied by the median inter-beat interval. The
    first candidate already inside [70, 200] wins. Otherwise the tempo estimate
    is doubled or halved until it lands in range; non-positive or NaN estimates
    are returned unchanged. Results are rounded to 0.1 and cached per (audio, sr).
    """
    ck = ("bpm", _audio_hash(y), sr)
    cached = cache_get(ck)
    if cached is not None:  # a cached 0.0 is still a hit
        return cached
    oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
    bpm = float(librosa.feature.tempo(onset_envelope=oe, sr=sr).item())
    _, beats = librosa.beat.beat_track(onset_envelope=oe, sr=sr, units='time')
    candidates = [bpm]
    if len(beats) > 2:
        candidates.append(60.0 / float(np.median(np.diff(beats))))
    in_range = [c for c in candidates if 70 <= c <= 200]
    if in_range:
        bpm = in_range[0]
    elif bpm > 0:
        # The range spans more than one octave, so repeated folding always lands
        # inside it. A single step left e.g. 30 BPM at 60.
        while bpm < 70:
            bpm *= 2
        while bpm > 200:
            bpm /= 2
    return cache_set(ck, round(bpm, 1))
def render_midi_with_samples(clusters,sr=44100):
    """Re-render the arrangement by playing each cluster's best hit at every onset.

    Gain per hit is sqrt(min(2, hit_rms / best_rms)), with best_rms = 0.1 when the
    best hit's RMS is not positive. The buffer covers the last hit end + 1 s,
    and grows if a sample runs past it. The result is float32, peak-normalized
    to 0.9 unless effectively silent.
    """
    end_time = max((hit.onset_time + hit.duration for cl in clusters for hit in cl.hits), default=1.0)
    out = np.zeros(int((end_time + 1.0) * sr), dtype=np.float64)
    for cl in clusters:
        ref = cl.best_hit
        sample = ref.audio.astype(np.float64)
        ref_rms = ref.rms_energy if ref.rms_energy > 0 else 0.1
        for hit in cl.hits:
            gain = min(2.0, hit.rms_energy / (ref_rms + 1e-8)) ** 0.5
            start = int(hit.onset_time * sr)
            stop = start + len(sample)
            if stop > len(out):
                out = np.concatenate([out, np.zeros(stop - len(out))])
            out[start:stop] += sample * gain
    peak = np.abs(out).max()
    if peak > 1e-8:
        return (out / peak * 0.9).astype(np.float32)
    return out.astype(np.float32)
def build_sample_map(clusters):
    """Map each cluster's MIDI note to its label, hit count and best-hit length (ms)."""
    sample_map = {}
    for cl in clusters:
        sample_map[cl.midi_note] = {
            'label': cl.label,
            'count': cl.count,
            'duration_ms': int(cl.best_hit.duration * 1000),
        }
    return sample_map
def build_archive(clusters,bpm,sr,midi_path=None,rendered_audio=None):
    """Package the extracted samples and metadata into a new (uncompressed) zip.

    Layout:
        samples/<label>.wav          best hit of each cluster, 24-bit PCM
        samples/<label>__synth.wav   cluster.synthesized, when set
        index.json                   bpm, sample rate, totals, per-cluster metadata
        reconstruction.mid           copy of `midi_path`, if that file exists
        rendered_reconstruction.wav  `rendered_audio` as 16-bit PCM, if given
    All audio is written at `sr`.

    Returns:
        Path of the zip, created in the system temp directory. The caller owns
        the file. If writing fails, the partial file is removed and the
        exception is re-raised.
    """
    import zipfile,tempfile,io
    # mkstemp creates the file atomically. tempfile.mktemp only returned a name,
    # leaving a window in which another process could create or replace it.
    fd, zp = tempfile.mkstemp(suffix='.zip')
    os.close(fd)
    idx = {'bpm': round(bpm, 1), 'sample_rate': sr, 'total_clusters': len(clusters),
           'total_hits': sum(c.count for c in clusters), 'samples': {}}
    try:
        with zipfile.ZipFile(zp, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                b = c.best_hit; fn = f"samples/{c.label}.wav"; buf = io.BytesIO()
                sf.write(buf, b.audio, sr, format='WAV', subtype='PCM_24'); zf.writestr(fn, buf.getvalue())
                ot = sorted([h.onset_time for h in c.hits])
                idx['samples'][c.label] = {'file': fn, 'classification': c.label.rsplit('_', 1)[0], 'midi_note': c.midi_note,
                    'occurrences': c.count, 'onset_times_sec': [round(t, 4) for t in ot], 'duration_sec': round(b.duration, 4),
                    'rms_energy': round(b.rms_energy, 6), 'spectral_centroid_hz': round(b.spectral_centroid, 1)}
                if c.synthesized is not None:
                    sf2 = f"samples/{c.label}__synth.wav"; b2 = io.BytesIO()
                    sf.write(b2, c.synthesized, sr, format='WAV', subtype='PCM_24'); zf.writestr(sf2, b2.getvalue())
                    idx['samples'][c.label]['synthesized_file'] = sf2
            zf.writestr('index.json', json.dumps(idx, indent=2))
            if midi_path and os.path.exists(midi_path): zf.write(midi_path, 'reconstruction.mid')
            if rendered_audio is not None:
                rb = io.BytesIO(); sf.write(rb, rendered_audio, sr, format='WAV', subtype='PCM_16')
                zf.writestr('rendered_reconstruction.wav', rb.getvalue())
    except BaseException:
        os.remove(zp)
        raise
    return zp

# ─── Auto-tuner with parameter locking ───────────────────────────────────────

def _spectral_envelope_corr(orig,rend,sr,n_fft=4096,hop=2048):
    """Pearson correlation of the time-averaged STFT magnitude spectra of two signals.

    Both are truncated to the shorter length. Returns 0.0 if that length is
    below n_fft or either averaged spectrum is (near) constant. `sr` is unused.
    """
    length = min(len(orig), len(rend))
    if length < n_fft:
        return 0.0
    spectra = [np.abs(librosa.stft(sig[:length], n_fft=n_fft, hop_length=hop)).mean(axis=1)
               for sig in (orig, rend)]
    if any(s.std() < 1e-10 for s in spectra):
        return 0.0
    return float(np.corrcoef(spectra[0], spectra[1])[0, 1])

def _rms_envelope_corr(orig,rend,sr,hop=1024):
    """Pearson correlation of the frame-RMS envelopes of two signals.

    Both are truncated to the shorter length. Returns 0.0 when there are fewer
    than 4 hops of audio or 4 common frames, or when either envelope is
    (near) constant. `sr` is unused.
    """
    length = min(len(orig), len(rend))
    if length < hop * 4:
        return 0.0
    env_o = librosa.feature.rms(y=orig[:length], hop_length=hop)[0]
    env_r = librosa.feature.rms(y=rend[:length], hop_length=hop)[0]
    frames = min(len(env_o), len(env_r))
    if frames < 4:
        return 0.0
    env_o, env_r = env_o[:frames], env_r[:frames]
    if env_o.std() < 1e-10 or env_r.std() < 1e-10:
        return 0.0
    return float(np.corrcoef(env_o, env_r)[0, 1])

def _reconstruction_score(orig,rend,sr):
    """Similarity of a rendering to the original, floored at 0 (at most ~100).

    100 * (0.5 * spectral-envelope corr + 0.4 * RMS-envelope corr
    + 0.1 * shorter/(longer + 1) length ratio).
    """
    spectral = _spectral_envelope_corr(orig, rend, sr)
    loudness = _rms_envelope_corr(orig, rend, sr)
    shorter, longer = sorted((len(rend), len(orig)))
    length_ratio = shorter / (longer + 1)
    weighted = spectral * 0.5 + loudness * 0.4 + length_ratio * 0.1
    return max(0.0, weighted * 100)


def auto_tune(stem_audio, sr, mode="auto", locks=None, log_fn=None):
    """Grid-search onset and cluster-count parameters by reconstruction score.

    For each onset config, hits are detected and classified, then clustered at
    several target cluster counts. Each clustering is rendered back to audio
    (render_midi_with_samples) and scored against `stem_audio`
    (_reconstruction_score). A final pass tries every target within ±3 of the
    best one, reusing the winning config's hits.

    locks: dict mapping param names to fixed values. Supported keys:
        'onset_delta', 'energy_threshold_db', 'min_gap',
        'target_min', 'target_max'
    A locked onset param replaces that value in every grid row. Locked target
    bounds restrict which cluster counts are tried, in the sweep and in the
    ±3 fine-tune. Other keys are ignored.

    Example: locks={'target_min': 5, 'target_max': 20}
      → only cluster counts between 5 and 20 are tried; onset params are still swept.

    log_fn: optional callable that receives each log line (lines are also printed).

    Returns: (best_params, best_score, log_lines). best_params is {} and
    best_score is -1 when nothing could be evaluated.
    """
    locks = locks or {}
    log = []
    def _log(m):
        log.append(m)
        if log_fn: log_fn(m)
        print(m)

    def _score_target(hits, target):
        # Cluster `hits` aiming for exactly `target` clusters. Each cluster is
        # represented by its loudest hit (a cheap stand-in for select_best).
        # Render and score against the stem. Returns (n_clusters, score), or
        # None when clustering yields nothing.
        clusters = cluster_hits(hits, target_min=target, target_max=target, linkage='average')
        if not clusters:
            return None
        for c in clusters:
            if c.count <= 1: c.best_hit_idx = 0
            else: c.best_hit_idx = int(np.argmax([h.rms_energy for h in c.hits]))
        rendered = render_midi_with_samples(clusters, sr=sr)
        return len(clusters), _reconstruction_score(stem_audio, rendered, sr)

    locked_names = list(locks.keys())
    _log(f"[Auto-tune] {len(stem_audio)/sr:.1f}s audio, locked: {locked_names or 'none'}")

    onset_keys = ('onset_delta', 'energy_threshold_db', 'min_gap')
    base_onset_configs = [
        {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
        {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
        {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
        {"onset_delta": 0.15, "energy_threshold_db": -30, "min_gap": 0.04},
        {"onset_delta": 0.20, "energy_threshold_db": -28, "min_gap": 0.05},
        {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
    ]

    # Overlay locked onset values on every row, then drop rows that became
    # identical. New dicts are built so the grid literals are not mutated.
    onset_locks = {k: locks[k] for k in onset_keys if k in locks}
    onset_configs, seen = [], set()
    for base in base_onset_configs:
        oc = {**base, **onset_locks}
        key = tuple(oc[k] for k in onset_keys)
        if key not in seen:
            seen.add(key)
            onset_configs.append(oc)

    # Cluster-count targets, restricted by any locked bounds.
    default_targets = [3, 5, 8, 10, 15, 20, 30]
    if 'target_min' in locks and 'target_max' in locks:
        tmin_locked, tmax_locked = int(locks['target_min']), int(locks['target_max'])
        cluster_targets = sorted(set([tmin_locked, tmax_locked,
            (tmin_locked+tmax_locked)//2,
            tmin_locked + (tmax_locked-tmin_locked)//3,
            tmin_locked + 2*(tmax_locked-tmin_locked)//3]))
        cluster_targets = [t for t in cluster_targets if tmin_locked <= t <= tmax_locked]
        _log(f"  Cluster targets (locked {tmin_locked}–{tmax_locked}): {cluster_targets}")
        if not cluster_targets:
            _log("  No cluster targets: locked target_min exceeds target_max")
    elif 'target_min' in locks:
        tmin_locked = int(locks['target_min'])
        cluster_targets = [t for t in default_targets if t >= tmin_locked] or [tmin_locked]
    elif 'target_max' in locks:
        tmax_locked = int(locks['target_max'])
        cluster_targets = [t for t in default_targets if t <= tmax_locked] or [tmax_locked]
    else:
        cluster_targets = list(default_targets)

    best_score, best_params, best_hits = -1, {}, None

    for oc_idx, oc in enumerate(onset_configs):
        _log(f"\n  [{oc_idx+1}/{len(onset_configs)}] delta={oc['onset_delta']}, "
             f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}")
        hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
        if len(hits) < 2: _log(f"    {len(hits)} hits, skip"); continue
        hits = classify_hits(hits)

        for nt in cluster_targets:
            if nt >= len(hits): continue
            result = _score_target(hits, nt)
            if result is None: continue
            n_cl, score = result
            if score > best_score:
                best_score = score
                best_params = {**oc, 'n_clusters': n_cl,
                               'target_min': nt, 'target_max': nt}
                best_hits = hits
                _log(f"    β˜… target={nt} β†’ {n_cl} cl, score={score:.1f} BEST")
            else:
                _log(f"    target={nt} β†’ {n_cl} cl, score={score:.1f}")

    # Fine-tune around the best target, using the hits from the winning onset
    # config. Detection and classification are deterministic, so re-running
    # them (as before) only reproduced the same list.
    if best_params:
        bt = best_params.get('target_min', 10)
        fine_oc = {k: best_params[k] for k in onset_keys}
        _log(f"\n  Fine-tuning Β±3 around target={bt}...")
        hits = best_hits
        # Clamp the ±3 window to any locked bounds. Previously a locked bound
        # replaced the window edge, which widened the search to the whole lock range.
        lo = max(2, bt - 3)
        hi = min(len(hits) - 1, bt + 3)
        if 'target_min' in locks: lo = max(lo, int(locks['target_min']))
        if 'target_max' in locks: hi = min(hi, int(locks['target_max']))
        for ft in range(lo, hi + 1):
            if ft in cluster_targets: continue  # already scored with these exact hits
            result = _score_target(hits, ft)
            if result is None: continue
            n_cl, score = result
            if score > best_score:
                best_score = score
                best_params = {**fine_oc, 'n_clusters': n_cl, 'target_min': ft, 'target_max': ft}
                _log(f"    β˜… target={ft} β†’ {n_cl} cl, score={score:.1f} BEST")

    _log(f"\n[Auto-tune] Best: score={best_score:.1f}, params={best_params}")
    return best_params, best_score, log