# drum-sample-extractor / sample_extractor.py
# Commit 2d1fee2 (verified) by rikhoffbauer2:
#   "v8: Parameter locking for auto-tune — lock any param to constrain the search"
#!/usr/bin/env python3
"""
Sample Extractor v8 — Auto-tuning with parameter locking.
auto_tune() accepts a `locks` dict: any onset param name mapped to its fixed
value will be held constant during the search; 'target_min' / 'target_max'
bound the cluster counts that are tried. All other params are swept freely.

Pipeline stages defined in this module:
  1. stem isolation with Demucs (extract_stem)
  2. onset slicing into Hit objects (detect_onsets)
  3. heuristic labelling (classify_hits)
  4. NCC-distance agglomerative clustering (cluster_hits)
  5. best-hit selection per cluster (select_best)
  6. peak-aligned cluster averaging (synthesize_from_cluster)
  7. MIDI export, sample re-rendering and zip packaging
"""
import argparse, json, os, sys, warnings, hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import librosa, numpy as np, soundfile as sf, torch
from scipy.signal import fftconvolve
# These filters are installed at import time and apply process-wide: every
# FutureWarning and UserWarning is hidden, including ones raised by callers
# that merely import this module.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@dataclass
class Hit:
    """One onset-to-onset slice of the analysed signal.

    The five leading fields describe the slice itself. detect_onsets fills in
    rms_energy and spectral_centroid, and classify_hits sets label. embedding
    and cluster_id keep their defaults in the code shown here.
    """

    audio: np.ndarray
    sr: int
    onset_time: float
    duration: float
    index: int
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    label: str = ""
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1

    def save(self, path):
        """Write this slice to *path* as 24-bit PCM using soundfile."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')
@dataclass
class Cluster:
    """A group of similar hits plus the per-group results of later stages."""

    cluster_id: int
    label: str
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0
    synthesized: Optional[np.ndarray] = None
    midi_note: int = 60

    @property
    def best_hit(self):
        """The member hit selected by ``best_hit_idx``."""
        chosen = self.best_hit_idx
        return self.hits[chosen]

    @property
    def count(self):
        """How many hits belong to this cluster."""
        members = self.hits
        return len(members)
DEMUCS_MODELS = [
    "htdemucs", "htdemucs_ft", "htdemucs_6s",
    "mdx", "mdx_extra", "mdx_extra_q",
]
# Every listed model reports the same four stems; htdemucs_6s adds two more.
# Each model gets its own list object so the entries never alias one another.
_STANDARD_STEMS = ("drums", "bass", "other", "vocals")
DEMUCS_STEMS = {
    model_name: list(_STANDARD_STEMS) + (["guitar", "piano"] if model_name == "htdemucs_6s" else [])
    for model_name in DEMUCS_MODELS
}
# ─── Cache ────────────────────────────────────────────────────────────────────
_cache = {}
def _audio_hash(a): return hashlib.md5(a[:4000].tobytes()).hexdigest()
def cache_get(k): return _cache.get(k)
def cache_set(k,v): _cache[k]=v; return v
def cache_clear(): _cache.clear()
# ─── Stage 1 ──────────────────────────────────────────────────────────────────
def extract_stem(audio_path, stem="drums", device="cpu", model_name="htdemucs_ft", shifts=1, overlap=0.25):
    """Isolate one Demucs source from *audio_path* and return it as mono float32.

    Args:
        audio_path: path of the input audio file.
        stem: source name to keep; must be in the model's ``sources``. The
            special value "all" skips separation and returns the whole mix
            resampled to 44.1 kHz mono (not cached).
        device, model_name, shifts, overlap: forwarded to Demucs
            (``get_model`` / ``apply_model``).

    Returns:
        (samples, sample_rate). Separated results are memoised in the module
        cache, keyed on the file's full content hash and the Demucs settings.

    Raises:
        ValueError: if *stem* is not one of the model's sources.
    """
    if stem == "all":
        y, sr = librosa.load(audio_path, sr=44100, mono=True)
        return y.astype(np.float32), sr
    # Hash the whole file. A prefix-only hash lets two files that share their
    # first bytes (same header and leading silence, alternate edits) return
    # each other's cached stem.
    digest = hashlib.md5()
    with open(audio_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    fh = digest.hexdigest()
    ck = ("stem", fh, stem, model_name, shifts, overlap)
    c = cache_get(ck)
    if c:
        print(f"[Stage 1] Cached {stem}")
        return c
    from demucs.pretrained import get_model
    from demucs.apply import apply_model
    print(f"[Stage 1] {stem} with {model_name}...")
    model = get_model(model_name)
    model.eval().to(device)
    sr = model.samplerate
    if stem not in model.sources:
        raise ValueError(f"'{stem}' not in {model.sources}")
    a, _ = librosa.load(audio_path, sr=sr, mono=False)
    # Force a (2, samples) array: duplicate mono input, drop channels beyond two.
    if a.ndim == 1:
        a = np.stack([a, a])
    elif a.shape[0] > 2:
        a = a[:2]
    elif a.shape[0] == 1:
        a = np.concatenate([a, a], axis=0)
    wav = torch.from_numpy(a).float().unsqueeze(0).to(device)  # add batch axis
    with torch.no_grad():
        src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
    # Batch item 0, the requested source; average its channels down to mono.
    r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
    print(f" βœ“ {len(r)/sr:.1f}s")
    return cache_set(ck, (r.astype(np.float32), sr))
# ─── Stage 2 ──────────────────────────────────────────────────────────────────
def _onset_envelope(y, sr, mode):
    """Return the onset-strength envelope detect_onsets uses for *mode*."""
    if mode == "percussive":
        return librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
    if mode == "harmonic":
        harmonic, _ = librosa.effects.hpss(y)
        return librosa.onset.onset_strength(y=harmonic, sr=sr, fmax=8000, lag=2, max_size=3)
    if mode == "broadband":
        return librosa.onset.onset_strength(y=y, sr=sr)
    # "auto" and any unrecognised mode: three band-limited envelopes plus one on
    # the harmonic component, each peak-normalised, merged by element-wise max.
    harmonic, _ = librosa.effects.hpss(y)

    def peak_normalise(env):
        top = env.max()
        return env / top if top > 0 else env

    bands = [
        librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median),
        librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median),
        librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr // 2, 20000), aggregate=np.median),
        librosa.onset.onset_strength(y=harmonic, sr=sr, lag=2),
    ]
    return np.maximum.reduce([peak_normalise(band) for band in bands])


def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5, min_gap=0.03,
                  energy_threshold_db=-35.0, mode="auto", backtrack=True, onset_delta=0.12):
    """Detect onsets in *y* and cut one Hit per surviving onset.

    Args:
        y, sr: the signal and its sample rate.
        pre_pad: seconds each slice starts before its onset.
        min_dur, max_dur: slices shorter than min_dur are dropped. Each slice runs
            to the next onset (or the end of y) but is capped at max_dur.
        min_gap: minimum onset spacing, converted to frames with a 512-sample hop.
        energy_threshold_db: slices whose RMS is below 10**(dB/20) (relative to
            amplitude 1.0) are dropped.
        mode: "percussive", "harmonic", "broadband", or anything else for the
            multi-band "auto" envelope.
        backtrack, onset_delta: forwarded to librosa.onset.onset_detect.

    Each kept slice gets a linear fade-out of 5 ms (at most 1/4 of its length).

    Returns:
        list of Hit, with index, rms_energy and spectral_centroid filled in.
    """
    print(f"[Stage 2] Onsets (delta={onset_delta}, energyβ‰₯{energy_threshold_db}dB)...")
    envelope = _onset_envelope(y, sr, mode)
    frames = librosa.onset.onset_detect(
        onset_envelope=envelope, sr=sr, wait=max(1, int(min_gap * sr / 512)),
        pre_avg=3, post_avg=3, pre_max=3, post_max=5,
        delta=onset_delta, backtrack=backtrack, units='frames')
    times = librosa.frames_to_time(frames, sr=sr)
    print(f" Raw: {len(times)}")

    gate = 10 ** (energy_threshold_db / 20)
    shortest = int(min_dur * sr)
    longest = int(max_dur * sr)
    fade_len = int(0.005 * sr)
    # Each slice's natural end is the next onset; the final slice ends with the signal.
    boundaries = [int(t * sr) for t in times[1:]] + [len(y)]

    hits = []
    for onset, boundary in zip(times, boundaries):
        start = max(0, int((onset - pre_pad) * sr))
        segment = y[start:min(boundary, start + longest)]
        if len(segment) < shortest:
            continue
        rms = np.sqrt(np.mean(segment ** 2))
        if rms < gate:
            continue
        tail = min(fade_len, len(segment) // 4)
        if tail > 0:
            segment = segment.copy()  # don't write the fade into y
            segment[-tail:] *= np.linspace(1, 0, tail)
        centroid = float(librosa.feature.spectral_centroid(y=segment, sr=sr).mean())
        hits.append(Hit(audio=segment, sr=sr, onset_time=onset, duration=len(segment) / sr,
                        index=len(hits), rms_energy=float(rms), spectral_centroid=centroid))
    print(f" βœ“ {len(hits)} hits")
    return hits
# ─── Stage 3 ──────────────────────────────────────────────────────────────────
# Ordered heuristic rules for classify_hit(). Each predicate receives, positionally,
# (low_ratio, mid_ratio, high_ratio, centroid_hz, zcr, duration_s). The first rule
# whose predicate holds names the hit.
# NOTE(review): every input accepted by "bass" (lr > 0.6, c < 400) is also accepted
# by the earlier "kick" rule (lr > 0.5, c < 800), so "bass" is never returned with
# the current ordering. Confirm whether that is intended.
def _rule_kick(lr, mr, hr, c, zcr, d):
    return lr > 0.5 and c < 800


def _rule_hihat_closed(lr, mr, hr, c, zcr, d):
    return hr > 0.35 and c > 4000 and d < 0.15


def _rule_hihat_open(lr, mr, hr, c, zcr, d):
    return hr > 0.35 and c > 4000 and d >= 0.15


def _rule_cymbal(lr, mr, hr, c, zcr, d):
    return hr > 0.25 and c > 3000


def _rule_snare(lr, mr, hr, c, zcr, d):
    return mr > 0.4 and zcr > 0.1 and c > 1000


def _rule_tom(lr, mr, hr, c, zcr, d):
    return lr > 0.3 and mr > 0.3 and c < 1500


def _rule_bass(lr, mr, hr, c, zcr, d):
    return lr > 0.6 and c < 400 and d > 0.2


def _rule_vocal(lr, mr, hr, c, zcr, d):
    return mr > 0.5 and 500 < c < 3000 and zcr < 0.15


def _rule_bright(lr, mr, hr, c, zcr, d):
    return c > 2500


def _rule_mid(lr, mr, hr, c, zcr, d):
    return c > 800


LABEL_RULES = [
    ("kick", _rule_kick),
    ("hihat_closed", _rule_hihat_closed),
    ("hihat_open", _rule_hihat_open),
    ("cymbal", _rule_cymbal),
    ("snare", _rule_snare),
    ("tom", _rule_tom),
    ("bass", _rule_bass),
    ("vocal", _rule_vocal),
    ("bright", _rule_bright),
    ("mid", _rule_mid),
]
def classify_hit(hit):
    """Return a heuristic label for *hit*, or "other" when no rule matches.

    Squared 2048-point STFT magnitudes are summed in three bands (20–200 Hz,
    200–4000 Hz, >=4000 Hz). The band shares, the hit's stored spectral
    centroid, its mean zero-crossing rate and its duration are then tested
    against LABEL_RULES in order.
    """
    spec = np.abs(librosa.stft(hit.audio, n_fft=2048))
    freqs = librosa.fft_frequencies(sr=hit.sr, n_fft=2048)

    def band_energy(mask):
        return np.sum(spec[mask] ** 2)

    low = band_energy((freqs >= 20) & (freqs < 200))
    mid = band_energy((freqs >= 200) & (freqs < 4000))
    high = band_energy(freqs >= 4000)
    total = low + mid + high + 1e-10  # epsilon keeps silent hits finite
    zcr = float(librosa.feature.zero_crossing_rate(y=hit.audio).mean())
    features = (low / total, mid / total, high / total, hit.spectral_centroid, zcr, hit.duration)
    return next((name for name, rule in LABEL_RULES if rule(*features)), "other")
def classify_hits(hits):
    """Set ``hit.label`` on every hit in place (via classify_hit) and return *hits*.

    Also prints a tally of labels, most frequent first; ties keep first-seen order.
    """
    print(f"[Stage 3] Classifying {len(hits)}...")
    tally = {}
    for hit in hits:
        hit.label = classify_hit(hit)
        tally[hit.label] = tally.get(hit.label, 0) + 1
    for label, count in sorted(tally.items(), key=lambda item: item[1], reverse=True):
        print(f" {label}: {count}")
    return hits
# ─── Stage 4 ──────────────────────────────────────────────────────────────────
def ncc_max(a, b):
    """Return the peak absolute normalised cross-correlation of *a* and *b*.

    Both signals are truncated to the shorter length and mean-removed. The
    result is max |xcorr| over all lags divided by sqrt(energy_a * energy_b),
    which lies in [0, 1] up to FFT round-off.

    Returns:
        float. 0.0 when either input is empty or has (near-)zero energy after
        mean removal.
    """
    n = min(len(a), len(b))
    if n == 0:
        return 0.0
    # Copy into float64: in-place mean removal fails on integer arrays, and
    # float64 avoids float32 FFT round-off.
    x = np.array(a[:n], dtype=np.float64)
    y = np.array(b[:n], dtype=np.float64)
    x -= x.mean()
    y -= y.mean()
    norm = np.sqrt(np.dot(x, x) * np.dot(y, y))
    if norm < 1e-10:
        return 0.0
    # Convolving with the reversed signal gives the cross-correlation at every lag.
    return float(np.max(np.abs(fftconvolve(x, y[::-1], mode='full'))) / norm)
def build_ncc_distance_matrix(hits, max_compare_samples=8820):
    """Return the symmetric (N, N) float32 matrix of ``max(0, 1 - ncc_max)`` distances.

    Each hit is truncated to *max_compare_samples* before comparison, and the
    diagonal stays 0. The result is memoised in the module cache, keyed on
    the hits' audio hashes and the truncation length.
    """
    key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
    cached = cache_get(key)
    if cached is not None:
        print(" Cached NCC matrix")
        return cached
    clips = [h.audio[:max_compare_samples] for h in hits]
    size = len(clips)
    dist = np.zeros((size, size), dtype=np.float32)
    for row, left in enumerate(clips):
        for col in range(row + 1, size):
            value = max(0.0, 1.0 - ncc_max(left, clips[col]))
            dist[row, col] = value
            dist[col, row] = value
    return cache_set(key, dist)
def _labels_to_clusters(labels,hits):
cm=defaultdict(list)
for i,l in enumerate(labels): cm[l].append(i)
clusters=[]
for _,idx in sorted(cm.items()):
v=defaultdict(int)
for i in idx: v[hits[i].label]+=1
maj=max(v,key=v.get); ex=sum(1 for c in clusters if c.label.rsplit('_',1)[0]==maj)
clusters.append(Cluster(cluster_id=len(clusters),label=f"{maj}_{ex}",hits=[hits[i] for i in idx]))
clusters.sort(key=lambda c:c.count,reverse=True)
for i,c in enumerate(clusters): c.cluster_id=i
return clusters
def _bisect_cluster_labels(D, N, tmin, tmax, linkage):
    """Search the distance threshold for a cluster count in [tmin, tmax]; return labels."""
    from sklearn.cluster import AgglomerativeClustering as AC
    print(f" Target: {tmin}–{tmax}")
    center = (tmin + tmax) / 2
    lo, hi = 0.001, 1.0
    best_labels, best_n = None, -1
    for _ in range(30):
        mid = (lo + hi) / 2
        lb = AC(n_clusters=None, distance_threshold=max(0.001, mid), metric='precomputed',
                linkage=linkage).fit_predict(D)
        n = len(set(lb))
        if tmin <= n <= tmax:
            best_labels, best_n = lb, n
            break
        if n > tmax:
            lo = mid  # too many clusters: merge more by raising the threshold
        else:
            hi = mid
        # Remember the count closest to the middle of the target range.
        if best_labels is None or abs(n - center) < abs(best_n - center):
            best_labels, best_n = lb, n
    if best_n < tmin or best_n > tmax:
        tm = min((tmin + tmax) // 2, N - 1)
        print(f" Fallback n={tm}")
        try:
            best_labels = AC(n_clusters=tm, metric='precomputed', linkage=linkage).fit_predict(D)
            best_n = tm
        except ValueError as err:
            # Best effort: keep the closest bisection result if sklearn rejects tm.
            print(f" Fallback failed ({err}); keeping n={best_n}")
    print(f" β†’ {best_n}")
    return best_labels


def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0, target_min=0, target_max=0, linkage='average'):
    """Group hits by waveform similarity with agglomerative clustering.

    Distances are ``1 - ncc_max`` over each hit's first ``ms`` samples. ``ms`` is
    *max_compare_ms* converted to samples or, when that is <= 0, the median hit
    length (at least 30 ms).

    There are two modes:
      * target mode (target_min > 0, target_max > 0, target_max >= target_min):
        bisect the distance threshold (30 steps over [0.001, 1.0]) for a cluster
        count in range, keeping the closest count seen. If none lands in range,
        request the midpoint count directly (capped at N - 1).
      * fixed mode: cut the dendrogram at distance ``1 - ncc_threshold``
        (at least 0.001).

    Returns:
        list of Cluster, sorted by size (see _labels_to_clusters). Returns [] for
        no hits, and a single cluster for one hit.
    """
    from sklearn.cluster import AgglomerativeClustering as AC
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
    if max_compare_ms <= 0:
        ms = max(int(0.03 * sr), int(np.median([len(h.audio) for h in hits])))
    else:
        ms = int(max_compare_ms / 1000.0 * sr)
    print(f"[Stage 4] NCC ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
    print(f" {N*(N-1)//2} pairs...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=ms)
    use_t = target_min > 0 and target_max > 0 and target_max >= target_min
    if use_t:
        tmin = max(1, min(target_min, N))
        tmax = max(target_min, min(target_max, N))
        labels = _bisect_cluster_labels(D, N, tmin, tmax, linkage)
    else:
        dt = max(0.001, 1.0 - ncc_threshold)
        print(f" Fixed dist≀{dt:.3f}")
        labels = AC(n_clusters=None, distance_threshold=dt, metric='precomputed',
                    linkage=linkage).fit_predict(D)
    print(f" βœ“ {len(set(labels))} clusters")
    cl = _labels_to_clusters(labels, hits)
    for c in cl:
        print(f" {c.label}: {c.count}")
    return cl
# ─── Stage 5 ──────────────────────────────────────────────────────────────────
def sample_quality_score(y, sr, label="other"):
    """Score how usable *y* is as a one-shot sample; higher is better.

    Each component lies in [0, 1]:
      * completeness: 0.6 * tail score + 0.4 * decay-fit score. The tail score is
        1 - 5 * (mean RMS of the last fifth, at least 3 frames) / peak RMS,
        floored at 0. The decay-fit score is the r^2 of a line fitted to
        log(RMS) from the peak onward, counted only when the slope is negative.
      * cleanness: mean of an SNR score (99th vs 10th percentile power) and a
        pre-onset-noise vs post-onset-signal score around the first onset.
      * onset_quality: (peak / mean of the onset-strength envelope - 1) / 5.
    total = 100 * (0.30*completeness + 0.40*cleanness + 0.20*onset_quality + 0.05)

    Args:
        y, sr: the slice and its sample rate.
        label: accepted for interface compatibility; currently unused.

    Returns:
        dict with float values for 'total', 'completeness', 'cleanness' and
        'onset_quality'.
    """
    import scipy.stats
    rms_env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(rms_env) >= 10:
        peak_idx = np.argmax(rms_env)
        decay = rms_env[peak_idx:]  # RMS envelope from the loudest frame onward
        tail = decay[-max(3, len(decay) // 5):]
        c1 = max(0, 1.0 - np.mean(tail) / (rms_env[peak_idx] + 1e-8) * 5)
        c2 = 0.0
        if len(decay) >= 5:
            fit = scipy.stats.linregress(np.arange(len(decay)), np.log(decay + 1e-8))
            if fit[0] < 0:  # slope: only a decaying envelope earns the fit bonus
                c2 = max(0, fit[2] ** 2)  # rvalue squared
    else:
        c1, c2 = 0.5, 0.0
    comp = c1 * 0.6 + c2 * 0.4

    snr = 10 * np.log10(np.percentile(y ** 2, 99) / (np.percentile(y ** 2, 10) + 1e-12))
    ns = np.clip((snr - 10) / 40, 0, 1)
    np2 = 0.5
    ons = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    if len(ons) > 0:
        o = int(ons[0])
        pre = y[max(0, o - int(sr * .02)):o]  # up to 20 ms before the onset
        sig = y[o:o + int(sr * .1)]           # up to 100 ms after it
        if len(pre) > 10 and len(sig) > 10:
            np2 = np.clip((-10 * np.log10(np.mean(pre ** 2 + 1e-12) / np.mean(sig ** 2 + 1e-12)) - 5) / 30, 0, 1)
    cl = ns * 0.5 + np2 * 0.5

    oe = librosa.onset.onset_strength(y=y, sr=sr)
    ratio = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
    # Subtract 1 in both branches so a flat envelope (ratio 1) scores 0. Writing
    # `a if c else 1.0 - 1.0` would apply the offset to the else-branch only.
    oq = float(np.clip((ratio - 1.0) / 5.0, 0, 1))
    return {'total': float((comp * 0.30 + cl * 0.40 + oq * 0.20 + 0.5 * 0.10) * 100),
            'completeness': float(comp), 'cleanness': float(cl), 'onset_quality': float(oq)}
def select_best(clusters):
    """Set each cluster's ``best_hit_idx`` to its highest-scoring member, in place.

    Clusters with at most one hit get index 0 without scoring. Larger clusters
    are ranked by sample_quality_score(...)['total'], using the cluster label
    without its trailing "_<n>". Ties go to the earliest hit.
    """
    print("[Stage 5] Selecting best...")
    for cluster in clusters:
        if cluster.count > 1:
            base_label = cluster.label.rsplit('_', 1)[0]
            totals = [sample_quality_score(h.audio, h.sr, base_label)['total'] for h in cluster.hits]
            cluster.best_hit_idx = int(np.argmax(totals))
        else:
            cluster.best_hit_idx = 0
# ─── Stage 6 ──────────────────────────────────────────────────────────────────
def synthesize_from_cluster(cluster):
    """Blend a cluster's hits into one peak-aligned, weighted-average waveform.

    Each hit is shifted so its absolute peak lines up with the first hit's
    peak, cut or zero-padded to the median hit length, and peak-normalised.
    The best hit (``best_hit_idx``) has weight 2 and the others weight 1. The
    weighted mean is rescaled to a 0.95 peak.

    Returns:
        float32 array, or None for clusters with fewer than two hits.
    """
    if cluster.count < 2:
        return None
    target_len = int(np.median([len(h.audio) for h in cluster.hits]))
    anchor = None
    aligned, weights = [], []
    for idx, hit in enumerate(cluster.hits):
        wave = hit.audio.copy()
        peak_pos = np.argmax(np.abs(wave))
        if anchor is None:
            anchor = peak_pos
        shift = anchor - peak_pos
        if shift > 0:
            wave = np.pad(wave, (shift, 0))
        elif shift < 0:
            wave = wave[-shift:]
        if len(wave) < target_len:
            wave = np.pad(wave, (0, target_len - len(wave)))
        else:
            wave = wave[:target_len]
        peak = np.abs(wave).max()
        if peak > 0:
            wave = wave / peak
        aligned.append(wave)
        weights.append(2.0 if idx == cluster.best_hit_idx else 1.0)
    w = np.array(weights)
    w /= w.sum()
    blend = np.average(np.array(aligned), axis=0, weights=w)
    top = np.abs(blend).max()
    if top > 0:
        return (blend * 0.95 / top).astype(np.float32)
    return blend.astype(np.float32)
# ─── Stage 7 ──────────────────────────────────────────────────────────────────
def build_midi(clusters, bpm=120.0):
    """Build a PrettyMIDI object with one drum track holding a note per hit.

    Side effect: each cluster's ``midi_note`` is set to ``min(36 + position, 127)``
    in list order. Velocity is rms_energy scaled so 0.3 maps to 127, clamped to
    1..127. Each note lasts the hit's duration, at least 0.05 s. Notes are
    sorted by start time.
    """
    import pretty_midi
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    for position, cluster in enumerate(clusters):
        cluster.midi_note = min(36 + position, 127)
    track = pretty_midi.Instrument(program=0, is_drum=True, name='Samples')
    pm.instruments.append(track)
    for cluster in clusters:
        for hit in cluster.hits:
            velocity = max(1, min(127, int(hit.rms_energy / 0.3 * 127)))
            note_end = hit.onset_time + max(hit.duration, 0.05)
            track.notes.append(pretty_midi.Note(velocity=velocity, pitch=cluster.midi_note,
                                                start=hit.onset_time, end=note_end))
    track.notes.sort(key=lambda note: note.start)
    return pm
def export_midi(clusters, path, bpm=120.0):
    """Build the reconstruction MIDI (see build_midi), write it to *path*, and return it."""
    midi = build_midi(clusters, bpm)
    midi.write(path)
    note_count = len(midi.instruments[0].notes)
    print(f" βœ“ MIDI: {note_count} notes")
    return midi
def detect_bpm(y, sr):
    """Estimate the tempo of *y* in BPM, rounded to 0.1 and memoised.

    There are two candidates: librosa's global tempo estimate and, when beat
    tracking finds more than two beats, 60 / median inter-beat interval. The
    first candidate in [70, 200] wins. Otherwise the tempo estimate is
    octave-folded (doubled or halved as often as needed) into that range.
    Non-positive or non-finite estimates are returned unchanged rather than
    looped on.
    """
    ck = ("bpm", _audio_hash(y), sr)
    cached = cache_get(ck)
    if cached:
        return cached
    oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
    bpm = float(librosa.feature.tempo(onset_envelope=oe, sr=sr).item())
    _, beats = librosa.beat.beat_track(onset_envelope=oe, sr=sr, units='time')
    candidates = [bpm]
    if len(beats) > 2:
        candidates.append(60.0 / float(np.median(np.diff(beats))))
    in_range = [cand for cand in candidates if 70 <= cand <= 200]
    if in_range:
        bpm = in_range[0]
    elif np.isfinite(bpm) and bpm > 0:
        # Doubling from below 70 cannot overshoot 200, and halving from above
        # 200 cannot undershoot 70, so these loops never oscillate.
        while bpm < 70:
            bpm *= 2
        while bpm > 200:
            bpm /= 2
    return cache_set(ck, round(bpm, 1))
def render_midi_with_samples(clusters, sr=44100):
    """Mix every hit back onto a timeline, using its cluster's best hit as the sample.

    Each occurrence is placed at ``int(onset_time * sr)`` with gain
    ``sqrt(min(2, hit_rms / best_rms))``; best_rms defaults to 0.1 when it is not
    positive. The buffer spans the last hit end plus 1 s and grows when a
    sample overruns it. The mix is peak-normalised to 0.9 (left unscaled when
    silent) and returned as float32.
    """
    last_end = max((h.onset_time + h.duration for c in clusters for h in c.hits), default=1.0)
    mix = np.zeros(int((last_end + 1.0) * sr), dtype=np.float64)
    for cluster in clusters:
        best = cluster.best_hit
        sample = best.audio.astype(np.float64)
        ref_rms = best.rms_energy if best.rms_energy > 0 else 0.1
        for hit in cluster.hits:
            gain = min(2.0, hit.rms_energy / (ref_rms + 1e-8)) ** 0.5
            start = int(hit.onset_time * sr)
            stop = start + len(sample)
            if stop > len(mix):
                mix = np.concatenate([mix, np.zeros(stop - len(mix))])
            mix[start:stop] += sample * gain
    peak = np.abs(mix).max()
    if peak > 1e-8:
        return (mix / peak * 0.9).astype(np.float32)
    return mix.astype(np.float32)
def build_sample_map(clusters):
    """Map each cluster's MIDI note to its label, hit count and best-hit length in ms.

    If two clusters share a note, the later one overwrites the earlier entry.
    """
    sample_map = {}
    for cluster in clusters:
        sample_map[cluster.midi_note] = {
            'label': cluster.label,
            'count': cluster.count,
            'duration_ms': int(cluster.best_hit.duration * 1000),
        }
    return sample_map
def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
    """Write the extracted kit to a new uncompressed zip file and return its path.

    Contents:
      * samples/<label>.wav: each cluster's best hit (24-bit PCM);
      * samples/<label>__synth.wav: the synthesized blend, when present;
      * index.json: bpm (1 dp), sample rate, totals, and per-sample metadata;
      * reconstruction.mid: copied from *midi_path* when that file exists;
      * rendered_reconstruction.wav: *rendered_audio* (16-bit PCM), when given.

    The returned file is not deleted by this function. If writing fails, the
    partial file is removed and the error is re-raised.
    """
    import zipfile, tempfile, io
    # mkstemp creates the file atomically. tempfile.mktemp only returns a name
    # that another process could claim first, and is documented as insecure.
    fd, zp = tempfile.mkstemp(suffix='.zip')
    os.close(fd)
    idx = {'bpm': round(bpm, 1), 'sample_rate': sr, 'total_clusters': len(clusters),
           'total_hits': sum(c.count for c in clusters), 'samples': {}}
    try:
        with zipfile.ZipFile(zp, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                b = c.best_hit
                fn = f"samples/{c.label}.wav"
                buf = io.BytesIO()
                sf.write(buf, b.audio, sr, format='WAV', subtype='PCM_24')
                zf.writestr(fn, buf.getvalue())
                ot = sorted(h.onset_time for h in c.hits)
                idx['samples'][c.label] = {
                    'file': fn, 'classification': c.label.rsplit('_', 1)[0], 'midi_note': c.midi_note,
                    'occurrences': c.count, 'onset_times_sec': [round(t, 4) for t in ot],
                    'duration_sec': round(b.duration, 4), 'rms_energy': round(b.rms_energy, 6),
                    'spectral_centroid_hz': round(b.spectral_centroid, 1)}
                if c.synthesized is not None:
                    sf2 = f"samples/{c.label}__synth.wav"
                    b2 = io.BytesIO()
                    sf.write(b2, c.synthesized, sr, format='WAV', subtype='PCM_24')
                    zf.writestr(sf2, b2.getvalue())
                    idx['samples'][c.label]['synthesized_file'] = sf2
            zf.writestr('index.json', json.dumps(idx, indent=2))
            if midi_path and os.path.exists(midi_path):
                zf.write(midi_path, 'reconstruction.mid')
            if rendered_audio is not None:
                rb = io.BytesIO()
                sf.write(rb, rendered_audio, sr, format='WAV', subtype='PCM_16')
                zf.writestr('rendered_reconstruction.wav', rb.getvalue())
    except BaseException:
        os.remove(zp)  # don't leave a half-written archive behind
        raise
    return zp
# ─── Auto-tuner with parameter locking ───────────────────────────────────────
def _spectral_envelope_corr(orig, rend, sr, n_fft=4096, hop=2048):
    """Pearson correlation of the time-averaged STFT magnitude spectra of two signals.

    Both signals are truncated to the shorter length. Returns 0.0 when that
    length is below n_fft, or when either averaged spectrum is (near-)constant.
    *sr* is unused.
    """
    length = min(len(orig), len(rend))
    if length < n_fft:
        return 0.0

    def mean_spectrum(signal):
        return np.abs(librosa.stft(signal[:length], n_fft=n_fft, hop_length=hop)).mean(axis=1)

    spec_orig = mean_spectrum(orig)
    spec_rend = mean_spectrum(rend)
    if spec_orig.std() < 1e-10 or spec_rend.std() < 1e-10:
        return 0.0
    return float(np.corrcoef(spec_orig, spec_rend)[0, 1])
def _rms_envelope_corr(orig, rend, sr, hop=1024):
    """Pearson correlation of the frame-RMS envelopes of two signals (*sr* unused).

    Returns 0.0 if the shorter signal has fewer than 4 hops, fewer than 4
    common frames, or a (near-)constant envelope.
    """
    length = min(len(orig), len(rend))
    if length < hop * 4:
        return 0.0
    env_orig = librosa.feature.rms(y=orig[:length], hop_length=hop)[0]
    env_rend = librosa.feature.rms(y=rend[:length], hop_length=hop)[0]
    frames = min(len(env_orig), len(env_rend))
    if frames < 4:
        return 0.0
    env_orig = env_orig[:frames]
    env_rend = env_rend[:frames]
    if env_orig.std() < 1e-10 or env_rend.std() < 1e-10:
        return 0.0
    return float(np.corrcoef(env_orig, env_rend)[0, 1])
def _reconstruction_score(orig, rend, sr):
    """Score how well *rend* reproduces *orig*, on a 0–100 scale floored at 0.

    The score is 100 * (0.5 * spectral-envelope correlation + 0.4 * RMS-envelope
    correlation + 0.1 * shorter_len / (longer_len + 1)).
    """
    spectral = _spectral_envelope_corr(orig, rend, sr)
    dynamics = _rms_envelope_corr(orig, rend, sr)
    shorter, longer = sorted((len(rend), len(orig)))
    length_ratio = shorter / (longer + 1)
    return max(0.0, (spectral * 0.5 + dynamics * 0.4 + length_ratio * 0.1) * 100)
_TUNED_ONSET_KEYS = ('onset_delta', 'energy_threshold_db', 'min_gap')
_DEFAULT_CLUSTER_TARGETS = [3, 5, 8, 10, 15, 20, 30]


def _auto_tune_onset_configs(locks):
    """Return the onset-parameter grid with locked values substituted and duplicates removed."""
    base_onset_configs = [
        {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
        {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
        {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
        {"onset_delta": 0.15, "energy_threshold_db": -30, "min_gap": 0.04},
        {"onset_delta": 0.20, "energy_threshold_db": -28, "min_gap": 0.05},
        {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
    ]
    configs, seen = [], set()
    for oc in base_onset_configs:
        for k in _TUNED_ONSET_KEYS:
            if k in locks:
                oc[k] = locks[k]
        # Configs that become identical after locking are tried only once.
        key = tuple(oc[k] for k in _TUNED_ONSET_KEYS)
        if key not in seen:
            seen.add(key)
            configs.append(oc)
    return configs


def _auto_tune_cluster_targets(locks, log_line):
    """Return the coarse cluster-count targets to try, respecting target locks."""
    if 'target_min' in locks and 'target_max' in locks:
        tmin_locked, tmax_locked = int(locks['target_min']), int(locks['target_max'])
        span = tmax_locked - tmin_locked
        targets = sorted({tmin_locked, tmax_locked, (tmin_locked + tmax_locked) // 2,
                          tmin_locked + span // 3, tmin_locked + 2 * span // 3})
        targets = [t for t in targets if tmin_locked <= t <= tmax_locked]
        log_line(f" Cluster targets (locked {tmin_locked}–{tmax_locked}): {targets}")
        return targets
    if 'target_min' in locks:
        tmin_locked = int(locks['target_min'])
        return [t for t in _DEFAULT_CLUSTER_TARGETS if t >= tmin_locked] or [tmin_locked]
    if 'target_max' in locks:
        tmax_locked = int(locks['target_max'])
        return [t for t in _DEFAULT_CLUSTER_TARGETS if t <= tmax_locked] or [tmax_locked]
    return list(_DEFAULT_CLUSTER_TARGETS)


def _auto_tune_score(stem_audio, sr, hits, target):
    """Cluster *hits* toward *target* clusters, re-render, and score against the stem.

    Each cluster's loudest hit (max rms_energy) is used as its sample.
    Returns (n_clusters, score), or None when clustering yields no clusters.
    """
    clusters = cluster_hits(hits, target_min=target, target_max=target, linkage='average')
    if not clusters:
        return None
    for c in clusters:
        c.best_hit_idx = 0 if c.count <= 1 else int(np.argmax([h.rms_energy for h in c.hits]))
    rendered = render_midi_with_samples(clusters, sr=sr)
    return len(clusters), _reconstruction_score(stem_audio, rendered, sr)


def auto_tune(stem_audio, sr, mode="auto", locks=None, log_fn=None):
    """Find best extraction parameters. Locked params are held constant.

    locks: dict mapping param names to fixed values. Supported keys:
        'onset_delta', 'energy_threshold_db', 'min_gap',
        'target_min', 'target_max'
    An onset key present in locks is not swept; its value is used as-is in
    every configuration tried. 'target_min' / 'target_max' bound the cluster
    counts tried, in both the coarse search and the fine-tuning pass.
    Example: locks={'target_min': 5, 'target_max': 20}
        → every cluster target tried lies in 5..20; only onset params are swept.

    Search: each onset config is detected and classified once, then clustered
    for every coarse cluster target, and each re-render is scored with
    _reconstruction_score. The winning config's hits are then re-clustered for
    targets within ±3 of the best target, clamped to the locks and to
    [2, n_hits - 1].

    Args:
        stem_audio, sr: the signal to analyse and its sample rate.
        mode: onset-envelope mode forwarded to detect_onsets.
        log_fn: optional callable that also receives every log line.

    Returns:
        (best_params, best_score, log_lines). best_params is {} and best_score
        is -1 when no configuration produced a clustering.
    """
    locks = locks or {}
    log = []

    def _log(m):
        log.append(m)
        if log_fn:
            log_fn(m)
        print(m)

    locked_names = list(locks.keys())
    _log(f"[Auto-tune] {len(stem_audio)/sr:.1f}s audio, locked: {locked_names or 'none'}")
    onset_configs = _auto_tune_onset_configs(locks)
    cluster_targets = _auto_tune_cluster_targets(locks, _log)

    best_score, best_params, best_hits = -1, {}, None
    for oc_idx, oc in enumerate(onset_configs):
        _log(f"\n [{oc_idx+1}/{len(onset_configs)}] delta={oc['onset_delta']}, "
             f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}")
        hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
        if len(hits) < 2:
            _log(f" {len(hits)} hits, skip")
            continue
        hits = classify_hits(hits)
        for nt in cluster_targets:
            if nt >= len(hits):
                continue
            result = _auto_tune_score(stem_audio, sr, hits, nt)
            if result is None:
                continue
            n_clusters, score = result
            if score > best_score:
                best_score = score
                best_params = {**oc, 'n_clusters': n_clusters, 'target_min': nt, 'target_max': nt}
                best_hits = hits
                _log(f" β˜… target={nt} β†’ {n_clusters} cl, score={score:.1f} BEST")
            else:
                _log(f" target={nt} β†’ {n_clusters} cl, score={score:.1f}")

    # Fine-tune around the best target using the winning configuration's hits,
    # rather than re-running detect_onsets/classify_hits with the same params.
    if best_params:
        bt = best_params.get('target_min', 10)
        fine_oc = {k: best_params[k] for k in _TUNED_ONSET_KEYS}
        _log(f"\n Fine-tuning Β±3 around target={bt}...")
        hits = best_hits
        # Clamp the ±3 window into the locked bounds instead of replacing it
        # with them; otherwise a lock widens the sweep to the whole locked range.
        lo = max(2, bt - 3, int(locks.get('target_min', 2)))
        hi = min(len(hits) - 1, bt + 3, int(locks.get('target_max', bt + 3)))
        for ft in range(lo, hi + 1):
            result = _auto_tune_score(stem_audio, sr, hits, ft)
            if result is None:
                continue
            n_clusters, score = result
            if score > best_score:
                best_score = score
                best_params = {**fine_oc, 'n_clusters': n_clusters, 'target_min': ft, 'target_max': ft}
                _log(f" β˜… target={ft} β†’ {n_clusters} cl, score={score:.1f} BEST")
    _log(f"\n[Auto-tune] Best: score={best_score:.1f}, params={best_params}")
    return best_params, best_score, log