# Classifier / app.py
# StaticFace — "Update app.py" (commit c205f25, verified)
import os

# Thread budget applied to the native math back-ends here and to torch below.
CPU_THREADS = 16

# These variables are assigned before torch/numpy are imported further down in
# this file; native thread pools typically read them once at library load time.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.update(
    {
        env_var: str(CPU_THREADS)
        for env_var in (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        )
    }
)
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
MODEL_ID = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"

# Intra-op parallelism uses CPU_THREADS; inter-op parallelism is pinned to 1.
# torch only accepts set_num_interop_threads before any inter-op parallel work
# has started, so both calls are made before the model is loaded.
torch.set_num_threads(CPU_THREADS)
torch.set_num_interop_threads(1)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only (disables dropout)

# Class-name -> class-index map with lower-cased names, so that "ENTAILMENT"
# and "entailment" both resolve.
label2id = {k.lower(): v for k, v in model.config.label2id.items()}

# Index of the entailment class in the model's logits. The fallback of 2
# presumably assumes the 3-way contradiction/neutral/entailment NLI ordering.
entail_id = label2id.get("entailment", 2)
# Fail fast at startup rather than raising an opaque IndexError on every
# request when the fallback index does not exist for this model.
if not 0 <= entail_id < model.config.num_labels:
    raise ValueError(
        f"Model {MODEL_ID!r} has no 'entailment' label and fallback index "
        f"{entail_id} is out of range for its {model.config.num_labels} "
        f"labels: {model.config.id2label}"
    )
def _softmax(x):
x = x - np.max(x)
e = np.exp(x)
return e / np.sum(e)
def _build_hypotheses(hypothesis_template, candidate_labels):
    """Format ``hypothesis_template`` once per candidate label.

    Returns ``None`` when the template is unusable: either no positional
    placeholder receives the label (every hypothesis would be identical), or
    the template references fields a single positional argument cannot fill
    (e.g. ``{label}`` or ``{} {}``).
    """
    probe = "\x00"  # sentinel: detects whether the argument reaches the output
    try:
        if probe not in hypothesis_template.format(probe):
            return None
        return [hypothesis_template.format(lab) for lab in candidate_labels]
    except (IndexError, KeyError, ValueError, AttributeError, TypeError):
        return None


def _nli_logits(text, hypotheses):
    """Return the classifier logits for each (text, hypothesis) pair.

    Runs one forward pass of the module-level ``model`` per hypothesis, with
    ``text`` as the first sequence and the hypothesis as the second. The
    result is a 2-D array with one row per hypothesis and one column per
    model class.
    """
    rows = []
    with torch.inference_mode():
        for hyp in hypotheses:
            inputs = tokenizer(text, hyp, return_tensors="pt", truncation=True)
            rows.append(model(**inputs).logits[0].float().cpu().numpy())
    return np.stack(rows)


def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
    """Zero-shot classify ``text`` against comma-separated candidate labels.

    Each label is inserted into ``hypothesis_template`` and scored by how
    strongly the NLI model predicts that ``text`` entails the hypothesis.

    Args:
        text: Text to classify; ``None`` is treated as empty.
        labels: Comma-separated candidate labels; blank entries are ignored.
        hypothesis_template: Format string with a positional placeholder for
            the label (e.g. ``"This text is about {}"``). A blank value falls
            back to that default.
        multi_label: If truthy, each label is scored independently as
            P(entailment) for its own pair, so scores need not sum to 1.
            Otherwise the labels compete: the entailment logits are
            softmaxed across labels and the scores sum to 1. With a single
            candidate label, independent scoring is always used, because a
            one-way softmax would always yield 100%.
        top_k: Maximum number of labels listed under ``"all"``; at least 1.

    Returns:
        On success, a dict with ``"cpu_threads"``, ``"top"`` (best label and
        its confidence in percent) and ``"all"`` (up to ``top_k`` labels in
        descending confidence). On invalid input, ``{"error": message}``.
    """
    text = (text or "").strip()
    labels = (labels or "").strip()
    hypothesis_template = (hypothesis_template or "").strip() or "This text is about {}"
    if not text:
        return {"error": "Enter some text."}
    candidate_labels = [x.strip() for x in labels.split(",") if x.strip()]
    if not candidate_labels:
        return {"error": "Enter at least 1 label (comma-separated)."}
    hypotheses = _build_hypotheses(hypothesis_template, candidate_labels)
    if hypotheses is None:
        return {"error": "Hypothesis template must contain one '{}' placeholder for the label."}

    logits = _nli_logits(text, hypotheses)
    if bool(multi_label) or len(candidate_labels) == 1:
        # Independent scoring: P(entailment) from each pair's own class softmax.
        out_scores = np.array([_softmax(row)[entail_id] for row in logits], dtype=np.float32)
    else:
        # Competing labels: softmax over the entailment *logits* across labels.
        # Softmaxing the probabilities instead (values in [0, 1]) would
        # compress the scores toward a uniform distribution.
        out_scores = _softmax(logits[:, entail_id])

    pairs = sorted(
        zip(candidate_labels, out_scores.tolist()),
        key=lambda x: x[1],
        reverse=True,
    )
    pairs = pairs[: max(1, int(top_k))]
    return {
        "cpu_threads": CPU_THREADS,
        "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
        "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
    }
# Gradio UI: the five input components map positionally onto run_zero_shot's
# parameters (text, labels, hypothesis_template, multi_label, top_k); the
# returned dict is rendered as JSON.
demo = gr.Interface(
    fn=run_zero_shot,
    inputs=[
        gr.Textbox(label="Text", lines=4, value="I am wahhhh"),
        gr.Textbox(label="Candidate Labels (comma-separated)", value="sad, happy, angry, neutral"),
        gr.Textbox(label="Hypothesis Template", value="This text is about {}"),
        gr.Checkbox(label="Multi-label", value=False),
        gr.Slider(label="Top-K to show", minimum=1, maximum=25, value=5, step=1),
    ],
    outputs=gr.JSON(label="Output"),
    # Derived from CPU_THREADS so the title cannot drift from the thread count
    # actually configured for torch and the BLAS back-ends.
    title=f"Zero-Shot Classification (DeBERTa v3 Large, {CPU_THREADS}-core CPU)",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()