| | import json |
| | from pathlib import Path |
| |
|
| | import pandas as pd |
| | import streamlit as st |
| |
|
| | from category_classification.models import models as class_models |
| | from languages import * |
| | from results import process_results |
| |
|
# Localized UI strings, keyed by the language identifiers `en` / `ru`
# (not defined in this file; they arrive via `from languages import *`).
page_title = {en: "Papers classification", ru: "Классификация статей"}
# Imperative "Выберите" ("select"), not "Выберете" (future tense).
model_label = {en: "Select model", ru: "Выберите модель"}
title_label = {en: "Title", ru: "Название статьи"}
authors_label = {en: "Author(s)", ru: "Автор(ы)"}
abstract_label = {en: "Abstract", ru: "Аннотация"}
# The metrics shown under this label are loaded from test_results.json, so
# both translations must say "test" (the Russian text previously said
# "training dataset").
metrics_label = {en: "Test metrics", ru: "Метрики на тестовом датасете"}
| |
|
# Per-language evaluation metrics displayed in the expander at the bottom of
# the page. JSON text is UTF-8 by specification, so decode it as UTF-8
# explicitly instead of relying on the platform's locale-dependent default
# encoding (which would mis-decode any non-ASCII content, e.g. on Windows).
with open(
    Path(__file__).parent / "category_classification" / "test_results.json",
    "r",
    encoding="utf-8",
) as metric_f:
    metrics = json.load(metric_f)
| |
|
| |
|
def text_area_height(line_height: int):
    """Convert a number of text lines into a pixel height for ``st.text_area``.

    Note that ``line_height`` is the number of lines the widget should show,
    not the height of a single line; each line is allotted 34 pixels.
    """
    pixels_per_line = 34
    return line_height * pixels_per_line
| |
|
| |
|
@st.cache_resource
def load_class_model(name):
    """Load the classification model called ``name``, caching it across reruns.

    ``st.cache_resource`` is used instead of ``st.cache_data`` because a model
    is a global resource, not data. ``cache_data`` pickles the return value
    and hands back a freshly deserialized copy on every call. That is slow for
    large models and fails outright for objects that cannot be pickled.
    ``cache_resource`` stores the object itself and returns the same instance
    each time.

    NOTE(review): the same instance is shared by all sessions, so the model
    must not be mutated per request. TODO confirm the model objects returned
    by ``class_models.get_model`` are safe to share.

    Args:
        name: Model identifier, as offered in the UI by
            ``class_models.get_model_names_by_lang``.

    Returns:
        The object returned by ``class_models.get_model(name)``; the page
        calls it as ``model(paper_dict)``.
    """
    return class_models.get_model(name)
| |
|
| |
|
# Language switcher. st.pills yields None until a pill is selected, in which
# case the page falls back to English.
lang = st.pills(label=langs_str, options=langs)
if lang is None:
    lang = en

st.title(page_title[lang])

# Only models registered for the chosen language are offered.
available_models = class_models.get_model_names_by_lang(lang)
model_name = st.selectbox(model_label[lang], options=available_models)

# Paper inputs; heights are expressed in visible text lines.
title = st.text_area(title_label[lang], height=text_area_height(2))
authors = st.text_area(authors_label[lang], height=text_area_height(2))
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
| |
|
# Classify only once a non-blank title has been entered; the title is the
# field that gates classification (abstract and authors may be empty).
if title and title.strip():
    # Named `paper` rather than `input` so the builtin is not shadowed for the
    # rest of the module.
    paper = {"title": title, "abstract": abstract, "authors": authors}
    model = load_class_model(model_name)
    raw_results = model(paper)
    # Convert the raw model output into a display table in the selected
    # language.
    results = process_results(raw_results, lang)
    st.dataframe(results, hide_index=True)
| |
|
# Evaluation metrics for the selected language, shown in a collapsible
# section. A language with no entry in test_results.json yields an empty
# frame, so the expander is skipped instead of the page crashing with
# KeyError.
lang_metrics = pd.DataFrame(metrics.get(lang, {}))
if not lang_metrics.empty:
    with st.expander(metrics_label[lang]):
        st.dataframe(lang_metrics)
| |
|