# busy-module-text / handler.py
# Author: EurekaPotato
# Commit c413c42 (verified): "Upload folder using huggingface_hub"
"""
Text Feature Extraction — Hugging Face Inference Endpoint Handler
Extracts all 10 text features from a conversation transcript:
t0_explicit_free, t1_explicit_busy, t2_avg_resp_len, t3_short_ratio,
t4_cognitive_load, t5_time_pressure, t6_deflection, t7_sentiment,
t8_coherence, t9_latency
Derived from: src/text_features.py
"""
# ──────────────────────────────────────────────────────────────────────── #
# Imports from standardized modules
# ──────────────────────────────────────────────────────────────────────── #
try:
    from text_features import TextFeatureExtractor
except ImportError:
    import os
    import sys

    # '.' is resolved against the current working directory at import time,
    # so on its own it only helps when the server is launched from the folder
    # holding text_features.py. Also search the directory containing this
    # handler (presumably where text_features.py is deployed — confirm),
    # which works regardless of the launch directory. '.' is kept so existing
    # cwd-based launches behave exactly as before.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append('.')
    from text_features import TextFeatureExtractor
# A single module-level extractor shared by every request handler below.
print("[INFO] Initializing Global TextFeatureExtractor...")
# preload=True loads the models at import time instead of on the first
# request, avoiding first-request latency in the Space runtime.
extractor = TextFeatureExtractor(preload=True, use_intent_model=True)
# ──────────────────────────────────────────────────────────────────────── #
# FastAPI handler for deployment
# ──────────────────────────────────────────────────────────────────────── #
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, List, Dict
import traceback
import numpy as np
import os
# ──────────────────────────────────────────────────────────────────────── #
# Constants & Defaults
# ──────────────────────────────────────────────────────────────────────── #
# Fallback feature values, returned (with error markers added) by the global
# exception handler when a request fails. Every feature is 0.0 except
# t8_coherence, which falls back to 0.5. Key order is t0..t9.
DEFAULT_TEXT_FEATURES = {
    **dict.fromkeys(
        (
            "t0_explicit_free",
            "t1_explicit_busy",
            "t2_avg_resp_len",
            "t3_short_ratio",
            "t4_cognitive_load",
            "t5_time_pressure",
            "t6_deflection",
            "t7_sentiment",
            "t8_coherence",
            "t9_latency",
        ),
        0.0,
    ),
    # Overrides the value but keeps the key in its t8 position.
    "t8_coherence": 0.5,
}
# The ASGI application; routes, middleware and exception handlers attach to it.
app = FastAPI(version="1.0.0", title="Text Feature Extraction API")
def _cors_origins_from_env() -> list[str]:
raw = (os.getenv("ALLOWED_ORIGINS") or "").strip()
if not raw:
return ["*"]
return [o.strip() for o in raw.split(",") if o.strip()]
_cors_origins = _cors_origins_from_env()
# Browsers reject Access-Control-Allow-Origin="*" combined with credentials,
# so credentials are enabled only when the allow-list contains no wildcard.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials="*" not in _cors_origins,
    allow_origins=_cors_origins,
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Answer any unhandled exception with the default feature payload.

    The response is HTTP 200 rather than a 5xx, so callers always receive a
    complete feature dict (``DEFAULT_TEXT_FEATURES``). Failure is signalled
    by the extra ``_error`` (exception message) and ``_handler`` keys.

    Args:
        request: The request whose processing raised.
        exc: The unhandled exception.

    Returns:
        JSONResponse with the fallback features plus error markers.
    """
    print(f"[GLOBAL ERROR] {request.url}: {exc}")
    # Print the traceback of *exc* itself. traceback.print_exc() instead reads
    # the implicit "currently handled" exception, and prints "NoneType: None"
    # if the handler ever runs outside an active except block.
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return JSONResponse(
        status_code=200,
        content={**DEFAULT_TEXT_FEATURES, "_error": str(exc), "_handler": "global"},
    )
class TextRequest(BaseModel):
    """JSON body accepted by ``POST /extract-text-features``."""

    # Whole transcript as a single string; forwarded as ``full_transcript``.
    transcript: str = ""
    # Optional list of extra utterances if available. When empty, the endpoint
    # falls back to treating ``transcript`` as the only utterance.
    utterances: List[str] = []
    # Forwarded to the extractor as ``question``.
    question: str = ""
    # Optional event records, forwarded to the extractor unchanged.
    events: Optional[List[Dict]] = None
@app.get("/")
async def root():
    """Describe the service and list its main endpoints."""
    endpoints = ["/health", "/extract-text-features"]
    return dict(
        service="Text Feature Extraction API",
        version="1.0.0",
        endpoints=endpoints,
    )
@app.get("/health")
async def health():
    """Report service health and the shared extractor's configuration.

    ``models_preloaded`` is always ``True``; the module-level extractor is
    constructed with ``preload=True``. ``intent_model_loaded`` echoes
    ``extractor.use_intent_model``.
    NOTE(review): that is the configuration flag, not a live check of the
    model's state.
    """
    status = {"status": "healthy"}
    status["intent_model_loaded"] = extractor.use_intent_model
    status["models_preloaded"] = True
    return status
def _sanitize_feature_value(value):
    """Return one feature value in a JSON-serializable form.

    NumPy scalars (``np.float32``, ``np.int64``, ``np.bool_``, ...) are
    unwrapped to the equivalent Python scalar via ``.item()``.
    ``np.float32`` in particular is not a ``float`` subclass, so a plain
    ``isinstance(v, float)`` check misses it. NaN and ±inf floats are
    replaced with 0.0 because strict JSON cannot represent them. Any other
    value is returned unchanged.
    """
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and not np.isfinite(value):
        return 0.0
    return value


@app.post("/extract-text-features")
async def extract_text_features(data: TextRequest):
    """Extract all 10 text features (t0..t9) from a transcript.

    ``data.utterances`` is used as the per-utterance list when non-empty.
    Otherwise a non-empty ``data.transcript`` is treated as a single
    utterance.

    Returns:
        dict of feature name -> value. NumPy scalars are converted to native
        Python types, and non-finite floats are replaced by 0.0, so the
        response always serializes as valid JSON.
    """
    # TextFeatureExtractor.extract_all expects:
    # transcript_list, full_transcript, question, events.
    transcript_list = data.utterances
    if not transcript_list and data.transcript:
        transcript_list = [data.transcript]

    # NOTE(review): extract_all runs synchronously inside this async endpoint,
    # so it blocks the event loop (including /health) while it runs. Moving it
    # to a threadpool requires the extractor to be thread-safe — confirm first.
    features = extractor.extract_all(
        transcript_list=transcript_list,
        full_transcript=data.transcript,
        question=data.question,
        events=data.events,
    )
    return {name: _sanitize_feature_value(value) for name, value in features.items()}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # on the port given by $PORT (7860 when unset). `os` comes from the
    # module-level import.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))