| import os |
| import torch |
| import torchaudio |
| import argparse |
| from huggingface_hub import hf_hub_download |
|
|
| |
| from pyharp import ModelCard, build_endpoint, load_audio, save_audio |
| import gradio as gr |
|
|
| |
# Metadata the HARP plugin displays for this endpoint (name, blurb, tags).
model_card = ModelCard(
    name="Apollo",
    description="High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)
|
|
def load_audio(file_path):
    """Load an audio file and return it with a leading batch dimension.

    NOTE(review): this definition shadows the ``load_audio`` imported from
    pyharp on the import line above — intentional or not, this local version
    is the one ``process_fn`` uses.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        A ``(1, channels, samples)`` float tensor.

    The file's native sample rate is discarded here, and the output is later
    saved at a hard-coded 44100 Hz — assumes 44.1 kHz input; TODO confirm.
    """
    # Fix: the original bound the sample rate to an unused local; discard it
    # explicitly instead.
    audio, _ = torchaudio.load(file_path)
    return audio.unsqueeze(0)
|
|
def save_audio(file_path, audio, samplerate=44100):
    """Write *audio* to *file_path*, dropping the leading batch dimension.

    NOTE(review): shadows the ``save_audio`` imported from pyharp above.

    Args:
        file_path: Destination path for the audio file.
        audio: A batched ``(1, channels, samples)`` tensor.
        samplerate: Output sample rate in Hz (default 44100).
    """
    waveform = audio.squeeze(0).cpu()
    torchaudio.save(file_path, waveform, samplerate)
|
|
| |
def _load_apollo_model(device):
    """Download the Apollo checkpoint from the Hugging Face Hub and return
    the instantiated model on *device* in eval mode.

    NOTE(review): this runs on every request — the checkpoint file is cached
    by hf_hub_download, but the model is rebuilt per call; consider caching
    the model object at module level.
    """
    print("Loading Apollo model...")
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints",
    )

    print(f"Loading checkpoint from {model_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # acceptable because the checkpoint comes from a pinned, trusted repo.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")

    # Project-local import kept at call time so importing this module stays
    # cheap and does not require look2hear at import time.
    from look2hear.models import get
    model_class = get(model_name)

    # model_args may be an OmegaConf-style container; convert to a plain
    # dict before splatting into the constructor.
    if hasattr(model_args, 'to_container'):
        model_args = model_args.to_container(resolve=True)

    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)

    print("Loading state dict...")
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
    return model


@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy-compressed audio file with the Apollo model.

    Downloads (and caches) the Apollo checkpoint, runs inference on the
    input file, and writes the restored waveform to
    ``src/_outputs/output_restored.wav``.

    Args:
        input_audio_path: Path to the audio file to restore.

    Returns:
        Path to the restored .wav file.
    """
    device = torch.device("cpu")
    print(f"Using device: {device}")

    model = _load_apollo_model(device)

    # load_audio already adds a batch dim -> (1, channels, samples).
    sig = load_audio(input_audio_path).to(device)

    # Defensive: batch a bare (channels, samples) tensor if one slips in.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    # Fix: dropped the redundant torch.no_grad() — @torch.inference_mode()
    # on this function already disables autograd.
    output = model(sig).squeeze(0)

    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    # NOTE(review): output is written at a hard-coded 44.1 kHz regardless of
    # the input file's native rate — confirm Apollo operates at 44.1 kHz.
    torchaudio.save(output_audio_path, output, 44100)
    print(f"✓ Saved output to {output_audio_path}")

    return output_audio_path
|
|
| |
| |
| |
|
|
| |
# Gradio UI wired up as a HARP endpoint. The .harp_required / .set_info
# calls are pyharp extensions on the Gradio components — presumably added
# by the pyharp import above; verify against the pyharp version in use.
with gr.Blocks() as demo:

    # Single required audio-file input.
    input_components = [
        gr.Audio(type="filepath",
                 label="Input Audio A")
        .harp_required(True),
    ]

    # Single audio-file output (path returned by process_fn).
    output_components = [
        gr.Audio(type="filepath",
                 label="Output Audio")
        .set_info("The restored audio."),
    ]

    # Registers process_fn as the HARP processing endpoint for this app.
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# NOTE(review): launches unconditionally at import time (no __main__ guard);
# share=True exposes a public tunnel URL.
demo.queue().launch(share=True, show_error=False, pwa=True)
|
|
| |
| ''' |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Audio Inference Script") |
| parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file") |
| parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file") |
| args = parser.parse_args() |
| |
| main(args.in_wav, args.out_wav) |
| ''' |
|
|