#!/usr/bin/env python3
"""
Sample Extractor v8 — Auto-tuning with parameter locking.

auto_tune() accepts a `locks` dict: any param name mapped to its fixed value
will be held constant during the search. All other params are swept freely.
"""
import argparse, json, os, sys, warnings, hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import librosa, numpy as np, soundfile as sf, torch
from scipy.signal import fftconvolve

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


@dataclass
class Hit:
    """One extracted transient: an audio slice plus its onset metadata."""
    audio: np.ndarray                       # mono float32 PCM
    sr: int                                 # sample rate of `audio`
    onset_time: float                       # onset position in the source, seconds
    duration: float                         # slice length, seconds
    index: int                              # ordinal among kept hits
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    label: str = ""                         # set by classify_hits()
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1

    def save(self, path):
        """Write the hit to `path` as 24-bit PCM WAV."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')


@dataclass
class Cluster:
    """A group of similar hits with one representative (best) sample."""
    cluster_id: int
    label: str
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0                   # index into `hits`, set by select_best()
    synthesized: Optional[np.ndarray] = None
    midi_note: int = 60                     # assigned by build_midi()

    @property
    def best_hit(self):
        return self.hits[self.best_hit_idx]

    @property
    def count(self):
        return len(self.hits)


DEMUCS_MODELS = ["htdemucs", "htdemucs_ft", "htdemucs_6s", "mdx", "mdx_extra", "mdx_extra_q"]
DEMUCS_STEMS = {
    "htdemucs":    ["drums", "bass", "other", "vocals"],
    "htdemucs_ft": ["drums", "bass", "other", "vocals"],
    "htdemucs_6s": ["drums", "bass", "other", "vocals", "guitar", "piano"],
    "mdx":         ["drums", "bass", "other", "vocals"],
    "mdx_extra":   ["drums", "bass", "other", "vocals"],
    "mdx_extra_q": ["drums", "bass", "other", "vocals"],
}

# ─── Cache ────────────────────────────────────────────────────────────────────
_cache = {}


def _audio_hash(a):
    # Hash only the first 4000 samples — cheap and adequate as a cache key.
    return hashlib.md5(a[:4000].tobytes()).hexdigest()


def cache_get(k):
    return _cache.get(k)


def cache_set(k, v):
    _cache[k] = v
    return v


def cache_clear():
    _cache.clear()


# ─── Stage 1 ──────────────────────────────────────────────────────────────────
def extract_stem(audio_path, stem="drums", device="cpu", model_name="htdemucs_ft", shifts=1, overlap=0.25):
    """Isolate `stem` from `audio_path` with Demucs; return (mono_float32, sr).

    `stem="all"` skips separation and returns the full mix at 44.1 kHz.
    Results are memoized on (file-head hash, stem, model, shifts, overlap).
    """
    if stem == "all":
        y, sr = librosa.load(audio_path, sr=44100, mono=True)
        return y.astype(np.float32), sr
    with open(audio_path, 'rb') as f:
        fh = hashlib.md5(f.read(200000)).hexdigest()  # hash file head only
    ck = ("stem", fh, stem, model_name, shifts, overlap)
    c = cache_get(ck)
    if c:
        print(f"[Stage 1] Cached {stem}")
        return c
    from demucs.pretrained import get_model
    from demucs.apply import apply_model
    print(f"[Stage 1] {stem} with {model_name}...")
    model = get_model(model_name)
    model.eval().to(device)
    sr = model.samplerate
    if stem not in model.sources:
        raise ValueError(f"'{stem}' not in {model.sources}")
    a, _ = librosa.load(audio_path, sr=sr, mono=False)
    # Demucs expects exactly two channels.
    if a.ndim == 1:
        a = np.stack([a, a])
    elif a.shape[0] > 2:
        a = a[:2]
    elif a.shape[0] == 1:
        a = np.concatenate([a, a], axis=0)
    wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
    with torch.no_grad():
        src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
    r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()  # downmix to mono
    print(f" ✓ {len(r)/sr:.1f}s")
    return cache_set(ck, (r.astype(np.float32), sr))


# ─── Stage 2 ──────────────────────────────────────────────────────────────────
def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5, min_gap=0.03,
                  energy_threshold_db=-35.0, mode="auto", backtrack=True, onset_delta=0.12):
    """Detect transient onsets in `y` and slice them into Hit objects.

    mode: "percussive" | "harmonic" | "broadband" | "auto" (band-fused envelope).
    Hits shorter than `min_dur` or with RMS below `energy_threshold_db` are dropped.
    """
    print(f"[Stage 2] Onsets (delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
    if mode == "percussive":
        oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
    elif mode == "harmonic":
        yh, _ = librosa.effects.hpss(y)
        oe = librosa.onset.onset_strength(y=yh, sr=sr, fmax=8000, lag=2, max_size=3)
    elif mode == "broadband":
        oe = librosa.onset.onset_strength(y=y, sr=sr)
    else:
        # "auto": fuse normalized band-limited envelopes so no register dominates.
        yh, _ = librosa.effects.hpss(y)

        def _n(x):
            m = x.max()
            return x / m if m > 0 else x

        oe = np.maximum.reduce([
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr // 2, 20000), aggregate=np.median)),
            _n(librosa.onset.onset_strength(y=yh, sr=sr, lag=2)),
        ])
    w = max(1, int(min_gap * sr / 512))  # min_gap expressed in frames (hop = 512)
    fr = librosa.onset.onset_detect(onset_envelope=oe, sr=sr, wait=w, pre_avg=3, post_avg=3,
                                    pre_max=3, post_max=5, delta=onset_delta,
                                    backtrack=backtrack, units='frames')
    times = librosa.frames_to_time(fr, sr=sr)
    print(f" Raw: {len(times)}")
    thr = 10 ** (energy_threshold_db / 20)  # dBFS → linear RMS threshold
    hits = []
    for i, t in enumerate(times):
        s = max(0, int((t - pre_pad) * sr))
        # NOTE(review): the segment-bound / gating logic below was reconstructed
        # from a corrupted source span — confirm against the original.
        e = min(int(times[i + 1] * sr) if i + 1 < len(times) else len(y),
                s + int(max_dur * sr))
        seg = y[s:e]
        if len(seg) < int(min_dur * sr):
            continue  # too short to be usable
        rms = np.sqrt(np.mean(seg ** 2))
        if rms < thr:
            continue  # below the energy gate
        fl = int(0.005 * sr)  # 5 ms fade-out avoids clicks at the cut point
        if fl > 0 and len(seg) > fl:
            seg = seg.copy()
            seg[-fl:] *= np.linspace(1, 0, fl)
        sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
        hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg) / sr,
                        index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
    print(f" ✓ {len(hits)} hits")
    return hits


# ─── Stage 3 ──────────────────────────────────────────────────────────────────
# Ordered heuristics — first matching rule wins.
# Rule args: (low_ratio, mid_ratio, high_ratio, centroid_hz, zero_cross_rate, duration_s)
LABEL_RULES = [
    ("kick",         lambda lr, mr, hr, c, zcr, d: lr > 0.5 and c < 800),
    ("hihat_closed", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d < 0.15),
    ("hihat_open",   lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d >= 0.15),
    ("cymbal",       lambda lr, mr, hr, c, zcr, d: hr > 0.25 and c > 3000),
    ("snare",        lambda lr, mr, hr, c, zcr, d: mr > 0.4 and zcr > 0.1 and c > 1000),
    ("tom",          lambda lr, mr, hr, c, zcr, d: lr > 0.3 and mr > 0.3 and c < 1500),
    ("bass",         lambda lr, mr, hr, c, zcr, d: lr > 0.6 and c < 400 and d > 0.2),
    # NOTE(review): the "vocal"/"high" pair was reconstructed from a corrupted
    # span ("5002500" in the source) — confirm thresholds against the original.
    ("vocal",        lambda lr, mr, hr, c, zcr, d: mr > 0.5 and 500 < c < 2500),
    ("high",         lambda lr, mr, hr, c, zcr, d: c > 2500),
    ("mid",          lambda lr, mr, hr, c, zcr, d: c > 800),
]


def classify_hit(hit):
    """Label one hit from band-energy ratios, centroid, ZCR and duration."""
    y, sr = hit.audio, hit.sr
    D = np.abs(librosa.stft(y, n_fft=2048))
    f = librosa.fft_frequencies(sr=sr, n_fft=2048)
    le = np.sum(D[(f >= 20) & (f < 200)] ** 2)
    me = np.sum(D[(f >= 200) & (f < 4000)] ** 2)
    he = np.sum(D[(f >= 4000)] ** 2)
    t = le + me + he + 1e-10
    lr, mr, hr = le / t, me / t, he / t
    zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
    for name, rule in LABEL_RULES:
        if rule(lr, mr, hr, hit.spectral_centroid, zcr, hit.duration):
            return name
    return "other"


def classify_hits(hits):
    """Classify every hit in place, print a label histogram, return `hits`."""
    print(f"[Stage 3] Classifying {len(hits)}...")
    for h in hits:
        h.label = classify_hit(h)
    counts = defaultdict(int)
    for h in hits:
        counts[h.label] += 1
    for lbl, n in sorted(counts.items(), key=lambda x: -x[1]):
        print(f"  {lbl}: {n}")
    return hits


# ─── Stage 4 ──────────────────────────────────────────────────────────────────
def ncc_max(a, b):
    """Maximum absolute normalized cross-correlation of two signals (0…1)."""
    n = min(len(a), len(b))
    a, b = a[:n].copy(), b[:n].copy()
    a -= a.mean()
    b -= b.mean()
    norm = np.sqrt(np.dot(a, a) * np.dot(b, b))
    if norm < 1e-10:
        return 0.0
    return float(np.max(np.abs(fftconvolve(a, b[::-1], mode='full')))) / norm


def build_ncc_distance_matrix(hits, max_compare_samples=8820):
    """Symmetric NCC-distance matrix (1 − max NCC), memoized on hit hashes."""
    key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
    c = cache_get(key)
    if c is not None:
        print(f" Cached NCC matrix")
        return c
    N = len(hits)
    D = np.zeros((N, N), dtype=np.float32)
    for i in range(N):
        ai = hits[i].audio[:max_compare_samples]
        for j in range(i + 1, N):
            D[i, j] = D[j, i] = max(0.0, 1.0 - ncc_max(ai, hits[j].audio[:max_compare_samples]))
    return cache_set(key, D)


def _labels_to_clusters(labels, hits):
    """Group hits by cluster label into Cluster objects named `<majority>_<k>`."""
    cm = defaultdict(list)
    for i, l in enumerate(labels):
        cm[l].append(i)
    clusters = []
    for _, idx in sorted(cm.items()):
        votes = defaultdict(int)
        for i in idx:
            votes[hits[i].label] += 1
        maj = max(votes, key=votes.get)
        # Disambiguate repeated majority labels: kick_0, kick_1, ...
        ex = sum(1 for c in clusters if c.label.rsplit('_', 1)[0] == maj)
        clusters.append(Cluster(cluster_id=len(clusters), label=f"{maj}_{ex}",
                                hits=[hits[i] for i in idx]))
    clusters.sort(key=lambda c: c.count, reverse=True)
    for i, c in enumerate(clusters):
        c.cluster_id = i
    return clusters


def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0, target_min=0, target_max=0, linkage='average'):
    """Agglomeratively cluster hits on NCC distance.

    With target_min/target_max > 0, binary-search the distance threshold to land
    in that cluster-count range; otherwise cut at the fixed `ncc_threshold`.
    """
    from sklearn.cluster import AgglomerativeClustering as AC
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
    ms = (max(int(0.03 * sr), int(np.median([len(h.audio) for h in hits])))
          if max_compare_ms <= 0 else int(max_compare_ms / 1000.0 * sr))
    print(f"[Stage 4] NCC ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
    print(f"  {N*(N-1)//2} pairs...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=ms)
    use_t = target_min > 0 and target_max > 0 and target_max >= target_min
    tmin = max(1, min(target_min or 1, N))
    tmax = max(max(1, target_min or 1), min(target_max or N, N))
    if use_t:
        print(f"  Target: {tmin}–{tmax}")
        lo, hi = 0.001, 1.0
        bl, bn, bd = None, -1, 0.5  # best labels / cluster count / distance so far
        for _ in range(30):  # binary search on distance threshold
            mid = (lo + hi) / 2
            lb = AC(n_clusters=None, distance_threshold=max(0.001, mid),
                    metric='precomputed', linkage=linkage).fit_predict(D)
            n = len(set(lb))
            if tmin <= n <= tmax:
                bl, bn, bd = lb, n, mid
                break
            elif n > tmax:
                lo = mid
            else:
                hi = mid
            # Keep the attempt whose count is closest to the target midpoint.
            # NOTE(review): reconstructed from a corrupted source span — confirm.
            if bl is None or abs(n - (tmin + tmax) / 2) < abs(bn - (tmin + tmax) / 2):
                bl, bn, bd = lb, n, mid
        if bn < tmin or bn > tmax:
            # Search never landed in range: force an explicit cluster count.
            tm = min((tmin + tmax) // 2, N - 1)
            print(f"  Fallback n={tm}")
            try:
                bl = AC(n_clusters=tm, metric='precomputed', linkage=linkage).fit_predict(D)
                bn = tm
            except Exception:
                pass  # keep the closest labels found by the search
        labels = bl
        print(f"  → {bn}")
    else:
        dt = max(0.001, 1.0 - ncc_threshold)
        print(f"  Fixed dist≤{dt:.3f}")
        labels = AC(n_clusters=None, distance_threshold=dt,
                    metric='precomputed', linkage=linkage).fit_predict(D)
    print(f"  ✓ {len(set(labels))} clusters")
    cl = _labels_to_clusters(labels, hits)
    for c in cl:
        print(f"    {c.label}: {c.count}")
    return cl


# ─── Stage 5 ──────────────────────────────────────────────────────────────────
def sample_quality_score(y, sr, label="other"):
    """Score a candidate sample 0–100 on completeness, cleanness and onset sharpness."""
    import scipy.stats
    env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(env) >= 10:
        pk = np.argmax(env)
        po = env[pk:]  # post-peak envelope
        # c1: the tail should have decayed well below the peak.
        c1 = max(0, 1.0 - np.mean(po[-max(3, len(po) // 5):]) / (env[pk] + 1e-8) * 5)
        # c2: R² of a log-linear fit — rewards a clean exponential decay.
        # (The source computed linregress twice; compute it once.)
        if len(po) >= 5:
            fit = scipy.stats.linregress(np.arange(len(po)), np.log(po + 1e-8))
            c2 = max(0, fit.rvalue ** 2) if fit.slope < 0 else 0.0
        else:
            c2 = 0.0
    else:
        c1, c2 = 0.5, 0.0
    comp = c1 * 0.6 + c2 * 0.4
    snr = 10 * np.log10(np.percentile(y ** 2, 99) / (np.percentile(y ** 2, 10) + 1e-12))
    ns = np.clip((snr - 10) / 40, 0, 1)
    ons = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    if len(ons) > 0:
        o = int(ons[0])
        pre = y[max(0, o - int(sr * .02)):o]   # 20 ms before the onset
        sig = y[o:o + int(sr * .1)]            # 100 ms after the onset
        np2 = (np.clip((-10 * np.log10(np.mean(pre ** 2 + 1e-12) / np.mean(sig ** 2 + 1e-12)) - 5) / 30, 0, 1)
               if len(pre) > 10 and len(sig) > 10 else 0.5)
    else:
        np2 = 0.5
    cl = ns * 0.5 + np2 * 0.5
    oe = librosa.onset.onset_strength(y=y, sr=sr)
    # NOTE(review): the source read "else 1.0-1.0" — interpreted as
    # (peak/mean ratio − 1) / 5; confirm against the original.
    peak_ratio = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
    oq = float(np.clip((peak_ratio - 1.0) / 5.0, 0, 1))
    return {'total': float((comp * 0.30 + cl * 0.40 + oq * 0.20 + 0.5 * 0.10) * 100),
            'completeness': float(comp), 'cleanness': float(cl), 'onset_quality': float(oq)}


def select_best(clusters):
    """Pick each cluster's representative hit by quality score (mutates in place)."""
    print(f"[Stage 5] Selecting best...")
    for c in clusters:
        if c.count <= 1:
            c.best_hit_idx = 0
            continue
        base = c.label.rsplit('_', 1)[0]
        c.best_hit_idx = int(np.argmax(
            [sample_quality_score(h.audio, h.sr, base)['total'] for h in c.hits]))


# ─── Stage 6 ──────────────────────────────────────────────────────────────────
def synthesize_from_cluster(cluster):
    """Peak-align, normalize and weighted-average a cluster's hits into one sample.

    Returns None for singleton clusters; the best hit gets double weight.
    """
    if cluster.count < 2:
        return None
    tl = int(np.median([len(h.audio) for h in cluster.hits]))  # target length
    aligned, weights = [], []
    anchor = None  # peak position of the first hit = alignment anchor
    for i, h in enumerate(cluster.hits):
        a = h.audio.copy()
        p = np.argmax(np.abs(a))
        if anchor is None:
            anchor = p
        shift = anchor - p
        if shift > 0:
            a = np.pad(a, (shift, 0))
        elif shift < 0:
            a = a[-shift:]
        a = a[:tl] if len(a) >= tl else np.pad(a, (0, tl - len(a)))
        pk = np.abs(a).max()
        if pk > 0:
            a = a / pk
        aligned.append(a)
        weights.append(2.0 if i == cluster.best_hit_idx else 1.0)
    aligned = np.array(aligned)
    w = np.array(weights)
    w /= w.sum()
    sy = np.average(aligned, axis=0, weights=w)
    pk = np.abs(sy).max()
    return (sy * 0.95 / pk).astype(np.float32) if pk > 0 else sy.astype(np.float32)


# ─── Stage 7 ──────────────────────────────────────────────────────────────────
def build_midi(clusters, bpm=120.0):
    """Build a PrettyMIDI drum track: one pitch per cluster, one note per hit."""
    import pretty_midi
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    for i, c in enumerate(clusters):
        c.midi_note = min(36 + i, 127)  # start at GM kick (36)
    inst = pretty_midi.Instrument(program=0, is_drum=True, name='Samples')
    pm.instruments.append(inst)
    for c in clusters:
        for h in c.hits:
            inst.notes.append(pretty_midi.Note(
                velocity=max(1, min(127, int(h.rms_energy / 0.3 * 127))),
                pitch=c.midi_note, start=h.onset_time,
                end=h.onset_time + max(h.duration, 0.05)))
    inst.notes.sort(key=lambda n: n.start)
    return pm


def export_midi(clusters, path, bpm=120.0):
    """Write the reconstruction MIDI to `path`; returns the PrettyMIDI object."""
    pm = build_midi(clusters, bpm)
    pm.write(path)
    print(f" ✓ MIDI: {len(pm.instruments[0].notes)} notes")
    return pm


def detect_bpm(y, sr):
    """Estimate tempo; prefer a candidate in 70–200 BPM, else fold octaves. Cached."""
    ck = ("bpm", _audio_hash(y), sr)
    c = cache_get(ck)
    if c:
        return c
    oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
    bpm = float(librosa.feature.tempo(onset_envelope=oe, sr=sr).item())
    _, beats = librosa.beat.beat_track(onset_envelope=oe, sr=sr, units='time')
    if len(beats) > 2:
        ibi = 60.0 / float(np.median(np.diff(beats)))  # tempo from inter-beat interval
        for cand in (bpm, ibi):
            if 70 <= cand <= 200:
                bpm = cand
                break
        else:
            # Neither candidate in range: fold by an octave.
            if bpm < 70:
                bpm *= 2
            elif bpm > 200:
                bpm /= 2
    return cache_set(ck, round(bpm, 1))


def render_midi_with_samples(clusters, sr=44100):
    """Render a mono preview by stamping each cluster's best hit at every onset."""
    end = max((h.onset_time + h.duration for c in clusters for h in c.hits), default=1.0)
    buf = np.zeros(int((end + 1.0) * sr), dtype=np.float64)
    for c in clusters:
        s = c.best_hit.audio.astype(np.float64)
        ref = c.best_hit.rms_energy if c.best_hit.rms_energy > 0 else 0.1
        for h in c.hits:
            vs = min(2.0, h.rms_energy / (ref + 1e-8)) ** 0.5  # velocity-like gain
            i = int(h.onset_time * sr)
            e = i + len(s)
            if e > len(buf):
                buf = np.concatenate([buf, np.zeros(e - len(buf))])
            buf[i:e] += s * vs
    pk = np.abs(buf).max()
    return (buf / pk * 0.9).astype(np.float32) if pk > 1e-8 else buf.astype(np.float32)


def build_sample_map(clusters):
    """Map midi_note → summary info for each cluster."""
    return {c.midi_note: {'label': c.label, 'count': c.count,
                          'duration_ms': int(c.best_hit.duration * 1000)}
            for c in clusters}


def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
    """Pack best samples (+synths, index.json, MIDI, render) into a zip; return its path."""
    import zipfile, tempfile, io
    # mkstemp instead of the insecure/deprecated mktemp: no race on the temp name.
    fd, zp = tempfile.mkstemp(suffix='.zip')
    os.close(fd)
    idx = {'bpm': round(bpm, 1), 'sample_rate': sr, 'total_clusters': len(clusters),
           'total_hits': sum(c.count for c in clusters), 'samples': {}}
    with zipfile.ZipFile(zp, 'w', compression=zipfile.ZIP_STORED) as zf:
        for c in clusters:
            b = c.best_hit
            fn = f"samples/{c.label}.wav"
            buf = io.BytesIO()
            sf.write(buf, b.audio, sr, format='WAV', subtype='PCM_24')
            zf.writestr(fn, buf.getvalue())
            ot = sorted([h.onset_time for h in c.hits])
            idx['samples'][c.label] = {
                'file': fn, 'classification': c.label.rsplit('_', 1)[0],
                'midi_note': c.midi_note, 'occurrences': c.count,
                'onset_times_sec': [round(t, 4) for t in ot],
                'duration_sec': round(b.duration, 4),
                'rms_energy': round(b.rms_energy, 6),
                'spectral_centroid_hz': round(b.spectral_centroid, 1)}
            if c.synthesized is not None:
                sf2 = f"samples/{c.label}__synth.wav"
                b2 = io.BytesIO()
                sf.write(b2, c.synthesized, sr, format='WAV', subtype='PCM_24')
                zf.writestr(sf2, b2.getvalue())
                idx['samples'][c.label]['synthesized_file'] = sf2
        zf.writestr('index.json', json.dumps(idx, indent=2))
        if midi_path and os.path.exists(midi_path):
            zf.write(midi_path, 'reconstruction.mid')
        if rendered_audio is not None:
            rb = io.BytesIO()
            sf.write(rb, rendered_audio, sr, format='WAV', subtype='PCM_16')
            zf.writestr('rendered_reconstruction.wav', rb.getvalue())
    return zp


# ─── Auto-tuner with parameter locking ───────────────────────────────────────
def _spectral_envelope_corr(orig, rend, sr, n_fft=4096, hop=2048):
    """Correlation (−1…1) of the mean log-spectral envelopes of two signals.

    NOTE(review): the body of this helper was lost to source corruption and
    reconstructed — verify against the original implementation.
    """
    n = min(len(orig), len(rend))
    if n < n_fft:
        return 0.0
    eo = np.log(np.abs(librosa.stft(orig[:n], n_fft=n_fft, hop_length=hop)).mean(axis=1) + 1e-8)
    er = np.log(np.abs(librosa.stft(rend[:n], n_fft=n_fft, hop_length=hop)).mean(axis=1) + 1e-8)
    if eo.std() < 1e-8 or er.std() < 1e-8:
        return 0.0
    return float(np.corrcoef(eo, er)[0, 1])


def _reconstruction_score(orig, rend, sr):
    """Score (0–100) how well `rend` reconstructs `orig`.

    NOTE(review): reconstructed from a corrupted source span; only the call
    signature and the 0–100 scale are certain from the surviving call sites.
    """
    return 100.0 * max(0.0, _spectral_envelope_corr(orig, rend, sr))


def auto_tune(stem_audio, sr, mode="auto", locks=None, verbose=True):
    """Grid-search onset/cluster parameters, honoring `locks` (param → fixed value).

    Returns (best_params, best_score, log_lines).

    NOTE(review): the signature, `_log` helper, sweep lists and the
    'target_min' lock branch were reconstructed from a corrupted source span;
    the search loops below survive verbatim in the source.
    """
    locks = dict(locks or {})
    log = []

    def _log(msg):
        log.append(msg)
        if verbose:
            print(msg)

    _log(f"[Auto-tune] mode={mode}, locks={locks or 'none'}")

    # Sweep values per onset parameter unless locked to a single value.
    deltas = [float(locks['onset_delta'])] if 'onset_delta' in locks else [0.05, 0.08, 0.12, 0.20]
    energies = ([float(locks['energy_threshold_db'])] if 'energy_threshold_db' in locks
                else [-45.0, -35.0, -28.0])
    gaps = [float(locks['min_gap'])] if 'min_gap' in locks else [0.02, 0.03, 0.05]
    onset_configs = [{'onset_delta': d, 'energy_threshold_db': e, 'min_gap': g}
                     for d in deltas for e in energies for g in gaps]

    if 'target_min' in locks:
        tmin_locked = int(locks['target_min'])
        cluster_targets = [t for t in [3, 5, 8, 10, 15, 20, 30] if t >= tmin_locked]
        if not cluster_targets:
            cluster_targets = [tmin_locked]
    elif 'target_max' in locks:
        tmax_locked = int(locks['target_max'])
        cluster_targets = [t for t in [3, 5, 8, 10, 15, 20, 30] if t <= tmax_locked]
        if not cluster_targets:
            cluster_targets = [tmax_locked]
    else:
        cluster_targets = [3, 5, 8, 10, 15, 20, 30]

    best_score, best_params = -1, {}
    for oc_idx, oc in enumerate(onset_configs):
        _log(f"\n  [{oc_idx+1}/{len(onset_configs)}] delta={oc['onset_delta']}, "
             f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}")
        hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
        if len(hits) < 2:
            _log(f"    {len(hits)} hits, skip")
            continue
        hits = classify_hits(hits)
        for nt in cluster_targets:
            if nt >= len(hits):
                continue
            clusters = cluster_hits(hits, target_min=nt, target_max=nt, linkage='average')
            if not clusters:
                continue
            # Cheap best-hit choice (loudest) — full quality scoring is too slow here.
            for c in clusters:
                if c.count <= 1:
                    c.best_hit_idx = 0
                else:
                    c.best_hit_idx = int(np.argmax([h.rms_energy for h in c.hits]))
            rendered = render_midi_with_samples(clusters, sr=sr)
            score = _reconstruction_score(stem_audio, rendered, sr)
            if score > best_score:
                best_score = score
                best_params = {**oc, 'n_clusters': len(clusters),
                               'target_min': nt, 'target_max': nt}
                _log(f"    ★ target={nt} → {len(clusters)} cl, score={score:.1f} BEST")
            else:
                _log(f"      target={nt} → {len(clusters)} cl, score={score:.1f}")

    # Fine-tune around the best cluster count, respecting target locks.
    if best_params:
        bt = best_params.get('target_min', 10)
        fine_oc = {k: best_params[k] for k in ['onset_delta', 'energy_threshold_db', 'min_gap']}
        _log(f"\n  Fine-tuning ±3 around target={bt}...")
        hits = detect_onsets(stem_audio, sr, mode=mode, **fine_oc)
        if len(hits) >= 2:
            hits = classify_hits(hits)
            lo = max(2, int(locks.get('target_min', bt - 3)))
            hi = min(len(hits) - 1, int(locks.get('target_max', bt + 3)))
            for ft in range(lo, hi + 1):
                clusters = cluster_hits(hits, target_min=ft, target_max=ft, linkage='average')
                if not clusters:
                    continue
                for c in clusters:
                    if c.count <= 1:
                        c.best_hit_idx = 0
                    else:
                        c.best_hit_idx = int(np.argmax([h.rms_energy for h in c.hits]))
                rendered = render_midi_with_samples(clusters, sr=sr)
                score = _reconstruction_score(stem_audio, rendered, sr)
                if score > best_score:
                    best_score = score
                    best_params = {**fine_oc, 'n_clusters': len(clusters),
                                   'target_min': ft, 'target_max': ft}
                    _log(f"    ★ target={ft} → {len(clusters)} cl, score={score:.1f} BEST")
    _log(f"\n[Auto-tune] Best: score={best_score:.1f}, params={best_params}")
    return best_params, best_score, log