Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
"""
Drum Sample Extractor Pipeline
===============================
Extracts individual drum samples from an audio file through:
1. STEM SEPARATION  → HTDemucs (v4 fine-tuned) isolates the drum track
2. ONSET DETECTION  → librosa detects individual hit boundaries
3. INTRA-DRUM SEP   → Spectral band splitting + optional AudioSep for overlapping sounds
4. CLUSTERING       → CLAP embeddings + auto-K KMeans groups identical hits
5. SELECTION        → Best representative per cluster (centroid-nearest + highest energy)
6. SYNTHESIS (opt)  → Weighted average of cluster members for an "ideal" sample
Usage:
    python drum_extractor.py input.mp3 --output-dir ./samples
    python drum_extractor.py input.wav --output-dir ./samples --no-gpu
    python drum_extractor.py input.mp3 --output-dir ./samples --clap
"""
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import warnings | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Optional | |
| import librosa | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
# Ignore every FutureWarning and UserWarning for the whole process
# (presumably to keep the stage-by-stage console output readable).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# ─────────────────────────────────────────────────────────────────────────────
# Data structures
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class DrumHit:
    """A single detected drum hit.

    Declared as a dataclass: the rest of the file constructs hits with
    keyword arguments (``DrumHit(audio=..., sr=..., ...)``), which requires
    the generated ``__init__``.
    """

    audio: np.ndarray  # mono waveform
    sr: int  # sample rate
    onset_time: float  # onset time in seconds (in the drum stem)
    duration: float  # duration in seconds
    index: int  # sequential index
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    rough_label: str = ""  # spectral rough label: kick/snare/hihat/other
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1

    def save(self, path: str) -> None:
        """Write this hit's waveform to ``path`` with the ``PCM_24`` subtype."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')
@dataclass
class DrumCluster:
    """A cluster of similar drum hits.

    ``best_hit`` and ``count`` are properties because callers in this file
    read them as attributes (``cluster.count == 1``, ``cluster.best_hit``).
    """

    cluster_id: int
    label: str  # e.g. "kick_0", "snare_1"
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0  # index into self.hits
    synthesized: Optional[np.ndarray] = None

    @property
    def best_hit(self) -> "DrumHit":
        """Return the hit at ``best_hit_idx`` in ``hits``."""
        return self.hits[self.best_hit_idx]

    @property
    def count(self) -> int:
        """Return the number of hits in this cluster."""
        return len(self.hits)
# ─────────────────────────────────────────────────────────────────────────────
# Stage 1: Drum stem extraction via Demucs
# ─────────────────────────────────────────────────────────────────────────────
def extract_drums_demucs(audio_path: str, device: str = "cpu") -> tuple[np.ndarray, int]:
    """Separate the drum stem of ``audio_path`` with a Demucs model.

    Tries the ``htdemucs_ft`` model first and falls back to ``htdemucs``.

    Args:
        audio_path: Path of the audio file to load.
        device: Torch device string the model and audio are moved to.

    Returns:
        ``(drums_mono, sample_rate)``: the drum stem averaged over its two
        channels as a NumPy array, and the model's sample rate.

    Raises:
        RuntimeError: If neither model could be loaded.
    """
    from demucs.apply import apply_model
    from demucs.pretrained import get_model

    rule = "=" * 60
    print(rule)
    print("STAGE 1: Extracting drum stem with HTDemucs")
    print(rule)

    # Prefer the fine-tuned variant; use the base model if it is unavailable.
    model = None
    for candidate in ("htdemucs_ft", "htdemucs"):
        try:
            model = get_model(candidate)
            print(f" Loaded model: {candidate}")
        except Exception as exc:
            print(f" Could not load {candidate}: {exc}")
            continue
        break
    if model is None:
        raise RuntimeError("Could not load any Demucs model")

    model.eval()
    model.to(device)
    target_sr = model.samplerate

    # Load at the model's sample rate and keep the channels separate so the
    # channel count can be forced to exactly two below.
    mix, _ = librosa.load(audio_path, sr=target_sr, mono=False)
    if mix.ndim == 1:
        # Mono file: duplicate the single channel.
        mix = np.stack([mix, mix])
    else:
        n_channels = mix.shape[0]
        if n_channels == 1:
            mix = np.concatenate([mix, mix], axis=0)
        elif n_channels > 2:
            # Keep only the first two channels.
            mix = mix[:2]

    # [2, T] -> [1, 2, T] batch on the requested device.
    wav = torch.from_numpy(mix).float().unsqueeze(0).to(device)
    print(f" Audio: {wav.shape[-1] / target_sr:.1f}s, {target_sr}Hz")

    with torch.no_grad():
        sources = apply_model(model, wav, device=device, shifts=1,
                              split=True, overlap=0.25, progress=True)

    # sources is indexed as [batch, source, channel, time]; pick the stem
    # whose name is 'drums' and average its channels down to mono.
    drums_idx = model.sources.index('drums')
    drums_mono = sources[0, drums_idx].mean(dim=0).cpu().numpy()
    print(f" β Extracted drums: {len(drums_mono) / target_sr:.1f}s")
    return drums_mono, target_sr
# ─────────────────────────────────────────────────────────────────────────────
# Stage 2: Onset detection & hit segmentation
# ─────────────────────────────────────────────────────────────────────────────
def detect_onsets(y: np.ndarray, sr: int,
                  pre_pad: float = 0.005,
                  min_hit_dur: float = 0.03,
                  max_hit_dur: float = 0.8,
                  min_gap: float = 0.02,
                  energy_threshold_db: float = -40.0) -> "list[DrumHit]":
    """Find drum onsets in ``y`` and cut the signal into individual hits.

    Args:
        y: Mono drum signal.
        sr: Sample rate of ``y``.
        pre_pad: Seconds kept before each detected onset.
        min_hit_dur: Segments shorter than this (seconds) are dropped.
        max_hit_dur: Upper bound on a segment's length in seconds.
        min_gap: Minimum spacing between onsets in seconds, converted to
            frames for the ``wait`` argument of the onset detector.
        energy_threshold_db: Segments whose RMS is below this level (dB,
            converted with ``10 ** (dB / 20)``) are dropped.

    Returns:
        The surviving hits, indexed sequentially in detection order.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STAGE 2: Detecting drum hit onsets")
    print(rule)

    def peak_normalised(env):
        top = env.max()
        return env / top if top > 0 else env

    # One onset-strength envelope per band (low / mid / high), each scaled
    # to its own peak, then merged by taking the per-frame maximum.
    band_limits = ((20, 250), (250, 4000), (4000, sr // 2))
    band_envelopes = [
        peak_normalised(
            librosa.onset.onset_strength(
                y=y, sr=sr, fmin=lo, fmax=hi, aggregate=np.median
            )
        )
        for lo, hi in band_limits
    ]
    onset_env = np.maximum.reduce(band_envelopes)

    # min_gap expressed in frames, assuming librosa's default hop of 512.
    wait_frames = max(1, int(min_gap * sr / 512))
    onsets_frames = librosa.onset.onset_detect(
        onset_envelope=onset_env,
        sr=sr,
        wait=wait_frames,
        pre_avg=3,
        post_avg=3,
        pre_max=3,
        post_max=5,
        backtrack=True,
        units='frames'
    )
    onset_times = librosa.frames_to_time(onsets_frames, sr=sr)
    print(f" Raw onsets detected: {len(onset_times)}")

    energy_threshold = 10 ** (energy_threshold_db / 20)
    n_onsets = len(onset_times)
    hits = []
    for pos, onset in enumerate(onset_times):
        start = max(0, int((onset - pre_pad) * sr))
        # A hit ends at the next onset (or the end of the signal for the
        # last one), but never runs longer than max_hit_dur.
        if pos + 1 < n_onsets:
            boundary = int(onset_times[pos + 1] * sr)
        else:
            boundary = len(y)
        stop = min(boundary, start + int(max_hit_dur * sr))
        segment = y[start:stop]

        if len(segment) < int(min_hit_dur * sr):
            continue
        rms = np.sqrt(np.mean(segment ** 2))
        if rms < energy_threshold:
            continue

        # Short linear fade-out on a copy so the cut does not click.
        fade_len = min(int(0.005 * sr), len(segment) // 4)
        if fade_len > 0:
            segment = segment.copy()
            segment[-fade_len:] *= np.linspace(1, 0, fade_len)

        centroid = librosa.feature.spectral_centroid(y=segment, sr=sr)
        hits.append(DrumHit(
            audio=segment,
            sr=sr,
            onset_time=onset,
            duration=len(segment) / sr,
            index=len(hits),
            rms_energy=float(rms),
            spectral_centroid=float(centroid.mean()),
        ))

    print(f" β Valid hits after filtering: {len(hits)}")
    return hits
# ─────────────────────────────────────────────────────────────────────────────
# Stage 3: Rough spectral classification + optional intra-drum separation
# ─────────────────────────────────────────────────────────────────────────────
def rough_spectral_label(hit: "DrumHit") -> str:
    """Return a coarse drum-type label for ``hit`` from simple spectral rules.

    The rules combine the hit's stored spectral centroid, the share of STFT
    energy in three bands (20-200 Hz, 200-4000 Hz, 4000 Hz and above), the
    mean zero-crossing rate and the hit duration. Possible results are
    ``kick``, ``hihat_closed``, ``hihat_open``, ``cymbal``, ``snare``,
    ``tom``, ``perc_high`` and ``perc_low``.
    """
    signal, rate = hit.audio, hit.sr
    centroid = hit.spectral_centroid

    magnitude = np.abs(librosa.stft(signal, n_fft=2048))
    bin_freqs = librosa.fft_frequencies(sr=rate, n_fft=2048)

    def band_energy(mask):
        return np.sum(magnitude[mask] ** 2)

    low_energy = band_energy((bin_freqs >= 20) & (bin_freqs < 200))
    mid_energy = band_energy((bin_freqs >= 200) & (bin_freqs < 4000))
    high_energy = band_energy(bin_freqs >= 4000)
    # The small constant keeps the ratios finite for an all-zero spectrum.
    total = low_energy + mid_energy + high_energy + 1e-10
    low_ratio = low_energy / total
    mid_ratio = mid_energy / total
    high_ratio = high_energy / total

    zcr = float(librosa.feature.zero_crossing_rate(y=signal).mean())

    # Rules are checked in order; the first match wins.
    if low_ratio > 0.5 and centroid < 800:
        return "kick"
    if high_ratio > 0.35 and centroid > 4000:
        return "hihat_closed" if hit.duration < 0.15 else "hihat_open"
    if high_ratio > 0.25 and centroid > 3000:
        return "cymbal"
    if mid_ratio > 0.4 and zcr > 0.1 and centroid > 1000:
        return "snare"
    if low_ratio > 0.3 and mid_ratio > 0.3:
        return "tom"
    return "perc_high" if centroid > 2500 else "perc_low"
def spectral_separate_hit(hit: "DrumHit") -> dict[str, np.ndarray]:
    """Split one hit into band-limited copies by masking its STFT.

    Bands are ``low`` (20-250 Hz), ``mid`` (250-4000 Hz) and ``high``
    (4000 Hz up to ``sr // 2``), with both edges inclusive.

    Returns:
        ``{band_name: audio}`` containing only the bands whose resynthesised
        audio has an RMS above 0.001; each array has the hit's length.
    """
    signal, rate = hit.audio, hit.sr
    spectrum = librosa.stft(signal, n_fft=2048)
    bin_freqs = librosa.fft_frequencies(sr=rate, n_fft=2048)

    band_ranges = (
        ("low", 20, 250),          # kick range
        ("mid", 250, 4000),        # snare/tom range
        ("high", 4000, rate // 2)  # hihat/cymbal range
    )

    separated = {}
    for band, lo, hi in band_ranges:
        in_band = (bin_freqs >= lo) & (bin_freqs <= hi)
        # Zero every bin outside the band, then invert back to audio.
        band_spectrum = np.zeros_like(spectrum)
        band_spectrum[in_band] = spectrum[in_band]
        band_audio = librosa.istft(band_spectrum, length=len(signal))
        band_rms = np.sqrt(np.mean(band_audio ** 2))
        if band_rms > 0.001:
            separated[band] = band_audio
    return separated
def classify_and_separate_hits(hits: "list[DrumHit]",
                               separate_overlaps: bool = True) -> "list[DrumHit]":
    """Label every hit and optionally replace overlapping hits by band sub-hits.

    Each hit gets ``rough_label`` from ``rough_spectral_label``. When
    ``separate_overlaps`` is true, a hit whose band decomposition has at
    least two bands with RMS above 15% of the strongest band is replaced by
    one sub-hit per such band (labelled kick / snare / hihat by band).
    Otherwise the hit is kept as is and re-indexed.

    Returns:
        The new list of hits, indexed sequentially.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STAGE 3: Spectral classification & separation")
    print(rule)

    band_labels = {"low": "kick", "mid": "snare", "high": "hihat"}
    all_hits = []
    overlap_count = 0

    for hit in hits:
        hit.rough_label = rough_spectral_label(hit)

        # Bands strong enough to count as separate simultaneous drums.
        significant = {}
        if separate_overlaps:
            bands = spectral_separate_hit(hit)
            if len(bands) >= 2:
                band_rms = {name: np.sqrt(np.mean(audio ** 2))
                            for name, audio in bands.items()}
                floor = 0.15 * max(band_rms.values())
                significant = {name: audio for name, audio in bands.items()
                               if band_rms[name] > floor}

        if len(significant) < 2:
            # No overlap detected: keep the original hit.
            hit.index = len(all_hits)
            all_hits.append(hit)
            continue

        # Overlap: the original hit is dropped in favour of its sub-hits.
        overlap_count += 1
        for band_name, band_audio in significant.items():
            band_centroid = librosa.feature.spectral_centroid(
                y=band_audio, sr=hit.sr
            )
            all_hits.append(DrumHit(
                audio=band_audio,
                sr=hit.sr,
                onset_time=hit.onset_time,
                duration=hit.duration,
                index=len(all_hits),
                rms_energy=float(np.sqrt(np.mean(band_audio ** 2))),
                spectral_centroid=float(band_centroid.mean()),
                rough_label=band_labels.get(band_name, "other"),
            ))

    label_counts = defaultdict(int)
    for kept in all_hits:
        label_counts[kept.rough_label] += 1

    print(f" Overlapping hits decomposed: {overlap_count}")
    print(f" Total hits after separation: {len(all_hits)}")
    print(" Label distribution:")
    # Most frequent label first.
    for label, count in sorted(label_counts.items(), key=lambda item: -item[1]):
        print(f" {label}: {count}")
    return all_hits
# ─────────────────────────────────────────────────────────────────────────────
# Stage 4: Embedding & Clustering
# ─────────────────────────────────────────────────────────────────────────────
def compute_librosa_embeddings(hits: "list[DrumHit]") -> np.ndarray:
    """Build a hand-crafted librosa feature vector per hit and z-score it.

    Per hit the vector concatenates, in this order: MFCC means and standard
    deviations (20 each), spectral centroid mean/std, spectral bandwidth
    mean/std, mean rolloff, mean spectral contrast per band, mean flatness,
    mean zero-crossing rate, mean RMS, four onset-envelope shape features
    and the hit duration.

    Returns:
        A float32 array with one row per hit, normalised per column to zero
        mean and unit standard deviation across the given hits.
    """
    rows = []
    for hit in hits:
        signal, rate = hit.audio, hit.sr

        # Zero-pad very short hits up to 50 ms.
        shortest = int(0.05 * rate)
        if len(signal) < shortest:
            signal = np.pad(signal, (0, shortest - len(signal)))

        # Timbre.
        mfcc = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=20)

        # Spectral shape.
        centroid = librosa.feature.spectral_centroid(y=signal, sr=rate)
        bandwidth = librosa.feature.spectral_bandwidth(y=signal, sr=rate)
        rolloff = librosa.feature.spectral_rolloff(y=signal, sr=rate)
        contrast = librosa.feature.spectral_contrast(y=signal, sr=rate, n_bands=4)
        flatness = librosa.feature.spectral_flatness(y=signal)

        # Temporal.
        zcr = librosa.feature.zero_crossing_rate(y=signal)
        rms = librosa.feature.rms(y=signal)

        # Shape of the peak-normalised onset envelope: mean, spread,
        # relative peak position and the value of the last frame.
        envelope = librosa.onset.onset_strength(y=signal, sr=rate)
        if len(envelope) > 1:
            shape = envelope / (envelope.max() + 1e-10)
            attack_feats = [
                shape.mean(),
                shape.std(),
                float(np.argmax(shape)) / len(shape),
                shape[-1],
            ]
        else:
            attack_feats = [0, 0, 0, 0]

        rows.append(np.concatenate([
            mfcc.mean(axis=1),
            mfcc.std(axis=1),
            [centroid.mean(), centroid.std()],
            [bandwidth.mean(), bandwidth.std()],
            [rolloff.mean()],
            contrast.mean(axis=1),
            [flatness.mean()],
            [zcr.mean()],
            [rms.mean()],
            attack_feats,
            [hit.duration],
        ]))

    embeddings = np.array(rows, dtype=np.float32)

    # Column-wise z-score; the epsilon guards constant columns.
    col_mean = embeddings.mean(axis=0)
    col_std = embeddings.std(axis=0) + 1e-8
    return (embeddings - col_mean) / col_std
def compute_clap_embeddings(hits: "list[DrumHit]", device: str = "cpu") -> np.ndarray:
    """Embed each hit with the ``laion/larger_clap_general`` CLAP audio model.

    Every hit is resampled to 48 kHz and zero-padded to at least 0.5 s
    before being passed through the processor and the model.

    Args:
        hits: Hits to embed.
        device: Torch device string for the model and its inputs.

    Returns:
        A float32 array with one embedding row per hit.
    """
    from transformers import ClapModel, ClapProcessor

    checkpoint = "laion/larger_clap_general"
    print(" Loading CLAP model (laion/larger_clap_general)...")
    model = ClapModel.from_pretrained(checkpoint).to(device)
    processor = ClapProcessor.from_pretrained(checkpoint)
    model.eval()

    clap_sr = 48000
    min_samples = int(0.5 * clap_sr)
    total = len(hits)

    vectors = []
    for done, hit in enumerate(hits, start=1):
        clip = librosa.resample(hit.audio, orig_sr=hit.sr, target_sr=clap_sr)
        shortfall = min_samples - len(clip)
        if shortfall > 0:
            clip = np.pad(clip, (0, shortfall))

        batch = processor(audios=clip, sampling_rate=clap_sr, return_tensors="pt")
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        with torch.no_grad():
            audio_embed = model.get_audio_features(**batch)
        vectors.append(audio_embed.squeeze().cpu().numpy())

        # Progress line every 50 hits.
        if done % 50 == 0:
            print(f" Embedded {done}/{total}")

    return np.array(vectors, dtype=np.float32)
def cluster_hits(hits: "list[DrumHit]",
                 embeddings: np.ndarray,
                 min_clusters: int = 2,
                 max_clusters: int = 30) -> "list[DrumCluster]":
    """Group hits by rough label, then sub-cluster each group with KMeans.

    Within each label group, K is chosen from ``2 .. min(max(2, n // 3), 15)``
    (``n`` = group size) by the highest silhouette score. Groups with fewer
    than two hits, or for which no K yields a valid score, become a single
    cluster.

    Args:
        hits: Hits to cluster; ``rough_label`` must already be set.
        embeddings: One row per hit, in the same order as ``hits``.
        min_clusters: Accepted for interface compatibility; the per-group K
            search does not use it.
        max_clusters: Accepted for interface compatibility; the per-group K
            search does not use it.

    Returns:
        Clusters labelled ``"<rough_label>_<sub_id>"`` with sequential
        ``cluster_id`` values.
    """
    from sklearn.cluster import KMeans
    from sklearn.metrics import silhouette_score

    print("\n" + "=" * 60)
    print("STAGE 4: Clustering similar drum hits")
    print("=" * 60)

    # Bucket hit positions by rough label; sub-clustering is per bucket.
    label_groups = defaultdict(list)
    for i, hit in enumerate(hits):
        label_groups[hit.rough_label].append(i)

    all_clusters = []
    for label, indices in label_groups.items():
        if len(indices) < 2:
            # A lone hit forms its own cluster.
            all_clusters.append(DrumCluster(
                cluster_id=len(all_clusters),
                label=f"{label}_0",
                hits=[hits[i] for i in indices],
            ))
            continue

        group_embeddings = embeddings[indices]

        # Try each K and keep the labelling of the best one, so no second
        # fit is needed (the fit is seeded, so a refit would reproduce it).
        max_k = min(max(2, len(indices) // 3), 15)
        best_k, best_score = 1, -1
        best_labels = np.zeros(len(indices), dtype=int)
        for k in range(2, max_k + 1):
            try:
                km = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
                candidate = km.fit_predict(group_embeddings)
                score = silhouette_score(group_embeddings, candidate)
            except ValueError:
                # E.g. K too large for the group or a degenerate labelling.
                continue
            if score > best_score:
                best_k, best_score, best_labels = k, score, candidate

        for sub_id in range(int(best_labels.max()) + 1):
            members = [hits[indices[j]]
                       for j in np.flatnonzero(best_labels == sub_id)]
            if not members:
                # A label id with no members would give a cluster without
                # any hit, which later stages cannot pick a sample from.
                continue
            all_clusters.append(DrumCluster(
                cluster_id=len(all_clusters),
                label=f"{label}_{sub_id}",
                hits=members,
            ))

        print(f" {label}: {len(indices)} hits β {best_k} sub-clusters "
              f"(silhouette={best_score:.3f})")

    print(f"\n β Total clusters: {len(all_clusters)}")
    for c in all_clusters:
        print(f" {c.label}: {c.count} hits")
    return all_clusters
# ─────────────────────────────────────────────────────────────────────────────
# Stage 5: Best representative selection
# ─────────────────────────────────────────────────────────────────────────────
def select_best_representatives(clusters: "list[DrumCluster]",
                                embeddings_dict: dict = None):
    """Set ``best_hit_idx`` on every cluster to its best representative.

    Single-hit clusters get index 0. For larger clusters each hit is scored
    as ``0.6 * centroid_score + 0.4 * energy_score``, where the centroid
    score falls with distance from the cluster mean in a z-scored feature
    space (13 MFCC means, RMS energy, spectral centroid, duration) and the
    energy score is the hit's RMS relative to the loudest member.

    Args:
        clusters: Clusters to update in place.
        embeddings_dict: Not used by this function.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STAGE 5: Selecting best representatives")
    print(rule)

    for cluster in clusters:
        if cluster.count == 1:
            cluster.best_hit_idx = 0
            continue

        members = cluster.hits

        # One feature row per member for the within-cluster comparison.
        feature_rows = np.array([
            np.concatenate([
                librosa.feature.mfcc(y=m.audio, sr=m.sr, n_mfcc=13).mean(axis=1),
                [m.rms_energy, m.spectral_centroid, m.duration],
            ])
            for m in members
        ])

        # Z-score the columns so no single feature dominates the distance.
        col_mean = feature_rows.mean(axis=0)
        col_std = feature_rows.std(axis=0) + 1e-8
        z = (feature_rows - col_mean) / col_std

        # Representativeness: 1 at the cluster mean, falling to ~0 for the
        # farthest member.
        dists = np.linalg.norm(z - z.mean(axis=0), axis=1)
        centroid_scores = 1.0 - (dists / (dists.max() + 1e-8))

        # Loudness relative to the loudest member.
        energies = np.array([m.rms_energy for m in members])
        energy_scores = energies / (energies.max() + 1e-8)

        scores = 0.6 * centroid_scores + 0.4 * energy_scores
        cluster.best_hit_idx = int(np.argmax(scores))
        print(f" {cluster.label}: selected hit {cluster.best_hit_idx} "
              f"(score={scores[cluster.best_hit_idx]:.3f}, "
              f"energy={cluster.hits[cluster.best_hit_idx].rms_energy:.4f})")
# ─────────────────────────────────────────────────────────────────────────────
# Stage 6 (optional): Synthesize optimal sample from cluster
# ─────────────────────────────────────────────────────────────────────────────
def synthesize_from_cluster(cluster: "DrumCluster") -> np.ndarray:
    """Blend a cluster's hits into a single sample by weighted averaging.

    Each hit is shifted so its absolute peak lines up with the peak of the
    first hit, cropped or zero-padded to the median hit length, and
    peak-normalised. The aligned hits are then averaged in the time domain
    with the hit at ``cluster.best_hit_idx`` weighted 2x and every other
    hit weighted 1x.

    Args:
        cluster: Cluster whose ``hits`` (each with an ``audio`` array) and
            ``best_hit_idx`` are read.

    Returns:
        The averaged waveform scaled to a peak of 0.95, or an unmodified
        copy of the only hit's audio for a single-hit cluster.
    """
    if cluster.count == 1:
        return cluster.hits[0].audio.copy()

    target_len = int(np.median([len(h.audio) for h in cluster.hits]))

    aligned = []
    peak_pos_target = None
    for hit in cluster.hits:
        audio = hit.audio.copy()
        peak_pos = int(np.argmax(np.abs(audio)))
        if peak_pos_target is None:
            # The first hit defines where every peak should land.
            peak_pos_target = peak_pos

        # Move this hit's peak onto the target position: pad the front
        # when the peak is early, trim the front when it is late.
        shift = peak_pos_target - peak_pos
        if shift > 0:
            audio = np.pad(audio, (shift, 0))
        elif shift < 0:
            audio = audio[-shift:]

        # Force the common length so the hits stack into one array.
        if len(audio) >= target_len:
            audio = audio[:target_len]
        else:
            audio = np.pad(audio, (0, target_len - len(audio)))

        # Peak-normalise so loud hits do not dominate the average.
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak
        aligned.append(audio)

    # Double weight for the selected best hit, unit weight otherwise;
    # np.average normalises by the weight sum.
    weights = np.where(
        np.arange(len(aligned)) == cluster.best_hit_idx, 2.0, 1.0
    )
    synthesized = np.average(np.array(aligned), axis=0, weights=weights)

    # Rescale the result to a 0.95 peak.
    peak = np.abs(synthesized).max()
    if peak > 0:
        synthesized = synthesized * (0.95 / peak)
    return synthesized
# ─────────────────────────────────────────────────────────────────────────────
# Main pipeline
# ─────────────────────────────────────────────────────────────────────────────
def run_pipeline(
    audio_path: str,
    output_dir: str = "./drum_samples",
    use_gpu: bool = True,
    use_clap: bool = False,       # CLAP embeddings (slower, semantic)
    use_audiosep: bool = False,   # AudioSep for overlap separation
    separate_overlaps: bool = True,
    synthesize: bool = True,
    min_hit_dur: float = 0.03,
    max_hit_dur: float = 0.8,
    energy_threshold_db: float = -40.0,
    save_intermediates: bool = True,
):
    """Run every stage of the drum sample extraction pipeline.

    Args:
        audio_path: Input audio file.
        output_dir: Directory that receives all outputs (created if needed).
        use_gpu: Use CUDA when it is available; otherwise run on CPU.
        use_clap: Cluster on CLAP embeddings instead of librosa features.
        use_audiosep: Not read anywhere in this function.
        separate_overlaps: Split overlapping hits into band sub-hits.
        synthesize: Also write an averaged sample per multi-hit cluster.
        min_hit_dur: Minimum hit duration in seconds.
        max_hit_dur: Maximum hit duration in seconds.
        energy_threshold_db: RMS threshold in dB for keeping a hit.
        save_intermediates: Also write the drum stem and every single hit.

    Returns:
        The list of clusters, or ``None`` when no hit was detected.
    """
    rule = "=" * 60

    device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
    print(f"Device: {device}")
    print(f"Input: {audio_path}")
    print(f"Output: {output_dir}")

    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # -- Stage 1: extract the drum stem --
    drums_audio, drums_sr = extract_drums_demucs(audio_path, device=device)
    if save_intermediates:
        drums_path = out_root / "drums_stem.wav"
        sf.write(str(drums_path), drums_audio, drums_sr, subtype='PCM_24')
        print(f" Saved drum stem: {drums_path}")

    # -- Stage 2: detect onsets and segment --
    hits = detect_onsets(
        drums_audio, drums_sr,
        min_hit_dur=min_hit_dur,
        max_hit_dur=max_hit_dur,
        energy_threshold_db=energy_threshold_db,
    )
    if len(hits) == 0:
        print("\nβ No drum hits detected! Try lowering energy_threshold_db.")
        return

    # -- Stage 3: classify and optionally separate overlaps --
    hits = classify_and_separate_hits(hits, separate_overlaps=separate_overlaps)
    if save_intermediates:
        hits_dir = out_root / "all_hits"
        hits_dir.mkdir(exist_ok=True)
        for hit in hits:
            hit_name = f"hit_{hit.index:04d}_{hit.rough_label}_{hit.onset_time:.3f}s.wav"
            hit.save(str(hits_dir / hit_name))

    # -- Stage 4: embed and cluster --
    print("\n" + rule)
    print("STAGE 4a: Computing embeddings")
    print(rule)
    if use_clap:
        embeddings = compute_clap_embeddings(hits, device=device)
        print(f" β CLAP embeddings: {embeddings.shape}")
    else:
        embeddings = compute_librosa_embeddings(hits)
        print(f" β Librosa embeddings: {embeddings.shape}")
    for hit, vector in zip(hits, embeddings):
        hit.embedding = vector
    clusters = cluster_hits(hits, embeddings)

    # -- Stage 5: pick the best representative per cluster --
    select_best_representatives(clusters)

    # -- Stage 6: optional synthesis (only for clusters with 2+ hits) --
    if synthesize:
        print("\n" + rule)
        print("STAGE 6: Synthesizing optimal samples")
        print(rule)
        for cluster in clusters:
            if cluster.count < 2:
                continue
            cluster.synthesized = synthesize_from_cluster(cluster)
            print(f" {cluster.label}: synthesized from {cluster.count} hits")

    # -- Export --
    print("\n" + rule)
    print("EXPORT: Saving results")
    print(rule)
    samples_dir = out_root / "samples"
    samples_dir.mkdir(exist_ok=True)
    if synthesize:
        synth_dir = out_root / "synthesized"
        synth_dir.mkdir(exist_ok=True)

    manifest = []
    for cluster in clusters:
        best = cluster.best_hit

        # Best representative of the cluster.
        sample_path = samples_dir / f"{cluster.label}__best.wav"
        best.save(str(sample_path))
        entry = {
            "cluster_id": cluster.cluster_id,
            "label": cluster.label,
            "count": cluster.count,
            "best_sample": str(sample_path),
            "best_onset_time": best.onset_time,
            "best_duration": best.duration,
            "best_rms_energy": best.rms_energy,
            "best_spectral_centroid": best.spectral_centroid,
        }

        # Synthesized version, when one was produced for this cluster.
        if synthesize and cluster.synthesized is not None:
            synth_path = synth_dir / f"{cluster.label}__synthesized.wav"
            sf.write(str(synth_path), cluster.synthesized, best.sr, subtype='PCM_24')
            entry["synthesized_sample"] = str(synth_path)

        manifest.append(entry)
        print(f" β {cluster.label}: {cluster.count} hits β {sample_path.name}")

    # One JSON manifest describing every exported cluster.
    manifest_path = out_root / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"\n Manifest saved: {manifest_path}")

    # -- Summary --
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    print(f" Input: {audio_path}")
    print(f" Drum stem: {out_root / 'drums_stem.wav'}")
    print(f" Total hits: {len(hits)}")
    print(f" Clusters: {len(clusters)}")
    print(f" Samples saved: {samples_dir}")
    if synthesize:
        print(f" Synthesized: {synth_dir}")
    print(f" Manifest: {manifest_path}")
    return clusters
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Parse the command line and run the pipeline on the given file.

    Prints an error and exits with status 1 when the input file does not
    exist. Each ``--no-*`` switch is inverted into the matching positive
    ``run_pipeline`` argument.
    """
    examples = (
        "\n"
        "Examples:\n"
        "%(prog)s song.mp3 -o ./my_samples\n"
        "%(prog)s drums.wav -o ./samples --no-gpu\n"
        "%(prog)s song.wav -o ./samples --clap # Use CLAP for semantic clustering\n"
        "%(prog)s song.wav -o ./samples --no-separate # Don't decompose overlaps\n"
        "%(prog)s song.wav -o ./samples --no-synthesize # Skip synthesis step\n"
    )
    cli = argparse.ArgumentParser(
        description="Extract individual drum samples from an audio file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )
    cli.add_argument("input", help="Input audio file (mp3, wav, flac, etc.)")
    cli.add_argument("-o", "--output-dir", default="./drum_samples",
                     help="Output directory (default: ./drum_samples)")

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--no-gpu", "Force CPU-only processing"),
        ("--clap", "Use CLAP embeddings for clustering (slower, more semantic)"),
        ("--no-separate", "Don't separate overlapping drum sounds"),
        ("--no-synthesize", "Don't synthesize optimal samples from clusters"),
        ("--no-intermediates",
         "Don't save intermediate files (drum stem, individual hits)"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    # Float-valued tuning options.
    tunables = (
        ("--min-hit-dur", 0.03,
         "Minimum hit duration in seconds (default: 0.03)"),
        ("--max-hit-dur", 0.8,
         "Maximum hit duration in seconds (default: 0.8)"),
        ("--energy-threshold", -40.0,
         "Energy threshold in dB for hit filtering (default: -40)"),
    )
    for flag, default, description in tunables:
        cli.add_argument(flag, type=float, default=default, help=description)

    opts = cli.parse_args()

    if not os.path.exists(opts.input):
        print(f"Error: Input file not found: {opts.input}")
        sys.exit(1)

    run_pipeline(
        audio_path=opts.input,
        output_dir=opts.output_dir,
        use_gpu=not opts.no_gpu,
        use_clap=opts.clap,
        separate_overlaps=not opts.no_separate,
        synthesize=not opts.no_synthesize,
        min_hit_dur=opts.min_hit_dur,
        max_hit_dur=opts.max_hit_dur,
        energy_threshold_db=opts.energy_threshold,
        save_intermediates=not opts.no_intermediates,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()