| import json |
| from pathlib import Path |
|
|
| import pandas as pd |
| import streamlit as st |
|
|
| from category_classification.models import models as class_models |
| from languages import * |
| from results import process_results |
|
|
# Localized UI strings. Each dict is keyed by the language constants
# (`en`, `ru`) brought in by the star-import from `languages`.
page_title = {en: "Papers classification", ru: "Классификация статей"}
# Russian text uses the imperative form "Выберите" ("select").
model_label = {en: "Select model", ru: "Выберите модель"}
title_label = {en: "Title", ru: "Название статьи"}
authors_label = {en: "Author(s)", ru: "Автор(ы)"}
abstract_label = {en: "Abstract", ru: "Аннотация"}
# Both translations say "test": the figures shown under this label are read
# from test_results.json.
metrics_label = {en: "Test metrics", ru: "Метрики на тестовом датасете"}
|
|
# Load the stored evaluation metrics that sit in the category_classification
# directory next to this script. The encoding is pinned to UTF-8 so parsing
# does not depend on the platform's default locale encoding.
with open(
    Path(__file__).parent / "category_classification" / "test_results.json",
    "r",
    encoding="utf-8",
) as metric_f:
    metrics = json.load(metric_f)
|
|
|
|
def text_area_height(line_height: int) -> int:
    """Return a widget height computed as 34 units per requested line.

    Args:
        line_height: Multiplier applied to the per-line height; callers in
            this file pass the number of text lines they want visible.

    Returns:
        ``34 * line_height``.
    """
    per_line = 34
    return line_height * per_line
|
|
|
|
@st.cache_resource
def load_class_model(name):
    """Return the classification model registered under ``name``.

    The model is obtained from ``class_models.get_model`` and cached by
    Streamlit, so script reruns with the same ``name`` reuse it instead of
    loading it again.

    ``st.cache_resource`` is used rather than ``st.cache_data`` because the
    latter serializes the return value and hands back a fresh copy on every
    call, which is the wrong trade-off for a model object.

    NOTE(review): the cached object is shared between sessions and is not
    copied; this assumes calling the model does not mutate it. Confirm
    against the model implementations.

    Args:
        name: Model identifier accepted by ``class_models.get_model``.

    Returns:
        Whatever ``class_models.get_model(name)`` returns.
    """
    return class_models.get_model(name)
|
|
|
|
# Language selector; when st.pills returns None the page falls back to English.
picked_lang = st.pills(label=langs_str, options=langs)
lang = en if picked_lang is None else picked_lang

st.title(page_title[lang])

# Only the models offered for the chosen language are selectable.
available_models = class_models.get_model_names_by_lang(lang)
model_name = st.selectbox(model_label[lang], options=available_models)

# Free-text inputs describing the paper to classify.
title = st.text_area(label=title_label[lang], height=text_area_height(2))
authors = st.text_area(label=authors_label[lang], height=text_area_height(2))
abstract = st.text_area(label=abstract_label[lang], height=text_area_height(5))
|
|
# Run the classifier only once a title has been entered; the abstract and
# authors fields are passed along as-is, even when left blank.
if title:
    # Named `paper` rather than `input` so the builtin is not shadowed.
    paper = {"title": title, "abstract": abstract, "authors": authors}
    model = load_class_model(model_name)
    # process_results receives the raw model output and the UI language, and
    # its return value is what gets rendered as a table.
    results = process_results(model(paper), lang)
    st.dataframe(results, hide_index=True)
|
|
# Show the stored metrics for the selected language inside a collapsible
# section, skipping the section entirely when there is nothing to show.
lang_metrics = pd.DataFrame(data=metrics[lang])
has_metrics = not lang_metrics.empty
if has_metrics:
    metrics_panel = st.expander(metrics_label[lang])
    with metrics_panel:
        st.dataframe(lang_metrics)
|
|