Utiuzhnikov commited on
Commit
44f7fa5
·
verified ·
1 Parent(s): 0ca213c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +334 -0
  2. requirements.txt +9 -3
app.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ arXiv Article Classifier — Streamlit UI
3
+
4
+ Запуск локально:
5
+ streamlit run app.py --server.port 8080
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Стили
17
+ # ---------------------------------------------------------------------------
18
+ st.markdown("""
19
+ <style>
20
+ /* Фон */
21
+ .stApp { background-color: #f7faf7; }
22
+ .main .block-container { padding-top: 2rem; }
23
+
24
+ /* Заголовки */
25
+ h1 { color: #2d6a4f !important; letter-spacing: -0.5px; }
26
+ h2, h3 { color: #40916c !important; }
27
+
28
+ /* Текст */
29
+ p, label, .stMarkdown { color: #374151 !important; }
30
+
31
+ /* Radio */
32
+ .stRadio > label { color: #40916c !important; font-weight: 600; }
33
+
34
+ /* Поля ввода */
35
+ .stTextInput input, .stTextArea textarea {
36
+ background-color: #ffffff !important;
37
+ border: 1px solid #b7e4c7 !important;
38
+ color: #1f2937 !important;
39
+ border-radius: 8px !important;
40
+ }
41
+ .stTextInput input:focus, .stTextArea textarea:focus {
42
+ border-color: #52b788 !important;
43
+ box-shadow: 0 0 0 2px rgba(82,183,136,0.15) !important;
44
+ }
45
+ .stTextInput label, .stTextArea label {
46
+ color: #40916c !important;
47
+ font-weight: 600;
48
+ }
49
+
50
+ /* Кнопка */
51
+ .stButton > button {
52
+ background-color: #52b788 !important;
53
+ color: #ffffff !important;
54
+ border: none !important;
55
+ border-radius: 8px !important;
56
+ font-weight: 600;
57
+ transition: all 0.2s;
58
+ }
59
+ .stButton > button:hover {
60
+ background-color: #40916c !important;
61
+ color: #ffffff !important;
62
+ }
63
+
64
+ /* Divider */
65
+ hr { border-color: #d8f3dc !important; }
66
+
67
+ /* Success/error */
68
+ .stSuccess { background-color: #d8f3dc !important; color: #1b4332 !important; border-color: #95d5b2 !important; }
69
+ .stError { background-color: #fef2f2 !important; }
70
+
71
+ /* Sidebar */
72
+ [data-testid="stSidebar"] {
73
+ background-color: #f0faf2 !important;
74
+ border-right: 1px solid #d8f3dc;
75
+ }
76
+ [data-testid="stSidebar"] p,
77
+ [data-testid="stSidebar"] span,
78
+ [data-testid="stSidebar"] div { color: #374151 !important; }
79
+ [data-testid="stSidebar"] a { color: #40916c !important; }
80
+
81
+ /* Карточка категории */
82
+ .cat-card {
83
+ background: #ffffff;
84
+ border: 1px solid #d8f3dc;
85
+ border-left: 4px solid #52b788;
86
+ border-radius: 8px;
87
+ padding: 10px 14px;
88
+ margin-bottom: 8px;
89
+ }
90
+ .cat-title { color: #1b4332; font-weight: 600; font-size: 0.95rem; }
91
+ .cat-code { color: #74c69d; font-size: 0.78rem; font-family: monospace; margin-top: 2px; }
92
+ .cat-pct { color: #40916c; font-size: 1.2rem; font-weight: 700; float: right; }
93
+
94
+ /* Заголовок колонки сравнения */
95
+ .col-header {
96
+ background: #d8f3dc;
97
+ border-radius: 8px;
98
+ padding: 8px 14px;
99
+ margin-bottom: 12px;
100
+ color: #1b4332 !important;
101
+ font-weight: 700;
102
+ font-size: 0.9rem;
103
+ text-align: center;
104
+ }
105
+ </style>
106
+ """, unsafe_allow_html=True)
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Конфиг моделей
110
+ # ---------------------------------------------------------------------------
111
+ MODELS = {
112
+ "large": {
113
+ "label": "Большая",
114
+ "dir": "./model_v2",
115
+ "base": "allenai/scibert_scivocab_uncased",
116
+ "base_url": "https://huggingface.co/allenai/scibert_scivocab_uncased",
117
+ "dataset": "mteb/arxiv-clustering-p2p",
118
+ "dataset_url": "https://huggingface.co/datasets/mteb/arxiv-clustering-p2p",
119
+ "n_classes": 122,
120
+ "desc": "SciBERT · 122 категории",
121
+ "topics": "CS · Math · Physics · HEP · Astrophysics · Condensed Matter · Statistics · EESS · Quantitative Biology · Quantitative Finance · Economics · Nonlinear Sciences",
122
+ },
123
+ "small": {
124
+ "label": "Простая",
125
+ "dir": "./model",
126
+ "base": "distilbert-base-cased",
127
+ "base_url": "https://huggingface.co/distilbert-base-cased",
128
+ "dataset": "ccdv/arxiv-classification",
129
+ "dataset_url": "https://huggingface.co/datasets/ccdv/arxiv-classification",
130
+ "n_classes": 11,
131
+ "desc": "DistilBERT · 11 категорий",
132
+ "topics": "cs.CV · cs.AI · cs.NE · cs.IT · cs.DS · cs.SY · cs.CE · cs.PL · math.AC · math.GR · math.ST",
133
+ },
134
+ }
135
+
136
+ MAX_LEN = 256
137
+ THRESHOLD = 0.95
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Загрузка модели
142
+ # ---------------------------------------------------------------------------
143
+ @st.cache_resource
144
+ def load_model(model_dir: str):
145
+ device = (
146
+ "mps" if torch.backends.mps.is_available() else
147
+ "cuda" if torch.cuda.is_available() else
148
+ "cpu"
149
+ )
150
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
151
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
152
+ model.to(device)
153
+ model.eval()
154
+
155
+ with open(f"{model_dir}/id2label.json") as f:
156
+ id2label = {int(k): v for k, v in json.load(f).items()}
157
+
158
+ label_full = {}
159
+ if os.path.exists(f"{model_dir}/label_full.json"):
160
+ with open(f"{model_dir}/label_full.json") as f:
161
+ label_full = json.load(f)
162
+
163
+ return tokenizer, model, id2label, label_full, device
164
+
165
+
166
+ def predict_top95(title, abstract, model_dir):
167
+ tokenizer, model, id2label, label_full, device = load_model(model_dir)
168
+ text = title.strip()
169
+ if abstract.strip():
170
+ text = text + "\n\n" + abstract.strip()
171
+
172
+ enc = tokenizer(
173
+ text, max_length=MAX_LEN, padding="max_length",
174
+ truncation=True, return_tensors="pt",
175
+ ).to(device)
176
+
177
+ with torch.no_grad():
178
+ logits = model(**enc).logits
179
+
180
+ probs = torch.softmax(logits, dim=-1).squeeze().cpu().numpy()
181
+ sorted_idx = np.argsort(probs)[::-1]
182
+
183
+ result, cumsum = [], 0.0
184
+ for idx in sorted_idx:
185
+ prob = float(probs[idx])
186
+ cat = id2label[int(idx)]
187
+ result.append({
188
+ "category": cat,
189
+ "full_name": label_full.get(cat, cat),
190
+ "probability": prob,
191
+ })
192
+ cumsum += prob
193
+ if cumsum >= THRESHOLD:
194
+ break
195
+ return result
196
+
197
+
198
+ def render_results(results):
199
+ for rank, r in enumerate(results, start=1):
200
+ pct = r["probability"] * 100
201
+ bar = int(r["probability"] * 20) * "█" + (20 - int(r["probability"] * 20)) * "░"
202
+ st.markdown(f"""
203
+ <div class="cat-card">
204
+ <span class="cat-pct">{pct:.1f}%</span>
205
+ <div class="cat-title">{rank}. {r['full_name']}</div>
206
+ <div class="cat-code">{r['category']}</div>
207
+ <div style="color:#95d5b2;font-size:0.75rem;letter-spacing:1px;margin-top:4px">{bar}</div>
208
+ </div>
209
+ """, unsafe_allow_html=True)
210
+
211
+
212
+ # ---------------------------------------------------------------------------
213
+ # UI
214
+ # ---------------------------------------------------------------------------
215
+ st.set_page_config(page_title="arXiv Classifier")
216
+
217
+ st.markdown("# arXiv Classifier")
218
+ st.markdown("<p style='color:#52b788;margin-top:-12px;margin-bottom:8px'>Классификация научных статей по тематике arxiv</p>", unsafe_allow_html=True)
219
+
220
+ # Проверяем доступность моделей
221
+ available = {k: v for k, v in MODELS.items() if os.path.exists(f"{v['dir']}/config.json")}
222
+ if not available:
223
+ st.error("Модели не найдены. Сначала запустите обучение.")
224
+ st.stop()
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Режим работы
228
+ # ---------------------------------------------------------------------------
229
+ mode = st.radio(
230
+ "Режим",
231
+ ["Одна модель", "Сравнение моделей"],
232
+ horizontal=True,
233
+ label_visibility="collapsed",
234
+ )
235
+
236
+ # ---------------------------------------------------------------------------
237
+ # Поля ввода
238
+ # ---------------------------------------------------------------------------
239
+ title = st.text_input("Название статьи *", placeholder="Например: Attention Is All You Need")
240
+ abstract = st.text_area(
241
+ "Аннотация (abstract)",
242
+ placeholder="Необязательно. Если не указана — классификация только по названию.",
243
+ height=150,
244
+ )
245
+
246
+ # Выбор модели (только в режиме одной)
247
+ if mode == "Одна модель":
248
+ model_key = st.radio(
249
+ "Модель",
250
+ list(available.keys()),
251
+ format_func=lambda k: f"{available[k]['label']} — {available[k]['desc']}",
252
+ horizontal=True,
253
+ )
254
+ cfg = available[model_key]
255
+
256
+ st.divider()
257
+ run = st.button("Классифицировать", type="primary", use_container_width=True)
258
+
259
+ # ---------------------------------------------------------------------------
260
+ # Предсказание
261
+ # ---------------------------------------------------------------------------
262
+ if run:
263
+ if not title.strip():
264
+ st.error("Пожалуйста, введите название статьи.")
265
+ st.stop()
266
+
267
+ if mode == "Одна модель":
268
+ cfg = available[model_key]
269
+ with st.spinner("Предсказываем..."):
270
+ try:
271
+ results = predict_top95(title, abstract, cfg["dir"])
272
+ except Exception as e:
273
+ st.error(f"Ошибка: {e}"); st.stop()
274
+
275
+ st.success(f"Топ-{len(results)} категорий (суммарная вероятность ≥ 95%)")
276
+ render_results(results)
277
+
278
+ else: # Сравнение
279
+ if len(available) < 2:
280
+ st.warning("Для сравнения нужны обе модели. Сейчас доступна только одна.")
281
+ st.stop()
282
+
283
+ with st.spinner("Запускаем обе модели..."):
284
+ try:
285
+ res_large = predict_top95(title, abstract, MODELS["large"]["dir"])
286
+ res_small = predict_top95(title, abstract, MODELS["small"]["dir"])
287
+ except Exception as e:
288
+ st.error(f"Ошибка: {e}"); st.stop()
289
+
290
+ col_l, col_r = st.columns(2)
291
+
292
+ with col_l:
293
+ st.markdown(
294
+ f"<div class='col-header'>{MODELS['large']['label']} — {MODELS['large']['desc']}</div>",
295
+ unsafe_allow_html=True,
296
+ )
297
+ render_results(res_large)
298
+
299
+ with col_r:
300
+ st.markdown(
301
+ f"<div class='col-header'>{MODELS['small']['label']} — {MODELS['small']['desc']}</div>",
302
+ unsafe_allow_html=True,
303
+ )
304
+ render_results(res_small)
305
+
306
+ # ---------------------------------------------------------------------------
307
+ # Сайдбар
308
+ # ---------------------------------------------------------------------------
309
+ with st.sidebar:
310
+ st.markdown("### О сервисе")
311
+
312
+ for key, cfg in available.items():
313
+ st.markdown(
314
+ f"**{cfg['label']}** \n"
315
+ f"Модель: [{cfg['base']}]({cfg['base_url']}) \n"
316
+ f"Датасет: [{cfg['dataset']}]({cfg['dataset_url']}) \n"
317
+ f"Классов: **{cfg['n_classes']}**"
318
+ )
319
+ # Тематики в виде тегов
320
+ tags = cfg["topics"].split(" · ")
321
+ tags_html = " ".join(
322
+ f"<span style='display:inline-block;background:#d8f3dc;color:#1b4332;"
323
+ f"border-radius:4px;padding:1px 6px;font-size:0.72rem;"
324
+ f"margin:2px 2px 2px 0;font-family:monospace'>{t}</span>"
325
+ for t in tags
326
+ )
327
+ st.markdown(tags_html, unsafe_allow_html=True)
328
+ st.markdown("")
329
+
330
+ st.divider()
331
+ st.caption(
332
+ "**Top-95%** — категории выводятся по убыванию вероятности, "
333
+ "пока суммарная вероятность не превысит 95%."
334
+ )
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ datasets>=2.0.0
4
+ scikit-learn>=1.0.0
5
+ numpy>=1.24.0
6
+ pandas>=1.5.0
7
+ matplotlib>=3.5.0
8
+ streamlit>=1.20.0
9
+ accelerate>=0.20.0