import gradio as gr
import spaces
import yaml
import torch
import librosa
import torchaudio
from diffusers import DDIMScheduler
from transformers import AutoProcessor, ClapModel, ClapConfig
from model.udit import UDiT
from vae_modules.autoencoder_wrapper import Autoencoder
import numpy as np
from huggingface_hub import hf_hub_download

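# Download the CLAP checkpoint from the Hugging Face Hub; its text encoder is used
# below to embed the target-sound description.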
clap_bin_path = hf_hub_download("laion/larger_clap_general", "pytorch_model.bin")

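# Diffusion config, local checkpoint paths, working sample rate, and compute device.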
diffusion_config = './config/SoloAudio.yaml'
diffusion_ckpt = './pretrained_models/soloaudio_v2.pt'
autoencoder_path = './pretrained_models/audio-vae.pt'
uncond_path = './pretrained_models/uncond.npz'
sample_rate = 24000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(diffusion_config, 'r') as fp:
    diff_config = yaml.safe_load(fp)

v_prediction = diff_config["ddim"]["v_prediction"]

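# Build the CLAP text encoder from its config and load the downloaded weights manually
# (rather than ClapModel.from_pretrained), presumably so the checkpoint fetched above is used.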
processor = AutoProcessor.from_pretrained('laion/larger_clap_general')
clap_config = ClapConfig.from_pretrained("laion/larger_clap_general")
clapmodel = ClapModel(clap_config)
clap_ckpt = torch.load(clap_bin_path, map_location='cpu')
clapmodel.load_state_dict(clap_ckpt)
clapmodel.to(device)
clapmodel.eval()  # inference only: keep dropout disabled when embedding prompts

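# Load the latent audio VAE and the UDiT diffusion backbone from local checkpoints.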
autoencoder = Autoencoder(autoencoder_path, 'stable_vae', quantization_first=True)
autoencoder.eval()
autoencoder = autoencoder.float().to(device)

unet = UDiT(
    **diff_config['diffwrap']['UDiT']
).to(device)
unet.load_state_dict(torch.load(diffusion_ckpt, map_location='cpu')['model'])
unet.eval()

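# DDIM scheduler built from the config; both branches construct the same scheduler,
# so the v-prediction vs. noise-prediction choice presumably lives in the diffusers
# kwargs and only the log message differs here.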
if v_prediction:
    print('v prediction')
    scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
else:
    print('noise prediction')
    scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])

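# Presumably a warm-up for the ZeroGPU runtime: one add_noise call on CUDA nudges the
# scheduler's internal tensors onto the right device/dtype before real inference.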
@spaces.GPU
def reset_scheduler_dtype():
    latents = torch.randn((1, 128, 128), device="cuda")
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device
    )
    _ = scheduler.add_noise(latents, noise, timesteps)
    return "Scheduler dtype reset completed."

@spaces.GPU
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # Rescale the guided prediction so its per-sample std matches the text-conditioned one.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # Blend with the original guided prediction to avoid overly flat outputs.
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg

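# DDIM sampling in the VAE latent space. With guidance_scale set, classifier-free
# guidance batches the conditional branch and the unconditional branch (the embedding
# stored in uncond.npz) through the UDiT in a single forward pass.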
@spaces.GPU
def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_scale=False, guidance_rescale=0.0):
    with torch.no_grad():
        scheduler.set_timesteps(ddim_steps)
        generator = torch.Generator(device=device).manual_seed(seed)

        noise = torch.randn(mixture.shape, generator=generator, device=device)
        pred = noise

        if guidance_scale:
            # Load the pre-computed unconditional embedding once, not on every step.
            uncond = torch.tensor(np.load(uncond_path)['arr_0']).unsqueeze(0).to(device)

        for t in scheduler.timesteps:
            pred = scheduler.scale_model_input(pred, t)
            if guidance_scale:
                pred_combined = torch.cat([pred, pred], dim=0)
                mixture_combined = torch.cat([mixture, mixture], dim=0)
                timbre_combined = torch.cat([timbre, uncond], dim=0)
                output_combined = unet(x=pred_combined, timesteps=t, mixture=mixture_combined, timbre=timbre_combined)
                output_pos, output_neg = torch.chunk(output_combined, 2, dim=0)

                model_output = output_neg + guidance_scale * (output_pos - output_neg)
                if guidance_rescale > 0.0:
                    # Avoid over-exposure at high guidance scales (arXiv:2305.08891, Sec. 3.4).
                    model_output = rescale_noise_cfg(model_output, output_pos,
                                                     guidance_rescale=guidance_rescale)
            else:
                model_output = unet(x=pred, timesteps=t, mixture=mixture, timbre=timbre)
            pred = scheduler.step(model_output=model_output, timestep=t, sample=pred,
                                  eta=eta, generator=generator).prev_sample

        # Decode the denoised latent back to a waveform with the VAE decoder.
        pred = autoencoder(embedding=pred).squeeze(1)

        return pred

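# Gradio callback: load the mixture, force it to 10 s of mono audio at 24 kHz, encode
# it with the VAE, embed the prompt with CLAP, then run the sampler and return audio.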
@spaces.GPU
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
    reset_scheduler_dtype()
    with torch.no_grad():
        # Load the mixture and resample to the model's 24 kHz working rate.
        mixture, sr = torchaudio.load(gt_file_input)
        if sr != sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            mixture = resampler(mixture)
            sr = sample_rate
        # Downmix to mono.
        if mixture.shape[0] > 1:
            mixture = torch.mean(mixture, dim=0)
        else:
            mixture = mixture[0]

        # Trim or zero-pad to exactly 10 seconds.
        current_length = len(mixture)
        target_length = sample_rate * 10
        if current_length > target_length:
            mixture = mixture[:target_length]
        elif current_length < target_length:
            padding = target_length - current_length
            mixture = torch.nn.functional.pad(mixture, (0, padding))
        mixture = mixture.unsqueeze(0).to(device)
        # Encode the waveform into the latent space the diffusion model operates in.
        mixture = autoencoder(audio=mixture.unsqueeze(1))

        # Embed the target-sound description with the CLAP text encoder.
        text_inputs = processor(
            text=[text_input],
            max_length=10,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        inputs = {
            "input_ids": text_inputs["input_ids"][0].unsqueeze(0),
            "attention_mask": text_inputs["attention_mask"][0].unsqueeze(0),
        }
        inputs = {key: value.to(device) for key, value in inputs.items()}
        timbre = clapmodel.get_text_features(**inputs)

        pred = sample_diffusion(mixture, timbre, num_infer_steps, eta, seed, guidance_scale, guidance_rescale)
        return sample_rate, pred.squeeze().cpu().numpy()

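# Gradio UI: one "Target Sound Extraction" tab with an audio upload, a text prompt,
# and advanced sampling controls.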
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SoloAudio: Target Sound Extraction with Language-oriented Audio Diffusion Transformer
        🔧 Tips: Adjust the advanced settings for more control. This Space currently supports only 10-second audio inputs.

        💡 If you're looking to extract a specific speaker's voice, try [SoloSpeech](https://huggingface.co/spaces/OpenSound/SoloSpeech).

        🔗 Learn more about 🎯**SoloAudio** on the [SoloAudio Homepage](https://wanghelin1997.github.io/SoloAudio-Demo/).
        """)

        with gr.Tab("Target Sound Extraction"):
            with gr.Row():
                gt_file_input = gr.Audio(
                    label="Upload Audio Mixture",
                    type="filepath",
                    value="demo/0_mix.wav"
                )

            with gr.Row(equal_height=True):
                text_input = gr.Textbox(
                    label="Describe The Sound You Want to Extract",
                    show_label=True,
                    max_lines=2,
                    placeholder="Enter your prompt",
                    value="The sound of gunshot",
                    container=True,
                    scale=4
                )
                run_button = gr.Button("Extract", scale=1)

            result = gr.Audio(label="Extracted Audio Stem", type="numpy")

            with gr.Accordion("Advanced Settings", open=False):
                guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=3.0, label="Guidance Scale")
                guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0., label="Guidance Rescale")
                num_infer_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
                eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Eta")
                seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")

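            # Run extraction on button click or when the prompt is submitted with Enter.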
            run_button.click(
                fn=tse,
                inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
                outputs=[result]
            )
            text_input.submit(
                fn=tse,
                inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
                outputs=[result]
            )

demo.launch()