Ad Classifiers
Collection
Advertisement classifier in the text generation era • 3 items • Updated
Binary classification model for ad detection in QA systems.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
classifier_model_path = "teknology/ad-classifier-v0.4"

# Run on CUDA when PyTorch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(classifier_model_path)

# Load the classifier, place it on `device`, and switch it to evaluation mode.
# `nn.Module.to` and `nn.Module.eval` both return the module itself, so the
# calls can be chained; the order of the two is irrelevant.
model = (
    AutoModelForSequenceClassification.from_pretrained(classifier_model_path)
    .to(device)
    .eval()
)
def classify(passages, *, batch_size=None, max_length=512):
    """Predict a class id for each passage using the module-level model.

    Args:
        passages: A single string, or an iterable of strings, to classify.
            A bare string is treated as one passage.
        batch_size: Maximum number of passages per forward pass. ``None``
            (the default) sends all passages in a single batch, matching
            the original behavior; set it to bound peak memory on large
            inputs.
        max_length: Token limit passed to the tokenizer; longer passages
            are truncated.

    Returns:
        A list of ints, one predicted class id per passage, in input order.
        The id -> label meaning is not defined here; presumably it is in
        ``model.config.id2label`` (e.g. which id means "ad") — confirm.

    Raises:
        ValueError: If ``batch_size`` is given and is not a positive integer.
    """
    # A bare string is one passage; wrapping it keeps the slicing below
    # from splitting it into characters.
    if isinstance(passages, str):
        passages = [passages]
    else:
        # Materialize so generators/tuples work and len()/slicing are valid.
        passages = list(passages)

    if batch_size is not None and batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Avoid calling the tokenizer/model with an empty batch: nothing to do.
    if not passages:
        return []

    step = len(passages) if batch_size is None else batch_size

    predictions = []
    with torch.no_grad():
        for start in range(0, len(passages), step):
            batch = passages[start:start + step]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            # The encoded tensors must live on the same device as the model.
            inputs = {k: v.to(device) for k, v in inputs.items()}
            logits = model(**inputs).logits
            # Highest-scoring class per row, moved back to CPU as plain ints.
            predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return predictions
# Example: classify two placeholder passages; `preds` receives one class id per passage.
preds = classify(passages=["sample_text_1", "sample_text_2"])
Previous versions can be found at:
Base model
microsoft/deberta-v3-base