# Hugging Face Space snapshot header (uploader: rohith083, "Update app.py",
# commit 28aca74 verified) — web-page residue, not Python; kept as a comment
# so the module parses.
from __future__ import annotations
import os
from typing import List
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Model checkpoint and tokenization limit, both overridable via environment.
MODEL_NAME = os.getenv("SCIBERT_MODEL", "allenai/scibert_scivocab_uncased")
# 512 is the BERT-family positional-embedding maximum; longer texts are truncated.
MAX_LENGTH = int(os.getenv("SCIBERT_MAX_LENGTH", "512"))
# Document-section label set. Order matters: /predict maps the model's argmax
# index directly into this list, so it must match the classification head's
# training-time label order. Indices beyond this list fall back to "BODY".
LABELS = [
    "HEADING",
    "ABSTRACT",
    "BODY",
    "REFERENCES",
    "FIGURE_CAPTION",
    "TABLE_CAPTION",
    "ACKNOWLEDGEMENTS",
    "EQUATION",
    "METHODOLOGY",
    "CONCLUSION",
    "AUTHOR_INFO",
    "TITLE",
]
app = FastAPI(title="Scholarform SciBERT Service", version="1.0.0")
# Prefer GPU when present; inputs are moved to the same device in /predict.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): ignore_mismatched_sizes=True silently re-initializes the
# classification head when its shape differs from the checkpoint's — a freshly
# initialized head yields effectively random predictions. Confirm the deployed
# checkpoint actually carries a trained head matching len(LABELS).
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True,
).to(device)
# Inference mode: disables dropout and other train-time behaviors.
model.eval()
class PredictRequest(BaseModel):
    """Request body for POST /predict: a batch of text snippets to classify."""

    # Defaults to an empty list; /predict rejects an empty batch with HTTP 422.
    texts: List[str] = Field(default_factory=list)
@app.get("/")
def root():
    """Service banner: liveness status plus configured model and device."""
    banner = {
        "status": "ok",
        "service": "scibert",
        "model": MODEL_NAME,
        "device": device,
    }
    return banner
@app.get("/health")
def health():
    """Health-check endpoint: reports status and the configured model name."""
    report = {
        "status": "ok",
        "service": "scibert",
        "model": MODEL_NAME,
    }
    return report
@app.post("/predict")
def predict(payload: PredictRequest):
    """Classify each input text into one of the section LABELS.

    Returns one ``{"type", "confidence"}`` entry per input, in the same
    order as ``payload.texts``.

    Raises:
        HTTPException: 422 when ``texts`` is empty; 500 when tokenization
            or model inference fails.
    """
    # Normalize None/empty entries so the tokenizer never receives None.
    texts = [t or "" for t in payload.texts]
    if not texts:
        raise HTTPException(status_code=422, detail="`texts` must contain at least one string")
    try:
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        confidences, label_idxs = torch.max(probs, dim=-1)
        # Transfer each tensor to host once via .tolist() instead of calling
        # .item() per element in the loop, which forces a device sync per
        # prediction when running on GPU.
        predictions = [
            {
                # Guard: a head with more classes than LABELS falls back to "BODY".
                "type": LABELS[idx] if idx < len(LABELS) else "BODY",
                "confidence": float(conf),
            }
            for conf, idx in zip(confidences.tolist(), label_idxs.tolist())
        ]
        return {
            "status": "ok",
            "service": "scibert",
            "model": MODEL_NAME,
            "predictions": predictions,
        }
    except Exception as exc:  # service boundary: surface any failure as HTTP 500
        raise HTTPException(status_code=500, detail=f"SciBERT inference failed: {exc}") from exc