File size: 2,470 Bytes
28aca74
 
 
313bb51
 
 
 
28aca74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313bb51
 
 
 
28aca74
313bb51
 
 
 
28aca74
313bb51
 
28aca74
 
 
 
 
313bb51
 
28aca74
 
 
 
 
 
 
 
 
313bb51
28aca74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import annotations

import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Model identifier handed to the `from_pretrained` loaders; the SCIBERT_MODEL
# environment variable overrides the default checkpoint name.
MODEL_NAME = os.environ.get("SCIBERT_MODEL", "allenai/scibert_scivocab_uncased")

# Value passed to the tokenizer as `max_length`; the SCIBERT_MAX_LENGTH
# environment variable overrides the default of 512.
MAX_LENGTH = int(os.environ.get("SCIBERT_MAX_LENGTH", "512"))

# Label names reported by /predict. A label's position in this list is the
# class index it is looked up by, so the order must not change.
LABELS = [
    "HEADING", "ABSTRACT", "BODY", "REFERENCES",
    "FIGURE_CAPTION", "TABLE_CAPTION", "ACKNOWLEDGEMENTS", "EQUATION",
    "METHODOLOGY", "CONCLUSION", "AUTHOR_INFO", "TITLE",
]

# FastAPI application object that the route decorators in this module register on.
app = FastAPI(title="Scholarform SciBERT Service", version="1.0.0")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    # Size the classification head to match LABELS. Without this the head takes
    # the label count from the checkpoint's config (two by default), so most
    # entries of LABELS could never be predicted.
    num_labels=len(LABELS),
    # Allows loading a checkpoint whose head has a different size; in that case
    # the head weights are newly initialised rather than loaded.
    # NOTE(review): a checkpoint without a head trained on these 12 labels will
    # produce untrained predictions — confirm SCIBERT_MODEL points at a
    # fine-tuned checkpoint in deployment.
    ignore_mismatched_sizes=True,
).to(device)
# Switch to evaluation mode for inference.
model.eval()


class PredictRequest(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Text snippets to classify, one prediction per entry. Defaults to an empty
    # list when omitted; the /predict handler rejects an empty list with a 422.
    texts: List[str] = Field(default_factory=list)


@app.get("/")
def root():
    """Return a status payload with the service name, model name and torch device."""
    info = dict(status="ok", service="scibert", model=MODEL_NAME)
    info["device"] = device
    return info


@app.get("/health")
def health():
    """Return a fixed health payload naming the service and the configured model."""
    return dict(status="ok", service="scibert", model=MODEL_NAME)


@app.post("/predict")
def predict(payload: PredictRequest):
    """Classify each text in the request and return one prediction per text.

    The texts are tokenized as a single padded, truncated batch, run through
    the model without gradient tracking, and the highest softmax probability
    of each row is reported together with its label.

    Args:
        payload: Request body whose ``texts`` list holds the strings to classify.

    Returns:
        A dict with ``status``, ``service``, ``model`` and ``predictions``;
        each prediction is ``{"type": <label>, "confidence": <float>}`` in the
        same order as the input texts.

    Raises:
        HTTPException: 422 when ``texts`` is empty; 500 when tokenization or
            model inference raises.
    """
    # Falsy entries are replaced by the empty string before tokenization.
    texts = [t or "" for t in payload.texts]
    if not texts:
        raise HTTPException(status_code=422, detail="`texts` must contain at least one string")

    # Only tokenization and inference sit under the blanket handler; building
    # the response happens afterwards on plain Python values.
    try:
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            confidences, label_idxs = torch.max(probs, dim=1)

        # Convert each result tensor to a Python list once instead of calling
        # .item() on every row inside the loop.
        confidence_values = confidences.tolist()
        label_indices = label_idxs.tolist()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"SciBERT inference failed: {exc}") from exc

    label_count = len(LABELS)
    predictions = [
        {
            # An index with no entry in LABELS falls back to "BODY".
            "type": LABELS[label_index] if label_index < label_count else "BODY",
            "confidence": float(confidence),
        }
        for confidence, label_index in zip(confidence_values, label_indices)
    ]

    return {
        "status": "ok",
        "service": "scibert",
        "model": MODEL_NAME,
        "predictions": predictions,
    }