import os

import torch
import torchaudio
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download

# For PyHARP wrapper
from pyharp import ModelCard, build_endpoint

# HARP model card describing this endpoint to clients.
model_card = ModelCard(
    name="Apollo",
    description=(
        "High-quality audio restoration for lossy MP3 compressed audio. "
        "Converts low-bitrate MP3s to near-lossless quality using "
        "band-sequence modeling."
    ),
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)


def load_audio(file_path):
    """Load *file_path* and return the waveform with a leading batch dim.

    Returns a tensor of shape [1, channels, samples]. NOTE: this helper
    discards the file's sample rate; callers that need it (e.g. to save
    output at the correct rate) should call ``torchaudio.load`` directly.
    Kept for backward compatibility with its original interface.
    """
    audio, _samplerate = torchaudio.load(file_path)
    return audio.unsqueeze(0)


# Load the model outside of the process function so that it only has to
# happen once, at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading Apollo model...")

# Download model weights from HuggingFace.
model_path = hf_hub_download(
    repo_id="JusperLee/Apollo",
    filename="pytorch_model.bin",
    cache_dir="./checkpoints",
)

# Load checkpoint WITH OmegaConf support (the checkpoint pickles OmegaConf
# objects, so weights_only=False is required; the repo is trusted here).
print(f"Loading checkpoint from {model_path}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Extract model info from the checkpoint dict.
model_name = checkpoint["model_name"]
state_dict = checkpoint["state_dict"]
model_args = checkpoint.get("model_args", {})

print(f"Model class: {model_name}")
print(f"Model args: {model_args}")

# Import the correct model class by name from the look2hear registry.
from look2hear.models import get

model_class = get(model_name)

# Convert OmegaConf config to a plain dict if needed before instantiating.
if hasattr(model_args, "to_container"):
    model_args = model_args.to_container(resolve=True)

print(f"Instantiating {model_name}...")
model = model_class(**model_args)

print("Loading state dict...")
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

print("✓ Model loaded successfully")


# Defining the process function
@spaces.GPU
@torch.inference_mode()
def process_fn(input_audio_path: str) -> str:
    """Restore the audio at *input_audio_path* with Apollo.

    Returns the path of the restored WAV file, saved at the input file's
    own sample rate.
    """
    # FIX: the original hard-coded torch.device("cuda"), which crashes on
    # CPU-only hosts; fall back gracefully when CUDA is unavailable.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FIX: load directly (instead of via load_audio) so we keep the file's
    # sample rate — the original discarded it and always saved at 44100 Hz,
    # mislabeling any input recorded at another rate.
    sig, sample_rate = torchaudio.load(input_audio_path)

    # Apollo expects [batch, channels, samples]; torchaudio.load returns
    # [channels, samples], so add exactly one batch dimension.
    sig = sig.unsqueeze(0).to(run_device)

    # Ensure model and input share a device inside the @spaces.GPU context.
    model.to(run_device)
    result = model(sig).squeeze(0)  # back to [channels, samples]

    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, result.cpu(), sample_rate)
    print(f"✓ Saved output to {output_audio_path}")

    return output_audio_path


# Build Gradio endpoint
with gr.Blocks() as demo:
    # Define input Gradio Components
    input_components = [
        gr.Audio(type="filepath", label="Input Audio A").harp_required(True),
    ]

    # Define output Gradio Components
    output_components = [
        gr.Audio(type="filepath", label="Output Audio").set_info(
            "The restored audio."
        ),
    ]

    # Build a HARP-compatible endpoint
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# run the model
demo.queue().launch(show_error=True, pwa=True)