| | import torch |
| | import torchaudio |
| | from pathlib import Path |
| | import argparse |
| | from tqdm import tqdm |
| | from acestep.music_dcae.music_dcae_pipeline import MusicDCAE |
| |
|
class AudioVAE:
    """Wrapper around ``MusicDCAE`` that maps audio to standardized latents.

    ``encode`` standardizes the DCAE latents with precomputed per-channel
    statistics; ``decode`` inverts the standardization before running the
    DCAE decoder. The statistics assume the 8-channel MusicDCAE latent
    space — TODO confirm against the checkpoint actually loaded.
    """

    def __init__(self, device: torch.device):
        """Load the DCAE in eval mode on *device* and build the
        broadcastable (1, C, 1, 1) normalization tensors."""
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel latent mean/std, shaped to broadcast over
        # (batch, channels, height, width) latents.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    @torch.no_grad()
    def encode(self, audio: torch.Tensor, sr: int = 48000) -> torch.Tensor:
        """Encode a batch of audio to standardized latents.

        Args:
            audio: (batch, channels, samples) waveform tensor on ``self.device``.
            sr: sample rate of *audio*. Defaults to 48000, which was
                previously hard-coded.

        Returns:
            Standardized latent tensor (latents minus mean, over std).
        """
        # All clips in a batch share the same length, so build the lengths
        # tensor directly on the target device instead of CPU-then-move.
        batch_size, _, num_samples = audio.shape
        audio_lengths = torch.full((batch_size,), num_samples, device=self.device)
        latents, _ = self.model.encode(audio, audio_lengths, sr=sr)
        return (latents - self.latent_mean) / self.latent_std

    @torch.no_grad()
    def decode(self, latents: torch.Tensor, sr: int = 48000) -> torch.Tensor:
        """Decode standardized latents back to a batch of audio.

        Args:
            latents: standardized latent tensor as produced by ``encode``.
            sr: output sample rate. Defaults to 48000, which was
                previously hard-coded.

        Returns:
            Stacked audio batch tensor on ``self.device``.
        """
        # Undo the standardization applied by ``encode``.
        latents = latents * self.latent_std + self.latent_mean
        _, audio_list = self.model.decode(latents, sr=sr)
        return torch.stack(audio_list).to(self.device)
| |
|
def load_audio(audio_path, target_sr=48000):
    """Load an audio file as a stereo waveform at *target_sr* Hz.

    Mono input is duplicated to two channels; input with more than two
    channels is truncated to the first two. Resampling happens only when
    the source rate differs from *target_sr*.
    """
    waveform, source_sr = torchaudio.load(audio_path)

    # Normalize the channel count to exactly two.
    n_channels = waveform.shape[0]
    if n_channels == 1:
        waveform = waveform.repeat(2, 1)
    elif n_channels > 2:
        waveform = waveform[:2]

    if source_sr != target_sr:
        waveform = torchaudio.transforms.Resample(source_sr, target_sr)(waveform)

    return waveform
| |
|
| |
|
def main():
    """CLI entry point: encode every audio file in a directory to VAE latents."""
    parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')
    parser.add_argument('--audio-dir', type=str, required=True,
                        help='Directory containing audio files')
    parser.add_argument('--output-dir', type=str, default="latents",
                        help='Directory to save encoded latents')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Gather every supported audio file, sorted for a deterministic order.
    audio_dir = Path(args.audio_dir)
    audio_files = sorted(
        path
        for pattern in ('*.mp3', '*.wav', '*.flac', '*.ogg', '*.m4a')
        for path in audio_dir.glob(pattern)
    )
    if not audio_files:
        raise ValueError(f"No audio files found in {args.audio_dir}")
    print(f"Found {len(audio_files)} audio files")

    vae = AudioVAE(device)
    print("VAE loaded")

    print("\nEncoding audio files...")
    for audio_path in tqdm(audio_files, desc="Encoding"):
        try:
            # Build a (1, channels, samples) batch on the target device.
            audio = load_audio(audio_path).unsqueeze(0).to(device)
            latents = vae.encode(audio).squeeze(0)
            torch.save(latents.cpu(), output_dir / f"{audio_path.stem}.pt")
        except Exception as e:
            # Best-effort batch job: report the failure and keep going.
            print(f"\nError encoding {audio_path.name}: {e}")
            continue

    print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")


if __name__ == '__main__':
    main()
| |
|