Spaces: Running on Zero
| import os | |
| import torch | |
| import torchaudio | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| # For PyHARP wrapper | |
| from pyharp import ModelCard, build_endpoint | |
| import gradio as gr | |
# HARP model card: metadata shown to HARP client UIs for this Space.
_card_fields = {
    "name": "Apollo",
    "description": (
        "High-quality audio restoration for lossy MP3 compressed audio. "
        "Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling."
    ),
    "author": "JusperLee",
    "tags": ["audio restoration", "music", "apollo", "mp3", "lossless"],
}
model_card = ModelCard(**_card_fields)
def load_audio(file_path):
    """Load an audio file and return it as a batched waveform tensor.

    Parameters
    ----------
    file_path : str
        Path to an audio file readable by torchaudio.

    Returns
    -------
    torch.Tensor
        Waveform of shape ``[1, channels, samples]`` (a leading batch
        dimension is added).
    """
    # The sample rate is deliberately discarded (``_``); callers that need
    # it must query the file themselves.
    audio, _ = torchaudio.load(file_path)
    return audio.unsqueeze(0)
# ---------------------------------------------------------------------------
# One-time model setup. This runs at import time so loading happens only once
# for the lifetime of the process.
# ---------------------------------------------------------------------------

# Prefer GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Loading Apollo model...")

# Download model weights from HuggingFace (cached locally under ./checkpoints,
# so repeated startups reuse the same file).
model_path = hf_hub_download(
    repo_id="JusperLee/Apollo",
    filename="pytorch_model.bin",
    cache_dir="./checkpoints"
)

# Load checkpoint WITH OmegaConf support.
# weights_only=False is required because the checkpoint embeds non-tensor
# objects (model name, constructor args). NOTE(review): weights_only=False
# unpickles arbitrary objects — acceptable only because the checkpoint comes
# from a known repo; confirm the source stays trusted.
print(f"Loading checkpoint from {model_path}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Extract model info from the checkpoint dict.
model_name = checkpoint['model_name']            # class name to look up in look2hear
state_dict = checkpoint['state_dict']            # trained weights
model_args = checkpoint.get('model_args', {})    # constructor kwargs (may be an OmegaConf config)
print(f"Model class: {model_name}")
print(f"Model args: {model_args}")

# Resolve the model class by name from the look2hear registry.
from look2hear.models import get
model_class = get(model_name)

# Create model instance with model_args.
# Convert OmegaConf to a plain dict if the object exposes to_container
# (older OmegaConf API); otherwise model_args is assumed to already be
# mapping-like and is unpacked directly below.
if hasattr(model_args, 'to_container'):
    model_args = model_args.to_container(resolve=True)
print(f"Instantiating {model_name}...")
model = model_class(**model_args)

# Load the trained weights, move to the selected device, and switch to
# inference mode (disables dropout/batch-norm updates).
print("Loading state dict...")
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print("✓ Model loaded successfully")
# Defining the process function
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy audio file with Apollo and return the output path.

    Parameters
    ----------
    input_audio_path : str
        Path to the input audio file (e.g. a low-bitrate MP3).

    Returns
    -------
    str
        Path to the restored WAV file written under ``src/_outputs``.
    """
    # Load the waveform AND keep the true sample rate: the previous
    # implementation discarded it and saved at a hard-coded 44100 Hz,
    # which silently speed/pitch-shifted inputs recorded at other rates.
    sig, sample_rate = torchaudio.load(input_audio_path)
    # Use the device selected at startup; hard-coding torch.device("cuda")
    # here crashed on CPU-only hosts.
    sig = sig.to(device)
    # Add batch dimension if needed (Apollo expects [batch, channels, samples]).
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)
    # Pure inference: disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        result = model(sig)
    # Remove batch dimension before saving.
    result = result.squeeze(0)
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, result.cpu(), sample_rate)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path
# Assemble the HARP-compatible Gradio app and launch it.
with gr.Blocks() as demo:
    # Single required audio input (filepath so process_fn receives a path).
    input_audio = (
        gr.Audio(type="filepath", label="Input Audio A")
        .harp_required(True)
    )
    # Single audio output with an info tooltip for HARP clients.
    output_audio = (
        gr.Audio(type="filepath", label="Output Audio")
        .set_info("The restored audio.")
    )

    # Wire model card, components, and the processing function together.
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_audio],
        process_fn=process_fn,
    )

# run the model
demo.queue().launch(show_error=True, pwa=True)