Spaces:
Sleeping
Sleeping
| """ | |
| Synthetic drum song generator with known ground-truth samples. | |
| Generates realistic drum patterns by: | |
| 1. Synthesizing individual drum samples (kick, snare, hihat, etc.) with controlled parameters | |
| 2. Placing them in musical patterns with velocity variation, timing humanization, and overlap | |
| 3. Optionally mixing with bass/harmony for realistic Demucs testing | |
| 4. Returning both the mix AND the isolated ground-truth samples + onset map | |
| This gives us a perfect evaluation setup: we know exactly which samples are where, | |
| so we can compare extracted samples against the originals. | |
| """ | |
| import numpy as np | |
| from scipy.signal import butter, filtfilt, lfilter | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import soundfile as sf | |
| import json | |
@dataclass
class GroundTruthSample:
    """A ground-truth drum sample used to build the synthetic song.

    Declared as a dataclass so that it can be constructed with keyword
    arguments (``GroundTruthSample(name=..., audio=..., ...)``), which is how
    the rest of this module instantiates it.
    """

    name: str  # e.g. "kick", "snare"
    audio: np.ndarray  # the clean one-shot sample
    sr: int  # sample rate used to convert len(audio) into seconds
    frequency_range: tuple  # (low_hz, high_hz) primary energy band

    # NOTE(review): kept as a plain method, matching the visible source;
    # confirm that callers do not expect a property here.
    def duration(self) -> float:
        """Return the length of the one-shot in seconds."""
        return len(self.audio) / self.sr
@dataclass
class PlacedHit:
    """A single hit placed in the timeline.

    Declared as a dataclass so that it can be built with keyword arguments,
    which is how the song assembler records each hit.
    """

    sample_name: str  # key of the sample that was triggered
    onset_time: float  # seconds
    velocity: float  # 0-1 amplitude multiplier
    audio: np.ndarray  # the actual audio placed (with velocity applied)
    sr: int  # sample rate of `audio`
@dataclass
class SyntheticSong:
    """A complete synthetic drum song with ground truth.

    Declared as a dataclass so that it can be built with keyword arguments,
    which is how the song assembler returns its result.
    """

    mix: np.ndarray  # full mix audio
    drums_only: np.ndarray  # drums-only mix
    sr: int  # sample rate shared by all audio arrays
    bpm: float  # tempo the pattern was rendered at
    duration: float  # total length in seconds
    samples: dict  # {name: GroundTruthSample}
    hits: list  # [PlacedHit, ...]
    per_sample_stems: dict  # {name: np.ndarray} isolated stems
    pattern_description: str  # free-form summary of the pattern
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Sample synthesis (parametric drum sounds) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| def _butter_filter(y, sr, fmin=None, fmax=None, order=4): | |
| """Apply butterworth bandpass/lowpass/highpass filter.""" | |
| nyq = sr / 2 | |
| if fmin and fmax: | |
| b, a = butter(order, [fmin / nyq, fmax / nyq], btype='band') | |
| elif fmin: | |
| b, a = butter(order, fmin / nyq, btype='high') | |
| elif fmax: | |
| b, a = butter(order, fmax / nyq, btype='low') | |
| else: | |
| return y | |
| return filtfilt(b, a, y) | |
def synthesize_kick(sr: int = 44100, pitch: float = 60.0,
                    decay: float = 12.0, punch: float = 150.0,
                    duration: float = 0.25, noise_amount: float = 0.05,
                    rng=None) -> np.ndarray:
    """Synthesize a kick drum: sine sweep + sub thump + click.

    Args:
        sr: sample rate in Hz.
        pitch: frequency (Hz) the sweep settles on; also the sub frequency.
        decay: exponential decay rate of the swept body.
        punch: starting frequency (Hz) of the sweep.
        duration: length of the one-shot in seconds.
        noise_amount: gain of the noise click transient.
        rng: optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.RandomState``). Defaults to the global
            ``np.random`` state, so pass one to get reproducible noise.

    Returns:
        float32 array normalised to a peak of about 0.95.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr
    # Frequency sweep: punch Hz -> pitch Hz
    freq = (punch - pitch) * np.exp(-30 * t) + pitch
    # Integrate the instantaneous frequency to obtain the phase.
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)
    # Sub thump
    sub = 0.4 * np.sin(2 * np.pi * pitch * t) * np.exp(-15 * t)
    # Click transient: very short noise burst, low-passed at 4 kHz
    click = noise_amount * noise_source.standard_normal(len(t)) * np.exp(-200 * t)
    click = _butter_filter(click, sr, fmax=4000)
    kick = body + sub + click
    # The epsilon avoids dividing by zero on an all-zero signal.
    kick = kick / (np.abs(kick).max() + 1e-8) * 0.95
    return kick.astype(np.float32)
def synthesize_snare(sr: int = 44100, body_freq: float = 200.0,
                     noise_decay: float = 12.0, body_decay: float = 20.0,
                     duration: float = 0.25, wire_amount: float = 0.6,
                     rng=None) -> np.ndarray:
    """Synthesize a snare drum: body tone + noise wires.

    Args:
        sr: sample rate in Hz.
        body_freq: frequency (Hz) of the body tone; the ring sits at 2.7x.
        noise_decay: exponential decay rate of the wire noise.
        body_decay: exponential decay rate of the body tone.
        duration: length of the one-shot in seconds.
        wire_amount: gain of the wire noise before band-pass filtering.
        rng: optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.RandomState``). Defaults to the global
            ``np.random`` state, so pass one to get reproducible noise.

    Returns:
        float32 array normalised to a peak of about 0.95.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr
    # Body
    body = np.sin(2 * np.pi * body_freq * t) * np.exp(-body_decay * t) * 0.5
    # Snare wires (noise band-passed to 1-10 kHz)
    noise = noise_source.standard_normal(len(t)) * np.exp(-noise_decay * t) * wire_amount
    noise = _butter_filter(noise, sr, fmin=1000, fmax=10000)
    # Overtone ring
    ring = 0.15 * np.sin(2 * np.pi * body_freq * 2.7 * t) * np.exp(-25 * t)
    snare = body + noise + ring
    # The epsilon avoids dividing by zero on an all-zero signal.
    snare = snare / (np.abs(snare).max() + 1e-8) * 0.95
    return snare.astype(np.float32)
def synthesize_hihat(sr: int = 44100, is_open: bool = False,
                     brightness: float = 8000.0,
                     duration: Optional[float] = None,
                     rng=None) -> np.ndarray:
    """Synthesize a hi-hat: filtered noise with metallic overtones.

    Args:
        sr: sample rate in Hz.
        is_open: open hats ring longer (slower decay, longer default length).
        brightness: high-pass cutoff (Hz) applied to the noise.
        duration: length in seconds; when ``None`` it defaults to 0.4 s for
            an open hat and 0.08 s for a closed one.
        rng: optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.RandomState``). Defaults to the global
            ``np.random`` state, so pass one to get reproducible noise.

    Returns:
        float32 array normalised to a peak of about 0.7.
    """
    noise_source = np.random if rng is None else rng
    if duration is None:
        duration = 0.4 if is_open else 0.08
    t = np.arange(int(sr * duration)) / sr
    decay = 6.0 if is_open else 40.0
    noise = noise_source.standard_normal(len(t)) * np.exp(-decay * t)
    noise = _butter_filter(noise, sr, fmin=brightness)
    # Metallic overtones: two fixed sine partials that fade slightly faster
    # than the noise.
    metal = 0.2 * np.sin(2 * np.pi * 6500 * t) * np.exp(-(decay + 5) * t)
    metal += 0.1 * np.sin(2 * np.pi * 9200 * t) * np.exp(-(decay + 8) * t)
    hh = noise + metal
    # The epsilon avoids dividing by zero on an all-zero signal.
    hh = hh / (np.abs(hh).max() + 1e-8) * 0.7
    return hh.astype(np.float32)
def synthesize_tom(sr: int = 44100, pitch: float = 120.0,
                   decay: float = 10.0, duration: float = 0.3,
                   rng=None) -> np.ndarray:
    """Synthesize a tom: pitched body + slight noise.

    Args:
        sr: sample rate in Hz.
        pitch: frequency (Hz) the body settles on; the sweep starts at
            2.5x this value.
        decay: exponential decay rate of the body.
        duration: length of the one-shot in seconds.
        rng: optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.RandomState``). Defaults to the global
            ``np.random`` state, so pass one to get reproducible noise.

    Returns:
        float32 array normalised to a peak of about 0.9.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr
    # Downward pitch sweep that converges on `pitch`.
    freq = pitch * 1.5 * np.exp(-8 * t) + pitch
    # Integrate the instantaneous frequency to obtain the phase.
    phase = 2 * np.pi * np.cumsum(freq / sr)
    body = np.sin(phase) * np.exp(-decay * t)
    # A little band-passed (200 Hz - 3 kHz) noise.
    noise = 0.1 * noise_source.standard_normal(len(t)) * np.exp(-20 * t)
    noise = _butter_filter(noise, sr, fmin=200, fmax=3000)
    tom = body + noise
    # The epsilon avoids dividing by zero on an all-zero signal.
    tom = tom / (np.abs(tom).max() + 1e-8) * 0.9
    return tom.astype(np.float32)
def synthesize_cymbal(sr: int = 44100, duration: float = 1.5,
                      rng=None) -> np.ndarray:
    """Synthesize a crash/ride cymbal: dense metallic noise.

    Args:
        sr: sample rate in Hz.
        duration: length of the one-shot in seconds.
        rng: optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.RandomState``). Defaults to the global
            ``np.random`` state, so pass one to get reproducible noise.

    Returns:
        float32 array normalised to a peak of about 0.6.
    """
    noise_source = np.random if rng is None else rng
    t = np.arange(int(sr * duration)) / sr
    # Slowly decaying noise, high-passed at 3 kHz.
    noise = noise_source.standard_normal(len(t)) * np.exp(-3 * t)
    noise = _butter_filter(noise, sr, fmin=3000)
    # Multiple metallic partials: each successive one is quieter and decays
    # faster than the one before.
    partials = sum(
        (0.15 / (i + 1)) * np.sin(2 * np.pi * f * t) * np.exp(-(2 + i) * t)
        for i, f in enumerate([4200, 5800, 7300, 9100, 11500])
    )
    cym = noise + partials
    # The epsilon avoids dividing by zero on an all-zero signal.
    cym = cym / (np.abs(cym).max() + 1e-8) * 0.6
    return cym.astype(np.float32)
def synthesize_bass_note(sr: int = 44100, freq: float = 65.0,
                         duration: float = 0.5) -> np.ndarray:
    """Render one bass note to mix under the drums (tests Demucs separation).

    The note is a fundamental plus its 2nd and 3rd harmonics at halving
    amplitudes, shaped by a fast-attack / slow-decay envelope, low-passed at
    500 Hz and normalised to a peak of about 0.5.

    Args:
        sr: sample rate in Hz.
        freq: fundamental frequency in Hz.
        duration: note length in seconds.

    Returns:
        float32 array holding the note.
    """
    n_frames = int(sr * duration)
    t = np.arange(n_frames) / sr

    # Sawtooth-ish tone: fundamental + two harmonics.
    fundamental = np.sin(2 * np.pi * freq * t)
    second = 0.5 * np.sin(2 * np.pi * freq * 2 * t)
    third = 0.25 * np.sin(2 * np.pi * freq * 3 * t)
    tone = fundamental + second + third

    # Linear ramp to full level within 20 ms, then exponential decay.
    attack = np.minimum(t * 50, 1.0)
    shaped = tone * (attack * np.exp(-3 * t))

    filtered = _butter_filter(shaped, sr, fmax=500)
    peak = np.abs(filtered).max() + 1e-8  # epsilon guards against silence
    return (filtered / peak * 0.5).astype(np.float32)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Sample set creation with controlled variation | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def create_sample_set(sr: int = 44100, seed: int = 42,
                      variation: str = "medium") -> dict:
    """Create a set of ground-truth drum samples with parametric variation.

    The synthesizers draw their noise from NumPy's global random state, so a
    private ``RandomState(seed)`` alone would only make the parameter jitter
    reproducible. To make the whole sample set a function of ``seed``, the
    global state is seeded for the duration of the call and restored
    afterwards, leaving the caller's random stream untouched.

    Args:
        sr: sample rate in Hz passed to every synthesizer.
        seed: seed for both the parameter jitter and the synthesis noise.
        variation: "none" (identical hits), "low", "medium", "high"

    Returns:
        ``{name: GroundTruthSample}`` for kick, snare, hihat_closed,
        hihat_open, tom and cymbal.

    Raises:
        KeyError: if ``variation`` is not one of the names listed above.
    """
    rng = np.random.RandomState(seed)
    # Maximum relative deviation applied to each base parameter.
    var_scale = {"none": 0.0, "low": 0.05, "medium": 0.15, "high": 0.3}[variation]

    def vary(val, amount=None):
        # Scale `val` by a random factor in [1 - a, 1 + a].
        a = amount if amount is not None else var_scale
        return val * (1.0 + rng.uniform(-a, a))

    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        samples = {
            'kick': GroundTruthSample(
                name='kick',
                audio=synthesize_kick(sr, pitch=vary(60), decay=vary(12), punch=vary(150)),
                sr=sr,
                frequency_range=(30, 300),
            ),
            'snare': GroundTruthSample(
                name='snare',
                audio=synthesize_snare(sr, body_freq=vary(200), noise_decay=vary(12)),
                sr=sr,
                frequency_range=(100, 8000),
            ),
            'hihat_closed': GroundTruthSample(
                name='hihat_closed',
                audio=synthesize_hihat(sr, is_open=False, brightness=vary(8000)),
                sr=sr,
                frequency_range=(3000, 20000),
            ),
            'hihat_open': GroundTruthSample(
                name='hihat_open',
                audio=synthesize_hihat(sr, is_open=True, brightness=vary(7000)),
                sr=sr,
                frequency_range=(2000, 20000),
            ),
            'tom': GroundTruthSample(
                name='tom',
                audio=synthesize_tom(sr, pitch=vary(120), decay=vary(10)),
                sr=sr,
                frequency_range=(50, 2000),
            ),
            # The cymbal uses fixed parameters (no jitter).
            'cymbal': GroundTruthSample(
                name='cymbal',
                audio=synthesize_cymbal(sr),
                sr=sr,
                frequency_range=(2000, 20000),
            ),
        }
    finally:
        # Hand the global random stream back exactly as we found it.
        np.random.set_state(saved_state)
    return samples
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Pattern generation | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def generate_basic_rock(bars: int = 4) -> dict:
    """Build a straight 4/4 rock groove.

    Each bar has kick on beats 1 and 3, snare on 2 and 4, closed hi-hat on
    every eighth note (on-beat eighths louder) and an open hi-hat on the
    "& of 4".

    Returns:
        ``{sample_name: [(beat_position, velocity), ...]}`` with beat
        positions counted from the start of the song.
    """
    kicks, snares, closed_hats, open_hats = [], [], [], []
    for bar_idx in range(bars):
        base = bar_idx * 4  # 4 beats per bar
        kicks += [(base + 0, 0.9), (base + 2, 0.85)]
        snares += [(base + 1, 0.85), (base + 3, 0.9)]
        closed_hats += [
            (base + step * 0.5, 0.4 if step % 2 else 0.6)
            for step in range(8)
        ]
        open_hats.append((base + 3.5, 0.55))
    return {
        'kick': kicks,
        'snare': snares,
        'hihat_closed': closed_hats,
        'hihat_open': open_hats,
    }
def generate_funk_pattern(bars: int = 4) -> dict:
    """Build a syncopated funk groove with ghost notes.

    Every bar gets a syncopated kick, a snare backbeat with two quiet ghost
    notes and sixteenth-note closed hats accented on each beat. The last bar
    additionally gets a four-note tom fill on beat 4. The ``hihat_open`` key
    is present but always empty.

    Returns:
        ``{sample_name: [(beat_position, velocity), ...]}``.
    """
    # (position within the bar, velocity) templates
    kick_cells = ((0, 0.95), (0.75, 0.6), (2, 0.9), (2.5, 0.7))
    snare_cells = ((1, 0.9), (1.75, 0.3), (3, 0.85), (3.25, 0.25))
    fill_cells = ((3, 0.8), (3.25, 0.75), (3.5, 0.85), (3.75, 0.9))

    groove = {
        name: []
        for name in ('kick', 'snare', 'hihat_closed', 'hihat_open', 'tom')
    }
    final_bar = bars - 1
    for bar_idx in range(bars):
        base = bar_idx * 4
        groove['kick'] += [(base + pos, vel) for pos, vel in kick_cells]
        groove['snare'] += [(base + pos, vel) for pos, vel in snare_cells]
        for step in range(16):
            accent = 0.2 if step % 4 == 0 else 0.0
            groove['hihat_closed'].append((base + step * 0.25, 0.5 + accent))
        if bar_idx == final_bar:
            groove['tom'] += [(base + pos, vel) for pos, vel in fill_cells]
    return groove
def generate_halftime_pattern(bars: int = 4) -> dict:
    """Build a half-time / trap-influenced groove.

    Kick on beat 1 of every bar (plus an extra kick half a beat later on
    every second bar), snare on beat 3 only, sixteenth-note closed hats that
    alternate loud/soft, and a single crash on the very first downbeat.

    Returns:
        ``{sample_name: [(beat_position, velocity), ...]}``.
    """
    kicks, snares, hats, crashes = [], [], [], []
    for bar_idx in range(bars):
        base = bar_idx * 4
        kicks.append((base + 0, 0.95))
        if bar_idx % 2 == 1:
            # Occasional double kick on odd-indexed bars.
            kicks.append((base + 0.5, 0.7))
        snares.append((base + 2, 0.9))
        hats.extend(
            (base + step * 0.25, 0.3 + (0.15 if step % 2 == 0 else 0.0))
            for step in range(16)
        )
        if bar_idx == 0:
            crashes.append((base + 0, 0.7))
    return {
        'kick': kicks,
        'snare': snares,
        'hihat_closed': hats,
        'cymbal': crashes,
    }
# Registry of available grooves: pattern name -> generator function.
PATTERNS = dict(
    rock=generate_basic_rock,
    funk=generate_funk_pattern,
    halftime=generate_halftime_pattern,
)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Song assembly | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def assemble_song(
    samples: dict,
    pattern: dict,
    sr: int = 44100,
    bpm: float = 120.0,
    humanize_timing_ms: float = 5.0,
    humanize_velocity: float = 0.05,
    add_bass: bool = True,
    bass_notes: list = None,
    room_noise_db: float = -60.0,
    seed: int = 42,
) -> SyntheticSong:
    """Assemble a complete synthetic song from samples and pattern.

    Args:
        samples: ``{name: GroundTruthSample}`` one-shots to trigger.
        pattern: ``{name: [(beat_position, velocity), ...]}``; names missing
            from ``samples`` are skipped.
        sr: sample rate in Hz of the rendered audio.
        bpm: tempo used to convert beat positions to seconds.
        humanize_timing_ms: standard deviation (ms) of the Gaussian onset
            jitter applied to every drum hit.
        humanize_velocity: maximum relative velocity jitter (uniform).
        add_bass: whether to mix a bass line under the drums.
        bass_notes: optional ``[(beat_position, freq_hz), ...]``; when
            ``None`` a two-bar riff is repeated across the song.
        room_noise_db: level of the added white noise in dB.
        seed: seed for the humanisation and room-noise generator.

    Returns:
        A ``SyntheticSong``. ``mix`` and ``drums_only`` are float32; the
        per-sample stems stay float64. Mix, drums and stems share a single
        normalisation gain, while the audio stored on each ``PlacedHit`` is
        the pre-normalisation signal.
    """
    rng = np.random.RandomState(seed)
    beat_dur = 60.0 / bpm

    # Calculate total duration from the latest beat in the pattern.
    all_beats = [event[0] for events in pattern.values() for event in events]
    max_beat = max(all_beats) if all_beats else 4
    total_dur = (max_beat + 2) * beat_dur  # add 2 beats of tail
    total_samples = int(total_dur * sr)

    # Initialize stems
    drums_mix = np.zeros(total_samples, dtype=np.float64)
    per_sample = {name: np.zeros(total_samples, dtype=np.float64) for name in samples}
    hits = []

    # Place each hit
    for sample_name, events in pattern.items():
        if sample_name not in samples:
            continue
        sample = samples[sample_name]
        for beat_pos, velocity in events:
            # Humanize timing (never before the start of the song)
            timing_offset = rng.normal(0, humanize_timing_ms / 1000.0)
            onset_time = beat_pos * beat_dur + timing_offset
            onset_time = max(0, onset_time)
            # Humanize velocity
            vel = velocity * (1.0 + rng.uniform(-humanize_velocity, humanize_velocity))
            vel = np.clip(vel, 0.05, 1.0)
            # Place in timeline, truncating anything past the end
            start = int(onset_time * sr)
            audio = sample.audio * vel
            end = min(start + len(audio), total_samples)
            actual_len = end - start
            if actual_len <= 0:
                continue
            drums_mix[start:end] += audio[:actual_len]
            per_sample[sample_name][start:end] += audio[:actual_len]
            hits.append(PlacedHit(
                sample_name=sample_name,
                onset_time=onset_time,
                velocity=vel,
                audio=audio[:actual_len],
                sr=sr,
            ))

    # Optional bass line (tests Demucs separation)
    bass_track = np.zeros(total_samples, dtype=np.float64)
    if add_bass:
        if bass_notes is None:
            # Two-bar riff: root on beats 1 and 3 of each bar. It spans
            # 8 beats, so it must be repeated every 8 beats; repeating it
            # every 4 would stack two different notes on the same beat.
            riff = [(0, 65), (2, 65), (4, 82), (6, 82)]
            riff_beats = 8
            repeats = int(max_beat / riff_beats) + 1
            bass_notes = [
                (beat + rep * riff_beats, freq)
                for rep in range(repeats)
                for beat, freq in riff
                if beat + rep * riff_beats <= max_beat
            ]
        for beat_pos, freq in bass_notes:
            # Same clamping/guarding as for drum hits, so caller-supplied
            # notes outside the timeline are skipped instead of crashing.
            start = int(max(0, beat_pos * beat_dur) * sr)
            if start >= total_samples:
                continue
            bass = synthesize_bass_note(sr, freq=freq, duration=beat_dur * 2)
            end = min(start + len(bass), total_samples)
            bass_track[start:end] += bass[:end - start]

    # Add room noise
    noise = rng.randn(total_samples) * (10 ** (room_noise_db / 20))

    # Final mix
    full_mix = drums_mix + bass_track + noise

    # Normalize the mix to a 0.9 peak and apply the same gain to the drum
    # bus and the stems so they stay level-matched with the mix.
    peak = np.abs(full_mix).max()
    if peak > 0:
        scale = 0.9 / peak
        full_mix *= scale
        drums_mix *= scale
        for name in per_sample:
            per_sample[name] *= scale

    return SyntheticSong(
        mix=full_mix.astype(np.float32),
        drums_only=drums_mix.astype(np.float32),
        sr=sr,
        bpm=bpm,
        duration=total_dur,
        samples=samples,
        hits=hits,
        per_sample_stems=per_sample,
        pattern_description=str({k: len(v) for k, v in pattern.items()}),
    )
def generate_test_song(
    pattern_name: str = 'rock',
    bars: int = 4,
    bpm: float = 120.0,
    sr: int = 44100,
    variation: str = 'medium',
    add_bass: bool = True,
    seed: int = 42,
) -> SyntheticSong:
    """One-call entry point: build a sample set, a pattern and the song.

    Args:
        pattern_name: key into ``PATTERNS``; an unknown name falls back to
            the basic rock pattern.
        bars: number of bars to generate.
        bpm: tempo of the rendered song.
        sr: sample rate in Hz.
        variation: amount of parameter variation for the sample set.
        add_bass: whether to mix a bass line under the drums.
        seed: seed shared by sample creation and song assembly.

    Returns:
        The assembled ``SyntheticSong`` with its ground truth.
    """
    kit = create_sample_set(sr=sr, seed=seed, variation=variation)
    groove = PATTERNS.get(pattern_name, generate_basic_rock)(bars=bars)
    return assemble_song(
        samples=kit,
        pattern=groove,
        sr=sr,
        bpm=bpm,
        add_bass=add_bass,
        seed=seed,
    )
def save_ground_truth(song: SyntheticSong, output_dir: str):
    """Write a song and all of its ground truth to ``output_dir``.

    Layout produced (all audio as 24-bit PCM WAV):
        mix.wav, drums_only.wav      - full mix and drums-only mix
        gt_samples/<name>.wav        - each clean one-shot
        gt_stems/<name>_stem.wav     - each isolated per-sample stem
        hit_map.json                 - bpm, duration, sr, pattern summary
                                       and the list of placed hits

    Directories are created as needed.
    """
    import os

    samples_dir = os.path.join(output_dir, 'gt_samples')
    stems_dir = os.path.join(output_dir, 'gt_stems')
    for directory in (output_dir, samples_dir, stems_dir):
        os.makedirs(directory, exist_ok=True)

    # Mix and drums
    for filename, signal in (('mix.wav', song.mix), ('drums_only.wav', song.drums_only)):
        sf.write(os.path.join(output_dir, filename), signal, song.sr, subtype='PCM_24')

    # Individual one-shots, each written at its own sample rate
    for name, one_shot in song.samples.items():
        target = os.path.join(samples_dir, f'{name}.wav')
        sf.write(target, one_shot.audio, one_shot.sr, subtype='PCM_24')

    # Isolated stems
    for name, stem_audio in song.per_sample_stems.items():
        target = os.path.join(stems_dir, f'{name}_stem.wav')
        sf.write(target, stem_audio, song.sr, subtype='PCM_24')

    # Hit map
    payload = {
        'bpm': song.bpm,
        'duration': song.duration,
        'sr': song.sr,
        'pattern': song.pattern_description,
        'hits': [
            {'sample': hit.sample_name, 'onset': hit.onset_time, 'velocity': hit.velocity}
            for hit in song.hits
        ],
    }
    with open(os.path.join(output_dir, 'hit_map.json'), 'w') as f:
        json.dump(payload, f, indent=2)