| import streamlit as st |
| import torch |
| from torch import nn |
| import csv |
| from transformers import AutoModel, AutoTokenizer |
| from huggingface_hub import hf_hub_download |
| from model import ClassificationModel |
|
|
# Page-level settings for the app (browser tab title and centered layout).
st.set_page_config(layout="centered", page_title="Article Theme Classifier")

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sequence length passed to the tokenizer as ``max_length``.
MAX_LENGTH = 512
|
|
@st.cache_resource
def get_model():
    """Build the classifier and load its fine-tuned weights.

    A ``distilbert-base-cased`` backbone is wrapped in ``ClassificationModel``,
    the ``pytorch_model.bin`` checkpoint is downloaded from the
    ``MostoHF/TunedDistillBertCased`` Hub repository and loaded as the state
    dict, and the module is moved to ``device`` and put into eval mode.

    Returns:
        The ``ClassificationModel`` instance, on ``device``, in eval mode.
    """
    base_model = AutoModel.from_pretrained("distilbert-base-cased")
    model = ClassificationModel(base_model)

    weights_path = hf_hub_download(
        repo_id="MostoHF/TunedDistillBertCased",
        filename="pytorch_model.bin",
    )

    # The checkpoint is a pickle fetched over the network.  It is only used as
    # a state dict, so restrict unpickling to tensors and plain containers
    # instead of allowing arbitrary objects to be constructed on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model
|
|
@st.cache_resource
def get_tokenizer():
    """Return the ``distilbert-base-cased`` tokenizer built by ``AutoTokenizer``."""
    checkpoint_name = "distilbert-base-cased"
    loaded = AutoTokenizer.from_pretrained(checkpoint_name)
    return loaded
|
|
@st.cache_resource
def get_ind_to_cat():
    """Load the class-index -> category-name mapping from ``ind_to_category.csv``.

    The path is relative, so the file is looked up in the current working
    directory.  The first row is treated as a header and skipped; every
    following non-blank row must hold exactly two columns: an integer index
    and the category name.

    Returns:
        A dict mapping ``int`` index to category name.  It is empty when the
        file contains no data rows.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank row does not have exactly two columns or
            its first column is not an integer.
    """
    ind_to_category_copy = {}
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open('ind_to_category.csv', mode='r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        # Skip the header; the default avoids a bare StopIteration on an
        # empty file.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields an empty list for blank lines (for
                # example a trailing newline at the end of the file).
                continue
            key, value = row
            ind_to_category_copy[int(key)] = value
    return ind_to_category_copy
|
|
# Module-level handles used by inference(): the classifier, its tokenizer and
# the index -> category lookup, evaluated in that order.
class_model, tokenizer, ind_to_category = (
    get_model(),
    get_tokenizer(),
    get_ind_to_cat(),
)
|
|
def inference(title, abstract, threshold=0.95):
    """Predict the most probable themes for an article.

    The title and abstract are joined with ``'@'``, tokenized (padded and
    truncated to ``MAX_LENGTH``) and passed through ``class_model``.  The
    exponentiated model output is sorted in descending order and categories
    are collected until their running sum reaches ``threshold``.

    Args:
        title: Article title text.
        abstract: Article abstract text.
        threshold: Cumulative value at which selection stops.  The check runs
            after each category is added, so at least one category is
            returned whenever the model produces any output.

    Returns:
        A ``(themes, probs)`` tuple of two parallel lists: the category names
        looked up in ``ind_to_category`` and their values as Python floats,
        both in descending order of value.

    Raises:
        KeyError: If a predicted index is missing from ``ind_to_category``.
    """
    cur_elem = title + '@' + abstract

    encoding = tokenizer(
        cur_elem,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        # NOTE(review): exp() suggests the model emits log-probabilities -
        # confirm against ClassificationModel.
        res_probs = torch.exp(class_model(input_ids, attention_mask))

    probs = res_probs.squeeze(0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    # Convert each tensor to a Python list once, rather than calling .item()
    # twice per category inside the loop.
    prob_values = sorted_probs.tolist()
    index_values = sorted_indices.tolist()

    total = 0.0
    selected_indices = []
    selected_probs = []
    for prob, idx in zip(prob_values, index_values):
        total += prob
        selected_indices.append(idx)
        selected_probs.append(prob)
        if total >= threshold:
            break

    ans_themes = [ind_to_category[idx] for idx in selected_indices]
    return ans_themes, selected_probs
|
|
|
|
| |
|
|
st.title("📄 Article Theme Classifier")

# The hint text is passed as a placeholder rather than as the initial value,
# so an untouched field is submitted as an empty string instead of having the
# hint itself classified as article text.
title = st.text_input("Title", placeholder="Введите title...")
abstract = st.text_input("Abstract", placeholder="Введите abstract...")
threshold = st.slider("Выберите cumulative probability threshold", 0.0, 1.0, step=0.01, value=0.95)

if st.button("Submit"):
    # Whitespace-only input counts as empty.
    if title.strip() or abstract.strip():
        st.success("✅ Title")
        st.info("📑 Abstract")
        themes, probs = inference(title, abstract, threshold)
        st.subheader("Predicted Themes:")
        for theme, prob in zip(themes, probs):
            st.write(f"**{theme}** — {prob:.4f}")
    else:
        st.warning("❌ Please fill in at least one of the fields.")
|
|