| | import torch |
| | import spaces |
| | import tempfile |
| | import soundfile as sf |
| | import gradio as gr |
| | import librosa as lb |
| | import yaml |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from model.cleanmel import CleanMel |
| | from model.vocos.pretrained import Vocos |
| | from model.stft import InputSTFT, TargetMel |
| |
|
# Device that every transform and model in this script is moved to.
# NOTE(review): the CUDA index is hard-coded; a host exposing fewer than six
# CUDA devices will fail as soon as something is moved here -- confirm that
# this matches the deployment target.
DEVICE = torch.device("cuda", 5)
| |
|
def read_audio(file_path, target_sr=16000):
    """Load an audio file as a single-channel float32 tensor of shape ``(1, T)``.

    Args:
        file_path: Audio file to read with ``soundfile.read``.
        target_sr: Sample rate in Hz the waveform is resampled to when the
            file was stored at a different rate. Defaults to 16000, the rate
            hard-coded elsewhere in this script.

    Returns:
        A ``torch.float32`` tensor holding the samples behind one leading
        singleton dimension.
    """
    audio, sample_rate = sf.read(file_path)

    # Multi-channel files come back as a 2-D array; only channel 0 is kept.
    if audio.ndim > 1:
        audio = audio[:, 0]

    if sample_rate != target_sr:
        audio = lb.resample(audio, orig_sr=sample_rate, target_sr=target_sr)

    # ``audio`` is 1-D at this point, so a single unsqueeze always produces
    # (1, T). The previous squeeze()/unsqueeze(0) pair collapsed a one-sample
    # file to a 1-D tensor instead.
    return torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
| |
|
def stft(audio):
    """Apply the script's fixed-configuration ``InputSTFT`` to ``audio``.

    A new ``InputSTFT`` module is constructed and moved to ``DEVICE`` on
    every call; its output is returned untouched. The caller in this file
    unpacks that output into two values.
    """
    stft_settings = {
        "n_fft": 512,
        "n_win": 512,
        "n_hop": 128,
        "normalize": False,
        "center": True,
        "onesided": True,
        "online": False,
    }
    frontend = InputSTFT(**stft_settings)
    frontend = frontend.to(DEVICE)
    return frontend(audio)
| |
|
def mel_transform(audio, X_norm):
    """Apply the script's fixed-configuration ``TargetMel`` to ``audio``.

    A new ``TargetMel`` module is constructed and moved to ``DEVICE`` on
    every call, then invoked as ``transform(audio, X_norm)``; its output is
    returned untouched.

    Args:
        audio: Waveform forwarded as the first argument of ``TargetMel``.
        X_norm: Forwarded unchanged as the second argument of ``TargetMel``.
    """
    mel_settings = {
        "sample_rate": 16000,
        "n_fft": 512,
        "n_win": 512,
        "n_hop": 128,
        "n_mels": 80,
        "f_min": 0,
        "f_max": 8000,
        "power": 2,
        "center": True,
        "normalize": False,
        "onesided": True,
        "mel_norm": "slaney",
        "mel_scale": "slaney",
        "librosa_mel": True,
        "online": False,
    }
    mel_frontend = TargetMel(**mel_settings)
    mel_frontend = mel_frontend.to(DEVICE)
    return mel_frontend(audio, X_norm)
| |
|
def load_cleanmel(model_name):
    """Build a ``CleanMel`` network and load the named checkpoint into it.

    The constructor arguments come from the ``model.arch.init_args`` entry
    of ``./configs/cleanmel_offline.yaml``.

    Args:
        model_name: Checkpoint stem; the weights are read from
            ``./ckpts/CleanMel/{model_name}.ckpt``.

    Returns:
        The ``CleanMel`` instance after ``.eval()``. The caller is expected
        to move it to the target device.
    """
    config_path = "./configs/cleanmel_offline.yaml"
    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, "r") as config_file:
        init_args = yaml.safe_load(config_file)["model"]["arch"]["init_args"]

    cleanmel = CleanMel(**init_args)

    # Deserialize onto the CPU so loading does not depend on the CUDA device
    # the checkpoint happened to be saved from.
    state_dict = torch.load(
        f"./ckpts/CleanMel/{model_name}.ckpt", map_location="cpu"
    )
    cleanmel.load_state_dict(state_dict)
    return cleanmel.eval()
| |
|
def load_vocos(model_name):
    """Build a ``Vocos`` model from the offline config and load its weights.

    Args:
        model_name: Checkpoint stem; the weights are read from
            ``./ckpts/Vocos/{model_name}.pt``.

    Returns:
        The object returned by ``Vocos.from_pretrained`` after ``.eval()``.
    """
    checkpoint_path = f"./ckpts/Vocos/{model_name}.pt"
    skeleton = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
    vocoder = Vocos.from_pretrained(None, model_path=checkpoint_path, model=skeleton)
    return vocoder.eval()
| |
|
def get_mrm_pred(Y_hat, x, X_norm):
    """Combine a predicted mask with the noisy mel features of ``x``.

    Computes ``(squeeze(Y_hat) * (sqrt(mel_transform(x, X_norm)) + 1e-10)) ** 2``.

    Args:
        Y_hat: Predicted mask; all singleton dimensions are squeezed away
            before it is applied.
        x: Waveform passed to ``mel_transform``.
        X_norm: Forwarded unchanged to ``mel_transform``.

    Returns:
        The squared product described above.
    """
    noisy_mel = mel_transform(x, X_norm)
    mask = Y_hat.squeeze()
    # The small offset is added after the square root, as in the formula above.
    noisy_root = torch.sqrt(noisy_mel) + 1e-10
    return torch.square(mask * noisy_root)
| |
|
def safe_log(x):
    """Return the natural log of ``x`` with its values floored at 1e-5.

    Flooring keeps zeros and negative entries from producing ``-inf``/``nan``.
    """
    floored = torch.clamp(x, min=1e-5)
    return floored.log()
| |
|
def _new_temp_path(suffix):
    """Create an empty temporary file that outlives this call and return its path."""
    # The handle is closed before the caller writes to the path by name.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _save_logmel_image(logmel_img, duration_s, image_path):
    """Render a 2-D log-mel array as a labelled image and save it to ``image_path``."""
    n_mels = logmel_img.shape[0]
    n_frames = logmel_img.shape[-1]
    fig, ax = plt.subplots(figsize=(logmel_img.shape[1] / 100, n_mels / 50))
    try:
        ax.imshow(logmel_img, vmin=-11, cmap="jet")
        ax.set_ylabel("Mel bands")
        ax.set_xlabel("Time (second)")
        # The rows were flipped by the caller, so the tick labels are reversed.
        ax.set_yticks([0, n_mels])
        ax.set_yticklabels([n_mels, 0])
        # Eleven evenly spaced ticks: frame positions labelled in seconds.
        tick_positions = [int(pos) for pos in np.linspace(0, n_frames, 11)]
        tick_labels = ["{:.1f}".format(sec) for sec in np.linspace(0, duration_s, 11)]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        # Lay out after labels and ticks exist so they are taken into account.
        fig.tight_layout()
        fig.savefig(image_path)
    finally:
        # Always release the figure; pyplot otherwise keeps every figure alive
        # for the lifetime of the process.
        plt.close(fig)


@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_L_mask(audio_path):
    """Enhance one audio file with the CleanMel-L mask model and the Vocos vocoder.

    The models are loaded and moved to ``DEVICE`` on each call. The sigmoid
    of the CleanMel output is used as a mask via ``get_mrm_pred``, passed
    through ``safe_log`` and fed to the vocoder.

    Args:
        audio_path: Path of the input audio file.

    Returns:
        A tuple ``(wav_path, image_path, npy_path)`` of temporary files: the
        enhanced waveform written at 16000 Hz, a spectrogram image of the
        predicted log-mel, and that log-mel saved as a ``.npy`` array. The
        files are created with ``delete=False`` and are not removed here.
    """
    model = load_cleanmel("offline_CleanMel_L_mask").to(DEVICE)
    vocos = load_vocos("vocos_offline").to(DEVICE)

    x = read_audio(audio_path).to(DEVICE)
    X, X_norm = stft(x)
    MRM_hat = torch.sigmoid(model(X))
    Y_hat = get_mrm_pred(MRM_hat, x, X_norm)
    logMel_hat = safe_log(Y_hat)
    y_hat = vocos(logMel_hat, X_norm)

    logmel_np = logMel_hat.squeeze().cpu().numpy()

    wav_path = _new_temp_path(".wav")
    sf.write(wav_path, y_hat.squeeze().cpu().numpy(), 16000)

    npy_path = _new_temp_path(".npy")
    np.save(npy_path, logmel_np)

    image_path = _new_temp_path(".png")
    duration_s = x.shape[-1] / 16000
    # Flip the first axis so the last row of the array is drawn at the top.
    _save_logmel_image(logmel_np[::-1, :], duration_s, image_path)

    return wav_path, image_path, npy_path
| |
|
if __name__ == "__main__":
    # Build the Gradio UI: one audio upload plus a button that runs the
    # enhancement and fills the three output widgets.
    # NOTE(review): the nesting below is reconstructed from the blank-line
    # grouping of the original; the Row is assumed to hold only the input
    # widget and the button -- confirm against the intended layout.
    with gr.Blocks(title="CleanMel Demo") as demo:
        gr.Markdown("## CleanMel Demo")
        gr.Markdown("This demo showcases the CleanMel model for speech enhancement.")

        with gr.Row():
            audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
            enhance_button = gr.Button("Enhance Audio")

        output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
        output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
        output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")

        # The output order matches the tuple returned by the handler.
        enhance_button.click(
            enhance_cleanmel_L_mask,
            inputs=audio_input,
            outputs=[output_audio, output_mel, output_np],
        )

    demo.launch(debug=False, share=True)