| """ |
| HuggingFace Inference Endpoint Handler for SongFormer |
| Supports binary audio input (WAV, MP3, etc.) via base64 encoding or direct bytes |
| """ |
|
|
| import os |
| import sys |
| import io |
| import base64 |
| import json |
| import tempfile |
| from typing import Dict, Any, Union |
| import librosa |
| import numpy as np |
| import torch |
| from transformers import AutoModel |
|
|
class EndpointHandler:
    """
    HuggingFace Inference Endpoint Handler for the SongFormer model.

    Accepts audio under the ``"inputs"`` key either as a base64-encoded
    string (WAV, MP3, FLAC, etc.) or as raw audio file bytes, matching the
    module docstring's advertised contract.
    """

    def __init__(self, path: str = ""):
        """
        Load the SongFormer model and prepare it for inference.

        Args:
            path: Path to the model directory (provided by HuggingFace).
                Falls back to the current working directory when empty.
        """
        self.model_path = path or os.getcwd()
        # The model's remote code presumably reads this variable to locate
        # its local assets — confirm against the model repository.
        os.environ["SONGFORMER_LOCAL_DIR"] = self.model_path
        # Make helper modules shipped alongside the weights importable by
        # the trust_remote_code machinery.
        sys.path.insert(0, self.model_path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SongFormer model on {self.device}...")

        # device_map=None: we place the model on a single device manually
        # below rather than letting accelerate shard it.
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map=None,
        )
        self.model.to(self.device)
        self.model.eval()

        # Sample rate all incoming audio is resampled to before inference.
        self.target_sr = 24000

        print("SongFormer model loaded successfully!")

    def _decode_base64_audio(self, audio_b64: str) -> np.ndarray:
        """
        Decode base64-encoded audio to a mono numpy array at 24 kHz.

        Args:
            audio_b64: Base64-encoded audio string.

        Returns:
            numpy array of audio samples resampled to ``self.target_sr``.

        Raises:
            ValueError: If the string is not valid base64.
        """
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except Exception as e:
            raise ValueError(f"Failed to decode base64 audio data: {e}")
        return self._load_audio_bytes(audio_bytes)

    def _load_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw (container-encoded) audio bytes to mono samples.

        librosa handles format detection, mono downmix, and resampling to
        ``self.target_sr``.
        """
        audio_io = io.BytesIO(audio_bytes)
        audio_array, _ = librosa.load(audio_io, sr=self.target_sr, mono=True)
        return audio_array

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Expected input:
            {
                "inputs": "<base64-encoded-audio-data>"  # or raw audio bytes
            }

        Returns:
            {
                "segments": [...],       # model output, one entry per segment
                "duration": 180.5,       # audio length in seconds
                "num_segments": 8
            }

            On failure, a dict with "error" / "error_type" plus empty
            segments — the handler never raises to the caller.
        """
        try:
            audio_input = data.get("inputs")
            if not audio_input:
                raise ValueError("Missing 'inputs' key with base64-encoded audio")

            # Accept raw bytes directly (as advertised in the module
            # docstring) in addition to a base64-encoded string.
            if isinstance(audio_input, (bytes, bytearray)):
                audio_array = self._load_audio_bytes(bytes(audio_input))
            elif isinstance(audio_input, str):
                audio_array = self._decode_base64_audio(audio_input)
            else:
                raise ValueError("Input must be a base64-encoded string")

            with torch.no_grad():
                result = self.model(audio_array)

            duration = len(audio_array) / self.target_sr

            return {
                "segments": result,
                "duration": float(duration),
                "num_segments": len(result),
            }

        except Exception as e:
            # Endpoint contract: surface errors in the payload instead of
            # letting the serving stack return an opaque 500.
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "segments": [],
                "duration": 0.0,
                "num_segments": 0,
            }
|
|
|
|
| |
| if __name__ == "__main__": |
| import argparse |
|
|
| parser = argparse.ArgumentParser(description="Test SongFormer handler locally") |
| parser.add_argument("audio_file", help="Path to audio file to test") |
| parser.add_argument("--model-path", default=".", help="Path to model directory") |
| args = parser.parse_args() |
|
|
| |
| handler = EndpointHandler(args.model_path) |
|
|
| |
| with open(args.audio_file, "rb") as f: |
| audio_bytes = f.read() |
|
|
| audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') |
|
|
| |
| print("\n=== Testing with base64-encoded audio ===") |
| result = handler({"inputs": audio_b64}) |
| print(json.dumps(result, indent=2)) |
|
|
| |
| print("\n=== Testing with direct file path (not typical for endpoint) ===") |
| result_direct = handler.model(args.audio_file) |
| print(json.dumps(result_direct, indent=2)) |
|
|