Spaces:
Sleeping
Sleeping
| """ | |
| arXiv Article Classifier — Streamlit UI | |
| Run locally: | |
| streamlit run app.py --server.port 8080 | |
| """ | |
| import json | |
| import os | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# ---------------------------------------------------------------------------
# Styles
# ---------------------------------------------------------------------------
# Inject a custom stylesheet into the page. The whole <style> element is passed
# as raw HTML, hence unsafe_allow_html=True. The .cat-card / .cat-title /
# .cat-code / .cat-pct and .col-header classes defined here are the ones used
# by the HTML snippets this script emits for results and comparison headers.
st.markdown("""
<style>
/* Фон */
.stApp { background-color: #f7faf7; }
.main .block-container { padding-top: 2rem; }
/* Заголовки */
h1 { color: #2d6a4f !important; letter-spacing: -0.5px; }
h2, h3 { color: #40916c !important; }
/* Текст */
p, label, .stMarkdown { color: #374151 !important; }
/* Radio */
.stRadio > label { color: #40916c !important; font-weight: 600; }
/* Поля ввода */
.stTextInput input, .stTextArea textarea {
background-color: #ffffff !important;
border: 1px solid #b7e4c7 !important;
color: #1f2937 !important;
border-radius: 8px !important;
}
.stTextInput input:focus, .stTextArea textarea:focus {
border-color: #52b788 !important;
box-shadow: 0 0 0 2px rgba(82,183,136,0.15) !important;
}
.stTextInput label, .stTextArea label {
color: #40916c !important;
font-weight: 600;
}
/* Кнопка */
.stButton > button {
background-color: #52b788 !important;
color: #ffffff !important;
border: none !important;
border-radius: 8px !important;
font-weight: 600;
transition: all 0.2s;
}
.stButton > button:hover {
background-color: #40916c !important;
color: #ffffff !important;
}
/* Divider */
hr { border-color: #d8f3dc !important; }
/* Success/error */
.stSuccess { background-color: #d8f3dc !important; color: #1b4332 !important; border-color: #95d5b2 !important; }
.stError { background-color: #fef2f2 !important; }
/* Sidebar */
[data-testid="stSidebar"] {
background-color: #f0faf2 !important;
border-right: 1px solid #d8f3dc;
}
[data-testid="stSidebar"] p,
[data-testid="stSidebar"] span,
[data-testid="stSidebar"] div { color: #374151 !important; }
[data-testid="stSidebar"] a { color: #40916c !important; }
/* Карточка категории */
.cat-card {
background: #ffffff;
border: 1px solid #d8f3dc;
border-left: 4px solid #52b788;
border-radius: 8px;
padding: 10px 14px;
margin-bottom: 8px;
}
.cat-title { color: #1b4332; font-weight: 600; font-size: 0.95rem; }
.cat-code { color: #74c69d; font-size: 0.78rem; font-family: monospace; margin-top: 2px; }
.cat-pct { color: #40916c; font-size: 1.2rem; font-weight: 700; float: right; }
/* Заголовок колонки сравнения */
.col-header {
background: #d8f3dc;
border-radius: 8px;
padding: 8px 14px;
margin-bottom: 12px;
color: #1b4332 !important;
font-weight: 700;
font-size: 0.9rem;
text-align: center;
}
</style>
""", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# Registry of selectable models keyed by a short identifier. Each entry holds
# the UI label and description, the directory the fine-tuned model is loaded
# from, and the base-model / dataset names and links shown in the sidebar.
MODELS = {
    # The "large" entry below is currently disabled; it is kept commented out
    # for reference and can be restored by uncommenting it.
    # "large": {
    #     "label": "Большая",
    #     "dir": "./model_v2",
    #     "base": "allenai/scibert_scivocab_uncased",
    #     "base_url": "https://huggingface.co/allenai/scibert_scivocab_uncased",
    #     "dataset": "mteb/arxiv-clustering-p2p",
    #     "dataset_url": "https://huggingface.co/datasets/mteb/arxiv-clustering-p2p",
    #     "n_classes": 122,
    #     "desc": "SciBERT · 122 категории",
    #     "topics": "CS · Math · Physics · HEP · Astrophysics · Condensed Matter · Statistics · EESS · Quantitative Biology · Quantitative Finance · Economics · Nonlinear Sciences",
    # },
    "small": {
        "label": "Простая",
        "dir": "./model",
        "base": "distilbert-base-cased",
        "base_url": "https://huggingface.co/distilbert-base-cased",
        "dataset": "ccdv/arxiv-classification",
        "dataset_url": "https://huggingface.co/datasets/ccdv/arxiv-classification",
        "n_classes": 11,
        "desc": "DistilBERT · 11 категорий",
        # " · "-separated list; the sidebar splits on that separator to build tags.
        "topics": "cs.CV · cs.AI · cs.NE · cs.IT · cs.DS · cs.SY · cs.CE · cs.PL · math.AC · math.GR · math.ST",
    },
}
# Maximum token length passed to the tokenizer (inputs are padded/truncated to it).
MAX_LEN = 256
# Cumulative-probability cut-off: categories are listed until their summed
# probability reaches this value.
THRESHOLD = 0.95
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
@st.cache_resource
def load_model(model_dir: str):
    """Load the tokenizer, classifier and label maps stored in *model_dir*.

    The result is cached with ``st.cache_resource`` so the weights are read
    from disk once per ``model_dir`` instead of on every prediction. The
    returned objects are therefore shared between calls and must be treated
    as read-only.

    Args:
        model_dir: Directory containing the saved model, its tokenizer and
            ``id2label.json``; ``label_full.json`` is optional.

    Returns:
        A tuple ``(tokenizer, model, id2label, label_full, device)`` where
        ``id2label`` maps integer class ids to category codes, ``label_full``
        is the content of ``label_full.json`` (an empty dict when the file is
        absent) and ``device`` is ``"mps"``, ``"cuda"`` or ``"cpu"``.

    Raises:
        OSError: If ``id2label.json`` cannot be opened.
    """
    # Same preference order as before: Apple MPS first, then CUDA, else CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()  # inference only

    # Read JSON explicitly as UTF-8 so label names do not depend on the
    # platform's default encoding.
    with open(os.path.join(model_dir, "id2label.json"), encoding="utf-8") as f:
        # JSON object keys are strings; convert them back to integer ids.
        id2label = {int(k): v for k, v in json.load(f).items()}

    label_full = {}
    label_full_path = os.path.join(model_dir, "label_full.json")
    if os.path.exists(label_full_path):
        with open(label_full_path, encoding="utf-8") as f:
            label_full = json.load(f)

    return tokenizer, model, id2label, label_full, device
def predict_top95(title, abstract, model_dir, *, threshold=None, max_len=None):
    """Classify an article and return the top categories by cumulative probability.

    The title and (if non-blank) abstract are joined with a blank line,
    tokenized, and run through the model obtained from ``load_model``.
    Categories are then taken in descending probability order until their
    summed probability reaches the cut-off.

    Args:
        title: Article title.
        abstract: Article abstract; may be empty or whitespace-only.
        model_dir: Directory passed to ``load_model``.
        threshold: Cumulative-probability cut-off. ``None`` (the default)
            uses the module-level ``THRESHOLD``.
        max_len: Token limit passed to the tokenizer. ``None`` (the default)
            uses the module-level ``MAX_LEN``.

    Returns:
        A list of dicts with keys ``"category"``, ``"full_name"`` and
        ``"probability"``, ordered from most to least probable. The list
        always contains at least one entry.
    """
    cutoff = THRESHOLD if threshold is None else threshold
    token_limit = MAX_LEN if max_len is None else max_len

    tokenizer, model, id2label, label_full, device = load_model(model_dir)

    text = title.strip()
    abstract_text = abstract.strip()
    if abstract_text:
        text = text + "\n\n" + abstract_text

    enc = tokenizer(
        text, max_length=token_limit, padding="max_length",
        truncation=True, return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**enc).logits
        # Select the single batch row explicitly; unlike a bare squeeze() this
        # keeps the result one-dimensional even if there is only one class.
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    result, cumsum = [], 0.0
    for idx in np.argsort(probs)[::-1]:  # class ids, most probable first
        prob = float(probs[idx])
        cat = id2label[int(idx)]
        result.append({
            "category": cat,
            # Fall back to the category code when no full name is known.
            "full_name": label_full.get(cat, cat),
            "probability": prob,
        })
        cumsum += prob
        if cumsum >= cutoff:
            break
    return result
def render_results(results):
    """Render every prediction as an HTML card through ``st.markdown``.

    Args:
        results: Iterable of dicts with keys ``"probability"``,
            ``"full_name"`` and ``"category"``. Each card shows the 1-based
            rank, the full name, the category code, the probability as a
            percentage and a 20-cell text bar.
    """
    bar_cells = 20
    for position, item in enumerate(results, start=1):
        probability = item["probability"]
        # Number of filled cells, truncated toward zero.
        filled = int(probability * bar_cells)
        gauge = "█" * filled + "░" * (bar_cells - filled)
        # The markup relies on the .cat-* classes from the page stylesheet.
        card = (
            "\n"
            '<div class="cat-card">\n'
            f'<span class="cat-pct">{probability * 100:.1f}%</span>\n'
            f'<div class="cat-title">{position}. {item["full_name"]}</div>\n'
            f'<div class="cat-code">{item["category"]}</div>\n'
            '<div style="color:#95d5b2;font-size:0.75rem;letter-spacing:1px;margin-top:4px">'
            f"{gauge}</div>\n"
            "</div>\n"
        )
        st.markdown(card, unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command of a page, but the stylesheet st.markdown call runs
# earlier in this script — confirm against the deployed Streamlit version.
st.set_page_config(page_title="arXiv Classifier")
st.markdown("# arXiv Classifier")
st.markdown(
    "<p style='color:#52b788;margin-top:-12px;margin-bottom:8px'>"
    "Классификация научных статей по тематике arxiv</p>",
    unsafe_allow_html=True,
)

# Keep only the models whose directory actually contains a config.json.
available = {
    model_id: model_cfg
    for model_id, model_cfg in MODELS.items()
    if os.path.exists(f"{model_cfg['dir']}/config.json")
}
if not available:
    # Nothing to classify with: show an error and end this script run.
    st.error("Модели не найдены. Сначала запустите обучение.")
    st.stop()
# ---------------------------------------------------------------------------
# Operating mode
# ---------------------------------------------------------------------------
mode = st.radio(
    "Режим",
    ["Одна модель", "Сравнение моделей"],
    horizontal=True,
    label_visibility="collapsed",
)

# ---------------------------------------------------------------------------
# Input fields
# ---------------------------------------------------------------------------
title = st.text_input(
    "Название статьи *",
    placeholder="Например: Attention Is All You Need",
)
abstract = st.text_area(
    "Аннотация (abstract)",
    placeholder="Необязательно. Если не указана — классификация только по названию.",
    height=150,
)

# Model picker (single-model mode only). The selected key is resolved to its
# config when a prediction is requested, so no config lookup is needed here.
if mode == "Одна модель":
    model_key = st.radio(
        "Модель",
        list(available.keys()),
        format_func=lambda k: f"{available[k]['label']} — {available[k]['desc']}",
        horizontal=True,
    )

st.divider()
run = st.button("Классифицировать", type="primary", use_container_width=True)
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
if run:
    # The title is the only mandatory input.
    if not title.strip():
        st.error("Пожалуйста, введите название статьи.")
        st.stop()

    if mode == "Одна модель":
        cfg = available[model_key]
        with st.spinner("Предсказываем..."):
            try:
                results = predict_top95(title, abstract, cfg["dir"])
            except Exception as e:
                # Top-level UI boundary: report any failure to the user and
                # end this script run.
                st.error(f"Ошибка: {e}")
                st.stop()
        st.success(f"Топ-{len(results)} категорий (суммарная вероятность ≥ 95%)")
        render_results(results)
    else:  # comparison mode
        if len(available) < 2:
            st.warning("Для сравнения нужны обе модели. Сейчас доступна только одна.")
            st.stop()
        with st.spinner("Запускаем обе модели..."):
            try:
                # Run every available model rather than looking up fixed
                # "large"/"small" keys, so the comparison does not depend on
                # which entries MODELS happens to define.
                compared = [
                    (model_cfg, predict_top95(title, abstract, model_cfg["dir"]))
                    for model_cfg in available.values()
                ]
            except Exception as e:
                st.error(f"Ошибка: {e}")
                st.stop()
        # One column per compared model, in registry order.
        columns = st.columns(len(compared))
        for column, (model_cfg, model_results) in zip(columns, compared):
            with column:
                st.markdown(
                    f"<div class='col-header'>{model_cfg['label']} — {model_cfg['desc']}</div>",
                    unsafe_allow_html=True,
                )
                render_results(model_results)
# ---------------------------------------------------------------------------
# Sidebar
# ---------------------------------------------------------------------------
with st.sidebar:
    st.markdown("### О сервисе")
    for model_cfg in available.values():
        # Two trailing spaces before "\n" form a Markdown hard line break;
        # with a single space the four lines collapse into one paragraph.
        st.markdown(
            f"**{model_cfg['label']}**  \n"
            f"Модель: [{model_cfg['base']}]({model_cfg['base_url']})  \n"
            f"Датасет: [{model_cfg['dataset']}]({model_cfg['dataset_url']})  \n"
            f"Классов: **{model_cfg['n_classes']}**"
        )
        # Topics rendered as small inline tags.
        tags_html = " ".join(
            f"<span style='display:inline-block;background:#d8f3dc;color:#1b4332;"
            f"border-radius:4px;padding:1px 6px;font-size:0.72rem;"
            f"margin:2px 2px 2px 0;font-family:monospace'>{tag}</span>"
            for tag in model_cfg["topics"].split(" · ")
        )
        st.markdown(tags_html, unsafe_allow_html=True)
        st.markdown("")
    st.divider()
    st.caption(
        "**Top-95%** — категории выводятся по убыванию вероятности, "
        "пока суммарная вероятность не превысит 95%."
    )