arXiv_Classifier / src /streamlit_app.py
Utiuzhnikov's picture
Update src/streamlit_app.py
24f1585 verified
"""
arXiv Article Classifier — Streamlit UI
Запуск локально:
streamlit run src/streamlit_app.py --server.port 8080
"""
import html
import json
import os

import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ---------------------------------------------------------------------------
# Styles
# ---------------------------------------------------------------------------
# Page-wide CSS theme (green palette) for headings, inputs, buttons, alerts and
# the sidebar, plus the .cat-* classes used by the result cards and the
# .col-header class used by the comparison-column headers.
_APP_CSS = """
<style>
/* Фон */
.stApp { background-color: #f7faf7; }
.main .block-container { padding-top: 2rem; }
/* Заголовки */
h1 { color: #2d6a4f !important; letter-spacing: -0.5px; }
h2, h3 { color: #40916c !important; }
/* Текст */
p, label, .stMarkdown { color: #374151 !important; }
/* Radio */
.stRadio > label { color: #40916c !important; font-weight: 600; }
/* Поля ввода */
.stTextInput input, .stTextArea textarea {
background-color: #ffffff !important;
border: 1px solid #b7e4c7 !important;
color: #1f2937 !important;
border-radius: 8px !important;
}
.stTextInput input:focus, .stTextArea textarea:focus {
border-color: #52b788 !important;
box-shadow: 0 0 0 2px rgba(82,183,136,0.15) !important;
}
.stTextInput label, .stTextArea label {
color: #40916c !important;
font-weight: 600;
}
/* Кнопка */
.stButton > button {
background-color: #52b788 !important;
color: #ffffff !important;
border: none !important;
border-radius: 8px !important;
font-weight: 600;
transition: all 0.2s;
}
.stButton > button:hover {
background-color: #40916c !important;
color: #ffffff !important;
}
/* Divider */
hr { border-color: #d8f3dc !important; }
/* Success/error */
.stSuccess { background-color: #d8f3dc !important; color: #1b4332 !important; border-color: #95d5b2 !important; }
.stError { background-color: #fef2f2 !important; }
/* Sidebar */
[data-testid="stSidebar"] {
background-color: #f0faf2 !important;
border-right: 1px solid #d8f3dc;
}
[data-testid="stSidebar"] p,
[data-testid="stSidebar"] span,
[data-testid="stSidebar"] div { color: #374151 !important; }
[data-testid="stSidebar"] a { color: #40916c !important; }
/* Карточка категории */
.cat-card {
background: #ffffff;
border: 1px solid #d8f3dc;
border-left: 4px solid #52b788;
border-radius: 8px;
padding: 10px 14px;
margin-bottom: 8px;
}
.cat-title { color: #1b4332; font-weight: 600; font-size: 0.95rem; }
.cat-code { color: #74c69d; font-size: 0.78rem; font-family: monospace; margin-top: 2px; }
.cat-pct { color: #40916c; font-size: 1.2rem; font-weight: 700; float: right; }
/* Заголовок колонки сравнения */
.col-header {
background: #d8f3dc;
border-radius: 8px;
padding: 8px 14px;
margin-bottom: 12px;
color: #1b4332 !important;
font-weight: 700;
font-size: 0.9rem;
text-align: center;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# One entry per fine-tuned checkpoint. Fields:
#   label, desc            - short name and one-line summary shown in the UI
#   dir                    - local checkpoint directory handed to the loader
#   base, base_url         - base-model name and link shown in the sidebar
#   dataset, dataset_url   - dataset name and link shown in the sidebar
#   n_classes              - number of categories, shown in the sidebar
#   topics                 - " · "-separated category codes, shown as tags
_SMALL_MODEL = {
    "label": "Простая",
    "dir": "./model",
    "base": "distilbert-base-cased",
    "base_url": "https://huggingface.co/distilbert-base-cased",
    "dataset": "ccdv/arxiv-classification",
    "dataset_url": "https://huggingface.co/datasets/ccdv/arxiv-classification",
    "n_classes": 11,
    "desc": "DistilBERT · 11 категорий",
    "topics": "cs.CV · cs.AI · cs.NE · cs.IT · cs.DS · cs.SY · cs.CE · cs.PL · math.AC · math.GR · math.ST",
}

# The larger SciBERT checkpoint is currently disabled. To re-enable it,
# uncomment this entry and add `"large": _LARGE_MODEL` before "small" below.
# _LARGE_MODEL = {
#     "label": "Большая",
#     "dir": "./model_v2",
#     "base": "allenai/scibert_scivocab_uncased",
#     "base_url": "https://huggingface.co/allenai/scibert_scivocab_uncased",
#     "dataset": "mteb/arxiv-clustering-p2p",
#     "dataset_url": "https://huggingface.co/datasets/mteb/arxiv-clustering-p2p",
#     "n_classes": 122,
#     "desc": "SciBERT · 122 категории",
#     "topics": "CS · Math · Physics · HEP · Astrophysics · Condensed Matter · Statistics · EESS · Quantitative Biology · Quantitative Finance · Economics · Nonlinear Sciences",
# }

MODELS = {
    "small": _SMALL_MODEL,
}

# Tokenizer max_length: longer title + abstract inputs are truncated to this.
MAX_LEN = 256
# Cumulative-probability cut-off for the returned category list.
THRESHOLD = 0.95
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
@st.cache_resource
def load_model(model_dir: str):
    """Load the tokenizer, classifier and label maps stored in *model_dir*.

    Wrapped in ``st.cache_resource`` so each directory is loaded once and the
    same objects are reused across reruns.

    Args:
        model_dir: directory holding a Hugging Face checkpoint plus
            ``id2label.json`` and, optionally, ``label_full.json``.

    Returns:
        ``(tokenizer, model, id2label, label_full, device)``. ``id2label`` maps
        integer class ids to category codes. ``label_full`` maps category codes
        to display names and is empty when ``label_full.json`` is absent.
        ``device`` is the torch device string the model was moved to.

    Raises:
        FileNotFoundError: if ``id2label.json`` is missing.
        json.JSONDecodeError: if a label file is not valid JSON.
        Errors raised by ``from_pretrained`` propagate unchanged.
    """
    # Prefer Apple MPS, then CUDA, and fall back to CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()  # inference mode (disables dropout)

    # Explicit UTF-8: label names may be non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(os.path.join(model_dir, "id2label.json"), encoding="utf-8") as f:
        # JSON object keys are always strings; class ids are ints.
        id2label = {int(k): v for k, v in json.load(f).items()}

    # Human-readable category names are optional.
    try:
        with open(os.path.join(model_dir, "label_full.json"), encoding="utf-8") as f:
            label_full = json.load(f)
    except FileNotFoundError:
        label_full = {}

    return tokenizer, model, id2label, label_full, device
def select_top_p(probs, id2label, label_full, threshold):
    """Return the most probable categories until their cumulative mass reaches *threshold*.

    Args:
        probs: 1-D sequence of class probabilities indexed by class id.
        id2label: mapping from integer class id to category code.
        label_full: mapping from category code to display name; codes not
            present fall back to the code itself.
        threshold: cumulative-probability cut-off. The category that makes the
            running sum reach (>=) the threshold is included.

    Returns:
        List of dicts with keys "category", "full_name" and "probability"
        (a Python float), in descending probability order. All classes are
        returned if the sum never reaches the threshold; [] for empty *probs*.
    """
    probs = np.asarray(probs).reshape(-1)
    selected = []
    cumulative = 0.0
    for idx in np.argsort(probs)[::-1]:
        prob = float(probs[idx])
        category = id2label[int(idx)]
        selected.append({
            "category": category,
            "full_name": label_full.get(category, category),
            "probability": prob,
        })
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected


def predict_top95(title, abstract, model_dir, threshold=None):
    """Classify an article and return its most probable categories.

    The stripped title, followed by a blank line and the stripped abstract if
    it is non-blank, is tokenized (truncated to MAX_LEN tokens) and run
    through the model loaded from *model_dir*.

    Args:
        title: article title.
        abstract: article abstract; may be an empty string.
        model_dir: checkpoint directory passed to load_model().
        threshold: cumulative-probability cut-off; defaults to THRESHOLD.

    Returns:
        List of {"category", "full_name", "probability"} dicts in descending
        probability order, ending with the first category whose cumulative
        probability reaches the threshold.
    """
    tokenizer, model, id2label, label_full, device = load_model(model_dir)

    text = title.strip()
    abstract_text = abstract.strip()
    if abstract_text:
        text = f"{text}\n\n{abstract_text}"

    # A single example needs no padding: truncation still caps the length at
    # MAX_LEN, and skipping padding avoids running the model over pad tokens.
    enc = tokenizer(
        text, max_length=MAX_LEN, truncation=True, return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**enc).logits

    # Take row 0 of the (batch=1, num_labels) logits rather than squeeze(),
    # which would also drop the class axis for a single-label head.
    probs = torch.softmax(logits[0], dim=-1).cpu().numpy()

    if threshold is None:
        threshold = THRESHOLD
    return select_top_p(probs, id2label, label_full, threshold)
def result_card_html(rank, result, bar_width=20):
    """Build the HTML card for one prediction.

    Args:
        rank: 1-based position shown before the category name.
        result: dict with "category", "full_name" and "probability" keys, as
            produced by predict_top95().
        bar_width: number of characters in the text progress bar.

    Returns:
        A single-line HTML string styled by the .cat-* CSS classes.
    """
    probability = result["probability"]
    # Clamp so an out-of-range value cannot yield a negative or over-long bar.
    filled = max(0, min(bar_width, int(probability * bar_width)))
    bar = "█" * filled + "░" * (bar_width - filled)
    # Labels come from JSON files on disk and the card is rendered with
    # unsafe_allow_html=True, so escape them before interpolating.
    name = html.escape(str(result["full_name"]))
    code = html.escape(str(result["category"]))
    # Emitted as one line with no leading indentation: Markdown treats lines
    # indented by 4+ spaces as a code block, which would show raw HTML.
    return (
        '<div class="cat-card">'
        f'<span class="cat-pct">{probability * 100:.1f}%</span>'
        f'<div class="cat-title">{rank}. {name}</div>'
        f'<div class="cat-code">{code}</div>'
        '<div style="color:#95d5b2;font-size:0.75rem;letter-spacing:1px;'
        f'margin-top:4px">{bar}</div>'
        "</div>"
    )


def render_results(results):
    """Render each prediction as a styled card, numbered from 1."""
    for rank, result in enumerate(results, start=1):
        st.markdown(result_card_html(rank, result), unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): older Streamlit releases require st.set_page_config() to be the
# first Streamlit command in the script, but the CSS st.markdown() call near the
# top of this file runs earlier. Confirm the deployed Streamlit version allows it.
st.set_page_config(page_title="arXiv Classifier")
st.markdown("# arXiv Classifier")
st.markdown(
    "<p style='color:#52b788;margin-top:-12px;margin-bottom:8px'>"
    "Классификация научных статей по тематике arxiv</p>",
    unsafe_allow_html=True,
)

# A model counts as available only when its checkpoint directory holds a
# config.json. Without any available model there is nothing to show.
available = {}
for _model_key, _model_cfg in MODELS.items():
    if os.path.exists(f"{_model_cfg['dir']}/config.json"):
        available[_model_key] = _model_cfg
if not available:
    st.error("Модели не найдены. Сначала запустите обучение.")
    st.stop()
# ---------------------------------------------------------------------------
# Mode selection
# ---------------------------------------------------------------------------
mode = st.radio(
    "Режим",
    ["Одна модель", "Сравнение моделей"],
    horizontal=True,
    label_visibility="collapsed",
)
# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------
title = st.text_input("Название статьи *", placeholder="Например: Attention Is All You Need")
abstract = st.text_area(
    "Аннотация (abstract)",
    placeholder="Необязательно. Если не указана — классификация только по названию.",
    height=150,
)
# Model picker, shown only in single-model mode. Options render as
# "<label> — <desc>"; the separator keeps the two fields from running together.
if mode == "Одна модель":
    model_key = st.radio(
        "Модель",
        list(available),
        format_func=lambda k: f"{available[k]['label']} — {available[k]['desc']}",
        horizontal=True,
    )
st.divider()
run = st.button("Классифицировать", type="primary", use_container_width=True)
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def _classify_and_render():
    """Validate the inputs, run the selected model(s) and render the results.

    Problems are reported with st.error / st.warning followed by an early
    return instead of st.stop(), so the rest of the script (the sidebar
    below) still renders on those paths.
    """
    if not title.strip():
        st.error("Пожалуйста, введите название статьи.")
        return

    if mode == "Одна модель":
        cfg = available[model_key]
        with st.spinner("Предсказываем..."):
            try:
                results = predict_top95(title, abstract, cfg["dir"])
            except Exception as e:  # UI boundary: surface any failure to the user
                st.error(f"Ошибка: {e}")
                return
        st.success(
            f"Топ-{len(results)} категорий "
            f"(суммарная вероятность ≥ {THRESHOLD:.0%})"
        )
        render_results(results)
        return

    # Comparison mode: one column per available model, in MODELS order,
    # instead of hard-coded "large"/"small" keys.
    if len(available) < 2:
        st.warning("Для сравнения нужны обе модели. Сейчас доступна только одна.")
        return
    with st.spinner("Запускаем обе модели..."):
        try:
            per_model = {
                key: predict_top95(title, abstract, cfg["dir"])
                for key, cfg in available.items()
            }
        except Exception as e:  # UI boundary: surface any failure to the user
            st.error(f"Ошибка: {e}")
            return
    columns = st.columns(len(available))
    for column, (key, cfg) in zip(columns, available.items()):
        with column:
            st.markdown(
                f"<div class='col-header'>{cfg['label']} — {cfg['desc']}</div>",
                unsafe_allow_html=True,
            )
            render_results(per_model[key])


if run:
    _classify_and_render()
# ---------------------------------------------------------------------------
# Sidebar
# ---------------------------------------------------------------------------
with st.sidebar:
    st.markdown("### О сервисе")
    for model_cfg in available.values():
        # Two trailing spaces before "\n" make a Markdown hard line break; a
        # single space is only a soft break and the lines would run together.
        st.markdown(
            f"**{model_cfg['label']}**  \n"
            f"Модель: [{model_cfg['base']}]({model_cfg['base_url']})  \n"
            f"Датасет: [{model_cfg['dataset']}]({model_cfg['dataset_url']})  \n"
            f"Классов: **{model_cfg['n_classes']}**"
        )
        # Topics rendered as small monospace tags.
        tags_html = " ".join(
            "<span style='display:inline-block;background:#d8f3dc;color:#1b4332;"
            "border-radius:4px;padding:1px 6px;font-size:0.72rem;"
            f"margin:2px 2px 2px 0;font-family:monospace'>{tag}</span>"
            for tag in model_cfg["topics"].split(" · ")
        )
        st.markdown(tags_html, unsafe_allow_html=True)
        st.markdown("")
    st.divider()
    # Selection stops at the first category whose cumulative probability
    # reaches (>=) THRESHOLD, so "reaches" rather than "exceeds".
    st.caption(
        f"**Top-{THRESHOLD:.0%}** — категории выводятся по убыванию вероятности, "
        f"пока суммарная вероятность не достигнет {THRESHOLD:.0%}."
    )