| |
| """ |
| Sample Extractor v8 — Auto-tuning with parameter locking. |
| |
| auto_tune() accepts a `locks` dict: any param name mapped to its fixed value |
| will be held constant during the search. All other params are swept freely. |
| """ |
|
|
| import argparse, json, os, sys, warnings, hashlib |
| from collections import defaultdict |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import librosa, numpy as np, soundfile as sf, torch |
| from scipy.signal import fftconvolve |
|
|
| warnings.filterwarnings("ignore", category=FutureWarning) |
| warnings.filterwarnings("ignore", category=UserWarning) |
|
|
@dataclass
class Hit:
    """One audio event cut out of a signal, plus the measurements made on it.

    The positional field order (audio, sr, onset_time, duration, index, ...)
    is part of the constructor signature and must not be rearranged.
    """

    audio: np.ndarray
    sr: int
    onset_time: float
    duration: float
    index: int
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    label: str = ""
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1

    def save(self, path):
        """Write this hit's audio to ``path`` with 24-bit PCM encoding."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')
|
|
@dataclass
class Cluster:
    """A labelled group of hits together with the index of its chosen representative."""

    cluster_id: int
    label: str
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0
    synthesized: Optional[np.ndarray] = None
    midi_note: int = 60

    @property
    def best_hit(self):
        """The member stored at position ``best_hit_idx`` of ``hits``."""
        return self.hits[self.best_hit_idx]

    @property
    def count(self):
        """How many hits the cluster holds."""
        return len(self.hits)
|
|
# Names of the Demucs models this module lists stems for.
DEMUCS_MODELS = [
    "htdemucs",
    "htdemucs_ft",
    "htdemucs_6s",
    "mdx",
    "mdx_extra",
    "mdx_extra_q",
]

# Stem names listed per model; only "htdemucs_6s" has the two extra entries.
DEMUCS_STEMS = {
    "htdemucs": ["drums", "bass", "other", "vocals"],
    "htdemucs_ft": ["drums", "bass", "other", "vocals"],
    "htdemucs_6s": ["drums", "bass", "other", "vocals", "guitar", "piano"],
    "mdx": ["drums", "bass", "other", "vocals"],
    "mdx_extra": ["drums", "bass", "other", "vocals"],
    "mdx_extra_q": ["drums", "bass", "other", "vocals"],
}
|
|
| |
| _cache = {} |
| def _audio_hash(a): return hashlib.md5(a[:4000].tobytes()).hexdigest() |
| def cache_get(k): return _cache.get(k) |
| def cache_set(k,v): _cache[k]=v; return v |
| def cache_clear(): _cache.clear() |
|
|
| |
def extract_stem(audio_path, stem="drums", device="cpu", model_name="htdemucs_ft", shifts=1, overlap=0.25):
    """Load ``audio_path`` and return one separated stem as ``(mono float32 array, sample_rate)``.

    With ``stem == "all"`` no separation is run: the file is loaded as mono at
    44100 Hz and returned directly (and not cached). Otherwise the named Demucs
    model is applied, the requested source is averaged across channels, and the
    result is memoised per (file contents, stem, model, shifts, overlap).

    Raises:
        ValueError: if ``stem`` is not one of the loaded model's sources.
    """
    if stem == "all":
        y, sr = librosa.load(audio_path, sr=44100, mono=True)
        return y.astype(np.float32), sr

    # Fingerprint the whole file. Hashing only a leading slice lets two
    # different recordings with the same header/intro share one cache entry.
    digest = hashlib.md5()
    with open(audio_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    ck = ("stem", digest.hexdigest(), stem, model_name, shifts, overlap)
    cached = cache_get(ck)
    if cached is not None:
        print(f"[Stage 1] Cached {stem}")
        return cached

    # Imported lazily so the "all" path and cache hits do not need Demucs.
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    print(f"[Stage 1] {stem} with {model_name}...")
    model = get_model(model_name)
    model.eval().to(device)
    sr = model.samplerate
    if stem not in model.sources:
        raise ValueError(f"'{stem}' not in {model.sources}")

    a, _ = librosa.load(audio_path, sr=sr, mono=False)
    # Coerce whatever channel layout was loaded into exactly two channels.
    if a.ndim == 1:
        a = np.stack([a, a])
    elif a.shape[0] > 2:
        a = a[:2]
    elif a.shape[0] == 1:
        a = np.concatenate([a, a], axis=0)

    wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
    with torch.no_grad():
        src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
    # Pick the requested source from the first batch item and mix it to mono.
    r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
    print(f" → {len(r)/sr:.1f}s")
    return cache_set(ck, (r.astype(np.float32), sr))
|
|
| |
def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5, min_gap=0.03,
                  energy_threshold_db=-35.0, mode="auto", backtrack=True, onset_delta=0.12):
    """Detect onsets in ``y`` and slice the signal into a list of ``Hit`` objects.

    Args:
        y: mono signal.
        sr: sample rate of ``y``.
        pre_pad: seconds kept before each onset.
        min_dur: segments shorter than this many seconds are dropped.
        max_dur: segments are truncated to this many seconds.
        min_gap: minimum spacing in seconds handed to the peak picker as ``wait``.
        energy_threshold_db: segments whose RMS is below this level (dB relative
            to 1.0) are dropped.
        mode: "percussive", "harmonic", "broadband", or anything else for the
            combined multi-band envelope.
        backtrack: forwarded to ``librosa.onset.onset_detect``.
        onset_delta: peak-picking threshold forwarded as ``delta``.

    Returns:
        Hits in onset order; ``index`` counts only the hits that were kept.
    """
    print(f"[Stage 2] Onsets (delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
    if mode == "percussive":
        oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
    elif mode == "harmonic":
        yh, _ = librosa.effects.hpss(y)
        oe = librosa.onset.onset_strength(y=yh, sr=sr, fmax=8000, lag=2, max_size=3)
    elif mode == "broadband":
        oe = librosa.onset.onset_strength(y=y, sr=sr)
    else:
        yh, _ = librosa.effects.hpss(y)

        def _n(x):
            # Scale an envelope to a peak of 1 so the bands are comparable.
            m = x.max()
            return x / m if m > 0 else x

        # Element-wise maximum of three band-limited envelopes and a harmonic one.
        oe = np.maximum.reduce([
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr // 2, 20000), aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=yh, sr=sr, lag=2)),
        ])

    # min_gap is converted to envelope frames; 512 is the hop length the
    # envelope above is computed with (no hop_length is passed to librosa).
    w = max(1, int(min_gap * sr / 512))
    fr = librosa.onset.onset_detect(onset_envelope=oe, sr=sr, wait=w, pre_avg=3, post_avg=3,
                                    pre_max=3, post_max=5, delta=onset_delta,
                                    backtrack=backtrack, units='frames')
    times = librosa.frames_to_time(fr, sr=sr)
    print(f" Raw: {len(times)}")

    thr = 10 ** (energy_threshold_db / 20)
    # Never accept an empty slice, even when min_dur rounds down to 0 samples.
    min_len = max(1, int(min_dur * sr))
    hits = []
    for i, t in enumerate(times):
        s = max(0, int((t - pre_pad) * sr))
        # A segment ends at the next onset (or end of signal), capped at max_dur.
        next_start = int(times[i + 1] * sr) if i + 1 < len(times) else len(y)
        e = min(next_start, s + int(max_dur * sr))
        if e - s < min_len:
            continue
        # Always copy so a hit never shares memory with the source signal.
        seg = y[s:e].copy()
        rms = np.sqrt(np.mean(seg ** 2))
        if rms < thr:
            continue
        # Short linear fade-out (at most 5 ms or a quarter of the segment).
        fl = min(int(0.005 * sr), len(seg) // 4)
        if fl > 0:
            seg[-fl:] *= np.linspace(1, 0, fl)
        sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
        hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg) / sr,
                        index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
    print(f" → {len(hits)} hits")
    return hits
|
|
| |
# Ordered (label, predicate) pairs; the first predicate that accepts wins.
# Predicate arguments: lr / mr / hr = low / mid / high band energy ratios,
# c = spectral centroid, zcr = mean zero-crossing rate, d = duration.
# "bass" is listed before "kick" on purpose: every input accepted by the bass
# rule (lr>0.6, c<400) is also accepted by the kick rule (lr>0.5, c<800), so
# with kick first the bass label could never be produced.
LABEL_RULES = [
    ("bass", lambda lr, mr, hr, c, zcr, d: lr > 0.6 and c < 400 and d > 0.2),
    ("kick", lambda lr, mr, hr, c, zcr, d: lr > 0.5 and c < 800),
    ("hihat_closed", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d < 0.15),
    ("hihat_open", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d >= 0.15),
    ("cymbal", lambda lr, mr, hr, c, zcr, d: hr > 0.25 and c > 3000),
    ("snare", lambda lr, mr, hr, c, zcr, d: mr > 0.4 and zcr > 0.1 and c > 1000),
    ("tom", lambda lr, mr, hr, c, zcr, d: lr > 0.3 and mr > 0.3 and c < 1500),
    ("vocal", lambda lr, mr, hr, c, zcr, d: mr > 0.5 and 500 < c < 3000 and zcr < 0.15),
    ("bright", lambda lr, mr, hr, c, zcr, d: c > 2500),
    ("mid", lambda lr, mr, hr, c, zcr, d: c > 800),
]
def classify_hit(hit):
    """Return the name of the first LABEL_RULES entry matching ``hit``, else "other".

    The rule inputs are the low / mid / high shares of STFT energy
    (20-200 Hz, 200-4000 Hz, 4000 Hz and up), the hit's stored spectral
    centroid, its mean zero-crossing rate and its duration.
    """
    n_fft = 2048
    magnitude = np.abs(librosa.stft(hit.audio, n_fft=n_fft))
    freqs = librosa.fft_frequencies(sr=hit.sr, n_fft=n_fft)

    def band_energy(mask):
        return np.sum(magnitude[mask] ** 2)

    low = band_energy((freqs >= 20) & (freqs < 200))
    mid = band_energy((freqs >= 200) & (freqs < 4000))
    high = band_energy(freqs >= 4000)
    # The small constant keeps the ratios finite for an all-zero spectrum.
    total = low + mid + high + 1e-10

    zcr = float(librosa.feature.zero_crossing_rate(y=hit.audio).mean())
    features = (low / total, mid / total, high / total,
                hit.spectral_centroid, zcr, hit.duration)
    return next((name for name, rule in LABEL_RULES if rule(*features)), "other")
def classify_hits(hits):
    """Label every hit in place, print a per-label tally, and return ``hits``."""
    print(f"[Stage 3] Classifying {len(hits)}...")
    tally = defaultdict(int)
    for hit in hits:
        hit.label = classify_hit(hit)
        tally[hit.label] += 1
    # Most frequent label first; ties keep first-seen order.
    for name in sorted(tally, key=lambda k: -tally[k]):
        print(f" {name}: {tally[name]}")
    return hits
|
|
| |
def ncc_max(a, b):
    """Return the largest absolute normalised cross-correlation of ``a`` and ``b`` over all lags.

    Both inputs are cut to the shorter length and mean-removed first. When
    either one has (near) zero energy after that, the result is 0.0.
    """
    length = min(len(a), len(b))
    x = a[:length].copy()
    y_ = b[:length].copy()
    x -= x.mean()
    y_ -= y_.mean()

    scale = np.sqrt(np.dot(x, x) * np.dot(y_, y_))
    if scale < 1e-10:
        return 0.0

    # Convolving with the reversed second signal is cross-correlation.
    xcorr = fftconvolve(x, y_[::-1], mode='full')
    peak = float(np.max(np.abs(xcorr)))
    return peak / scale
def build_ncc_distance_matrix(hits, max_compare_samples=8820):
    """Return the symmetric float32 matrix of ``1 - ncc_max`` distances between hits.

    Only the first ``max_compare_samples`` samples of each hit are compared.
    The result is memoised on the hits' audio hashes and the sample limit.
    """
    fingerprints = tuple(_audio_hash(h.audio) for h in hits)
    key = ("ncc_dist", fingerprints, max_compare_samples)
    cached = cache_get(key)
    if cached is not None:
        print(" Cached NCC matrix")
        return cached

    heads = [h.audio[:max_compare_samples] for h in hits]
    count = len(heads)
    dist = np.zeros((count, count), dtype=np.float32)
    for i, head_i in enumerate(heads):
        for j in range(i + 1, count):
            # Clamp at zero in case the correlation comes out above 1.
            d = max(0.0, 1.0 - ncc_max(head_i, heads[j]))
            dist[i, j] = d
            dist[j, i] = d
    return cache_set(key, dist)
def _labels_to_clusters(labels, hits):
    """Group ``hits`` by their entry in ``labels`` and wrap each group in a ``Cluster``.

    Each cluster is named ``<majority hit label>_<k>`` where ``k`` counts the
    clusters already created with the same majority label. The returned list
    is ordered largest cluster first and ``cluster_id`` is renumbered to match.
    """
    members = defaultdict(list)
    for position, group in enumerate(labels):
        members[group].append(position)

    clusters = []
    for group in sorted(members):
        grouped = [hits[i] for i in members[group]]
        votes = defaultdict(int)
        for hit in grouped:
            votes[hit.label] += 1
        majority = max(votes, key=votes.get)
        suffix = sum(1 for c in clusters if c.label.rsplit('_', 1)[0] == majority)
        clusters.append(Cluster(cluster_id=len(clusters),
                                label=f"{majority}_{suffix}",
                                hits=grouped))

    clusters.sort(key=lambda c: c.count, reverse=True)
    for new_id, cluster in enumerate(clusters):
        cluster.cluster_id = new_id
    return clusters
def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0, target_min=0, target_max=0, linkage='average'):
    """Group hits by waveform similarity using agglomerative clustering on NCC distances.

    Args:
        hits: hits to cluster.
        ncc_threshold: used when no target range is given; clusters are cut at
            distance ``1 - ncc_threshold``.
        max_compare_ms: length of each hit that is compared; ``<= 0`` means the
            median hit length, but at least 30 ms.
        target_min, target_max: when both are positive and ``target_max >=
            target_min``, the distance threshold is bisected until the cluster
            count falls inside the range; if that fails, a fixed cluster count
            at the middle of the range is requested instead.
        linkage: linkage criterion passed to ``AgglomerativeClustering``.

    Returns:
        Clusters ordered largest first (empty list for empty input).
    """
    from sklearn.cluster import AgglomerativeClustering as AC
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]

    if max_compare_ms <= 0:
        ms = max(int(0.03 * sr), int(np.median([len(h.audio) for h in hits])))
    else:
        ms = int(max_compare_ms / 1000.0 * sr)
    print(f"[Stage 4] NCC ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
    print(f" {N*(N-1)//2} pairs...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=ms)

    use_t = target_min > 0 and target_max > 0 and target_max >= target_min
    tmin = max(1, min(target_min or 1, N))
    tmax = max(max(1, target_min or 1), min(target_max or N, N))

    if use_t:
        print(f" Target: {tmin}–{tmax}")
        centre = (tmin + tmax) / 2
        lo, hi = 0.001, 1.0
        bl, bn = None, -1
        for _ in range(30):
            mid = (lo + hi) / 2
            lb = AC(n_clusters=None, distance_threshold=max(0.001, mid),
                    metric='precomputed', linkage=linkage).fit_predict(D)
            n = len(set(lb))
            if tmin <= n <= tmax:
                bl, bn = lb, n
                break
            # Too many clusters -> raise the threshold; too few -> lower it.
            if n > tmax:
                lo = mid
            else:
                hi = mid
            # Remember the attempt whose count is closest to the range centre.
            if bl is None or abs(n - centre) < abs(bn - centre):
                bl, bn = lb, n
        if bn < tmin or bn > tmax:
            tm = min((tmin + tmax) // 2, N - 1)
            print(f" Fallback n={tm}")
            try:
                bl = AC(n_clusters=tm, metric='precomputed', linkage=linkage).fit_predict(D)
                bn = tm
            except ValueError as exc:
                # Best effort: keep the closest bisection result, but say why.
                print(f" Fallback failed ({exc}); keeping {bn} clusters")
        labels = bl
        print(f" → {bn}")
    else:
        dt = max(0.001, 1.0 - ncc_threshold)
        print(f" Fixed dist≤{dt:.3f}")
        labels = AC(n_clusters=None, distance_threshold=dt,
                    metric='precomputed', linkage=linkage).fit_predict(D)
        print(f" → {len(set(labels))} clusters")

    cl = _labels_to_clusters(labels, hits)
    for c in cl:
        print(f" {c.label}: {c.count}")
    return cl
|
|
| |
def sample_quality_score(y, sr, label="other"):
    """Score how usable ``y`` is as a one-shot sample.

    Args:
        y: mono signal.
        sr: sample rate of ``y``.
        label: accepted for interface compatibility; not used in the score.

    Returns:
        dict with ``total`` (0-100) and the 0-1 components ``completeness``,
        ``cleanness`` and ``onset_quality``.
    """
    import scipy.stats

    # Completeness: does the RMS envelope decay away after its peak?
    re = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(re) >= 10:
        pk = np.argmax(re)
        po = re[pk:]
        tail = po[-max(3, len(po) // 5):]
        c1 = max(0, 1.0 - np.mean(tail) / (re[pk] + 1e-8) * 5)
        c2 = 0.0
        if len(po) >= 5:
            # One fit of log-energy against time; a falling slope earns its r².
            fit = scipy.stats.linregress(np.arange(len(po)), np.log(po + 1e-8))
            if fit[0] < 0:
                c2 = max(0, fit[2] ** 2)
    else:
        c1, c2 = 0.5, 0.0
    comp = c1 * 0.6 + c2 * 0.4

    # Cleanness: 99th vs 10th percentile power, plus pre-onset vs post-onset level.
    power = y ** 2
    # The numerator epsilon keeps log10 finite for an all-zero signal.
    snr = 10 * np.log10((np.percentile(power, 99) + 1e-12) / (np.percentile(power, 10) + 1e-12))
    ns = np.clip((snr - 10) / 40, 0, 1)
    ons = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    np2 = 0.5
    if len(ons) > 0:
        o = int(ons[0])
        pre = y[max(0, o - int(sr * .02)):o]
        sig = y[o:o + int(sr * .1)]
        if len(pre) > 10 and len(sig) > 10:
            ratio_db = -10 * np.log10(np.mean(pre ** 2 + 1e-12) / np.mean(sig ** 2 + 1e-12))
            np2 = np.clip((ratio_db - 5) / 30, 0, 1)
    cl = ns * 0.5 + np2 * 0.5

    # Onset quality: how far the onset-strength peak stands above its mean.
    # The ratio is 1.0 for a flat envelope, so 1.0 is subtracted before scaling.
    # (The previous one-liner applied "- 1.0" only to the fallback value.)
    oe = librosa.onset.onset_strength(y=y, sr=sr)
    peak_ratio = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
    oq = float(np.clip((peak_ratio - 1.0) / 5.0, 0, 1))

    total = (comp * 0.30 + cl * 0.40 + oq * 0.20 + 0.5 * 0.10) * 100
    return {'total': float(total), 'completeness': float(comp),
            'cleanness': float(cl), 'onset_quality': float(oq)}
def select_best(clusters):
    """Set each cluster's ``best_hit_idx`` to its member with the highest quality total."""
    print("[Stage 5] Selecting best...")
    for cluster in clusters:
        if cluster.count <= 1:
            cluster.best_hit_idx = 0
            continue
        # Cluster labels look like "<name>_<k>"; pass only the name part.
        family = cluster.label.rsplit('_', 1)[0]
        totals = [sample_quality_score(h.audio, h.sr, family)['total']
                  for h in cluster.hits]
        cluster.best_hit_idx = int(np.argmax(totals))
|
|
| |
def synthesize_from_cluster(cluster):
    """Blend all hits of a cluster into one peak-aligned, weighted-average sample.

    Returns ``None`` for clusters with fewer than two hits. Otherwise each hit
    is shifted so its absolute peak lines up with the first hit's peak, fitted
    to the median hit length, peak-normalised, and averaged with the cluster's
    best hit weighted double. The float32 result is scaled to a peak of 0.95.
    """
    if cluster.count < 2:
        return None

    target_len = int(np.median([len(h.audio) for h in cluster.hits]))
    aligned = []
    weights = []
    ref_peak = None
    for pos, hit in enumerate(cluster.hits):
        take = hit.audio.copy()
        peak_at = np.argmax(np.abs(take))
        if ref_peak is None:
            ref_peak = peak_at

        # Positive shift: pad the front. Negative shift: drop leading samples.
        shift = ref_peak - peak_at
        if shift > 0:
            take = np.pad(take, (shift, 0))
        elif shift < 0:
            take = take[-shift:]

        if len(take) >= target_len:
            take = take[:target_len]
        else:
            take = np.pad(take, (0, target_len - len(take)))

        level = np.abs(take).max()
        if level > 0:
            take = take / level
        aligned.append(take)
        weights.append(2.0 if pos == cluster.best_hit_idx else 1.0)

    stack = np.array(aligned)
    w = np.array(weights)
    w /= w.sum()
    mix = np.average(stack, axis=0, weights=w)
    top = np.abs(mix).max()
    if top > 0:
        return (mix * 0.95 / top).astype(np.float32)
    return mix.astype(np.float32)
|
|
| |
def build_midi(clusters, bpm=120.0):
    """Build a PrettyMIDI object holding one note per hit on a single drum instrument.

    Clusters are assigned pitches 36, 37, ... in list order (capped at 127) and
    the pitch is stored on each cluster as ``midi_note``. Velocity is derived
    from the hit's RMS energy and each note lasts at least 0.05 s.
    """
    import pretty_midi

    midi = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    for offset, cluster in enumerate(clusters):
        cluster.midi_note = min(36 + offset, 127)

    drums = pretty_midi.Instrument(program=0, is_drum=True, name='Samples')
    midi.instruments.append(drums)
    for cluster in clusters:
        for hit in cluster.hits:
            velocity = int(hit.rms_energy / 0.3 * 127)
            velocity = max(1, min(127, velocity))
            drums.notes.append(pretty_midi.Note(
                velocity=velocity,
                pitch=cluster.midi_note,
                start=hit.onset_time,
                end=hit.onset_time + max(hit.duration, 0.05),
            ))
    drums.notes.sort(key=lambda note: note.start)
    return midi
def export_midi(clusters, path, bpm=120.0):
    """Build the MIDI for ``clusters``, write it to ``path`` and return the PrettyMIDI object."""
    midi = build_midi(clusters, bpm)
    midi.write(path)
    note_count = len(midi.instruments[0].notes)
    print(f" β MIDI: {note_count} notes")
    return midi
def detect_bpm(y, sr):
    """Estimate the tempo of ``y`` in BPM, rounded to one decimal, memoised per signal.

    Two candidates are considered in order: librosa's global tempo estimate
    and, when beat tracking yields more than two beats, the tempo implied by
    the median inter-beat interval. The first candidate inside 70-200 BPM is
    used. If neither qualifies, the global estimate is doubled or halved until
    it lands in that range.
    """
    ck = ("bpm", _audio_hash(y), sr)
    cached = cache_get(ck)
    if cached is not None:
        return cached

    oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
    bpm = float(librosa.feature.tempo(onset_envelope=oe, sr=sr).item())
    _, beats = librosa.beat.beat_track(onset_envelope=oe, sr=sr, units='time')

    candidates = [bpm]
    if len(beats) > 2:
        interval = float(np.median(np.diff(beats)))
        if interval > 0:  # guard against a zero interval
            candidates.append(60.0 / interval)

    chosen = next((cand for cand in candidates if 70 <= cand <= 200), None)
    if chosen is not None:
        bpm = chosen
    elif bpm > 0 and np.isfinite(bpm):
        # Octave-correct; repeated so e.g. 30 BPM reaches the range as well.
        while bpm < 70:
            bpm *= 2
        while bpm > 200:
            bpm /= 2
    return cache_set(ck, round(bpm, 1))
def render_midi_with_samples(clusters, sr=44100):
    """Re-render the hit timeline using each cluster's best hit as the only sound.

    At every hit's onset the cluster's best-hit audio is mixed in, scaled by
    the square root of the hit's RMS relative to the best hit's RMS (ratio
    capped at 2). The float32 result is normalised to a peak of 0.9 unless the
    mix is effectively silent.
    """
    end_times = [h.onset_time + h.duration for c in clusters for h in c.hits]
    last_end = max(end_times, default=1.0)
    # One extra second of room after the last hit.
    mix = np.zeros(int((last_end + 1.0) * sr), dtype=np.float64)

    for cluster in clusters:
        reference = cluster.best_hit
        sample = reference.audio.astype(np.float64)
        ref_rms = reference.rms_energy if reference.rms_energy > 0 else 0.1
        for hit in cluster.hits:
            gain = min(2.0, hit.rms_energy / (ref_rms + 1e-8)) ** 0.5
            start = int(hit.onset_time * sr)
            stop = start + len(sample)
            if stop > len(mix):
                # Grow the buffer when the sample runs past its end.
                mix = np.concatenate([mix, np.zeros(stop - len(mix))])
            mix[start:stop] += sample * gain

    peak = np.abs(mix).max()
    if peak > 1e-8:
        return (mix / peak * 0.9).astype(np.float32)
    return mix.astype(np.float32)
def build_sample_map(clusters):
    """Map each cluster's MIDI note to its label, hit count and best-hit length in ms."""
    mapping = {}
    for cluster in clusters:
        mapping[cluster.midi_note] = {
            'label': cluster.label,
            'count': cluster.count,
            'duration_ms': int(cluster.best_hit.duration * 1000),
        }
    return mapping
def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
    """Write a ZIP containing one WAV per cluster plus an ``index.json`` and return its path.

    Args:
        clusters: clusters to export; each contributes its best hit (24-bit
            WAV) and, when present, its synthesized sample.
        bpm: tempo recorded in the index.
        sr: sample rate used for every WAV written and recorded in the index.
        midi_path: optional existing MIDI file added as ``reconstruction.mid``.
        rendered_audio: optional signal added as ``rendered_reconstruction.wav``
            (16-bit).

    Returns:
        Path of the temporary ZIP file. The caller owns (and should delete) it.
    """
    import io
    import tempfile
    import zipfile

    # mkstemp creates the file atomically; mktemp only returns a name and is
    # deprecated because another process can claim that name first.
    fd, zp = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    idx = {
        'bpm': round(bpm, 1),
        'sample_rate': sr,
        'total_clusters': len(clusters),
        'total_hits': sum(c.count for c in clusters),
        'samples': {},
    }
    with zipfile.ZipFile(zp, 'w', compression=zipfile.ZIP_STORED) as zf:
        for c in clusters:
            b = c.best_hit
            fn = f"samples/{c.label}.wav"
            buf = io.BytesIO()
            sf.write(buf, b.audio, sr, format='WAV', subtype='PCM_24')
            zf.writestr(fn, buf.getvalue())
            ot = sorted(h.onset_time for h in c.hits)
            idx['samples'][c.label] = {
                'file': fn,
                'classification': c.label.rsplit('_', 1)[0],
                'midi_note': c.midi_note,
                'occurrences': c.count,
                'onset_times_sec': [round(t, 4) for t in ot],
                'duration_sec': round(b.duration, 4),
                'rms_energy': round(b.rms_energy, 6),
                'spectral_centroid_hz': round(b.spectral_centroid, 1),
            }
            if c.synthesized is not None:
                sf2 = f"samples/{c.label}__synth.wav"
                b2 = io.BytesIO()
                sf.write(b2, c.synthesized, sr, format='WAV', subtype='PCM_24')
                zf.writestr(sf2, b2.getvalue())
                idx['samples'][c.label]['synthesized_file'] = sf2
        zf.writestr('index.json', json.dumps(idx, indent=2))
        if midi_path and os.path.exists(midi_path):
            zf.write(midi_path, 'reconstruction.mid')
        if rendered_audio is not None:
            rb = io.BytesIO()
            sf.write(rb, rendered_audio, sr, format='WAV', subtype='PCM_16')
            zf.writestr('rendered_reconstruction.wav', rb.getvalue())
    return zp
|
|
| |
|
|
| def _spectral_envelope_corr(orig,rend,sr,n_fft=4096,hop=2048): |
| n=min(len(orig),len(rend)) |
| if n<n_fft: return 0.0 |
| eo=np.abs(librosa.stft(orig[:n],n_fft=n_fft,hop_length=hop)).mean(axis=1) |
| er=np.abs(librosa.stft(rend[:n],n_fft=n_fft,hop_length=hop)).mean(axis=1) |
| if eo.std()<1e-10 or er.std()<1e-10: return 0.0 |
| return float(np.corrcoef(eo,er)[0,1]) |
|
|
| def _rms_envelope_corr(orig,rend,sr,hop=1024): |
| n=min(len(orig),len(rend)) |
| if n<hop*4: return 0.0 |
| ro=librosa.feature.rms(y=orig[:n],hop_length=hop)[0]; rr=librosa.feature.rms(y=rend[:n],hop_length=hop)[0] |
| n2=min(len(ro),len(rr)) |
| if n2<4: return 0.0 |
| ro,rr=ro[:n2],rr[:n2] |
| if ro.std()<1e-10 or rr.std()<1e-10: return 0.0 |
| return float(np.corrcoef(ro,rr)[0,1]) |
|
|
def _reconstruction_score(orig, rend, sr):
    """Score (0 or higher) how closely ``rend`` resembles ``orig``.

    Weighted blend of spectral-envelope correlation (0.5), RMS-envelope
    correlation (0.4) and a shorter/longer length ratio (0.1), times 100 and
    floored at zero.
    """
    spectral = _spectral_envelope_corr(orig, rend, sr)
    envelope = _rms_envelope_corr(orig, rend, sr)
    shorter, longer = sorted((len(rend), len(orig)))
    length_ratio = shorter / (longer + 1)
    blended = spectral * 0.5 + envelope * 0.4 + length_ratio * 0.1
    return max(0.0, blended * 100)
|
|
|
|
def _score_cluster_target(hits, target, stem_audio, sr):
    """Cluster ``hits`` into ``target`` groups and score the re-rendered result.

    Returns ``(clusters, score)``; ``score`` is ``None`` when clustering
    produced no clusters.
    """
    clusters = cluster_hits(hits, target_min=target, target_max=target, linkage='average')
    if not clusters:
        return clusters, None
    for c in clusters:
        # During tuning the loudest member stands in for the best hit.
        c.best_hit_idx = 0 if c.count <= 1 else int(np.argmax([h.rms_energy for h in c.hits]))
    rendered = render_midi_with_samples(clusters, sr=sr)
    return clusters, _reconstruction_score(stem_audio, rendered, sr)


def auto_tune(stem_audio, sr, mode="auto", locks=None, log_fn=None):
    """Find best extraction parameters. Locked params are held constant.

    locks: dict mapping param names to fixed values. Supported keys:
        'onset_delta', 'energy_threshold_db', 'min_gap',
        'target_min', 'target_max'
    Any key present in locks will NOT be swept — its value is used as-is.
    Unsupported keys are reported in the log and ignored. If both target
    bounds are locked in inverted order they are swapped (and logged).

    Example: locks={'target_min': 5, 'target_max': 20}
    → cluster count will always be between 5 and 20, only onset params are tuned.

    The search first sweeps a fixed grid of onset configurations against a
    coarse list of cluster counts, then re-scores every cluster count in a
    window around the best one (±3, or the locked bounds where given) using
    the best onset configuration.

    Args:
        stem_audio: mono signal to tune against.
        sr: sample rate of ``stem_audio``.
        mode: onset-detection mode forwarded to ``detect_onsets``.
        locks: see above.
        log_fn: optional callable that receives every log line.

    Returns: (best_params, best_score, log_lines). ``best_params`` is empty and
    ``best_score`` is -1 when no configuration could be scored.
    """
    locks = locks or {}
    log = []

    def _log(m):
        log.append(m)
        if log_fn:
            log_fn(m)
        print(m)

    locked_names = list(locks.keys())
    _log(f"[Auto-tune] {len(stem_audio)/sr:.1f}s audio, locked: {locked_names or 'none'}")

    onset_keys = ('onset_delta', 'energy_threshold_db', 'min_gap')
    unknown = [k for k in locks if k not in onset_keys + ('target_min', 'target_max')]
    if unknown:
        _log(f" Ignoring unsupported lock keys: {unknown}")

    base_onset_configs = [
        {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
        {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
        {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
        {"onset_delta": 0.15, "energy_threshold_db": -30, "min_gap": 0.04},
        {"onset_delta": 0.20, "energy_threshold_db": -28, "min_gap": 0.05},
        {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
    ]

    # Apply locks to a copy of each base config and drop the duplicates that
    # locking produces (e.g. all three keys locked leaves a single config).
    onset_configs = []
    seen = set()
    for base in base_onset_configs:
        oc = {k: locks.get(k, base[k]) for k in onset_keys}
        key = tuple(oc[k] for k in onset_keys)
        if key not in seen:
            seen.add(key)
            onset_configs.append(oc)

    lock_min = int(locks['target_min']) if 'target_min' in locks else None
    lock_max = int(locks['target_max']) if 'target_max' in locks else None
    if lock_min is not None and lock_max is not None and lock_min > lock_max:
        # An inverted range would otherwise leave nothing to search.
        _log(f" Locked target range {lock_min}–{lock_max} is inverted; swapping")
        lock_min, lock_max = lock_max, lock_min

    default_targets = [3, 5, 8, 10, 15, 20, 30]
    if lock_min is not None and lock_max is not None:
        # Both ends, the midpoint and the two third-points of the locked range.
        span = lock_max - lock_min
        cluster_targets = sorted({lock_min, lock_max, (lock_min + lock_max) // 2,
                                  lock_min + span // 3, lock_min + 2 * span // 3})
        _log(f" Cluster targets (locked {lock_min}–{lock_max}): {cluster_targets}")
    elif lock_min is not None:
        cluster_targets = [t for t in default_targets if t >= lock_min] or [lock_min]
    elif lock_max is not None:
        cluster_targets = [t for t in default_targets if t <= lock_max] or [lock_max]
    else:
        cluster_targets = list(default_targets)

    best_score, best_params = -1, {}
    best_hits = None  # hits detected with the currently best onset config

    for oc_idx, oc in enumerate(onset_configs):
        _log(f"\n [{oc_idx+1}/{len(onset_configs)}] delta={oc['onset_delta']}, "
             f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}")
        hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
        if len(hits) < 2:
            _log(f" {len(hits)} hits, skip")
            continue
        hits = classify_hits(hits)

        for nt in cluster_targets:
            if nt >= len(hits):
                continue
            clusters, score = _score_cluster_target(hits, nt, stem_audio, sr)
            if score is None:
                continue
            if score > best_score:
                best_score = score
                best_params = {**oc, 'n_clusters': len(clusters),
                               'target_min': nt, 'target_max': nt}
                best_hits = hits
                _log(f" ★ target={nt} → {len(clusters)} cl, score={score:.1f} BEST")
            else:
                _log(f" target={nt} → {len(clusters)} cl, score={score:.1f}")

    if best_params:
        bt = best_params.get('target_min', 10)
        fine_oc = {k: best_params[k] for k in onset_keys}
        _log(f"\n Fine-tuning ±3 around target={bt}...")
        # Reuse the hits already detected for the winning config instead of
        # running onset detection and classification again.
        hits = best_hits
        lo = max(2, lock_min if lock_min is not None else bt - 3)
        hi = min(len(hits) - 1, lock_max if lock_max is not None else bt + 3)
        for ft in range(lo, hi + 1):
            if ft == bt:
                continue  # already scored in the coarse sweep
            clusters, score = _score_cluster_target(hits, ft, stem_audio, sr)
            if score is None:
                continue
            if score > best_score:
                best_score = score
                best_params = {**fine_oc, 'n_clusters': len(clusters),
                               'target_min': ft, 'target_max': ft}
                _log(f" ★ target={ft} → {len(clusters)} cl, score={score:.1f} BEST")

    _log(f"\n[Auto-tune] Best: score={best_score:.1f}, params={best_params}")
    return best_params, best_score, log
|
|