| from __future__ import annotations |
|
|
| import os |
| from typing import List |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
# Runtime configuration, each value overridable through an environment variable.
# SCIBERT_MODEL      -> identifier handed to the `from_pretrained` loaders.
# SCIBERT_MAX_LENGTH -> token limit passed to the tokenizer (`max_length`);
#                       parsed with int(), so a non-numeric value fails at import.
MODEL_NAME = os.environ.get("SCIBERT_MODEL", "allenai/scibert_scivocab_uncased")
MAX_LENGTH = int(os.environ.get("SCIBERT_MAX_LENGTH", "512"))
|
|
# Section classes the service can report.  Order matters: the /predict handler
# uses the classifier's winning output index directly as a position in this list.
LABELS: List[str] = [
    "HEADING", "ABSTRACT", "BODY", "REFERENCES",
    "FIGURE_CAPTION", "TABLE_CAPTION", "ACKNOWLEDGEMENTS", "EQUATION",
    "METHODOLOGY", "CONCLUSION", "AUTHOR_INFO", "TITLE",
]
|
|
# Application object that the route decorators in this module register against.
app = FastAPI(
    version="1.0.0",
    title="Scholarform SciBERT Service",
)
|
|
# Use the GPU when torch reports one, otherwise the CPU.  Kept as a plain string
# because the `/` endpoint returns it verbatim in its JSON body.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and classifier are loaded once, at import time, and shared by every
# request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Size the classification head to LABELS.  Without `num_labels` the head width
# comes from the checkpoint's own config (two classes by default in
# Transformers), so at most the first two entries of LABELS could ever be
# predicted and `ignore_mismatched_sizes` would have nothing to reconcile.
# NOTE(review): when the checkpoint's head has a different width,
# `ignore_mismatched_sizes=True` makes Transformers initialise a fresh
# (untrained) head instead of raising — confirm the deployed SCIBERT_MODEL was
# fine-tuned on exactly these labels.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    ignore_mismatched_sizes=True,
).to(device)

# Inference only: put the module in evaluation mode.
model.eval()
|
|
|
|
# Request body accepted by POST /predict.
# (Documented with comments rather than a class docstring so the model's
# generated schema is left exactly as it was.)
class PredictRequest(BaseModel):
    # Texts to classify; the handler returns one prediction per entry.
    # Omitting the field yields an empty list, which the /predict handler
    # rejects with HTTP 422.
    texts: List[str] = Field(default_factory=list)
|
|
|
|
# GET / -- service banner: fixed status/service markers plus the configured
# model identifier and the device string chosen at import time.
@app.get("/")
def root():
    banner = dict(
        status="ok",
        service="scibert",
        model=MODEL_NAME,
        device=device,
    )
    return banner
|
|
|
|
# GET /health -- liveness probe: same payload as the root banner minus the
# device field.  Performs no model call.
@app.get("/health")
def health():
    report = dict(
        status="ok",
        service="scibert",
        model=MODEL_NAME,
    )
    return report
|
|
|
|
@app.post("/predict")
def predict(payload: PredictRequest):
    """Classify each submitted text into one of the section labels.

    All texts are tokenised and scored together as a single batch.

    Args:
        payload: Request body; ``payload.texts`` holds the strings to classify.

    Returns:
        A dict with ``status``, ``service``, ``model`` and ``predictions``.
        ``predictions`` contains one ``{"type": <label>, "confidence": <float>}``
        entry per input text, in input order, where ``confidence`` is the
        softmax probability of the winning class.

    Raises:
        HTTPException: 422 when ``texts`` is empty; 500 when tokenisation,
            inference or post-processing raises.
    """
    # Falsy entries are normalised to the empty string before tokenisation.
    texts = [t or "" for t in payload.texts]
    if not texts:
        raise HTTPException(status_code=422, detail="`texts` must contain at least one string")

    try:
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        # Move every encoded tensor onto the device the model lives on.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            confidences, label_idxs = torch.max(probs, dim=1)

        # Convert each result tensor to Python values in one call instead of
        # calling `.item()` twice per text, which forces a separate
        # device-to-host synchronisation for every element on CUDA.
        predictions = [
            {
                # An index beyond LABELS (head wider than the label list)
                # falls back to "BODY".
                "type": LABELS[label_index] if label_index < len(LABELS) else "BODY",
                "confidence": confidence,
            }
            for confidence, label_index in zip(confidences.tolist(), label_idxs.tolist())
        ]

        return {
            "status": "ok",
            "service": "scibert",
            "model": MODEL_NAME,
            "predictions": predictions,
        }
    except Exception as exc:
        # Service boundary: surface any failure as HTTP 500, keeping the cause chained.
        raise HTTPException(status_code=500, detail=f"SciBERT inference failed: {exc}") from exc
|
|