busy-module-audio / handler.py
EurekaPotato's picture
Upload folder using huggingface_hub
8263279 verified
"""
Audio Feature Extraction — Hugging Face Inference Endpoint Handler
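Extracts all 17 voice features from uploaded audio:
v1_snr, v2_noise_{traffic,office,crowd,wind,clean} (5), v3_speech_rate,
v4_pitch_mean, v5_pitch_std, v6_energy_mean, v7_energy_std,
v8_pause_ratio, v9_avg_pause_dur, v10_mid_pause_cnt,
v11_emotion_stress, v12_emotion_energy, v13_emotion_valence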
Derived from: src/audio_features.py, src/emotion_features.py
"""
import io
import os
import tempfile
import numpy as np
import librosa
from scipy import signal as scipy_signal
from typing import Dict
import torch
import torch.nn as nn
from torchvision import models
import warnings
warnings.filterwarnings("ignore")
# ──────────────────────────────────────────────────────────────────────── #
# Imports from standardized modules
# ──────────────────────────────────────────────────────────────────────── #
try:
    # Normal case: the audio_features module is importable as-is.
    from audio_features import AudioFeatureExtractor
except ImportError:
    # Fallback if running from a different context: append '.' (resolved
    # against the current working directory) to the module search path and
    # retry. A second ImportError here is not caught and aborts module import.
    import sys
    sys.path.append('.')
    from audio_features import AudioFeatureExtractor
# Initialize global extractor
# We use a global instance to cache models (VAD, Emotion)
# NOTE: everything in this section runs as a side effect of importing the
# module, before any request is served.
print("[INFO] Initializing Global AudioFeatureExtractor...")
extractor = AudioFeatureExtractor(
    sample_rate=16000,
    use_emotion=True,
    emotion_models_dir="/app/models"  # Absolute path in Docker container
)
# Ensure models are downloaded/ready
# Only attempted when emotion extraction is enabled and an emotion extractor
# object was actually created.
if extractor.use_emotion and extractor.emotion_extractor:
    print("[INFO] Checking for emotion models...")
    # Trigger download if needed/possible
    # Best effort: any failure below is logged as a warning and start-up
    # continues with whatever the extractor already holds.
    try:
        if len(extractor.emotion_extractor.models) == 0:
            print("[INFO] Models not found, attempting download...")
            extractor.emotion_extractor.download_models()
            # Re-init manually to load them
            # NOTE(review): this re-runs __init__ on the existing instance;
            # presumably the constructor is what loads models from
            # models_dir - confirm against the emotion extractor class.
            extractor.emotion_extractor.__init__(models_dir=extractor.emotion_extractor.models_dir)
    except Exception as e:
        print(f"[WARN] Failed to download emotion models: {e}")
# ──────────────────────────────────────────────────────────────────────── #
# Helper to handle NaN/Inf for JSON
# ──────────────────────────────────────────────────────────────────────── #
def sanitize_features(features: Dict[str, float]) -> Dict[str, float]:
    """Return a copy of *features* whose numeric values are JSON-safe.

    Float values (Python or NumPy) become built-in ``float``; NaN and
    +/-infinity are replaced by ``0.0``. Integer values (Python or NumPy)
    become built-in ``int``. Every other value (strings, ``None``, ...) is
    passed through unchanged. Key order is preserved.
    """

    def _to_json_safe(value):
        # Floats are checked first so non-finite values are zeroed out.
        if isinstance(value, (float, np.floating)):
            return float(value) if np.isfinite(value) else 0.0
        if isinstance(value, (int, np.integer)):
            return int(value)
        return value  # keep string/other as is

    return {name: _to_json_safe(value) for name, value in features.items()}
# ──────────────────────────────────────────────────────────────────────── #
# FastAPI handler for deployment (HF Spaces / Cloud Run / Lambda)
# ──────────────────────────────────────────────────────────────────────── #
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import base64
import traceback
# Single FastAPI application object; the middleware, exception handler and
# route decorators in this module all attach to it.
app = FastAPI(title="Audio Feature Extraction API", version="1.0.0")
def _cors_origins_from_env() -> list[str]:
    """Read the allowed CORS origins from the ``ALLOWED_ORIGINS`` env var.

    The variable is a comma-separated list; entries are whitespace-trimmed
    and empty entries are dropped. When the variable is unset or blank the
    wildcard ``["*"]`` is returned.
    """
    configured = (os.getenv("ALLOWED_ORIGINS") or "").strip()
    if configured:
        trimmed = map(str.strip, configured.split(","))
        return [origin for origin in trimmed if origin]
    return ["*"]
# Resolve the allowed origins once at import time and register CORS handling.
_cors_origins = _cors_origins_from_env()
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Browsers reject: Access-Control-Allow-Origin="*" with credentials=true.
    # So credentials are enabled only when an explicit origin list is configured.
    allow_credentials=("*" not in _cors_origins),
    # Methods and headers are left unrestricted.
    allow_methods=["*"], allow_headers=["*"],
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Turn any unhandled exception into a 200 response carrying default features.

    The request URL, the exception and the traceback are printed. The response
    body is a copy of ``DEFAULT_AUDIO_FEATURES`` plus ``_error`` (the exception
    text) and ``_handler`` set to ``"global"``.
    """
    print(f"[GLOBAL ERROR] {request.url}: {exc}")
    traceback.print_exc()

    # Build the payload on a copy so the shared defaults are never mutated.
    payload = dict(DEFAULT_AUDIO_FEATURES)
    payload["_error"] = str(exc)
    payload["_handler"] = "global"
    return JSONResponse(status_code=200, content=payload)
# Extractor is already initialized globally above
# ──────────────────────────────────────────────────────────────────────── #
# Constants & Defaults
# ──────────────────────────────────────────────────────────────────────── #
# Fallback feature payload returned when audio is missing, too short, or
# fails to process. Every value is 0.0 except v2_noise_clean, which is 1.0.
DEFAULT_AUDIO_FEATURES = dict(
    # v1: signal-to-noise ratio
    v1_snr=0.0,
    # v2: noise-class scores
    v2_noise_traffic=0.0,
    v2_noise_office=0.0,
    v2_noise_crowd=0.0,
    v2_noise_wind=0.0,
    v2_noise_clean=1.0,
    # v3: speech rate
    v3_speech_rate=0.0,
    # v4/v5: pitch statistics
    v4_pitch_mean=0.0,
    v5_pitch_std=0.0,
    # v6/v7: energy statistics
    v6_energy_mean=0.0,
    v7_energy_std=0.0,
    # v8-v10: pause statistics
    v8_pause_ratio=0.0,
    v9_avg_pause_dur=0.0,
    v10_mid_pause_cnt=0.0,
    # v11-v13: emotion scores
    v11_emotion_stress=0.0,
    v12_emotion_energy=0.0,
    v13_emotion_valence=0.0,
)
class AudioBase64Request(BaseModel):
    """JSON request body for the ``/extract-audio-features-base64`` endpoint.

    Every field defaults to an empty string, so a body with missing keys
    still validates.
    """

    # Base64-encoded audio; the endpoint also accepts a leading data-URL
    # prefix such as "data:audio/wav;base64,".
    audio_base64: str = ""
    # Optional transcript, forwarded unchanged to the feature extractor.
    transcript: str = ""
    # Optional MIME type hint, used when choosing a temp-file extension
    # during decoding.
    mime_type: str = ""
def infer_audio_extension(audio_bytes: bytes, mime_type: str = "") -> str:
    """Pick a file extension for *audio_bytes*.

    The MIME type wins when it is recognised (case-insensitive, parameters
    after ``;`` ignored). Otherwise the leading bytes are matched against
    known container signatures. ``".bin"`` is returned when nothing matches.
    """
    extension_by_mime = {
        "audio/webm": ".webm",
        "audio/ogg": ".ogg",
        "audio/wav": ".wav",
        "audio/x-wav": ".wav",
        "audio/mpeg": ".mp3",
        "audio/mp3": ".mp3",
        "audio/mp4": ".m4a",
        "audio/x-m4a": ".m4a",
        "audio/aac": ".aac",
        "audio/flac": ".flac",
    }
    base_mime = (mime_type or "").lower().split(";")[0].strip()
    from_mime = extension_by_mime.get(base_mime)
    if from_mime is not None:
        return from_mime

    # Signatures that sit at offset 0 of the payload.
    leading_signatures = (
        (b"RIFF", ".wav"),
        (b"OggS", ".ogg"),
        (b"\x1A\x45\xDF\xA3", ".webm"),
        (b"fLaC", ".flac"),
    )
    for signature, extension in leading_signatures:
        if audio_bytes.startswith(signature):
            return extension

    # The "ftyp" marker sits at bytes 4-7 rather than at the start.
    if audio_bytes[4:8] == b"ftyp":
        return ".m4a"

    # MP3: either an ID3 tag or a frame sync (0xFF followed by three set high bits).
    starts_with_id3 = audio_bytes.startswith(b"ID3")
    starts_with_frame_sync = (
        len(audio_bytes) > 1
        and audio_bytes[0] == 0xFF
        and (audio_bytes[1] & 0xE0) == 0xE0
    )
    if starts_with_id3 or starts_with_frame_sync:
        return ".mp3"

    return ".bin"
def decode_audio_bytes(audio_bytes: bytes, mime_type: str = ""):
    """Decode raw audio bytes, falling back through three decoders.

    Stages, tried in order:

    1. ``soundfile.read`` on an in-memory buffer. The result is returned
       exactly as soundfile produced it, with no resampling or down-mixing.
    2. ``librosa.load`` on an in-memory buffer, requesting 16 kHz mono.
    3. ``librosa.load`` on a temporary file whose extension is chosen by
       ``infer_audio_extension``, requesting 16 kHz mono.

    Args:
        audio_bytes: Encoded audio payload.
        mime_type: Optional MIME type hint used only to pick the temp-file
            extension in stage 3.

    Returns:
        A ``(samples, sample_rate)`` tuple from the first stage that succeeds.

    Raises:
        Exception: Whatever the final (temp-file) stage raises when every
            decoder fails.
    """
    import soundfile as sf

    try:
        y, sr = sf.read(io.BytesIO(audio_bytes))
        return y, sr
    except Exception as sf_err:
        print(f"[WARN] soundfile failed ({sf_err}), trying librosa from buffer...")

    try:
        y, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
        return y, sr
    except Exception as librosa_err:
        print(f"[WARN] librosa buffer decode failed ({librosa_err}), trying temp file...")

    suffix = infer_audio_extension(audio_bytes, mime_type)
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path BEFORE writing: the file is created with
            # delete=False, so if the write fails the finally-clause must
            # still know which file to remove.
            temp_path = temp_file.name
            temp_file.write(audio_bytes)
        # The file is closed (and flushed) here, so librosa can reopen it by path.
        y, sr = librosa.load(temp_path, sr=16000, mono=True)
        return y, sr
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass
@app.get("/")
async def root():
    """Return a small service descriptor: name, version and endpoint paths."""
    endpoint_paths = [
        "/health",
        "/extract-audio-features",
        "/extract-audio-features-base64",
    ]
    return dict(
        service="Audio Feature Extraction API",
        version="1.0.0",
        endpoints=endpoint_paths,
    )
@app.get("/health")
async def health():
    """Report liveness plus whether the VAD and emotion components exist.

    ``vad_loaded`` is true when the global extractor has a VAD model object.
    ``emotion_loaded`` is true only when emotion extraction is enabled and an
    emotion extractor object is present.
    """
    vad_ready = extractor.vad_model is not None

    emotion_ready = False
    if extractor.use_emotion:
        emotion_ready = extractor.emotion_extractor is not None

    return {
        "status": "healthy",
        "vad_loaded": vad_ready,
        "emotion_loaded": emotion_ready,
    }
@app.post("/extract-audio-features")
async def extract_audio_features(audio: UploadFile = File(...), transcript: str = Form("")):
    """Extract all 17 voice features from an uploaded audio file.

    The upload goes through the same pipeline as the base64 endpoint:
    ``decode_audio_bytes`` (with the upload's content type as the format
    hint), down-mix to mono, conversion to float32 and resampling to 16 kHz.
    Audio shorter than 100 samples after decoding yields the default features.

    Args:
        audio: Multipart file upload containing the encoded audio.
        transcript: Optional transcript, forwarded to the extractor.

    Returns:
        The sanitized feature dict on success. On any failure, a copy of
        ``DEFAULT_AUDIO_FEATURES`` with an ``_error`` message instead of an
        error status.
    """
    try:
        audio_bytes = await audio.read()

        # Use the shared multi-stage decoder so formats that need the
        # temp-file fallback work here exactly as they do over base64.
        y, sr = decode_audio_bytes(audio_bytes, audio.content_type or "")

        # Multi-dimensional decoder output is averaged over axis 1 to get mono.
        if hasattr(y, "shape") and len(y.shape) > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # The extractor is configured for 16 kHz input.
        if sr != 16000:
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
            y = y.astype(np.float32)

        if len(y) < 100:
            print("[WARN] Audio too short after decode, returning defaults")
            return {**DEFAULT_AUDIO_FEATURES}

        # AudioFeatureExtractor.extract_all expects numpy array and optional transcript
        features = extractor.extract_all(y, transcript)
        return sanitize_features(features)
    except Exception as e:
        print(f"[ERROR] extract_audio_features: {e}")
        traceback.print_exc()
        # Return defaults rather than 500
        return {**DEFAULT_AUDIO_FEATURES, "_error": str(e)}
@app.post("/extract-audio-features-base64")
async def extract_audio_features_base64(data: AudioBase64Request):
    """Extract features from base64-encoded audio (for Vercel serverless calls).

    Missing or very short payloads, and audio shorter than 100 samples after
    decoding, yield a copy of ``DEFAULT_AUDIO_FEATURES``. Any exception during
    decoding or extraction yields the defaults plus an ``_error`` message.
    """
    encoded = data.audio_base64
    transcript_text = data.transcript
    mime_hint = data.mime_type

    # Guard: nothing usable was sent, so skip decoding entirely.
    if not encoded or len(encoded) < 100:
        print("[INFO] Empty or too-short audio_base64, returning defaults")
        return dict(DEFAULT_AUDIO_FEATURES)

    try:
        # A comma near the start marks a data URL header
        # (e.g. "data:audio/wav;base64,..."); keep only what follows it.
        if "," in encoded[:80]:
            encoded = encoded.partition(",")[2]

        raw_audio = base64.b64decode(encoded)
        print(f"[INFO] Decoded {len(raw_audio)} bytes of audio")
        if mime_hint:
            print(f"[INFO] MIME type hint: {mime_hint}")

        waveform, native_sr = decode_audio_bytes(raw_audio, mime_hint)

        # Collapse multi-dimensional decoder output to mono by averaging axis 1.
        if len(getattr(waveform, "shape", ())) > 1:
            waveform = np.mean(waveform, axis=1)
        waveform = np.asarray(waveform, dtype=np.float32)

        # Bring everything to the 16 kHz rate the extractor is configured for.
        if native_sr != 16000:
            waveform = librosa.resample(waveform, orig_sr=native_sr, target_sr=16000)
            waveform = waveform.astype(np.float32)

        if len(waveform) < 100:
            print("[WARN] Audio too short after decode, returning defaults")
            return dict(DEFAULT_AUDIO_FEATURES)

        extracted = extractor.extract_all(waveform, transcript_text)
        print(f"[OK] Extracted {len(extracted)} audio features")
        return sanitize_features(extracted)
    except Exception as e:
        print(f"[ERROR] extract_audio_features_base64: {e}")
        traceback.print_exc()
        # Return defaults rather than 500
        fallback = dict(DEFAULT_AUDIO_FEATURES)
        fallback["_error"] = str(e)
        return fallback
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    # The port comes from the PORT environment variable, defaulting to 7860.
    # (`os` is already imported at the top of the module.)
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)