# Apollo / app.py — Hugging Face Space entry point.
# Commit d4f3fbb (NatalieElizabeth): "further cleanup and resetting device to
# cuda inside the process function".
import os
import torch
import torchaudio
import spaces
from huggingface_hub import hf_hub_download
# For PyHARP wrapper
from pyharp import ModelCard, build_endpoint
import gradio as gr
# HARP model card: metadata shown to client applications for this endpoint.
_card_fields = {
    "name": "Apollo",
    "description": "High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    "author": "JusperLee",
    "tags": ["audio restoration", "music", "apollo", "mp3", "lossless"],
}
model_card = ModelCard(**_card_fields)
def load_audio(file_path):
    """Load an audio file and return a batched waveform at 44.1 kHz.

    Args:
        file_path: Filesystem path of the audio file to load.

    Returns:
        Float tensor of shape [1, channels, samples], resampled to 44100 Hz.

    Bug fix: the original discarded ``samplerate`` while the output is always
    saved at 44100 Hz downstream, so any non-44.1 kHz input came back at the
    wrong speed/pitch. Resample here so the save rate is actually correct.
    """
    audio, samplerate = torchaudio.load(file_path)
    if samplerate != 44100:
        audio = torchaudio.functional.resample(audio, samplerate, 44100)
    # Add a leading batch dimension: [channels, samples] -> [1, channels, samples]
    return audio.unsqueeze(0)
# Load the model outside of the process function so that it only has to happen once
# (at module import / Space startup) instead of once per request.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Loading Apollo model...")

# Download model weights from HuggingFace (cached locally under ./checkpoints,
# so restarts do not re-download).
model_path = hf_hub_download(
    repo_id="JusperLee/Apollo",
    filename="pytorch_model.bin",
    cache_dir="./checkpoints"
)

# Load checkpoint WITH OmegaConf support.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# (required here for the OmegaConf config stored in the checkpoint). This is
# only acceptable because the file comes from a fixed, trusted repo — never
# use weights_only=False on untrusted checkpoints.
print(f"Loading checkpoint from {model_path}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Extract model info: the checkpoint bundles the class name, the weights,
# and (optionally) the constructor kwargs.
model_name = checkpoint['model_name']
state_dict = checkpoint['state_dict']
model_args = checkpoint.get('model_args', {})
print(f"Model class: {model_name}")
print(f"Model args: {model_args}")

# Import the correct model class by name from the project's model registry.
from look2hear.models import get
model_class = get(model_name)

# Create model instance with model_args.
# Convert OmegaConf to dict if needed.
# NOTE(review): OmegaConf exposes conversion as the *static* method
# OmegaConf.to_container(cfg), not as an instance method, so this hasattr
# branch may never trigger; **model_args still works below because
# DictConfig is a Mapping. Verify against the omegaconf version in use.
if hasattr(model_args, 'to_container'):
    model_args = model_args.to_container(resolve=True)
print(f"Instantiating {model_name}...")
model = model_class(**model_args)

# Load state dict, move to the chosen device, and switch to inference mode.
print("Loading state dict...")
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print("✓ Model loaded successfully")
# Defining the process function
@spaces.GPU
@torch.inference_mode()
def process_fn(
    input_audio_path: str
) -> str:
    """Restore a lossy audio file with Apollo and return the output path.

    Args:
        input_audio_path: Path of the uploaded audio file on disk.

    Returns:
        Path of the restored WAV file written under src/_outputs/.
    """
    # @spaces.GPU guarantees CUDA while this runs on the Space; the fallback
    # fixes a crash when the app is run on a CPU-only machine (the original
    # hard-coded torch.device("cuda")). Behavior on the Space is unchanged.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sig = load_audio(input_audio_path)
    # Move audio data to device
    sig = sig.to(device)
    # Defensive: Apollo expects [batch, channels, samples]. load_audio already
    # adds the batch dim, so this only triggers for unbatched 2-D input.
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)
    result = model(sig)
    # Remove batch dimension: [1, channels, samples] -> [channels, samples]
    result = result.squeeze(0)
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    # NOTE(review): 44100 assumes the whole pipeline runs at 44.1 kHz —
    # confirm the input is (re)sampled to that rate upstream.
    torchaudio.save(output_audio_path, result.cpu(), 44100)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path
# Build Gradio endpoint: wire the process function into a HARP-compatible UI.
with gr.Blocks() as demo:
    # Define input Gradio Components.
    # NOTE(review): .harp_required() and .set_info() are PyHARP extensions
    # patched onto Gradio components, not stock Gradio API — they tag the
    # component metadata consumed by HARP clients.
    input_components = [
        gr.Audio(type="filepath",
                 label="Input Audio A")
        .harp_required(True),
    ]
    # Define output Gradio Components
    output_components = [
        gr.Audio(type="filepath",
                 label="Output Audio")
        .set_info("The restored audio."),
    ]
    # Build a HARP-compatible endpoint (registers process_fn + the model card
    # inside this Blocks context).
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# run the model (queue() enables request queuing; pwa=True serves the app as
# an installable progressive web app)
demo.queue().launch(show_error=True, pwa=True)