rikhoffbauer2 committed on
Commit
2d1fee2
·
verified ·
1 Parent(s): 9cb951c

v8: Parameter locking for auto-tune — lock any param to constrain the search

Browse files
Files changed (1) hide show
  1. sample_extractor.py +252 -356
sample_extractor.py CHANGED
@@ -1,11 +1,9 @@
1
  #!/usr/bin/env python3
2
  """
3
- Sample Extractor v7 — Auto-tuning via reconstruction quality.
4
 
5
- New: auto_tune() optimizes parameters against the uploaded audio itself.
6
- No ground truth needed — measures spectral envelope correlation between
7
- the rendered reconstruction and the original stem. Sweeps onset detection
8
- params and cluster counts, using cached NCC matrices for speed.
9
  """
10
 
11
  import argparse, json, os, sys, warnings, hashlib
@@ -20,278 +18,205 @@ from scipy.signal import fftconvolve
20
  warnings.filterwarnings("ignore", category=FutureWarning)
21
  warnings.filterwarnings("ignore", category=UserWarning)
22
 
23
-
24
- # ─── Data structures ─────────────────────────────────────────────────────────
25
-
26
  @dataclass
27
  class Hit:
28
  audio: np.ndarray; sr: int; onset_time: float; duration: float; index: int
29
  rms_energy: float = 0.0; spectral_centroid: float = 0.0
30
  label: str = ""; embedding: Optional[np.ndarray] = None; cluster_id: int = -1
31
- def save(self, path: str): sf.write(path, self.audio, self.sr, subtype='PCM_24')
32
 
33
  @dataclass
34
  class Cluster:
35
  cluster_id: int; label: str; hits: list = field(default_factory=list)
36
  best_hit_idx: int = 0; synthesized: Optional[np.ndarray] = None; midi_note: int = 60
37
  @property
38
- def best_hit(self) -> Hit: return self.hits[self.best_hit_idx]
39
  @property
40
- def count(self) -> int: return len(self.hits)
41
-
42
 
43
- DEMUCS_MODELS = ["htdemucs", "htdemucs_ft", "htdemucs_6s", "mdx", "mdx_extra", "mdx_extra_q"]
44
  DEMUCS_STEMS = {
45
- "htdemucs": ["drums","bass","other","vocals"], "htdemucs_ft": ["drums","bass","other","vocals"],
46
- "htdemucs_6s": ["drums","bass","other","vocals","guitar","piano"],
47
- "mdx": ["drums","bass","other","vocals"], "mdx_extra": ["drums","bass","other","vocals"],
48
- "mdx_extra_q": ["drums","bass","other","vocals"],
49
  }
50
 
51
-
52
- # ─── Caching ──────────────────────────────────────────────────────────────────
53
-
54
  _cache = {}
55
-
56
- def _audio_hash(audio: np.ndarray) -> str:
57
- return hashlib.md5(audio[:4000].tobytes()).hexdigest()
58
-
59
- def cache_get(key): return _cache.get(key)
60
- def cache_set(key, value): _cache[key] = value; return value
61
  def cache_clear(): _cache.clear()
62
 
63
-
64
- # ─── Stage 1: Stem separation ────────────────────────────────────────────────
65
-
66
- def extract_stem(audio_path, stem="drums", device="cpu",
67
- model_name="htdemucs_ft", shifts=1, overlap=0.25):
68
- if stem == "all":
69
- y, sr = librosa.load(audio_path, sr=44100, mono=True)
70
- return y.astype(np.float32), sr
71
- with open(audio_path, 'rb') as f:
72
- fh = hashlib.md5(f.read(200000)).hexdigest()
73
- ck = ("stem", fh, stem, model_name, shifts, overlap)
74
- c = cache_get(ck)
75
- if c is not None: print(f"[Stage 1] Cached {stem} stem"); return c
76
- from demucs.pretrained import get_model
77
- from demucs.apply import apply_model
78
- print(f"[Stage 1] Extracting '{stem}' with {model_name}...")
79
- model = get_model(model_name); model.eval().to(device); sr = model.samplerate
80
  if stem not in model.sources: raise ValueError(f"'{stem}' not in {model.sources}")
81
- a, _ = librosa.load(audio_path, sr=sr, mono=False)
82
  if a.ndim==1: a=np.stack([a,a])
83
  elif a.shape[0]>2: a=a[:2]
84
  elif a.shape[0]==1: a=np.concatenate([a,a],axis=0)
85
- wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
86
- with torch.no_grad():
87
- src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
88
- r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
89
- print(f" ✓ {stem}: {len(r)/sr:.1f}s")
90
- return cache_set(ck, (r.astype(np.float32), sr))
91
-
92
-
93
- # ─── Stage 2: Onset detection ────────────────────────────────────────────────
94
-
95
- def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5,
96
- min_gap=0.03, energy_threshold_db=-35.0,
97
- mode="auto", backtrack=True, onset_delta=0.12):
98
- print(f"[Stage 2] Onsets (mode={mode}, delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
99
- if mode == "percussive":
100
- oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
101
- elif mode == "harmonic":
102
- yh, _ = librosa.effects.hpss(y)
103
- oe = librosa.onset.onset_strength(y=yh, sr=sr, fmax=8000, lag=2, max_size=3)
104
- elif mode == "broadband":
105
- oe = librosa.onset.onset_strength(y=y, sr=sr)
106
  else:
107
- yh, yp = librosa.effects.hpss(y)
108
- envs = [librosa.onset.onset_strength(y=y,sr=sr,fmin=20,fmax=250,aggregate=np.median),
109
- librosa.onset.onset_strength(y=y,sr=sr,fmin=250,fmax=4000,aggregate=np.median),
110
- librosa.onset.onset_strength(y=y,sr=sr,fmin=4000,fmax=min(sr//2,20000),aggregate=np.median),
111
- librosa.onset.onset_strength(y=yh,sr=sr,lag=2)]
112
  def _n(x): m=x.max(); return x/m if m>0 else x
113
- oe = np.maximum.reduce([_n(e) for e in envs])
114
- w = max(1, int(min_gap*sr/512))
115
- fr = librosa.onset.onset_detect(onset_envelope=oe, sr=sr, wait=w,
116
- pre_avg=3, post_avg=3, pre_max=3, post_max=5,
117
- delta=onset_delta, backtrack=backtrack, units='frames')
118
- times = librosa.frames_to_time(fr, sr=sr)
119
- print(f" Raw: {len(times)}")
120
- thr = 10**(energy_threshold_db/20); hits = []
121
- for i, t in enumerate(times):
122
- s = max(0, int((t-pre_pad)*sr))
123
- e = min(int(times[i+1]*sr) if i+1<len(times) else len(y), s+int(max_dur*sr))
124
- seg = y[s:e]
 
125
  if len(seg)<int(min_dur*sr): continue
126
- rms = np.sqrt(np.mean(seg**2))
127
  if rms<thr: continue
128
- fl = min(int(0.005*sr), len(seg)//4)
129
  if fl>0: seg=seg.copy(); seg[-fl:]*=np.linspace(1,0,fl)
130
- sc = float(librosa.feature.spectral_centroid(y=seg,sr=sr).mean())
131
  hits.append(Hit(audio=seg,sr=sr,onset_time=t,duration=len(seg)/sr,
132
  index=len(hits),rms_energy=float(rms),spectral_centroid=sc))
133
  print(f" ✓ {len(hits)} hits"); return hits
134
 
135
-
136
- # ─── Stage 3: Classification ─────────────────────────────────────────────────
137
-
138
- LABEL_RULES = [
139
- ("kick", lambda lr,mr,hr,c,zcr,d: lr>0.5 and c<800),
140
- ("hihat_closed", lambda lr,mr,hr,c,zcr,d: hr>0.35 and c>4000 and d<0.15),
141
- ("hihat_open", lambda lr,mr,hr,c,zcr,d: hr>0.35 and c>4000 and d>=0.15),
142
- ("cymbal", lambda lr,mr,hr,c,zcr,d: hr>0.25 and c>3000),
143
- ("snare", lambda lr,mr,hr,c,zcr,d: mr>0.4 and zcr>0.1 and c>1000),
144
- ("tom", lambda lr,mr,hr,c,zcr,d: lr>0.3 and mr>0.3 and c<1500),
145
- ("bass", lambda lr,mr,hr,c,zcr,d: lr>0.6 and c<400 and d>0.2),
146
- ("vocal", lambda lr,mr,hr,c,zcr,d: mr>0.5 and 500<c<3000 and zcr<0.15),
147
- ("bright", lambda lr,mr,hr,c,zcr,d: c>2500),
148
- ("mid", lambda lr,mr,hr,c,zcr,d: c>800),
149
- ]
150
-
151
  def classify_hit(hit):
152
- y,sr = hit.audio, hit.sr
153
- D = np.abs(librosa.stft(y, n_fft=2048))
154
- f = librosa.fft_frequencies(sr=sr, n_fft=2048)
155
  le=np.sum(D[(f>=20)&(f<200)]**2); me=np.sum(D[(f>=200)&(f<4000)]**2)
156
- he=np.sum(D[(f>=4000)]**2); t=le+me+he+1e-10
157
- lr,mr,hr = le/t,me/t,he/t
158
- zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
159
- for name, fn in LABEL_RULES:
160
- if fn(lr,mr,hr,hit.spectral_centroid,zcr,hit.duration): return name
161
  return "other"
162
-
163
  def classify_hits(hits):
164
- print(f"[Stage 3] Classifying {len(hits)} hits...")
165
- for h in hits: h.label = classify_hit(h)
166
- counts = defaultdict(int)
167
- for h in hits: counts[h.label] += 1
168
- for l, c in sorted(counts.items(), key=lambda x: -x[1]): print(f" {l}: {c}")
169
  return hits
170
 
171
-
172
- # ─── Stage 4: NCC-based clustering ───────────────────────────────────────────
173
-
174
- def ncc_max(a, b):
175
- n = min(len(a), len(b)); a,b = a[:n].copy(), b[:n].copy()
176
- a-=a.mean(); b-=b.mean()
177
- norm = np.sqrt(np.dot(a,a)*np.dot(b,b))
178
  if norm<1e-10: return 0.0
179
  return float(np.max(np.abs(fftconvolve(a,b[::-1],mode='full'))))/norm
180
-
181
- def build_ncc_distance_matrix(hits, max_compare_samples=8820):
182
- key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
183
- c = cache_get(key)
184
  if c is not None: print(f" Cached NCC matrix"); return c
185
  N=len(hits); D=np.zeros((N,N),dtype=np.float32)
186
  for i in range(N):
187
  ai=hits[i].audio[:max_compare_samples]
188
- for j in range(i+1,N):
189
- D[i,j]=D[j,i]=max(0.0, 1.0-ncc_max(ai,hits[j].audio[:max_compare_samples]))
190
- return cache_set(key, D)
191
-
192
- def _labels_to_clusters(labels, hits):
193
- cm = defaultdict(list)
194
  for i,l in enumerate(labels): cm[l].append(i)
195
- clusters = []
196
- for _, idx in sorted(cm.items()):
197
- v = defaultdict(int)
198
  for i in idx: v[hits[i].label]+=1
199
- maj = max(v, key=v.get)
200
- ex = sum(1 for c in clusters if c.label.rsplit('_',1)[0]==maj)
201
- clusters.append(Cluster(cluster_id=len(clusters), label=f"{maj}_{ex}",
202
- hits=[hits[i] for i in idx]))
203
- clusters.sort(key=lambda c: c.count, reverse=True)
204
  for i,c in enumerate(clusters): c.cluster_id=i
205
  return clusters
206
-
207
- def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0,
208
- target_min=0, target_max=0, linkage='average'):
209
- from sklearn.cluster import AgglomerativeClustering
210
  if not hits: return []
211
  N=len(hits); sr=hits[0].sr
212
- if N==1: return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
213
- if max_compare_ms<=0:
214
- ms = max(int(0.03*sr), int(np.median([len(h.audio) for h in hits])))
215
- else: ms = int(max_compare_ms/1000.0*sr)
216
- print(f"[Stage 4] NCC clustering ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
217
- print(f" Computing {N*(N-1)//2} pairwise distances...")
218
- D = build_ncc_distance_matrix(hits, max_compare_samples=ms)
219
- use_t = target_min>0 and target_max>0 and target_max>=target_min
220
- tmin=max(1,min(target_min or 1,N)); tmax=max(tmin,min(target_max or N,N))
221
  if use_t:
222
  print(f" Target: {tmin}–{tmax}")
223
  lo,hi=0.001,1.0; bl,bn,bd=None,-1,0.5
224
  for _ in range(30):
225
- mid=(lo+hi)/2
226
- agg=AgglomerativeClustering(n_clusters=None,distance_threshold=max(0.001,mid),
227
- metric='precomputed',linkage=linkage)
228
- lb=agg.fit_predict(D); n=len(set(lb))
229
  if tmin<=n<=tmax: bl,bn,bd=lb,n,mid; break
230
  elif n>tmax: lo=mid
231
  else: hi=mid
232
  if bl is None or abs(n-(tmin+tmax)/2)<abs(bn-(tmin+tmax)/2): bl,bn,bd=lb,n,mid
233
  if bn<tmin or bn>tmax:
234
- tm=min((tmin+tmax)//2, N-1)
235
- print(f" Fallback n_clusters={tm}")
236
- try:
237
- agg=AgglomerativeClustering(n_clusters=tm,metric='precomputed',linkage=linkage)
238
- bl=agg.fit_predict(D); bn=tm
239
  except: pass
240
- labels=bl; print(f" → {bn} clusters")
241
  else:
242
- dt=max(0.001, 1.0-ncc_threshold); print(f" Fixed: dist≤{dt:.3f}")
243
- labels=AgglomerativeClustering(n_clusters=None,distance_threshold=dt,
244
- metric='precomputed',linkage=linkage).fit_predict(D)
245
  print(f" ✓ {len(set(labels))} clusters")
246
- cl = _labels_to_clusters(labels, hits)
247
- for c in cl: print(f" {c.label}: {c.count} hits")
248
  return cl
249
 
250
-
251
- # ─── Stage 5: Quality scoring ────────────────────────────────────────────────
252
-
253
- def sample_quality_score(y, sr, label="other"):
254
  import scipy.stats
255
- rms_env=librosa.feature.rms(y=y,frame_length=512,hop_length=128)[0]
256
- if len(rms_env)>=10:
257
- pk=np.argmax(rms_env); post=rms_env[pk:]
258
- c1=max(0,1.0-np.mean(post[-max(3,len(post)//5):])/( rms_env[pk]+1e-8)*5)
259
- if len(post)>=5:
260
- sl,_,r,_,_=scipy.stats.linregress(np.arange(len(post)),np.log(post+1e-8))
261
- c2=max(0,r**2) if sl<0 else r**2*0.3
262
- else: c2=0.0
263
  else: c1,c2=0.5,0.0
264
- comp=c1*0.6+c2*0.4
265
- snr=10*np.log10(np.percentile(y**2,99)/(np.percentile(y**2,10)+1e-12))
266
- ns=np.clip((snr-10)/40,0,1)
267
- ons=librosa.onset.onset_detect(y=y,sr=sr,units='samples',backtrack=True)
268
  if len(ons)>0:
269
  o=int(ons[0]); pre=y[max(0,o-int(sr*.02)):o]; sig=y[o:o+int(sr*.1)]
270
- np2=np.clip((-10*np.log10(np.mean(pre**2+1e-12)/np.mean(sig**2+1e-12))-5)/30,0,1) \
271
- if len(pre)>10 and len(sig)>10 else 0.5
272
  else: np2=0.5
273
- clean=ns*0.5+np2*0.5
274
- oe=librosa.onset.onset_strength(y=y,sr=sr)
275
- sh=float(np.max(oe)/(np.mean(oe)+1e-8)) if len(oe)>1 else 1.0
276
- oq=float(np.clip((sh-1.0)/5.0,0,1))
277
- tot=(comp*0.30+clean*0.40+oq*0.20+0.5*0.10)*100
278
- return {'total':float(tot),'completeness':float(comp),'cleanness':float(clean),'onset_quality':float(oq)}
279
-
280
  def select_best(clusters):
281
  print(f"[Stage 5] Selecting best...")
282
  for c in clusters:
283
  if c.count<=1: c.best_hit_idx=0; continue
284
- sc=[sample_quality_score(h.audio,h.sr,c.label.rsplit('_',1)[0])['total'] for h in c.hits]
285
- c.best_hit_idx=int(np.argmax(sc))
286
-
287
-
288
- # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
289
 
 
290
  def synthesize_from_cluster(cluster):
291
  if cluster.count<2: return None
292
- tl=int(np.median([len(h.audio) for h in cluster.hits]))
293
- al,wt=[],[]
294
- pp=None
295
  for i,h in enumerate(cluster.hits):
296
  a=h.audio.copy(); p=np.argmax(np.abs(a))
297
  if pp is None: pp=p
@@ -302,33 +227,24 @@ def synthesize_from_cluster(cluster):
302
  pk=np.abs(a).max()
303
  if pk>0: a=a/pk
304
  al.append(a); wt.append(2.0 if i==cluster.best_hit_idx else 1.0)
305
- al=np.array(al); w=np.array(wt); w/=w.sum()
306
- sy=np.average(al,axis=0,weights=w); pk=np.abs(sy).max()
307
  return (sy*0.95/pk).astype(np.float32) if pk>0 else sy.astype(np.float32)
308
 
309
-
310
- # ─── Stage 7: MIDI + rendering ───────────────────────────────────────────────
311
-
312
- def build_midi(clusters, bpm=120.0):
313
- import pretty_midi
314
- pm=pretty_midi.PrettyMIDI(initial_tempo=bpm)
315
  for i,c in enumerate(clusters): c.midi_note=min(36+i,127)
316
- inst=pretty_midi.Instrument(program=0,is_drum=True,name='Extracted Samples')
317
- pm.instruments.append(inst)
318
  for c in clusters:
319
  for h in c.hits:
320
- v=max(1,min(127,int(h.rms_energy/0.3*127)))
321
- inst.notes.append(pretty_midi.Note(velocity=v,pitch=c.midi_note,
322
- start=h.onset_time,end=h.onset_time+max(h.duration,0.05)))
323
- inst.notes.sort(key=lambda n: n.start); return pm
324
-
325
- def export_midi(clusters, path, bpm=120.0):
326
- pm=build_midi(clusters,bpm); pm.write(path)
327
- print(f" ✓ MIDI: {path} ({len(pm.instruments[0].notes)} notes)"); return pm
328
-
329
- def detect_bpm(y, sr):
330
  ck=("bpm",_audio_hash(y),sr); c=cache_get(ck)
331
- if c is not None: return c
332
  oe=librosa.onset.onset_strength(y=y,sr=sr,aggregate=np.median)
333
  bpm=float(librosa.feature.tempo(onset_envelope=oe,sr=sr).item())
334
  _,beats=librosa.beat.beat_track(onset_envelope=oe,sr=sr,units='time')
@@ -339,48 +255,34 @@ def detect_bpm(y, sr):
339
  else:
340
  if bpm<70: bpm*=2
341
  elif bpm>200: bpm/=2
342
- return cache_set(ck, round(bpm,1))
343
-
344
- def render_midi_with_samples(clusters, sr=44100):
345
  me=max((h.onset_time+h.duration for c in clusters for h in c.hits),default=1.0)
346
  buf=np.zeros(int((me+1.0)*sr),dtype=np.float64)
347
  for c in clusters:
348
- s=c.best_hit.audio.astype(np.float64)
349
- re=c.best_hit.rms_energy if c.best_hit.rms_energy>0 else 0.1
350
  for h in c.hits:
351
- vs=min(2.0,h.rms_energy/(re+1e-8))**0.5
352
- i=int(h.onset_time*sr); e=i+len(s)
353
  if e>len(buf): buf=np.concatenate([buf,np.zeros(e-len(buf))])
354
  buf[i:e]+=s*vs
355
- pk=np.abs(buf).max()
356
- return (buf/pk*0.9).astype(np.float32) if pk>1e-8 else buf.astype(np.float32)
357
-
358
  def build_sample_map(clusters):
359
- return {c.midi_note:{'label':c.label,'count':c.count,
360
- 'duration_ms':int(c.best_hit.duration*1000)} for c in clusters}
361
-
362
- def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
363
- import zipfile, tempfile, io
364
- zp=tempfile.mktemp(suffix='.zip')
365
- idx={'bpm':round(bpm,1),'sample_rate':sr,'total_clusters':len(clusters),
366
- 'total_hits':sum(c.count for c in clusters),'samples':{}}
367
  with zipfile.ZipFile(zp,'w',compression=zipfile.ZIP_STORED) as zf:
368
  for c in clusters:
369
- b=c.best_hit; fn=f"samples/{c.label}.wav"
370
- buf=io.BytesIO(); sf.write(buf,b.audio,sr,format='WAV',subtype='PCM_24')
371
- zf.writestr(fn,buf.getvalue())
372
  ot=sorted([h.onset_time for h in c.hits])
373
- idx['samples'][c.label]={
374
- 'file':fn,'classification':c.label.rsplit('_',1)[0],
375
- 'midi_note':c.midi_note,'occurrences':c.count,
376
- 'onset_times_sec':[round(t,4) for t in ot],
377
- 'duration_sec':round(b.duration,4),
378
- 'rms_energy':round(b.rms_energy,6),
379
- 'spectral_centroid_hz':round(b.spectral_centroid,1)}
380
  if c.synthesized is not None:
381
- sf2=f"samples/{c.label}__synthesized.wav"; b2=io.BytesIO()
382
- sf.write(b2,c.synthesized,sr,format='WAV',subtype='PCM_24')
383
- zf.writestr(sf2,b2.getvalue()); idx['samples'][c.label]['synthesized_file']=sf2
384
  zf.writestr('index.json',json.dumps(idx,indent=2))
385
  if midi_path and os.path.exists(midi_path): zf.write(midi_path,'reconstruction.mid')
386
  if rendered_audio is not None:
@@ -388,66 +290,57 @@ def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
388
  zf.writestr('rendered_reconstruction.wav',rb.getvalue())
389
  return zp
390
 
 
391
 
392
- # ─── Auto-tuner ──────────────────────────────────────────────────────────────
393
-
394
- def _spectral_envelope_corr(original, rendered, sr, n_fft=4096, hop=2048):
395
- """Spectral envelope correlation between two signals. Phase-insensitive.
396
- Returns correlation coefficient [−1, 1]. Higher = better reconstruction."""
397
- n = min(len(original), len(rendered))
398
- if n < n_fft: return 0.0
399
- S_orig = np.abs(librosa.stft(original[:n], n_fft=n_fft, hop_length=hop))
400
- S_rend = np.abs(librosa.stft(rendered[:n], n_fft=n_fft, hop_length=hop))
401
- # Average over time to get spectral envelope
402
- env_o = S_orig.mean(axis=1)
403
- env_r = S_rend.mean(axis=1)
404
- if env_o.std() < 1e-10 or env_r.std() < 1e-10: return 0.0
405
- return float(np.corrcoef(env_o, env_r)[0, 1])
406
-
407
 
408
- def _rms_envelope_corr(original, rendered, sr, hop=1024):
409
- """RMS amplitude envelope correlation. Measures timing/dynamics match."""
410
- n = min(len(original), len(rendered))
411
- if n < hop * 4: return 0.0
412
- rms_o = librosa.feature.rms(y=original[:n], hop_length=hop)[0]
413
- rms_r = librosa.feature.rms(y=rendered[:n], hop_length=hop)[0]
414
- n2 = min(len(rms_o), len(rms_r))
415
- if n2 < 4: return 0.0
416
- rms_o, rms_r = rms_o[:n2], rms_r[:n2]
417
- if rms_o.std() < 1e-10 or rms_r.std() < 1e-10: return 0.0
418
- return float(np.corrcoef(rms_o, rms_r)[0, 1])
419
 
 
 
 
 
420
 
421
- def _reconstruction_score(original, rendered, sr):
422
- """Combined score [0, 100] measuring how well the reconstruction matches."""
423
- spec_corr = _spectral_envelope_corr(original, rendered, sr)
424
- rms_corr = _rms_envelope_corr(original, rendered, sr)
425
- # Penalize if reconstruction is much shorter/longer
426
- len_ratio = min(len(rendered), len(original)) / (max(len(rendered), len(original)) + 1)
427
- score = (spec_corr * 0.5 + rms_corr * 0.4 + len_ratio * 0.1) * 100
428
- return max(0.0, score)
429
 
 
 
430
 
431
- def auto_tune(stem_audio, sr, mode="auto", log_fn=None):
432
- """Automatically find the best extraction parameters for this audio.
 
 
433
 
434
- Sweeps onset detection params and cluster counts. Uses the cached NCC matrix
435
- so re-clustering is near-instant after the first onset detection.
436
 
437
- Returns: (best_params: dict, best_score: float, log: list[str])
438
-
439
- The returned params can be passed directly to the extraction pipeline.
440
  """
 
441
  log = []
442
- def _log(msg):
443
- log.append(msg)
444
- if log_fn: log_fn(msg)
445
- print(msg)
446
 
447
- _log(f"[Auto-tune] Starting parameter search on {len(stem_audio)/sr:.1f}s audio...")
 
448
 
449
- # Parameter grid — coarse sweep first
450
- onset_configs = [
451
  {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
452
  {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
453
  {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
@@ -456,85 +349,88 @@ def auto_tune(stem_audio, sr, mode="auto", log_fn=None):
456
  {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
457
  ]
458
 
459
- cluster_targets = [3, 5, 8, 10, 15, 20, 30]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- best_score = -1
462
- best_params = {}
463
- best_clusters = None
464
- results = []
465
 
466
  for oc_idx, oc in enumerate(onset_configs):
467
- _log(f"\n Config {oc_idx+1}/{len(onset_configs)}: delta={oc['onset_delta']}, "
468
- f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}s")
469
-
470
  hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
471
- if len(hits) < 2:
472
- _log(f" → Only {len(hits)} hits, skipping")
473
- continue
474
-
475
  hits = classify_hits(hits)
476
 
477
- # NCC matrix computed once per onset config (cached within)
478
- for n_target in cluster_targets:
479
- if n_target >= len(hits): continue
480
-
481
- clusters = cluster_hits(hits, target_min=n_target, target_max=n_target,
482
- linkage='average')
483
  if not clusters: continue
484
-
485
- # Quick select best + render
486
  for c in clusters:
487
- if c.count <= 1: c.best_hit_idx = 0
488
- else:
489
- energies = [h.rms_energy for h in c.hits]
490
- c.best_hit_idx = int(np.argmax(energies)) # fast: pick loudest
491
-
492
  rendered = render_midi_with_samples(clusters, sr=sr)
493
  score = _reconstruction_score(stem_audio, rendered, sr)
494
-
495
- results.append({**oc, 'n_clusters': len(clusters),
496
- 'target': n_target, 'n_hits': len(hits), 'score': score})
497
-
498
  if score > best_score:
499
  best_score = score
500
- best_params = {**oc, 'n_clusters': len(clusters), 'target_min': n_target,
501
- 'target_max': n_target}
502
- best_clusters = clusters
503
- _log(f" ★ target={n_target} → {len(clusters)} clusters, "
504
- f"score={score:.1f} (NEW BEST)")
505
  else:
506
- _log(f" target={n_target} → {len(clusters)} clusters, score={score:.1f}")
507
 
508
- # Fine-tune: try ±1 around best target with best onset config
509
  if best_params:
510
  bt = best_params.get('target_min', 10)
511
- _log(f"\n Fine-tuning around best (delta={best_params['onset_delta']}, "
512
- f"target{bt})...")
513
-
514
- fine_oc = {k: best_params[k] for k in ['onset_delta', 'energy_threshold_db', 'min_gap']}
515
  hits = detect_onsets(stem_audio, sr, mode=mode, **fine_oc)
516
  if len(hits) >= 2:
517
  hits = classify_hits(hits)
518
- for ft in range(max(2, bt-3), bt+4):
519
- if ft >= len(hits): continue
 
520
  clusters = cluster_hits(hits, target_min=ft, target_max=ft, linkage='average')
521
  if not clusters: continue
522
  for c in clusters:
523
- if c.count <= 1: c.best_hit_idx = 0
524
- else: c.best_hit_idx = int(np.argmax([h.rms_energy for h in c.hits]))
525
  rendered = render_midi_with_samples(clusters, sr=sr)
526
  score = _reconstruction_score(stem_audio, rendered, sr)
527
  if score > best_score:
528
- best_score = score
529
- best_params = {**fine_oc, 'n_clusters': len(clusters),
530
- 'target_min': ft, 'target_max': ft}
531
- best_clusters = clusters
532
- _log(f" ★ target={ft} → {len(clusters)} clusters, "
533
- f"score={score:.1f} (NEW BEST)")
534
-
535
- _log(f"\n[Auto-tune] Best: score={best_score:.1f}, "
536
- f"delta={best_params.get('onset_delta')}, "
537
- f"energy={best_params.get('energy_threshold_db')}dB, "
538
- f"clusters={best_params.get('n_clusters')}")
539
 
 
540
  return best_params, best_score, log
 
1
  #!/usr/bin/env python3
2
  """
3
+ Sample Extractor v8 — Auto-tuning with parameter locking.
4
 
5
+ auto_tune() accepts a `locks` dict: any param name mapped to its fixed value
6
+ will be held constant during the search. All other params are swept freely.
 
 
7
  """
8
 
9
  import argparse, json, os, sys, warnings, hashlib
 
18
  warnings.filterwarnings("ignore", category=FutureWarning)
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
 
 
 
 
21
@dataclass(eq=False)
class Hit:
    """One audio event sliced out of a stem by onset detection.

    Equality is identity-based (``eq=False``): the field-wise ``__eq__`` that
    ``@dataclass`` would otherwise generate compares the ``audio`` arrays with
    ``==`` and raises ``ValueError`` ("truth value of an array is ambiguous"),
    which breaks ``hit in hits`` and ``hits.index(hit)``.
    """
    audio: np.ndarray           # samples of the slice
    sr: int                     # sample rate of ``audio``
    onset_time: float           # onset position in the stem, seconds
    duration: float             # slice length, seconds
    index: int                  # position among the accepted hits
    rms_energy: float = 0.0     # RMS level of the slice
    spectral_centroid: float = 0.0  # mean spectral centroid of the slice
    label: str = ""             # class name filled in by classification
    embedding: Optional[np.ndarray] = None  # NOTE(review): not written by the code visible here
    cluster_id: int = -1        # -1 by default; presumably "not clustered yet" - confirm

    def save(self, path):
        """Write the slice to *path* with soundfile using the 'PCM_24' subtype."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')
27
 
28
@dataclass(eq=False)
class Cluster:
    """A group of hits judged to be the same sound.

    Equality is identity-based (``eq=False``): a generated field-wise
    ``__eq__`` would compare the ``synthesized`` arrays (and the arrays held
    by the member hits) with ``==`` and raise ``ValueError``.
    """
    cluster_id: int             # rank of the cluster
    label: str                  # "<class name>_<n>"
    hits: list = field(default_factory=list)  # member hits
    best_hit_idx: int = 0       # index into ``hits`` of the representative hit
    synthesized: Optional[np.ndarray] = None  # optional synthesized sample for the cluster
    midi_note: int = 60         # MIDI pitch assigned to the cluster

    @property
    def best_hit(self):
        """The member at ``best_hit_idx`` (raises IndexError if ``hits`` is empty)."""
        return self.hits[self.best_hit_idx]

    @property
    def count(self):
        """Number of member hits."""
        return len(self.hits)
 
36
 
37
+ DEMUCS_MODELS = ["htdemucs","htdemucs_ft","htdemucs_6s","mdx","mdx_extra","mdx_extra_q"]
38
  DEMUCS_STEMS = {
39
+ "htdemucs":["drums","bass","other","vocals"],"htdemucs_ft":["drums","bass","other","vocals"],
40
+ "htdemucs_6s":["drums","bass","other","vocals","guitar","piano"],
41
+ "mdx":["drums","bass","other","vocals"],"mdx_extra":["drums","bass","other","vocals"],
42
+ "mdx_extra_q":["drums","bass","other","vocals"],
43
  }
44
 
45
+ # ─── Cache ────────────────────────────────────────────────────────────────────
 
 
46
  _cache = {}
47
+ def _audio_hash(a): return hashlib.md5(a[:4000].tobytes()).hexdigest()
48
+ def cache_get(k): return _cache.get(k)
49
+ def cache_set(k,v): _cache[k]=v; return v
 
 
 
50
  def cache_clear(): _cache.clear()
51
 
52
+ # ─── Stage 1 ──────────────────────────────────────────────────────────────────
53
def extract_stem(audio_path, stem="drums", device="cpu", model_name="htdemucs_ft", shifts=1, overlap=0.25):
    """Return ``(audio, sample_rate)`` for one stem of the file at *audio_path*.

    With ``stem == "all"`` no separation is done: the file is loaded through
    librosa at 44100 Hz, mono. Otherwise the named Demucs model is applied and
    the requested source is reduced with ``mean(dim=0)`` (presumably averaging
    the channels to mono - confirm against the Demucs output layout). The
    audio is returned as float32. Separated results are memoised in the
    module cache.

    Raises:
        ValueError: if *stem* is not one of ``model.sources``.
    """
    if stem == "all":
        y, sr = librosa.load(audio_path, sr=44100, mono=True)
        return y.astype(np.float32), sr

    # Fingerprint the file by its first 200000 bytes AND its total size: the
    # prefix alone is identical for two renders that only differ later on,
    # which would hand back a stale cached stem.
    with open(audio_path, 'rb') as f:
        fh = hashlib.md5(f.read(200000)).hexdigest()
    size = os.path.getsize(audio_path)
    ck = ("stem", fh, size, stem, model_name, shifts, overlap)
    c = cache_get(ck)
    if c is not None:
        print(f"[Stage 1] Cached {stem}")
        return c

    # Imported lazily so the "all" path and cache hits do not need demucs.
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    print(f"[Stage 1] {stem} with {model_name}...")
    model = get_model(model_name)
    model.eval().to(device)
    sr = model.samplerate
    if stem not in model.sources:
        raise ValueError(f"'{stem}' not in {model.sources}")

    a, _ = librosa.load(audio_path, sr=sr, mono=False)
    # Force exactly two channels: duplicate mono, truncate anything wider.
    if a.ndim == 1:
        a = np.stack([a, a])
    elif a.shape[0] > 2:
        a = a[:2]
    elif a.shape[0] == 1:
        a = np.concatenate([a, a], axis=0)

    wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
    with torch.no_grad():
        src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
    r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
    print(f" ✓ {len(r)/sr:.1f}s")
    return cache_set(ck, (r.astype(np.float32), sr))
71
+
72
+ # ─── Stage 2 ──────────────────────────────────────────────────────────────────
73
def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5, min_gap=0.03,
                  energy_threshold_db=-35.0, mode="auto", backtrack=True, onset_delta=0.12):
    """Slice *y* into a list of ``Hit`` objects, one per accepted onset.

    An onset-strength envelope is built according to *mode* ("percussive",
    "harmonic", "broadband"; anything else takes the multi-band branch), and
    onsets are picked from it with ``librosa.onset.onset_detect``. Each slice
    starts ``pre_pad`` seconds before its onset and ends at the next onset,
    capped at ``max_dur`` seconds. Slices shorter than ``min_dur`` seconds or
    with RMS below ``energy_threshold_db`` (dB, converted to linear
    amplitude) are dropped. A short linear fade-out is applied to the tail of
    every kept slice.
    """
    print(f"[Stage 2] Onsets (delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
    strength = librosa.onset.onset_strength

    if mode == "percussive":
        envelope = strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
    elif mode == "harmonic":
        harmonic, _ = librosa.effects.hpss(y)
        envelope = strength(y=harmonic, sr=sr, fmax=8000, lag=2, max_size=3)
    elif mode == "broadband":
        envelope = strength(y=y, sr=sr)
    else:
        harmonic, _ = librosa.effects.hpss(y)

        def _peak_normalise(env):
            peak = env.max()
            return env / peak if peak > 0 else env

        # Three frequency bands of the full signal plus the harmonic part,
        # each peak-normalised, combined by element-wise maximum.
        band_edges = ((20, 250), (250, 4000), (4000, min(sr // 2, 20000)))
        candidates = [
            _peak_normalise(strength(y=y, sr=sr, fmin=low, fmax=high, aggregate=np.median))
            for low, high in band_edges
        ]
        candidates.append(_peak_normalise(strength(y=harmonic, sr=sr, lag=2)))
        envelope = np.maximum.reduce(candidates)

    # min_gap seconds expressed in frames; 512 is presumably the hop length
    # librosa uses by default - confirm.
    wait_frames = max(1, int(min_gap * sr / 512))
    frames = librosa.onset.onset_detect(
        onset_envelope=envelope, sr=sr, wait=wait_frames,
        pre_avg=3, post_avg=3, pre_max=3, post_max=5,
        delta=onset_delta, backtrack=backtrack, units='frames')
    times = librosa.frames_to_time(frames, sr=sr)
    print(f" Raw: {len(times)}")

    min_rms = 10 ** (energy_threshold_db / 20)
    final_pos = len(times) - 1
    hits = []
    for pos, onset in enumerate(times):
        start = max(0, int((onset - pre_pad) * sr))
        # A slice runs to the next onset, or to the end of the signal for the last one.
        boundary = len(y) if pos == final_pos else int(times[pos + 1] * sr)
        stop = min(boundary, start + int(max_dur * sr))
        segment = y[start:stop]
        if len(segment) < int(min_dur * sr):
            continue
        level = np.sqrt(np.mean(segment ** 2))
        if level < min_rms:
            continue
        fade = min(int(0.005 * sr), len(segment) // 4)
        if fade > 0:
            # Copy first so the fade does not write into the caller's signal.
            segment = segment.copy()
            segment[-fade:] *= np.linspace(1, 0, fade)
        centroid = float(librosa.feature.spectral_centroid(y=segment, sr=sr).mean())
        hits.append(Hit(audio=segment, sr=sr, onset_time=onset, duration=len(segment) / sr,
                        index=len(hits), rms_energy=float(level), spectral_centroid=centroid))
    print(f" ✓ {len(hits)} hits")
    return hits
104
 
105
+ # ─── Stage 3 ──────────────────────────────────────────────────────────────────
106
# Ordered (name, predicate) pairs; the first predicate that holds names the hit.
# Predicate arguments, as passed by classify_hit:
#   lr, mr, hr - share of spectral energy in 20-200 Hz, 200-4000 Hz and >=4000 Hz
#   c          - spectral centroid of the hit
#   zcr        - mean zero-crossing rate
#   d          - hit duration
# "bass" is listed before "kick": every hit that satisfies the bass rule also
# satisfies the kick rule (lr>0.6 implies lr>0.5, c<400 implies c<800), so with
# kick first the bass rule could never fire.
LABEL_RULES = [
    ("bass", lambda lr, mr, hr, c, zcr, d: lr > 0.6 and c < 400 and d > 0.2),
    ("kick", lambda lr, mr, hr, c, zcr, d: lr > 0.5 and c < 800),
    ("hihat_closed", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d < 0.15),
    ("hihat_open", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d >= 0.15),
    ("cymbal", lambda lr, mr, hr, c, zcr, d: hr > 0.25 and c > 3000),
    ("snare", lambda lr, mr, hr, c, zcr, d: mr > 0.4 and zcr > 0.1 and c > 1000),
    ("tom", lambda lr, mr, hr, c, zcr, d: lr > 0.3 and mr > 0.3 and c < 1500),
    ("vocal", lambda lr, mr, hr, c, zcr, d: mr > 0.5 and 500 < c < 3000 and zcr < 0.15),
    ("bright", lambda lr, mr, hr, c, zcr, d: c > 2500),
    ("mid", lambda lr, mr, hr, c, zcr, d: c > 800),
]
 
 
 
 
 
 
115
def classify_hit(hit):
    """Return the name of the first LABEL_RULES entry matching *hit*, else "other".

    The rules are evaluated on the share of STFT energy in three bands
    (20-200 Hz, 200-4000 Hz, 4000 Hz and above), the hit's stored spectral
    centroid, its mean zero-crossing rate and its duration.
    """
    samples = hit.audio
    magnitude = np.abs(librosa.stft(samples, n_fft=2048))
    freqs = librosa.fft_frequencies(sr=hit.sr, n_fft=2048)

    low_band = (freqs >= 20) & (freqs < 200)
    mid_band = (freqs >= 200) & (freqs < 4000)
    high_band = freqs >= 4000
    low_energy = np.sum(magnitude[low_band] ** 2)
    mid_energy = np.sum(magnitude[mid_band] ** 2)
    high_energy = np.sum(magnitude[high_band] ** 2)
    # Small constant keeps the ratios finite for an all-zero spectrum.
    total = low_energy + mid_energy + high_energy + 1e-10
    low_ratio = low_energy / total
    mid_ratio = mid_energy / total
    high_ratio = high_energy / total

    crossing_rate = float(librosa.feature.zero_crossing_rate(y=samples).mean())

    for rule_name, predicate in LABEL_RULES:
        if predicate(low_ratio, mid_ratio, high_ratio,
                     hit.spectral_centroid, crossing_rate, hit.duration):
            return rule_name
    return "other"
 
124
def classify_hits(hits):
    """Label every hit in place, print a per-label tally, and return *hits*.

    The tally is printed most frequent label first.
    """
    print(f"[Stage 3] Classifying {len(hits)}...")
    tally = defaultdict(int)
    for hit in hits:
        hit.label = classify_hit(hit)
    for hit in hits:
        tally[hit.label] += 1
    ranked = sorted(tally.items(), key=lambda pair: -pair[1])
    for name, total in ranked:
        print(f" {name}: {total}")
    return hits
131
 
132
+ # ─── Stage 4 ──────────────────────────────────────────────────────────────────
133
def ncc_max(a, b):
    """Return the peak absolute normalised cross-correlation of *a* and *b*.

    Both inputs are truncated to the shorter length and mean-centred; the
    cross-correlation over all lags is divided by the product of the two
    centred norms. The result is 1.0 for identical or sign-flipped signals.

    Returns 0.0 when either input is empty or when the norm product is below
    1e-10 (for example a constant signal). The inputs are never modified;
    non-floating arrays are promoted to float64 so the in-place centring
    cannot fail on integer data.
    """
    n = min(len(a), len(b))
    if n == 0:
        # Nothing to compare; the mean of an empty slice would be NaN.
        return 0.0

    def _centred(x):
        x = np.asarray(x)[:n]
        target = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
        x = x.astype(target)  # astype copies, so the caller's array is untouched
        x -= x.mean()
        return x

    a = _centred(a)
    b = _centred(b)
    norm = np.sqrt(np.dot(a, a) * np.dot(b, b))
    if norm < 1e-10:
        return 0.0
    # Convolving with the reversed signal yields the cross-correlation at every lag.
    return float(np.max(np.abs(fftconvolve(a, b[::-1], mode='full'))) / norm)
138
def build_ncc_distance_matrix(hits, max_compare_samples=8820):
    """Return the symmetric float32 matrix of pairwise NCC distances between hits.

    The distance between two hits is ``max(0.0, 1.0 - ncc_max(...))`` computed
    on the first *max_compare_samples* samples of each; the diagonal stays 0.
    The matrix is cached under the hits' audio hashes and the comparison
    length, and a cached matrix is returned as-is.
    """
    fingerprints = tuple(_audio_hash(hit.audio) for hit in hits)
    key = ("ncc_dist", fingerprints, max_compare_samples)
    cached = cache_get(key)
    if cached is not None:
        print(f" Cached NCC matrix")
        return cached

    count = len(hits)
    dist = np.zeros((count, count), dtype=np.float32)
    for row in range(count):
        head = hits[row].audio[:max_compare_samples]
        # Only the upper triangle is computed; each value is mirrored.
        for col in range(row + 1, count):
            other = hits[col].audio[:max_compare_samples]
            value = max(0.0, 1.0 - ncc_max(head, other))
            dist[row, col] = dist[col, row] = value
    return cache_set(key, dist)
147
def _labels_to_clusters(labels, hits):
    """Turn per-hit cluster labels into ``Cluster`` objects, largest first.

    Each cluster is named ``"<majority hit label>_<k>"``, where *k* counts the
    clusters already created with the same majority label. After sorting by
    size (descending), ``cluster_id`` is reassigned to the sorted position.
    """
    members = defaultdict(list)
    for position, cluster_label in enumerate(labels):
        members[cluster_label].append(position)

    clusters = []
    for _, positions in sorted(members.items()):
        votes = defaultdict(int)
        for position in positions:
            votes[hits[position].label] += 1
        majority = max(votes, key=votes.get)
        # Number of earlier clusters that already carry this majority label.
        already = 0
        for existing in clusters:
            if existing.label.rsplit('_', 1)[0] == majority:
                already += 1
        clusters.append(Cluster(cluster_id=len(clusters),
                                label=f"{majority}_{already}",
                                hits=[hits[position] for position in positions]))

    clusters.sort(key=lambda cluster: cluster.count, reverse=True)
    for rank, cluster in enumerate(clusters):
        cluster.cluster_id = rank
    return clusters
159
def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0, target_min=0, target_max=0, linkage='average'):
    """Cluster hits by waveform similarity using agglomerative clustering.

    Args:
        hits: hits to cluster; each needs ``audio``, ``sr`` and ``label``.
        ncc_threshold: similarity threshold used when no target range is
            given; the distance cut-off is ``1 - ncc_threshold``.
        max_compare_ms: comparison window in milliseconds. When <= 0 the
            window is the median hit length, but at least 30 ms.
        target_min, target_max: when both are positive and
            ``target_max >= target_min``, the distance threshold is
            bisected (up to 30 steps) to land the cluster count in that range.
        linkage: linkage criterion passed to ``AgglomerativeClustering``.

    Returns:
        A list of ``Cluster`` objects as produced by ``_labels_to_clusters``;
        an empty list for empty input.
    """
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]

    # Imported only once real clustering is needed, so the trivial cases above
    # work without scikit-learn being importable.
    from sklearn.cluster import AgglomerativeClustering as AC

    if max_compare_ms <= 0:
        ms = max(int(0.03 * sr), int(np.median([len(h.audio) for h in hits])))
    else:
        ms = int(max_compare_ms / 1000.0 * sr)
    print(f"[Stage 4] NCC ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
    print(f" {N*(N-1)//2} pairs...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=ms)

    use_t = target_min > 0 and target_max > 0 and target_max >= target_min
    tmin = max(1, min(target_min or 1, N))
    tmax = max(max(1, target_min or 1), min(target_max or N, N))

    if use_t:
        print(f" Target: {tmin}–{tmax}")
        centre = (tmin + tmax) / 2
        lo, hi = 0.001, 1.0
        bl, bn = None, -1
        for _ in range(30):
            mid = (lo + hi) / 2
            lb = AC(n_clusters=None, distance_threshold=max(0.001, mid),
                    metric='precomputed', linkage=linkage).fit_predict(D)
            n = len(set(lb))
            if tmin <= n <= tmax:
                bl, bn = lb, n
                break
            # Too many clusters -> raise the threshold; too few -> lower it.
            if n > tmax:
                lo = mid
            else:
                hi = mid
            # Remember the labelling whose count is closest to the range centre.
            if bl is None or abs(n - centre) < abs(bn - centre):
                bl, bn = lb, n
        if bn < tmin or bn > tmax:
            tm = min((tmin + tmax) // 2, N - 1)
            print(f" Fallback n={tm}")
            try:
                bl = AC(n_clusters=tm, metric='precomputed', linkage=linkage).fit_predict(D)
                bn = tm
            except Exception as exc:
                # Best effort: keep the closest bisection result, but say why
                # the fixed-count fallback was not used.
                print(f" Fallback failed ({exc}); keeping closest result")
        labels = bl
        print(f" → {bn}")
    else:
        dt = max(0.001, 1.0 - ncc_threshold)
        print(f" Fixed dist≤{dt:.3f}")
        labels = AC(n_clusters=None, distance_threshold=dt,
                    metric='precomputed', linkage=linkage).fit_predict(D)

    print(f" ✓ {len(set(labels))} clusters")
    cl = _labels_to_clusters(labels, hits)
    for c in cl:
        print(f" {c.label}: {c.count}")
    return cl
191
 
192
+ # ─── Stage 5 ──────────────────────────────────────────────────────────────────
193
def sample_quality_score(y, sr, label="other"):
    """Score how usable a hit is as a one-shot sample.

    Args:
        y: audio samples of the hit.
        sr: sample rate in Hz.
        label: accepted for call compatibility; not used by the scoring.

    Returns:
        A dict with ``'total'`` (weighted sum scaled by 100) and the component
        scores ``'completeness'``, ``'cleanness'`` and ``'onset_quality'``.
    """
    import scipy.stats

    # Completeness: does the envelope decay after its peak?
    rms = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(rms) >= 10:
        peak = np.argmax(rms)
        tail = rms[peak:]
        tail_level = np.mean(tail[-max(3, len(tail) // 5):])
        decay_score = max(0, 1.0 - tail_level / (rms[peak] + 1e-8) * 5)
        if len(tail) >= 5:
            # Fit once (the original evaluated the same regression twice).
            fit = scipy.stats.linregress(np.arange(len(tail)), np.log(tail + 1e-8))
            # Only a falling log-envelope counts; r^2 measures how exponential it is.
            fit_score = max(0, fit.rvalue ** 2) if fit.slope < 0 else 0.0
        else:
            fit_score = 0.0
    else:
        decay_score, fit_score = 0.5, 0.0
    comp = decay_score * 0.6 + fit_score * 0.4

    # Cleanness: loud-vs-quiet power ratio plus pre-onset silence.
    snr = 10 * np.log10(np.percentile(y ** 2, 99) / (np.percentile(y ** 2, 10) + 1e-12))
    noise_score = np.clip((snr - 10) / 40, 0, 1)
    onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    if len(onsets) > 0:
        first = int(onsets[0])
        pre = y[max(0, first - int(sr * .02)):first]
        sig = y[first:first + int(sr * .1)]
        if len(pre) > 10 and len(sig) > 10:
            ratio_db = -10 * np.log10(np.mean(pre ** 2 + 1e-12) / np.mean(sig ** 2 + 1e-12))
            pre_score = np.clip((ratio_db - 5) / 30, 0, 1)
        else:
            pre_score = 0.5
    else:
        pre_score = 0.5
    cl = noise_score * 0.5 + pre_score * 0.5

    # Onset quality: peak-to-mean ratio of the onset envelope, offset by 1.
    # The original one-liner parsed as `ratio if ... else (1.0 - 1.0)`, so the
    # offset was never subtracted from the ratio; it is now applied to both
    # branches.
    oe = librosa.onset.onset_strength(y=y, sr=sr)
    peak_ratio = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
    oq = float(np.clip((peak_ratio - 1.0) / 5.0, 0, 1))

    total = (comp * 0.30 + cl * 0.40 + oq * 0.20 + 0.5 * 0.10) * 100
    return {
        'total': float(total),
        'completeness': float(comp),
        'cleanness': float(cl),
        'onset_quality': float(oq),
    }
 
 
 
 
210
def select_best(clusters):
    """Set ``best_hit_idx`` on every cluster to its highest-scoring hit.

    Clusters with at most one hit get index 0 without scoring. Otherwise each
    hit is scored with ``sample_quality_score`` (passing the cluster label
    without its numeric suffix) and the index of the largest ``'total'`` wins.
    """
    print(f"[Stage 5] Selecting best...")
    for cluster in clusters:
        if cluster.count <= 1:
            cluster.best_hit_idx = 0
            continue
        kind = cluster.label.rsplit('_', 1)[0]
        totals = [
            sample_quality_score(hit.audio, hit.sr, kind)['total']
            for hit in cluster.hits
        ]
        cluster.best_hit_idx = int(np.argmax(totals))
 
 
 
 
215
 
216
+ # ─── Stage 6 ──────────────────────────────────────────────────────────────────
217
  def synthesize_from_cluster(cluster):
218
  if cluster.count<2: return None
219
+ tl=int(np.median([len(h.audio) for h in cluster.hits])); al,wt=[],[]; pp=None
 
 
220
  for i,h in enumerate(cluster.hits):
221
  a=h.audio.copy(); p=np.argmax(np.abs(a))
222
  if pp is None: pp=p
 
227
  pk=np.abs(a).max()
228
  if pk>0: a=a/pk
229
  al.append(a); wt.append(2.0 if i==cluster.best_hit_idx else 1.0)
230
+ al=np.array(al); w=np.array(wt); w/=w.sum(); sy=np.average(al,axis=0,weights=w); pk=np.abs(sy).max()
 
231
  return (sy*0.95/pk).astype(np.float32) if pk>0 else sy.astype(np.float32)
232
 
233
+ # ─── Stage 7 ──────────────────────────────────────────────────────────────────
234
def build_midi(clusters, bpm=120.0):
    """Build a single-track drum MIDI object with one note per hit.

    Clusters are assigned consecutive pitches starting at 36 (capped at 127),
    stored on ``cluster.midi_note``. Each hit becomes a note at its onset
    time, lasting at least 0.05 s, with velocity scaled from its RMS energy
    and clamped to 1–127. Notes are ordered by start time.
    """
    import pretty_midi

    midi = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    for position, cluster in enumerate(clusters):
        cluster.midi_note = min(36 + position, 127)

    track = pretty_midi.Instrument(program=0, is_drum=True, name='Samples')
    midi.instruments.append(track)

    for cluster in clusters:
        for hit in cluster.hits:
            velocity = max(1, min(127, int(hit.rms_energy / 0.3 * 127)))
            note_end = hit.onset_time + max(hit.duration, 0.05)
            track.notes.append(
                pretty_midi.Note(
                    velocity=velocity,
                    pitch=cluster.midi_note,
                    start=hit.onset_time,
                    end=note_end,
                )
            )

    track.notes.sort(key=lambda note: note.start)
    return midi
243
def export_midi(clusters, path, bpm=120.0):
    """Build the MIDI reconstruction, write it to ``path`` and return it."""
    midi = build_midi(clusters, bpm)
    midi.write(path)
    note_count = len(midi.instruments[0].notes)
    print(f" ✓ MIDI: {note_count} notes")
    return midi
245
+ def detect_bpm(y,sr):
 
 
 
 
246
  ck=("bpm",_audio_hash(y),sr); c=cache_get(ck)
247
+ if c: return c
248
  oe=librosa.onset.onset_strength(y=y,sr=sr,aggregate=np.median)
249
  bpm=float(librosa.feature.tempo(onset_envelope=oe,sr=sr).item())
250
  _,beats=librosa.beat.beat_track(onset_envelope=oe,sr=sr,units='time')
 
255
  else:
256
  if bpm<70: bpm*=2
257
  elif bpm>200: bpm/=2
258
+ return cache_set(ck,round(bpm,1))
259
def render_midi_with_samples(clusters, sr=44100):
    """Render an audio reconstruction by pasting each cluster's best hit.

    For every hit, the cluster's best-hit audio is added at the hit's onset,
    scaled by the square root of the hit-to-best RMS ratio (ratio capped at
    2.0). The mix is peak-normalized to 0.9 and returned as float32; a mix
    whose peak is at most 1e-8 is returned unscaled.
    """
    end_time = max(
        (hit.onset_time + hit.duration
         for cluster in clusters for hit in cluster.hits),
        default=1.0,
    )
    mix = np.zeros(int((end_time + 1.0) * sr), dtype=np.float64)

    for cluster in clusters:
        sample = cluster.best_hit.audio.astype(np.float64)
        reference_rms = cluster.best_hit.rms_energy
        if not reference_rms > 0:
            reference_rms = 0.1
        for hit in cluster.hits:
            gain = min(2.0, hit.rms_energy / (reference_rms + 1e-8)) ** 0.5
            start = int(hit.onset_time * sr)
            stop = start + len(sample)
            if stop > len(mix):
                # Grow the buffer when a sample runs past the estimated end.
                mix = np.concatenate([mix, np.zeros(stop - len(mix))])
            mix[start:stop] += sample * gain

    peak = np.abs(mix).max()
    if peak > 1e-8:
        return (mix / peak * 0.9).astype(np.float32)
    return mix.astype(np.float32)
 
 
269
def build_sample_map(clusters):
    """Map each cluster's MIDI note to its label, hit count and duration (ms).

    The duration is taken from the cluster's best hit and truncated to whole
    milliseconds. A later cluster with the same note replaces an earlier one.
    """
    sample_map = {}
    for cluster in clusters:
        sample_map[cluster.midi_note] = {
            'label': cluster.label,
            'count': cluster.count,
            'duration_ms': int(cluster.best_hit.duration * 1000),
        }
    return sample_map
271
+ def build_archive(clusters,bpm,sr,midi_path=None,rendered_audio=None):
272
+ import zipfile,tempfile,io; zp=tempfile.mktemp(suffix='.zip')
273
+ idx={'bpm':round(bpm,1),'sample_rate':sr,'total_clusters':len(clusters),'total_hits':sum(c.count for c in clusters),'samples':{}}
 
 
 
 
274
  with zipfile.ZipFile(zp,'w',compression=zipfile.ZIP_STORED) as zf:
275
  for c in clusters:
276
+ b=c.best_hit; fn=f"samples/{c.label}.wav"; buf=io.BytesIO()
277
+ sf.write(buf,b.audio,sr,format='WAV',subtype='PCM_24'); zf.writestr(fn,buf.getvalue())
 
278
  ot=sorted([h.onset_time for h in c.hits])
279
+ idx['samples'][c.label]={'file':fn,'classification':c.label.rsplit('_',1)[0],'midi_note':c.midi_note,
280
+ 'occurrences':c.count,'onset_times_sec':[round(t,4) for t in ot],'duration_sec':round(b.duration,4),
281
+ 'rms_energy':round(b.rms_energy,6),'spectral_centroid_hz':round(b.spectral_centroid,1)}
 
 
 
 
282
  if c.synthesized is not None:
283
+ sf2=f"samples/{c.label}__synth.wav"; b2=io.BytesIO()
284
+ sf.write(b2,c.synthesized,sr,format='WAV',subtype='PCM_24'); zf.writestr(sf2,b2.getvalue())
285
+ idx['samples'][c.label]['synthesized_file']=sf2
286
  zf.writestr('index.json',json.dumps(idx,indent=2))
287
  if midi_path and os.path.exists(midi_path): zf.write(midi_path,'reconstruction.mid')
288
  if rendered_audio is not None:
 
290
  zf.writestr('rendered_reconstruction.wav',rb.getvalue())
291
  return zp
292
 
293
+ # ─── Auto-tuner with parameter locking ───────────────────────────────────────
294
 
295
+ def _spectral_envelope_corr(orig,rend,sr,n_fft=4096,hop=2048):
296
+ n=min(len(orig),len(rend))
297
+ if n<n_fft: return 0.0
298
+ eo=np.abs(librosa.stft(orig[:n],n_fft=n_fft,hop_length=hop)).mean(axis=1)
299
+ er=np.abs(librosa.stft(rend[:n],n_fft=n_fft,hop_length=hop)).mean(axis=1)
300
+ if eo.std()<1e-10 or er.std()<1e-10: return 0.0
301
+ return float(np.corrcoef(eo,er)[0,1])
 
 
 
 
 
 
 
 
302
 
303
+ def _rms_envelope_corr(orig,rend,sr,hop=1024):
304
+ n=min(len(orig),len(rend))
305
+ if n<hop*4: return 0.0
306
+ ro=librosa.feature.rms(y=orig[:n],hop_length=hop)[0]; rr=librosa.feature.rms(y=rend[:n],hop_length=hop)[0]
307
+ n2=min(len(ro),len(rr))
308
+ if n2<4: return 0.0
309
+ ro,rr=ro[:n2],rr[:n2]
310
+ if ro.std()<1e-10 or rr.std()<1e-10: return 0.0
311
+ return float(np.corrcoef(ro,rr)[0,1])
 
 
312
 
313
def _reconstruction_score(orig, rend, sr):
    """Score a rendered reconstruction against the original, floored at 0.

    Weighted sum of the spectral-envelope correlation (0.5), the RMS-envelope
    correlation (0.4) and a length-agreement ratio (0.1), scaled by 100.
    """
    spectral = _spectral_envelope_corr(orig, rend, sr)
    loudness = _rms_envelope_corr(orig, rend, sr)
    shorter = min(len(rend), len(orig))
    longer = max(len(rend), len(orig))
    # The +1 keeps the ratio defined when both signals are empty.
    length_ratio = shorter / (longer + 1)
    combined = (spectral * 0.5 + loudness * 0.4 + length_ratio * 0.1) * 100
    return max(0.0, combined)
317
 
 
 
 
 
 
 
 
 
318
 
319
+ def auto_tune(stem_audio, sr, mode="auto", locks=None, log_fn=None):
320
+ """Find best extraction parameters. Locked params are held constant.
321
 
322
+ locks: dict mapping param names to fixed values. Supported keys:
323
+ 'onset_delta', 'energy_threshold_db', 'min_gap',
324
+ 'target_min', 'target_max'
325
+ Any key present in locks will NOT be swept — its value is used as-is.
326
 
327
+ Example: locks={'target_min': 5, 'target_max': 20}
328
+ cluster count will always be between 5 and 20, only onset params are tuned.
329
 
330
+ Returns: (best_params, best_score, log_lines)
 
 
331
  """
332
+ locks = locks or {}
333
  log = []
334
+ def _log(m):
335
+ log.append(m)
336
+ if log_fn: log_fn(m)
337
+ print(m)
338
 
339
+ locked_names = list(locks.keys())
340
+ _log(f"[Auto-tune] {len(stem_audio)/sr:.1f}s audio, locked: {locked_names or 'none'}")
341
 
342
+ # Build onset config grid — replace locked values with fixed ones
343
+ base_onset_configs = [
344
  {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
345
  {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
346
  {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
 
349
  {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
350
  ]
351
 
352
+ # Apply locks to onset configs
353
+ onset_configs = []
354
+ seen = set()
355
+ for oc in base_onset_configs:
356
+ for k in ['onset_delta', 'energy_threshold_db', 'min_gap']:
357
+ if k in locks:
358
+ oc[k] = locks[k]
359
+ # Deduplicate configs that become identical after locking
360
+ key = (oc['onset_delta'], oc['energy_threshold_db'], oc['min_gap'])
361
+ if key not in seen:
362
+ seen.add(key)
363
+ onset_configs.append(oc)
364
+
365
+ # Build cluster target grid — respect locks
366
+ if 'target_min' in locks and 'target_max' in locks:
367
+ tmin_locked, tmax_locked = int(locks['target_min']), int(locks['target_max'])
368
+ # Generate targets within the locked range
369
+ cluster_targets = sorted(set([tmin_locked, tmax_locked,
370
+ (tmin_locked+tmax_locked)//2,
371
+ tmin_locked + (tmax_locked-tmin_locked)//3,
372
+ tmin_locked + 2*(tmax_locked-tmin_locked)//3]))
373
+ cluster_targets = [t for t in cluster_targets if tmin_locked <= t <= tmax_locked]
374
+ _log(f" Cluster targets (locked {tmin_locked}–{tmax_locked}): {cluster_targets}")
375
+ elif 'target_min' in locks:
376
+ tmin_locked = int(locks['target_min'])
377
+ cluster_targets = [t for t in [3,5,8,10,15,20,30] if t >= tmin_locked]
378
+ if not cluster_targets: cluster_targets = [tmin_locked]
379
+ elif 'target_max' in locks:
380
+ tmax_locked = int(locks['target_max'])
381
+ cluster_targets = [t for t in [3,5,8,10,15,20,30] if t <= tmax_locked]
382
+ if not cluster_targets: cluster_targets = [tmax_locked]
383
+ else:
384
+ cluster_targets = [3, 5, 8, 10, 15, 20, 30]
385
 
386
+ best_score, best_params = -1, {}
 
 
 
387
 
388
  for oc_idx, oc in enumerate(onset_configs):
389
+ _log(f"\n [{oc_idx+1}/{len(onset_configs)}] delta={oc['onset_delta']}, "
390
+ f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}")
 
391
  hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
392
+ if len(hits) < 2: _log(f" {len(hits)} hits, skip"); continue
 
 
 
393
  hits = classify_hits(hits)
394
 
395
+ for nt in cluster_targets:
396
+ if nt >= len(hits): continue
397
+ clusters = cluster_hits(hits, target_min=nt, target_max=nt, linkage='average')
 
 
 
398
  if not clusters: continue
 
 
399
  for c in clusters:
400
+ if c.count<=1: c.best_hit_idx=0
401
+ else: c.best_hit_idx=int(np.argmax([h.rms_energy for h in c.hits]))
 
 
 
402
  rendered = render_midi_with_samples(clusters, sr=sr)
403
  score = _reconstruction_score(stem_audio, rendered, sr)
 
 
 
 
404
  if score > best_score:
405
  best_score = score
406
+ best_params = {**oc, 'n_clusters': len(clusters),
407
+ 'target_min': nt, 'target_max': nt}
408
+ _log(f" ★ target={nt} → {len(clusters)} cl, score={score:.1f} BEST")
 
 
409
  else:
410
+ _log(f" target={nt} → {len(clusters)} cl, score={score:.1f}")
411
 
412
+ # Fine-tune around best
413
  if best_params:
414
  bt = best_params.get('target_min', 10)
415
+ fine_oc = {k: best_params[k] for k in ['onset_delta','energy_threshold_db','min_gap']}
416
+ _log(f"\n Fine-tuning ±3 around target={bt}...")
 
 
417
  hits = detect_onsets(stem_audio, sr, mode=mode, **fine_oc)
418
  if len(hits) >= 2:
419
  hits = classify_hits(hits)
420
+ lo = max(2, int(locks.get('target_min', bt-3)))
421
+ hi = min(len(hits)-1, int(locks.get('target_max', bt+3)))
422
+ for ft in range(lo, hi+1):
423
  clusters = cluster_hits(hits, target_min=ft, target_max=ft, linkage='average')
424
  if not clusters: continue
425
  for c in clusters:
426
+ if c.count<=1: c.best_hit_idx=0
427
+ else: c.best_hit_idx=int(np.argmax([h.rms_energy for h in c.hits]))
428
  rendered = render_midi_with_samples(clusters, sr=sr)
429
  score = _reconstruction_score(stem_audio, rendered, sr)
430
  if score > best_score:
431
+ best_score=score
432
+ best_params={**fine_oc,'n_clusters':len(clusters),'target_min':ft,'target_max':ft}
433
+ _log(f" ★ target={ft} → {len(clusters)} cl, score={score:.1f} BEST")
 
 
 
 
 
 
 
 
434
 
435
+ _log(f"\n[Auto-tune] Best: score={best_score:.1f}, params={best_params}")
436
  return best_params, best_score, log