""" arXiv Article Classifier — Streamlit UI Запуск локально: streamlit run app.py --server.port 8080 """ import json import os import numpy as np import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification # --------------------------------------------------------------------------- # Стили # --------------------------------------------------------------------------- st.markdown(""" """, unsafe_allow_html=True) # --------------------------------------------------------------------------- # Конфиг моделей # --------------------------------------------------------------------------- MODELS = { # "large": { # "label": "Большая", # "dir": "./model_v2", # "base": "allenai/scibert_scivocab_uncased", # "base_url": "https://huggingface.co/allenai/scibert_scivocab_uncased", # "dataset": "mteb/arxiv-clustering-p2p", # "dataset_url": "https://huggingface.co/datasets/mteb/arxiv-clustering-p2p", # "n_classes": 122, # "desc": "SciBERT · 122 категории", # "topics": "CS · Math · Physics · HEP · Astrophysics · Condensed Matter · Statistics · EESS · Quantitative Biology · Quantitative Finance · Economics · Nonlinear Sciences", # }, "small": { "label": "Простая", "dir": "./model", "base": "distilbert-base-cased", "base_url": "https://huggingface.co/distilbert-base-cased", "dataset": "ccdv/arxiv-classification", "dataset_url": "https://huggingface.co/datasets/ccdv/arxiv-classification", "n_classes": 11, "desc": "DistilBERT · 11 категорий", "topics": "cs.CV · cs.AI · cs.NE · cs.IT · cs.DS · cs.SY · cs.CE · cs.PL · math.AC · math.GR · math.ST", }, } MAX_LEN = 256 THRESHOLD = 0.95 # --------------------------------------------------------------------------- # Загрузка модели # --------------------------------------------------------------------------- @st.cache_resource def load_model(model_dir: str): device = ( "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" ) tokenizer = AutoTokenizer.from_pretrained(model_dir) model = AutoModelForSequenceClassification.from_pretrained(model_dir) model.to(device) model.eval() with open(f"{model_dir}/id2label.json") as f: id2label = {int(k): v for k, v in json.load(f).items()} label_full = {} if os.path.exists(f"{model_dir}/label_full.json"): with open(f"{model_dir}/label_full.json") as f: label_full = json.load(f) return tokenizer, model, id2label, label_full, device def predict_top95(title, abstract, model_dir): tokenizer, model, id2label, label_full, device = load_model(model_dir) text = title.strip() if abstract.strip(): text = text + "\n\n" + abstract.strip() enc = tokenizer( text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt", ).to(device) with torch.no_grad(): logits = model(**enc).logits probs = torch.softmax(logits, dim=-1).squeeze().cpu().numpy() sorted_idx = np.argsort(probs)[::-1] result, cumsum = [], 0.0 for idx in sorted_idx: prob = float(probs[idx]) cat = id2label[int(idx)] result.append({ "category": cat, "full_name": label_full.get(cat, cat), "probability": prob, }) cumsum += prob if cumsum >= THRESHOLD: break return result def render_results(results): for rank, r in enumerate(results, start=1): pct = r["probability"] * 100 bar = int(r["probability"] * 20) * "█" + (20 - int(r["probability"] * 20)) * "░" st.markdown(f"""
{pct:.1f}%
{rank}. {r['full_name']}
{r['category']}
{bar}
""", unsafe_allow_html=True) # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- st.set_page_config(page_title="arXiv Classifier") st.markdown("# arXiv Classifier") st.markdown("

Классификация научных статей по тематике arxiv

", unsafe_allow_html=True) # Проверяем доступность моделей available = {k: v for k, v in MODELS.items() if os.path.exists(f"{v['dir']}/config.json")} if not available: st.error("Модели не найдены. Сначала запустите обучение.") st.stop() # --------------------------------------------------------------------------- # Режим работы # --------------------------------------------------------------------------- mode = st.radio( "Режим", ["Одна модель", "Сравнение моделей"], horizontal=True, label_visibility="collapsed", ) # --------------------------------------------------------------------------- # Поля ввода # --------------------------------------------------------------------------- title = st.text_input("Название статьи *", placeholder="Например: Attention Is All You Need") abstract = st.text_area( "Аннотация (abstract)", placeholder="Необязательно. Если не указана — классификация только по названию.", height=150, ) # Выбор модели (только в режиме одной) if mode == "Одна модель": model_key = st.radio( "Модель", list(available.keys()), format_func=lambda k: f"{available[k]['label']} — {available[k]['desc']}", horizontal=True, ) cfg = available[model_key] st.divider() run = st.button("Классифицировать", type="primary", use_container_width=True) # --------------------------------------------------------------------------- # Предсказание # --------------------------------------------------------------------------- if run: if not title.strip(): st.error("Пожалуйста, введите название статьи.") st.stop() if mode == "Одна модель": cfg = available[model_key] with st.spinner("Предсказываем..."): try: results = predict_top95(title, abstract, cfg["dir"]) except Exception as e: st.error(f"Ошибка: {e}"); st.stop() st.success(f"Топ-{len(results)} категорий (суммарная вероятность ≥ 95%)") render_results(results) else: # Сравнение if len(available) < 2: st.warning("Для сравнения нужны обе модели. Сейчас доступна только одна.") st.stop() with st.spinner("Запускаем обе модели..."): try: res_large = predict_top95(title, abstract, MODELS["large"]["dir"]) res_small = predict_top95(title, abstract, MODELS["small"]["dir"]) except Exception as e: st.error(f"Ошибка: {e}"); st.stop() col_l, col_r = st.columns(2) with col_l: st.markdown( f"
{MODELS['large']['label']} — {MODELS['large']['desc']}
", unsafe_allow_html=True, ) render_results(res_large) with col_r: st.markdown( f"
{MODELS['small']['label']} — {MODELS['small']['desc']}
", unsafe_allow_html=True, ) render_results(res_small) # --------------------------------------------------------------------------- # Сайдбар # --------------------------------------------------------------------------- with st.sidebar: st.markdown("### О сервисе") for key, cfg in available.items(): st.markdown( f"**{cfg['label']}** \n" f"Модель: [{cfg['base']}]({cfg['base_url']}) \n" f"Датасет: [{cfg['dataset']}]({cfg['dataset_url']}) \n" f"Классов: **{cfg['n_classes']}**" ) # Тематики в виде тегов tags = cfg["topics"].split(" · ") tags_html = " ".join( f"{t}" for t in tags ) st.markdown(tags_html, unsafe_allow_html=True) st.markdown("") st.divider() st.caption( "**Top-95%** — категории выводятся по убыванию вероятности, " "пока суммарная вероятность не превысит 95%." )