Add model code and full model weights
Browse files- .gitignore +144 -0
- backbone/__init__.py +0 -0
- backbone/dataset.py +221 -0
- backbone/pl_model.py +275 -0
- backbone/pl_train.py +335 -0
- backbone_weights/leftBinSyntax_R3D_full_fold00.pt +3 -0
- backbone_weights/leftBinSyntax_R3D_full_fold01.pt +3 -0
- backbone_weights/leftBinSyntax_R3D_full_fold02.pt +3 -0
- backbone_weights/leftBinSyntax_R3D_full_fold03.pt +3 -0
- backbone_weights/leftBinSyntax_R3D_full_fold04.pt +3 -0
- backbone_weights/rightBinSyntax_R3D_full_fold00.pt +3 -0
- backbone_weights/rightBinSyntax_R3D_full_fold01.pt +3 -0
- backbone_weights/rightBinSyntax_R3D_full_fold02.pt +3 -0
- backbone_weights/rightBinSyntax_R3D_full_fold03.pt +3 -0
- backbone_weights/rightBinSyntax_R3D_full_fold04.pt +3 -0
- full_model/__init__.py +0 -0
- full_model/rnn_dataset.py +257 -0
- full_model/rnn_model.py +386 -0
- full_model/rnn_train.py +418 -0
- full_model_weights/LeftBinSyntax_R3D_fold00_lstm_mean_post_best.pt +3 -0
- full_model_weights/LeftBinSyntax_R3D_fold01_lstm_mean_post_best.pt +3 -0
- full_model_weights/LeftBinSyntax_R3D_fold02_lstm_mean_post_best.pt +3 -0
- full_model_weights/LeftBinSyntax_R3D_fold03_lstm_mean_post_best.pt +3 -0
- full_model_weights/LeftBinSyntax_R3D_fold04_lstm_mean_post_best.pt +3 -0
- full_model_weights/RightBinSyntax_R3D_fold00_lstm_mean_post_best.pt +3 -0
- full_model_weights/RightBinSyntax_R3D_fold01_lstm_mean_post_best.pt +3 -0
- full_model_weights/RightBinSyntax_R3D_fold02_lstm_mean_post_best.pt +3 -0
- full_model_weights/RightBinSyntax_R3D_fold03_lstm_mean_post_best.pt +3 -0
- full_model_weights/RightBinSyntax_R3D_fold04_lstm_mean_post_best.pt +3 -0
- inference/__init__.py +0 -0
- inference/metrics_visualization.py +426 -0
- inference/rnn_apply.py +344 -0
- requirements.txt +21 -0
- scaling_coeffs/scaling_coeffs.json +27 -0
.gitignore
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
|
| 11 |
+
# PyInstaller
|
| 12 |
+
# Usually these files are written by a python script from a template
|
| 13 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 14 |
+
*.manifest
|
| 15 |
+
*.spec
|
| 16 |
+
|
| 17 |
+
# Installer logs
|
| 18 |
+
pip-log.txt
|
| 19 |
+
pip-delete-this-directory.txt
|
| 20 |
+
|
| 21 |
+
# Unit test / coverage reports
|
| 22 |
+
htmlcov/
|
| 23 |
+
.tox/
|
| 24 |
+
.nox/
|
| 25 |
+
.coverage
|
| 26 |
+
.coverage.*
|
| 27 |
+
.cache
|
| 28 |
+
nosetests.xml
|
| 29 |
+
coverage.xml
|
| 30 |
+
*.cover
|
| 31 |
+
*.py,cover
|
| 32 |
+
.hypothesis/
|
| 33 |
+
.pytest_cache/
|
| 34 |
+
cover/
|
| 35 |
+
|
| 36 |
+
# Translations
|
| 37 |
+
*.mo
|
| 38 |
+
*.pot
|
| 39 |
+
|
| 40 |
+
# Django stuff:
|
| 41 |
+
*.log
|
| 42 |
+
local_settings.py
|
| 43 |
+
db.sqlite3
|
| 44 |
+
db.sqlite3-journal
|
| 45 |
+
|
| 46 |
+
# Flask stuff:
|
| 47 |
+
instance/
|
| 48 |
+
.webassets-cache
|
| 49 |
+
|
| 50 |
+
# Scrapy stuff:
|
| 51 |
+
.scrapy
|
| 52 |
+
|
| 53 |
+
# Sphinx documentation
|
| 54 |
+
docs/_build/
|
| 55 |
+
|
| 56 |
+
# PyBuilder
|
| 57 |
+
target/
|
| 58 |
+
|
| 59 |
+
# Jupyter Notebook
|
| 60 |
+
.ipynb_checkpoints
|
| 61 |
+
|
| 62 |
+
# IPython
|
| 63 |
+
profile_default/
|
| 64 |
+
ipython_config.py
|
| 65 |
+
|
| 66 |
+
# pyenv
|
| 67 |
+
.python-version
|
| 68 |
+
|
| 69 |
+
# pipenv
|
| 70 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 71 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 72 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 73 |
+
# install all needed dependencies.
|
| 74 |
+
#Pipfile.lock
|
| 75 |
+
|
| 76 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 77 |
+
__pypackages__/
|
| 78 |
+
|
| 79 |
+
# Celery stuff
|
| 80 |
+
celerybeat-schedule
|
| 81 |
+
celerybeat.pid
|
| 82 |
+
|
| 83 |
+
# SageMath parsed files
|
| 84 |
+
*.sage.py
|
| 85 |
+
|
| 86 |
+
# Environments
|
| 87 |
+
.env
|
| 88 |
+
.venv
|
| 89 |
+
env/
|
| 90 |
+
venv/
|
| 91 |
+
ENV/
|
| 92 |
+
env.bak/
|
| 93 |
+
venv.bak/
|
| 94 |
+
|
| 95 |
+
# Spyder project settings
|
| 96 |
+
.spyderproject
|
| 97 |
+
.spyproject
|
| 98 |
+
|
| 99 |
+
# Rope project settings
|
| 100 |
+
.ropeproject
|
| 101 |
+
|
| 102 |
+
# mkdocs documentation
|
| 103 |
+
/site
|
| 104 |
+
|
| 105 |
+
# mypy
|
| 106 |
+
.mypy_cache/
|
| 107 |
+
.dmypy.json
|
| 108 |
+
dmypy.json
|
| 109 |
+
|
| 110 |
+
# Pyre type checker
|
| 111 |
+
.pyre/
|
| 112 |
+
|
| 113 |
+
# Cython debug symbols
|
| 114 |
+
cython_debug/
|
| 115 |
+
|
| 116 |
+
# PyCharm
|
| 117 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 118 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 119 |
+
# and can be added to your project. Official documentation:
|
| 120 |
+
# https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
| 121 |
+
.idea/
|
| 122 |
+
|
| 123 |
+
# VS Code
|
| 124 |
+
.vscode/
|
| 125 |
+
|
| 126 |
+
# Pyright
|
| 127 |
+
.pyright/
|
| 128 |
+
|
| 129 |
+
# Pyre type checker
|
| 130 |
+
.pyre/
|
| 131 |
+
|
| 132 |
+
__pycache__/
|
| 133 |
+
*.pyc
|
| 134 |
+
*.pyo
|
| 135 |
+
logs/
|
| 136 |
+
rnn_logs/
|
| 137 |
+
backbone_logs/
|
| 138 |
+
checkpoints/
|
| 139 |
+
lightning_logs/
|
| 140 |
+
wandb/
|
| 141 |
+
runs/
|
| 142 |
+
tensorboard/
|
| 143 |
+
results/
|
| 144 |
+
visualizations/
|
backbone/__init__.py
ADDED
|
File without changes
|
backbone/dataset.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import Callable, Optional, Tuple, Any
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pydicom
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SyntaxDataset(Dataset):
    """PyTorch dataset for training a 3D backbone on DICOM videos.

    Expects:
        - ``meta`` (JSON) holding a list of records with fields:
            "path":    DICOM file path relative to the JSON's directory
            "artery":  0 (left) or 1 (right)
            "<label>": numeric SYNTAX score (e.g. "syntax_left")
        - the video is read from a multi-frame DICOM (pydicom) as a 3D array.
    """

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery_bin: int,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        # Dataset root (used when ``meta`` is given as a relative path).
        self.root = Path(root).resolve()
        # Train vs. validation sampling behaviour.
        self.train = train
        # Required clip length (number of temporal frames).
        self.length = int(length)
        # Name of the score field (e.g. "syntax_left").
        self.label = label
        # Optional augmentation / transform pipeline.
        self.transform = transform
        # When True, keep only records with a positive score.
        self.validation = validation

        # Artery code must be 0 (left) or 1 (right).
        if artery_bin not in (0, 1):
            raise ValueError("artery_bin must be 0 (left) or 1 (right)")
        self.artery_bin = artery_bin

        # Resolve the metadata JSON location.
        meta_path = meta if os.path.isabs(meta) else self.root / meta
        meta_path = Path(meta_path).resolve()

        # "path" entries inside the JSON are relative to the JSON's directory.
        json_dir = meta_path.parent

        print(f"Backbone dataset: root={self.root}, meta={meta_path}, json_dir={json_dir}")

        # Read the record list from the JSON file.
        with open(meta_path, "r", encoding="utf-8") as fh:
            records = json.load(fh)

        # Keep only records for the requested artery.
        records = [rec for rec in records if rec.get("artery") == artery_bin]

        # Validation mode optionally keeps only positive-score records.
        if validation:
            records = [rec for rec in records if float(rec.get(self.label, 0.0)) > 0]

        # Remember the JSON directory for later DICOM path resolution.
        self.json_dir = json_dir
        self.dataset = records

        # Default per-record weight to 1.0 where the JSON omits it.
        for rec in self.dataset:
            rec.setdefault("weight", 1.0)

        print(f"Backbone dataset loaded: {len(self.dataset)} samples after filtering")

    def get_sample_weights(self) -> Tensor:
        """Per-record weights for use with WeightedRandomSampler.

        Records are grouped into severity (score) intervals, separately for
        the left and right artery; the rarer an interval, the larger the
        weight given to its records.
        """
        # Score thresholds delimiting the severity bins, per artery.
        thresholds = {
            0: [0, 5, 10, 15],  # left
            1: [0, 2, 5, 8],    # right
        }[self.artery_bin]
        t0, t1, t2, t3 = thresholds

        def bin_of(value: float) -> int:
            # Map a score onto one of five severity bins.
            if value == t0:
                return 0
            if t0 < value <= t1:
                return 1
            if t1 < value <= t2:
                return 2
            if t2 < value <= t3:
                return 3
            return 4

        # Bin index of every record's score.
        bin_ids = [bin_of(float(rec.get(self.label, 0.0))) for rec in self.dataset]

        # Occurrences per bin; the bin weight is the inverse frequency.
        counts = np.bincount(np.array(bin_ids, dtype=np.int64), minlength=5)
        total = int(counts.sum())
        per_bin = np.array(
            [(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
            dtype=np.float64,
        )

        # Each record inherits its bin's weight.
        sample_weights = np.array([per_bin[b] for b in bin_ids], dtype=np.float64)

        return torch.as_tensor(sample_weights, dtype=torch.double)

    def __len__(self) -> int:
        # One item per (filtered) JSON record.
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, float, str, Tensor]:
        """Return one sample.

        Returns:
            video: Tensor, video clip (T, H, W, C) before ``transform``
            label: Tensor(1,), binary classification target
            target: Tensor(1,), regression target (log1p(score))
            sample_weight: float, per-record weight taken from the JSON
            path: str, DICOM path relative to the JSON directory (as stored)
            original_label: Tensor(1,), raw score value
        """
        rec = self.dataset[idx]

        # Relative DICOM path exactly as stored in the JSON.
        rel_path = rec["path"]
        sample_weight = float(rec.get("weight", 1.0))

        # Full path: JSON directory + relative path from the record.
        full_path = (self.json_dir / rel_path).resolve()

        if not full_path.exists():
            raise FileNotFoundError(
                f"DICOM not found: {full_path}\n"
                f" json_dir={self.json_dir}\n"
                f" rel_path='{rel_path}'"
            )

        # Read the DICOM and pull out the pixel data.
        video = pydicom.dcmread(str(full_path)).pixel_array

        # A 3D array is required: (T, H, W) or (H, W, T).
        if video.ndim != 3:
            raise ValueError(f"Expected 3D video array, got shape={video.shape} for {rel_path}")

        # Heuristic: if time ended up last (H, W, T), move it to the front.
        if video.shape[0] > 128 and video.shape[-1] <= 128:
            video = np.moveaxis(video, -1, 0)

        # Normalise uint16 to uint8 by rescaling onto [0, 255].
        if video.dtype == np.uint16:
            vmax = int(np.max(video))
            if vmax <= 0:
                raise ValueError(f"Invalid vmax={vmax} for {rel_path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
        else:
            video = video.astype(np.uint8)

        # Numeric score for this record.
        score = float(rec.get(self.label, 0.0))

        # Binary-classification threshold, per artery: 15 (left) / 5 (right).
        threshold = {
            0: 15,  # left
            1: 5,   # right
        }[self.artery_bin]

        # Binary label: 1 when the score exceeds the threshold, else 0.
        label = torch.tensor(
            [1.0 if score > threshold else 0.0],
            dtype=torch.float32,
        )
        # Regression target: log(1 + score).
        target = torch.tensor([np.log1p(score)], dtype=torch.float32)
        # Raw score, kept for reporting.
        original_label = torch.tensor([score], dtype=torch.float32)

        # Tile the clip along time until it is at least ``self.length`` long.
        while video.shape[0] < self.length:
            video = np.concatenate([video, video], axis=0)

        # Random or deterministic temporal crop.
        total_frames = int(video.shape[0])
        if self.train:
            # Training: random window of ``self.length`` frames.
            start = torch.randint(low=0, high=total_frames - self.length + 1, size=(1,)).item()
            video = video[start: start + self.length]
        else:
            # Validation: take the first ``self.length`` frames.
            video = video[:self.length]

        # (T, H, W) -> (T, H, W, 3): replicate grayscale across channels.
        video = torch.from_numpy(np.stack([video, video, video], axis=-1))

        # Apply the transform chain, if configured.
        if self.transform is not None:
            video = self.transform(video)

        return video, label, target, sample_weight, str(rel_path), original_label
|
backbone/pl_model.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, optim
|
| 6 |
+
import lightning.pytorch as pl
|
| 7 |
+
import torchvision.models.video as tvmv
|
| 8 |
+
import sklearn.metrics as skm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SyntaxLightningModule(pl.LightningModule):
    """LightningModule for training a 3D backbone on the SYNTAX score.

    Architecture:
        - backbone: ResNet3D (r3d_18) from torchvision
        - final fully connected layer with two outputs:
            [0] logit for binary classification (significant lesion)
            [1] regression output for the SYNTAX score (log1p scale)

    Training modes:
        - pretrain (``weight_path`` is None):
            the whole backbone is frozen; only the final layer (fc) is trained
        - finetune (``weight_path`` given):
            weights are loaded from a checkpoint and the whole network is trained.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = int(num_classes)
        self.lr = float(lr)
        self.weight_decay = float(weight_decay)
        self.max_epochs = max_epochs
        self.weight_path = weight_path

        # Coefficients of the score-dependent noise model: sigma = a*t + b.
        self.sigma_a = float(sigma_a)
        self.sigma_b = float(sigma_b)

        # 3D ResNet-18 initialised from the default pretrained weights.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        # Swap the final fc layer for one producing ``num_classes`` outputs.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=self.num_classes, bias=True)

        # Optionally initialise the backbone from a checkpoint file.
        if self.weight_path is not None:
            self._load_backbone_weights(self.weight_path)

        # Per-sample losses; reduction is applied manually in the steps.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Accumulators for epoch-level validation metrics.
        self._y_true = []
        self._y_prob = []
        self._y_pred = []
        self._t_true = []
        self._t_pred = []

    def _load_backbone_weights(self, weight_path: str) -> None:
        """Load backbone weights from either checkpoint flavour.

        Accepts:
            - a Lightning checkpoint (dict containing 'state_dict'), or
            - a bare state_dict (.pt/.pth) saved via ``model.state_dict()``.

        Logs the source type and key statistics.
        """
        obj = torch.load(weight_path, map_location="cpu", weights_only=False)

        if isinstance(obj, dict) and "state_dict" in obj:
            # Strip the LightningModule's "model." prefix from the keys.
            # NOTE(review): str.replace removes "model." anywhere in a key,
            # not only as a prefix — confirm that is intended.
            state_dict = {k.replace("model.", ""): v for k, v in obj["state_dict"].items()}
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"

        incompatible = self.model.load_state_dict(state_dict, strict=False)

        loaded_keys = [k for k in state_dict.keys() if k not in incompatible.missing_keys]
        print(
            f"[Backbone] Loaded weights from '{weight_path}' "
            f"(type={src_type}): {len(loaded_keys)} params, "
            f"missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)}"
        )
        if incompatible.missing_keys:
            print(f"[Backbone] Missing keys example: {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            print(f"[Backbone] Unexpected keys example: {incompatible.unexpected_keys[:5]}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone.

        Input:
            x: (B, C, T, H, W)

        Output:
            (B, 2) tensor of [clf_logit, reg_output]
        """
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Single backbone training step."""
        x, y, target, sample_weight, path, original_label = batch

        preds = self(x)
        clf_logit = preds[:, 0:1]
        reg_out = preds[:, 1:2]

        # Negatives are down-weighted (0.45) relative to positives (1.0).
        clf_weights = torch.where(y > 0, 1.0, 0.45).to(y.dtype)
        clf_loss = (self.loss_clf(clf_logit, y) * clf_weights).mean()

        # Heteroscedastic MSE: each term divided by sigma(target)^2.
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (self.loss_reg(reg_out, target) / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        probs = torch.sigmoid(clf_logit).detach()
        pred_np = (probs > 0.5).int().cpu().numpy()
        true_np = y.detach().int().cpu().numpy()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_reg_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)

        self.log("train_f1", skm.f1_score(true_np, pred_np, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(true_np, pred_np),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Single backbone validation step."""
        x, y, target, sample_weight, path, original_label = batch

        preds = self(x)
        clf_logit = preds[:, 0:1]
        reg_out = preds[:, 1:2]

        # Unweighted BCE at validation time.
        clf_loss = self.loss_clf(clf_logit, y).mean()

        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (self.loss_reg(reg_out, target) / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        probs = torch.sigmoid(clf_logit).float()
        # NOTE(review): float()/int() below require single-element tensors,
        # i.e. a validation batch size of 1 — confirm against the dataloader.
        self._y_true.append(float(y[..., 0].float().cpu()))
        self._y_prob.append(float(probs[..., 0].cpu()))
        self._y_pred.append(int((probs[..., 0] > 0.5).cpu()))

        self._t_true.append(float(target[..., 0].float().cpu()))
        self._t_pred.append(float(reg_out[..., 0].cpu()))

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_clf_loss", clf_loss, prog_bar=False, sync_dist=True)
        self.log("val_reg_loss", reg_loss, prog_bar=False, sync_dist=True)

        return loss

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level validation metrics, then reset buffers."""
        if self._t_true:
            self.log("val_rmse", skm.root_mean_squared_error(self._t_true, self._t_pred),
                     prog_bar=True, sync_dist=True)
            self.log("val_reg_mae", skm.mean_absolute_error(self._t_true, self._t_pred),
                     prog_bar=True, sync_dist=True)

        # AUC/F1/accuracy are only defined when both classes are present.
        if len(set(self._y_true)) > 1:
            self.log("val_auc", skm.roc_auc_score(self._y_true, self._y_prob),
                     prog_bar=True, sync_dist=True)
            self.log("val_f1", skm.f1_score(self._y_true, self._y_pred, zero_division=0),
                     prog_bar=True, sync_dist=True)
            self.log("val_acc", skm.accuracy_score(self._y_true, self._y_pred),
                     prog_bar=True, sync_dist=True)

        for buf in (self._y_true, self._y_prob, self._y_pred, self._t_true, self._t_pred):
            buf.clear()

    def on_train_epoch_end(self) -> None:
        """Log the current learning rate."""
        opt = self.optimizers()
        self.log(
            "lr",
            opt.optimizer.param_groups[0]["lr"],
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        """Build the AdamW optimizer and an optional OneCycleLR schedule.

        When ``weight_path`` is None only ``self.model.fc`` is trained
        (pretrain); otherwise the whole model is fine-tuned.
        """
        pretrain = self.weight_path is None

        # Freeze everything in pretrain mode; unfreeze everything otherwise.
        for p in self.parameters():
            p.requires_grad = not pretrain

        if pretrain:
            # Only the final fc layer is trainable.
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = self.model.fc.parameters()
        else:
            params = self.parameters()

        optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None and getattr(self, "trainer", None) is not None:
            # Step-wise OneCycle schedule over the trainer's estimated steps.
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.lr,
                total_steps=self.trainer.estimated_stepping_batches,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                },
            }

        return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Backbone inference step."""
        x, y, target, sample_weight, path, original_label = batch
        preds = self(x)
        clf_logit = preds[:, 0:1]
        reg_out = preds[:, 1:2]
        probs = torch.sigmoid(clf_logit)

        return {
            "y": y,
            "y_pred": (probs > 0.5).int(),
            "y_prob": probs,
            "y_reg": reg_out,
            "target": target,
            "original_label": original_label,
            "path": path,
        }
|
backbone/pl_train.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import click
|
| 5 |
+
import lightning.pytorch as pl
|
| 6 |
+
import torch
|
| 7 |
+
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 8 |
+
from lightning.pytorch.loggers import TensorBoardLogger
|
| 9 |
+
from pytorchvideo.transforms import Normalize, Permute, RandAugment
|
| 10 |
+
from torch.utils.data import DataLoader, WeightedRandomSampler
|
| 11 |
+
from torchvision.transforms import transforms as T
|
| 12 |
+
from torchvision.transforms._transforms_video import ToTensorVideo
|
| 13 |
+
from torchvision.transforms import InterpolationMode
|
| 14 |
+
|
| 15 |
+
from backbone.dataset import SyntaxDataset
|
| 16 |
+
from backbone.pl_model import SyntaxLightningModule
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Отключаем предупреждение Lightning о device id при DDP-инициализации
|
| 20 |
+
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")
|
| 21 |
+
|
| 22 |
+
# Устанавливаем точность матричных умножений (оптимизация производительности)
|
| 23 |
+
torch.set_float32_matmul_precision("medium")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
    """
    Build the video augmentation / preprocessing pipeline.

    Input:
        a video Tensor of shape (T, H, W, C), dtype uint8.

    Output:
        a normalized Tensor of shape (C, T, H, W), ready for a 3D ResNet.
    """
    if not train:
        # Validation / inference: only tensor conversion, resize and normalization.
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    # Resize using a randomly picked interpolation mode per sample.
    random_resize = T.RandomChoice([
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ])

    return T.Compose([
        # (T, H, W, C) -> (C, T, H, W), values scaled to [0, 1]
        ToTensorVideo(),
        # (C, T, H, W) -> (T, C, H, W): RandAugment expects channel-second layout
        Permute(dims=[1, 0, 2, 3]),
        # Random spatial/temporal augmentations
        RandAugment(magnitude=10, num_layers=2),
        # Random horizontal mirroring
        T.RandomHorizontalFlip(),
        # Back to (C, T, H, W)
        Permute(dims=[1, 0, 2, 3]),
        random_resize,
        # Normalize with ImageNet statistics
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool):
    """
    Build a DataLoader with an optional WeightedRandomSampler.

    When use_weighted_sampler=True:
      - sampling probabilities come from dataset.get_sample_weights()
      - shuffling is disabled, since the sampler fully determines the order.
    """
    sampler = None
    if use_weighted_sampler:
        weights = dataset.get_sample_weights().cpu()
        sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=sampler is None,
        drop_last=True,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def make_model(num_classes: int, lr: float, weight_decay: float, max_epochs: int, weight_path: str = None):
    """
    Construct the backbone LightningModule.

    num_classes:
        number of output neurons (usually 2: classification + regression).
    """
    module_kwargs = dict(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )
    return SyntaxLightningModule(**module_kwargs)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def make_callbacks(phase: str):
    """
    Assemble the Trainer callback list:
      - learning-rate monitoring
      - checkpointing on the val_rmse metric
    """
    top_k = 1 if phase == "pre" else 3
    return [
        LearningRateMonitor(logging_interval="epoch"),
        ModelCheckpoint(
            monitor="val_rmse",
            save_top_k=top_k,
            mode="min",
            filename="model-{epoch:02d}-{val_rmse:.3f}",
            save_last=True,
        ),
    ]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def make_trainer(
    max_epochs: int,
    logdir: str,
    logger_name: str,
    devices: list[int],
    precision: str,
    callbacks: list = None,
):
    """
    Build a Trainer object.

    Args:
        max_epochs: number of training epochs.
        logdir: directory for TensorBoard logs.
        logger_name: subdirectory name for the current experiment.
        devices: GPU device ids (any sequence of ints).
        precision: numeric precision mode (e.g. "bf16-mixed").
        callbacks: optional callbacks registered at construction time.
            Defaults to an empty list, so existing callers that extend
            `trainer.callbacks` after construction keep working; new callers
            should prefer passing callbacks here, since Lightning configures
            callbacks during Trainer initialization.
    """
    logger = TensorBoardLogger(save_dir=logdir, name=logger_name)

    # More than one device -> DDP; otherwise keep the default strategy.
    strategy = "ddp_find_unused_parameters_true" if len(devices) > 1 else "auto"

    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=list(devices),
        strategy=strategy,
        precision=precision,
        callbacks=list(callbacks) if callbacks else [],
        log_every_n_steps=10,
        logger=logger,
    )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (JSON и DICOM-пути считаются относительно него).",
)
@click.option("--fold", type=int, default=4, show_default=True, help="Номер фолда.")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    show_default=True,
    help="Название артерии: left или right.",
)
@click.option(
    "-nc",
    "--num-classes",
    type=int,
    default=2,
    show_default=True,
    help="Число выходных нейронов (обычно 2: clf + reg).",
)
@click.option("-b", "--batch-size", type=int, default=50, show_default=True, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True, help="Число кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    show_default=True,
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох для full train.")
@click.option("--num-workers", type=int, default=8, show_default=True, help="Число DataLoader workers.")
@click.option(
    "--devices",
    # BUGFIX: the original used `type=list[int]`, which is not a valid click
    # parameter type. With `multiple=True` click needs the type of a SINGLE
    # occurrence, so each `--devices N` is parsed as one int and click
    # collects the occurrences into a tuple.
    type=int,
    multiple=True,
    default=(0,),
    show_default=True,
    help="Список GPU id",
)
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим точности.")
@click.option(
    "--logdir",
    type=click.Path(),
    default="./logs/backbone",
    show_default=True,
    help="Каталог для логов и чекпоинтов backbone.",
)
@click.option(
    "--use-weighted-sampler",
    is_flag=True,
    default=False,
    show_default=True,
    help="Использовать ли WeightedRandomSampler по score-интервалам.",
)
@click.option("--seed", type=int, default=42, show_default=True, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    devices,
    precision,
    logdir,
    use_weighted_sampler,
    seed,
):
    """
    Entry point for training the backbone model.

    Sequence:
      1) pretrain: train only the final fc layer
      2) full train: finetune the whole model starting from the last
         pretrain checkpoint.
    """
    # Fix the random seed in all supported libraries.
    pl.seed_everything(seed)

    # click collects multiple=True options into a tuple; normalize to a list.
    devices = list(devices)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")

    # ImageNet statistics for input normalization.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Fold metadata JSON paths, relative to dataset_root.
    train_meta = f"folds/step2_fold{fold:02d}_train.json"
    eval_meta = f"folds/step2_fold{fold:02d}_eval.json"

    # Training dataset.
    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )

    # Validation dataset.
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=eval_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # DataLoaders: pretrain can afford a doubled batch size (fc-only training).
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
    val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)

    # Peek at one batch to report the per-sample video shape (C, T, H, W).
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]
    print(f"Backbone input video shape: {video_shape}")

    # Callbacks for the pretrain and full-train phases.
    callbacks_pre = make_callbacks(phase="pre")
    callbacks_full = make_callbacks(phase="full")

    # ------------------- Pretrain (fc only) -------------------
    num_pre_epochs = 10

    model_pre = make_model(
        num_classes=num_classes,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
        weight_path=None,
    )

    trainer_pre = make_trainer(
        max_epochs=num_pre_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_pre.callbacks.extend(callbacks_pre)
    trainer_pre.fit(model_pre, train_loader_pre, val_loader)

    # ------------------- Full train (finetune) -------------------
    model_full = make_model(
        num_classes=num_classes,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        # Warm-start from the last pretrain checkpoint.
        weight_path=trainer_pre.checkpoint_callback.last_model_path,
    )

    trainer_full = make_trainer(
        max_epochs=max_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_full.callbacks.extend(callbacks_full)
    trainer_full.fit(model_full, train_loader_post, val_loader)


if __name__ == "__main__":
    main()
|
backbone_weights/leftBinSyntax_R3D_full_fold00.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:299db032713b8b8247fa42e1bfcf993d6c0ea162a2d47211cd4ef83e8e7083ac
|
| 3 |
+
size 398135752
|
backbone_weights/leftBinSyntax_R3D_full_fold01.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fef1d7fdc6d2fc3e8b64b807c21a3d15af832bad8dfd1cbe59c1122c5f020a62
|
| 3 |
+
size 398135752
|
backbone_weights/leftBinSyntax_R3D_full_fold02.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81b7f04e2d6735c505fe50101c2f7dedacb1503681c14f28c6fd8007bb7a4255
|
| 3 |
+
size 398135752
|
backbone_weights/leftBinSyntax_R3D_full_fold03.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7decf5e0fec66d734901a8cb0c47a09b0d7ec33a341f245caa7aeee0011162fc
|
| 3 |
+
size 398135752
|
backbone_weights/leftBinSyntax_R3D_full_fold04.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e85aa8b198f29e9375d93a37e11eb130b73b9d8a03f68a60c7aa7fd16d899f0f
|
| 3 |
+
size 398135752
|
backbone_weights/rightBinSyntax_R3D_full_fold00.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6285307728b4c15a5f9ded2c311981d9c71f81bfbab4c11663fa73af57ffd35
|
| 3 |
+
size 398135752
|
backbone_weights/rightBinSyntax_R3D_full_fold01.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b77467117b220d4245d7c0bc00b9072719c0c563672a34a643d986986be32866
|
| 3 |
+
size 398135752
|
backbone_weights/rightBinSyntax_R3D_full_fold02.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d0d1fb3aac64364e2e99b0c25d17bf1650b15b9da71cbba4ea6dd3ce2cb2bb0
|
| 3 |
+
size 398135752
|
backbone_weights/rightBinSyntax_R3D_full_fold03.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c3387e5bb99b615d7c8d5a67b41d53f4cc4d872e82f7b391be4da3321490fdb
|
| 3 |
+
size 398135752
|
backbone_weights/rightBinSyntax_R3D_full_fold04.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e3cf35b8d0450f23bb105d6367efc8ab32bc73a0bc4170e0fc07fbe54e593f9
|
| 3 |
+
size 398135752
|
full_model/__init__.py
ADDED
|
File without changes
|
full_model/rnn_dataset.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Callable, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pydicom
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
DTYPE = torch.float16
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SyntaxDataset(Dataset):
    """
    Dataset for the head model (RNN/LSTM) on top of the backbone.

    Expected JSON structure:
    [
        {
            "study_uid": "...",
            "syntax_left": 12.5,
            "syntax_right": 8.2,
            "videos_left": [
                {"path": "../data/anon_data/.../IM-0001-0001.dcm"},
                ...
            ],
            "videos_right": [
                {"path": "../data/anon_data/.../IM-0002-0001.dcm"},
                ...
            ],
        },
        ...
    ]

    IMPORTANT: each "videos_{artery}[i]['path']" entry in the JSON is a DICOM
    path relative to the directory containing that JSON (the rnn_folds/
    folder).
    """

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery: str,
        inference: bool = False,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        # Dataset root; `meta` may be given relative to it.
        self.root = Path(root).resolve()
        # Mode flag: training vs. evaluation (affects temporal-window choice).
        self.train = train
        # Temporal length of a single clip (frames).
        self.length = int(length)
        # Name of the numeric score field ("syntax_left"/"syntax_right") or empty.
        self.label = label
        # Artery: "left" or "right".
        self.artery = artery.lower()
        # Inference mode: return all clips instead of a random subset.
        self.inference = inference
        # Validation flag: keep only records with a positive score.
        self.validation = validation
        # Per-clip video transforms.
        self.transform = transform

        # Resolve the full path of the metadata JSON.
        meta_path = Path(meta)
        if not meta_path.is_absolute():
            meta_path = self.root / meta_path
        meta_path = meta_path.resolve()

        # DICOM paths are interpreted relative to the JSON's directory.
        self.base_dir = meta_path.parent

        print(f"RNN Dataset: root={self.root}, meta={meta_path}, base_dir={self.base_dir}")

        # Load the JSON metadata.
        with open(meta_path, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        # Outside inference, drop records with no videos for the chosen artery.
        if not self.inference:
            dataset = [rec for rec in dataset if len(rec.get(f"videos_{self.artery}", [])) > 0]

        # For validation, optionally keep only records with a positive score.
        if validation and self.label:
            dataset = [rec for rec in dataset if float(rec.get(self.label, 0.0)) > 0]

        self.dataset = dataset
        print(f"RNN Dataset loaded: {len(self.dataset)} samples after filtering")

        # Artery code used to select the score thresholds.
        artery_bin = {"left": 0, "right": 1}.get(self.artery)
        if artery_bin is None:
            raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")
        self.artery_bin = artery_bin

    def __len__(self) -> int:
        # One sample per study record.
        return len(self.dataset)

    def get_sample_weights(self) -> Tensor:
        """
        Return per-sample weights for a WeightedRandomSampler, based on
        score intervals.

        Each artery has its own thresholds; the weight of a sample is the
        inverse frequency of the interval (bin) its score falls into.
        """
        bin_thresholds = {
            0: [0, 5, 10, 15],  # left artery
            1: [0, 2, 5, 8],  # right artery
        }
        thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]

        def in_bin(score: float) -> int:
            # Bin 0 is reserved for an exactly-zero score; bin 4 is the tail.
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        scores = [float(rec.get(self.label, 0.0)) for rec in self.dataset]
        bins = [in_bin(s) for s in scores]

        counts = np.bincount(np.array(bins, dtype=np.int64), minlength=5)
        total = int(counts.sum())

        # Inverse-frequency weight for each bin (0 for empty bins).
        weights_by_bin = np.array(
            [(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
            dtype=np.float64,
        )

        weights = np.array([weights_by_bin[b] for b in bins], dtype=np.float64)
        print(
            "RNN sample weights counts:",
            int(counts[0]),
            int(counts[1]),
            int(counts[2]),
            int(counts[3]),
            int(counts[4]),
        )
        # BUGFIX: keep the weights in double precision. The previous code cast
        # them to float16 (DTYPE), but WeightedRandomSampler passes the weights
        # to torch.multinomial, which does not support half precision on CPU,
        # and float16 also loses precision for large inverse frequencies.
        return torch.as_tensor(weights, dtype=torch.double)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Any]:
        """
        Returns:
            clips: Tensor, stacked clips (N_clips, C, T, H, W) after transform
            label: Tensor(1,), binary label (0/1)
            target: Tensor(1,), regression target (log1p(score))
            suid: study identifier (study_uid)
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]

        # Build labels when a score field name is configured.
        if self.label:
            bin_thresholds = {
                0: 15,  # left artery
                1: 5,  # right artery
            }
            score = float(rec.get(self.label, 0.0))
            label = torch.tensor(
                [1.0 if score > bin_thresholds[self.artery_bin] else 0.0],
                dtype=DTYPE,
            )
            target = torch.tensor([np.log1p(score)], dtype=DTYPE)
        else:
            # Pure inference on unlabeled data: dummy label/target.
            label = torch.tensor([0.0], dtype=DTYPE)
            target = torch.tensor([0.0], dtype=DTYPE)

        videos_list = rec.get(f"videos_{self.artery}", [])
        nv = len(videos_list)

        # Choose which clips to load.
        if self.inference:
            # Inference: return every available clip.
            if nv == 0:
                # Fall back to an empty sequence as a last resort.
                return torch.zeros(0), label, target, suid
            seq_indices = range(nv)
        else:
            if nv == 0:
                raise ValueError(f"No videos for artery={self.artery} in record {suid}")
            # Random set of clip indices (4 clips, sampled with replacement).
            seq_indices = torch.randint(low=0, high=nv, size=(4,))

        clips = []
        for vi in seq_indices:
            vi_idx = int(vi)
            video_rec = videos_list[vi_idx]
            rel_path = video_rec["path"]

            # Full DICOM path: JSON directory + relative path.
            full_path = (self.base_dir / rel_path).resolve()
            if not full_path.exists():
                raise FileNotFoundError(
                    f"DICOM not found: {full_path}\n"
                    f"  base_dir={self.base_dir}\n"
                    f"  rel_path='{rel_path}'\n"
                    f"  study={suid}"
                )

            # Load the DICOM pixel data.
            video = pydicom.dcmread(str(full_path)).pixel_array

            # Expect a 3D array with a temporal axis.
            if video.ndim != 3:
                raise ValueError(f"Expected 3D video, got {video.shape} in {full_path}")

            # Heuristic: if time sits on the last axis, move it to the front.
            # NOTE(review): assumes spatial sizes are > 128 px and frame counts
            # are <= 128 — confirm against the acquisition protocol.
            if video.shape[0] > 128 and video.shape[-1] <= 128:
                video = np.moveaxis(video, -1, 0)

            # Convert to uint8 (rescale 16-bit data by its max value).
            if video.dtype == np.uint16:
                vmax = int(np.max(video))
                if vmax <= 0:
                    raise ValueError(f"Invalid vmax={vmax} in {full_path}")
                video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
            else:
                video = video.astype(np.uint8)

            # Tile along time until the required clip length fits.
            while video.shape[0] < self.length:
                video = np.concatenate([video, video], axis=0)

            t = int(video.shape[0])
            if self.train:
                # Training: random temporal window.
                begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)).item()
            else:
                # Validation: centered temporal window.
                begin = (t - self.length) // 2
            video = video[begin: begin + self.length]

            # (T, H, W) -> (T, H, W, C) with C=3 (grayscale replicated).
            video = torch.from_numpy(np.stack([video, video, video], axis=-1))

            # Apply transforms; expected output layout is (C, T, H, W).
            if self.transform is not None:
                video = self.transform(video)

            clips.append(video)

        # Stacked clips: (N_clips, C, T, H, W)
        clips = torch.stack(clips, dim=0) if clips else torch.zeros(0, dtype=DTYPE)

        return clips, label, target, suid
|
full_model/rnn_model.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Optional, Tuple
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import nn, optim
|
| 5 |
+
import lightning.pytorch as pl
|
| 6 |
+
import torchvision.models.video as tvmv
|
| 7 |
+
import sklearn.metrics as skm
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
The head of model
|
| 12 |
+
"""
|
| 13 |
+
class SyntaxLightningModule(pl.LightningModule):
    """SYNTAX-score video model: r3d_18 backbone plus a selectable head.

    The backbone encodes each clip of a clip-sequence independently; the head
    (chosen via ``variant``) aggregates per-clip features into two outputs:
    a binary classification logit (index 0) and a regression value (index 1).
    """

    def __init__(
        self,
        num_classes,
        lr: float,
        variant: str,  # mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2
        weight_decay: float = 0,
        max_epochs: int = None,
        weight_path: str = None,  # backbone (r3d_18) weights (ckpt or pt)
        save_path: str = None,  # path for saving by best auc (kept as before)
        pl_weight_path: str = None,  # full Lightning checkpoint of the whole model
        pt_weights_format: bool = False,  # .pt format
        sigma_a: float = 0,
        sigma_b: float = 1,
        **kwargs,
    ):
        # NOTE(review): save_hyperparameters() is called BEFORE super().__init__().
        # This happens to work with recent Lightning/nn.Module internals but is
        # unconventional — confirm against the Lightning version in use.
        self.save_hyperparameters()
        super().__init__()
        self.num_classes = num_classes
        self.save_path = save_path
        self.weight_path = weight_path
        self.variant = variant
        # sigma_a/sigma_b parameterize a target-dependent noise scale used to
        # weight the regression loss: sigma = sigma_a * target + sigma_b.
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b

        # Video ResNet (backbone)
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        self.lr = lr
        # Per-element losses (reduction="none") so that custom per-sample
        # weighting can be applied before averaging.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Default final backbone layer: 2 outputs (clf + reg)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=2, bias=True)

        # Load pretrained backbone weights (r3d_18), as before via weight_path,
        # but now both .ckpt and .pt/.pth are supported
        if weight_path is not None:
            print("Load model weights (backbone)")
            self.load_weights_backbone(weight_path, self.model)

        # Head selection: for every variant except "mean_out" the backbone's fc
        # is replaced by Identity so the backbone emits raw features.
        if self.variant != "mean_out":
            self.model.fc = nn.Identity()

        if self.variant == "mean_out":
            # Only self.model with fc=Linear(…, 2)
            pass
        elif self.variant in ("gru_mean", "gru_last"):
            self.rnn = nn.GRU(in_features, in_features // 4, batch_first=True)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features=in_features // 4, out_features=num_classes, bias=True)
        elif self.variant in ("lstm_mean", "lstm_last"):
            # proj_size=num_classes makes the LSTM output the final predictions
            # directly (no separate fc layer for the LSTM variants).
            self.lstm = nn.LSTM(
                input_size=in_features,
                hidden_size=in_features // 4,
                proj_size=num_classes,
                batch_first=True,
            )
        elif self.variant == "mean":
            self.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=in_features,
                nhead=4,
                batch_first=True,
                dim_feedforward=in_features // 4,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
            if self.variant == "bert_cls2":
                # Learnable [CLS]-style token prepended to the clip sequence.
                self.cls = nn.Parameter(torch.randn(1, 1, in_features))
        else:
            raise ValueError(f"Unknown model variant {self.variant}")

        # Load a full Lightning checkpoint (backbone + head), as before
        if pl_weight_path is not None:
            print(f"Load LightningModule weights from {pl_weight_path}")

            # pt_weights_format=True -> file is a raw state_dict (.pt);
            # otherwise a Lightning .ckpt with a 'state_dict' entry.
            if pt_weights_format:
                pl_state_dict = torch.load(pl_weight_path, weights_only=False)
            else:
                pl_state_dict = torch.load(pl_weight_path, weights_only=False)["state_dict"]

            # Load the backbone
            self.load_weights(pl_state_dict, self.model, "model")

            # Load the head depending on the variant
            if self.variant == "mean_out":
                pass  # only self.model
            elif self.variant in ("gru_mean", "gru_last"):
                self.load_weights(pl_state_dict, self.rnn, "rnn")
                self.load_weights(pl_state_dict, self.fc, "fc")
            elif self.variant in ("lstm_mean", "lstm_last"):
                self.load_weights(pl_state_dict, self.lstm, "lstm")
            elif self.variant == "mean":
                self.load_weights(pl_state_dict, self.fc, "fc")
            elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
                self.load_weights(pl_state_dict, self.encoder, "encoder")
                self.load_weights(pl_state_dict, self.fc, "fc")
                if self.variant == "bert_cls2":
                    # Replace the cls token wholesale; guard the shape to catch
                    # checkpoints trained with a different feature width.
                    old_shape = self.cls.shape
                    self.cls = nn.Parameter(pl_state_dict["cls"])
                    assert old_shape == self.cls.shape
            else:
                raise ValueError(f"Unknown model variant {self.variant}")

        self.max_epochs = max_epochs
        self.weight_decay = weight_decay

        # Per-epoch validation accumulators, flushed in on_validation_epoch_end:
        # y_val: true binary labels; p_val: predicted probabilities;
        # r_val: rounded (hard) predictions; ty_val/tp_val: regression target/pred.
        self.y_val = []
        self.p_val = []
        self.r_val = []
        self.ty_val = []
        self.tp_val = []

    def load_weights_backbone(self, weight_path: str, model: nn.Module) -> None:
        """Universal backbone (r3d_18) weight loading.

        - If the file is a Lightning checkpoint (dict with 'state_dict'),
          take state_dict['state_dict'] and strip the 'model.' prefix.
        - If the file is a bare state_dict (.pt/.pth) saved via
          model.state_dict(), load it directly.

        Before loading, all keys with mismatched tensor sizes are dropped
        (e.g. fc.weight/fc.bias when the number of outputs differs).
        """
        obj = torch.load(weight_path, weights_only=False, map_location="cpu")

        if isinstance(obj, dict) and "state_dict" in obj:
            raw_state = obj["state_dict"]
            # NOTE: str.replace strips 'model.' anywhere in the key, not only
            # as a prefix — fine for r3d_18 key names, which never contain it.
            state_dict = {k.replace("model.", ""): v for k, v in raw_state.items()}
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"

        current_state = model.state_dict()
        filtered_state = {}
        mismatched_keys = []

        # Keep only weights whose tensor size matches the current model
        for k, v in state_dict.items():
            if k in current_state and current_state[k].shape == v.shape:
                filtered_state[k] = v
            else:
                # either no such key at all, or the size does not match
                mismatched_keys.append(k)

        # Load only the compatible weights
        incompatible = model.load_state_dict(filtered_state, strict=False)

        loaded_keys = [k for k in filtered_state.keys() if k not in incompatible.missing_keys]
        print(
            f"[Backbone] Loaded weights from '{weight_path}' "
            f"(type={src_type}): {len(loaded_keys)} params, "
            f"missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)}, "
            f"skipped_mismatched={len(mismatched_keys)}"
        )
        if mismatched_keys:
            print(f"[Backbone] Size‑mismatched keys (skipped), example: {mismatched_keys[:5]}")
        if incompatible.missing_keys:
            print(f"[Backbone] Missing keys after filtering, example: {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            print(f"[Backbone] Unexpected keys after filtering, example: {incompatible.unexpected_keys[:5]}")

    def load_weights(self, state_dict, module, prefix: str):
        """Filter and load only the weights belonging to a specific submodule."""
        # NOTE: startswith(prefix) also matches longer names sharing the prefix
        # (e.g. 'fc2' for prefix 'fc'); such keys surface as `unexpected` below.
        module_state = {
            k.replace(f"{prefix}.", ""): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
        missing, unexpected = module.load_state_dict(module_state, strict=False)
        if missing:
            print(f"Missing keys for {prefix}: {missing}")
        if unexpected:
            print(f"Unexpected keys for {prefix}: {unexpected}")

    def forward(self, x):
        """Encode each clip with the backbone, then aggregate per the variant.

        x: (batch, seq, C, T, H, W) -> returns (batch, out) predictions
        (except lstm_last — see note below).
        """
        # x: (batch, seq, C, T, H, W)
        batch_seq_shape = x.shape[0:2]
        x = torch.flatten(x, start_dim=0, end_dim=1)  # (batch*seq, C, T, H, W)
        x = self.model(x)
        x = torch.unflatten(x, 0, batch_seq_shape)  # (batch, seq, feat)

        if self.variant == "mean_out":
            # Average per-clip (clf, reg) outputs over the sequence.
            x = torch.mean(x, dim=1)
        elif self.variant in ("gru_mean", "gru_last"):
            # GRU returns (outputs, h_n); [_last_out_] unpacks the single layer.
            _all_outs_, [_last_out_] = self.rnn(x)
            if self.variant == "gru_mean":
                x = torch.mean(_all_outs_, dim=1)
            else:
                x = _last_out_
            x = self.dropout(x)
            x = self.fc(x)
        elif self.variant in ("lstm_mean", "lstm_last"):
            _all_outs_, (_last_out_, _last_state_) = self.lstm(x)
            if self.variant == "lstm_mean":
                x = torch.mean(_all_outs_, dim=1)
            else:
                # NOTE(review): h_n has shape (num_layers, batch, proj_size);
                # for lstm_last this returns a leading layer dim of 1 — confirm
                # downstream indexing (y_hat[:, 0:1]) accounts for it.
                x = _last_out_
        elif self.variant == "mean":
            x = torch.mean(x, dim=1)
            x = self.fc(x)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            if self.variant == "bert_cls":
                # Prepend a constant-zero CLS slot along the sequence dim.
                x = F.pad(x, (0, 0, 1, 0), "constant", 0)
            elif self.variant == "bert_cls2":
                bs = x.size(0)
                # Prepend the learnable CLS token, broadcast over the batch.
                x = torch.cat([self.cls.expand(bs, -1, -1), x], dim=1)
            x = self.encoder(x)
            if self.variant == "bert_mean":
                x = torch.mean(x, dim=1)
            else:
                # CLS-token readout.
                x = x[:, 0, :]
            x = self.dropout(x)
            x = self.fc(x)
        else:
            raise ValueError(f"Unknown model variant {self.variant}")

        return x

    def training_step(self, batch, batch_idx):
        """One optimization step: weighted BCE (clf) + sigma-scaled MSE (reg)."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # Down-weight negative samples (0.2) relative to positives (1.0).
        weights_clf = torch.where(y > 0, 1.0, 0.2)
        clf_loss = self.loss_clf(yp_clf, y)
        clf_loss = (clf_loss * weights_clf).mean()

        # Heteroscedastic-style scaling: larger targets tolerate larger errors
        # when sigma_a > 0 (sigma = sigma_a * target + sigma_b).
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.cpu().detach()).int()
        y_pred_bin = torch.round(y_pred.cpu().detach()).int()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_bin, y_pred_bin),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate per-sample predictions for epoch-level metrics.

        NOTE(review): the int()/float() conversions below assume a validation
        batch size of 1 (multi-element tensors would raise) — confirm the val
        DataLoader is configured accordingly.
        """
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # This first loss computation is superseded by the recomputation below;
        # kept as-is to preserve behavior.
        loss = self.loss_clf(yp_clf, y)
        reg_loss_raw = self.loss_reg(yp_reg, target)
        loss = loss.mean()

        y_pred = torch.sigmoid(yp_clf)

        self.y_val.append(int(y[..., 0].cpu()))
        self.p_val.append(float(y_pred[..., 0].cpu()))
        self.r_val.append(round(float(y_pred[..., 0].cpu())))

        self.ty_val.append(float(target[..., 0].cpu()))
        self.tp_val.append(float(yp_reg[..., 0].cpu()))

        # NOTE(review): clf_loss is unreduced here, so the returned `loss` is a
        # tensor rather than a scalar; it is not logged, so this is harmless,
        # but a .mean() would make it consistent with training_step.
        clf_loss = self.loss_clf(yp_clf, y)
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        return loss

    def on_validation_epoch_end(self):
        """Compute and log epoch-level validation metrics, then reset buffers."""
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # NOTE(review): this MAE compares binary labels with rounded
            # probabilities (i.e. the classification error rate), not the
            # regression targets — confirm that is intended.
            mae = skm.mean_absolute_error(self.y_val, self.r_val)
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)

            # Regression quality on the raw targets (requires sklearn >= 1.4
            # for root_mean_squared_error).
            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)

        except ValueError as err:
            # e.g. roc_auc_score with a single class present in y_val.
            print(err)
            print("Y_VAL", self.y_val)
            print("P_VAL", self.p_val)
        self.y_val.clear()
        self.p_val.clear()
        self.r_val.clear()
        self.ty_val.clear()
        self.tp_val.clear()

    def on_train_epoch_end(self) -> None:
        """Log the current learning rate once per epoch."""
        self.log(
            "lr",
            self.optimizers().optimizer.param_groups[0]["lr"],
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        """Build Adam (+ optional OneCycleLR) and freeze/unfreeze parameters.

        If weight_path is set, the backbone is frozen and only the head is
        trained; otherwise the whole network trains end-to-end.
        """
        # First decide which modules to train
        if self.weight_path:  # pretrain without video backbone
            if self.variant == "mean_out":
                trainable_modules = [self.model.fc]
            elif self.variant in ("gru_mean", "gru_last"):
                trainable_modules = [self.rnn, self.fc]
            elif self.variant in ("lstm_mean", "lstm_last"):
                trainable_modules = [self.lstm]
            elif self.variant == "mean":
                trainable_modules = [self.fc]
            elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
                trainable_modules = [self.encoder, self.fc]
                if self.variant == "bert_cls2":
                    # NOTE(review): self.cls is an nn.Parameter, not a Module —
                    # the m.parameters() loop below would raise for it; confirm
                    # the bert_cls2 + weight_path combination is exercised.
                    trainable_modules.append(self.cls)
            else:
                trainable_modules = []

            for param in self.parameters():
                param.requires_grad = False

            for m in trainable_modules:
                for p in m.parameters():
                    p.requires_grad = True

            params = [p for m in trainable_modules for p in m.parameters()]
        else:
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None:
            # Scheduler stepped per epoch (Lightning default interval), hence
            # total_steps=max_epochs.
            lr_scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer, max_lr=self.lr, total_steps=self.max_epochs
            )
            return [optimizer], [lr_scheduler]
        else:
            return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Inference: return labels, hard/soft classification, and regression."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]
        y_pred = torch.sigmoid(yp_clf)

        return {
            "y": y,
            "y_pred": torch.round(y_pred),
            "y_prob": y_pred,
            "y_reg": yp_reg,
            "target": target,
        }
|
full_model/rnn_train.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# full_model/rnn_train.py
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import click
|
| 7 |
+
import lightning.pytorch as pl
|
| 8 |
+
import torch
|
| 9 |
+
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 10 |
+
from lightning.pytorch.loggers import TensorBoardLogger
|
| 11 |
+
from pytorchvideo.transforms import Normalize, Permute, RandAugment
|
| 12 |
+
from torch.utils.data import DataLoader, WeightedRandomSampler
|
| 13 |
+
from torchvision.transforms import transforms as T
|
| 14 |
+
from torchvision.transforms._transforms_video import ToTensorVideo
|
| 15 |
+
from torchvision.transforms import InterpolationMode
|
| 16 |
+
|
| 17 |
+
from full_model.rnn_dataset import SyntaxDataset
|
| 18 |
+
from full_model.rnn_model import SyntaxLightningModule
|
| 19 |
+
|
| 20 |
+
torch.set_float32_matmul_precision("medium")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
    """Build the clip transform pipeline: augmenting for train, plain for eval."""
    normalize = Normalize(mean=imagenet_mean, std=imagenet_std)

    if not train:
        # Evaluation: deterministic resize + normalization only.
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            normalize,
        ])

    # One resize op per interpolation mode; RandomChoice picks one per clip.
    resize_variants = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in [InterpolationMode.BILINEAR, InterpolationMode.BICUBIC]
    ]

    # RandAugment expects (T, C, H, W), so permute around it.
    return T.Compose([
        ToTensorVideo(),
        Permute(dims=[1, 0, 2, 3]),  # (C, T, H, W) -> (T, C, H, W)
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),  # back to (C, T, H, W)
        T.RandomChoice(resize_variants),
        normalize,
    ])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool):
    """DataLoader with an optional score-based WeightedRandomSampler.

    NOTE(review): shuffling is disabled in both branches — presumably ordering
    is provided by the sampler or handled upstream; confirm for the train loader.
    """
    sampler = None
    if use_weighted_sampler:
        # Sample with replacement, weighted by the dataset-provided scores.
        weights = dataset.get_sample_weights().cpu()
        sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def make_model(
    num_classes: int,
    lr: float,
    variant: str,
    weight_decay: float,
    max_epochs: int,
    weight_path: str | None = None,
    pl_weight_path: str | None = None,
    pt_weights_format: bool = False,
) -> SyntaxLightningModule:
    """Create the head model.

    Args:
        num_classes: number of head outputs (clf + reg).
        lr: learning rate.
        variant: head variant (mean_out, mean, lstm_*, gru_*, bert_*).
        weight_decay: optimizer weight decay.
        max_epochs: total epochs (drives the OneCycleLR schedule).
        weight_path: pretrain weights for the backbone (r3d_18), .pt or .ckpt.
        pl_weight_path: full head-model checkpoint (Lightning .ckpt or raw .pt).
        pt_weights_format: True -> pl_weight_path is a raw state_dict (.pt);
            False -> pl_weight_path is a Lightning .ckpt with 'state_dict'.
    """
    return SyntaxLightningModule(
        num_classes=num_classes,
        lr=lr,
        variant=variant,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
        pl_weight_path=pl_weight_path,
        # BUG FIX: this was passed as `yulie_model=pt_weights_format`, which the
        # module's **kwargs silently swallowed — pt_weights_format then always
        # stayed at its default (False), so raw .pt head checkpoints were
        # incorrectly treated as Lightning .ckpt files.
        pt_weights_format=pt_weights_format,
    )
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def make_callbacks(phase: str):
    """Callbacks: LR monitor plus a val_rmse-driven ModelCheckpoint.

    'pre' keeps only the single best checkpoint; 'full' keeps the top 3.
    """
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    checkpoint = ModelCheckpoint(
        monitor="val_rmse",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        filename="rnn_model-{epoch:02d}-{val_rmse:.3f}",
        save_last=True,
    )
    return [LearningRateMonitor(logging_interval="epoch"), checkpoint]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def make_trainer(max_epochs: int, logdir: str, logger_name: str, devices: list[int], precision: str, callbacks):
    """Assemble a Trainer with a TensorBoard logger.

    Uses DDP (with unused-parameter detection) only when more than one
    device is requested; falls back to CPU when no GPU is available.
    """
    tb_logger = TensorBoardLogger(save_dir=logdir, name=logger_name)
    multi_device = len(devices) > 1

    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=devices,
        strategy="ddp_find_unused_parameters_true" if multi_device else "auto",
        precision=precision,
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=tb_logger,
    )
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def find_backbone_ckpt_lightning(backbone_logdir: str, artery: str, fold: int, phase: str = "full") -> str:
    """Locate the most recently created Lightning backbone checkpoint.

    Expected layout:
        backbone_logdir/
            {artery}BinSyntax_R3D_{phase}_foldXX/version_*/checkpoints/*.ckpt

    Raises:
        FileNotFoundError: when no checkpoint matches the pattern.
    """
    logger_name = f"{artery}BinSyntax_R3D_{phase}_fold{fold:02d}"
    pattern = os.path.join(backbone_logdir, logger_name, "version_*/checkpoints", "*.ckpt")

    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(
            f"No backbone Lightning checkpoints found for\n"
            f" artery={artery}, fold={fold}, phase={phase}\n"
            f" in '{backbone_logdir}' (pattern: {pattern})"
        )

    # Newest by filesystem creation time wins.
    newest = max(candidates, key=os.path.getctime)
    print(f"[Backbone] Using Lightning checkpoint: {newest}")
    return newest
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def build_backbone_pt_path(backbone_pt_dir: str, artery: str, fold: int) -> str:
    """Resolve the backbone .pt path by naming convention:

        rightBinSyntax_R3D_full_fold00.pt
        leftBinSyntax_R3D_full_fold00.pt
        ...

    Raises:
        FileNotFoundError: when the expected file does not exist.
    """
    fname = f"{artery}BinSyntax_R3D_full_fold{fold:02d}.pt"
    full_path = os.path.join(backbone_pt_dir, fname)

    if not os.path.exists(full_path):
        raise FileNotFoundError(
            f"Backbone .pt not found for artery={artery}, fold={fold} in '{backbone_pt_dir}'\n"
            f"Expected file: {fname}"
        )

    print(f"[Backbone] Using .pt file: {full_path}")
    return full_path
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@click.command()
|
| 185 |
+
@click.option(
|
| 186 |
+
"-r",
|
| 187 |
+
"--dataset-root",
|
| 188 |
+
type=click.Path(exists=True),
|
| 189 |
+
default=".",
|
| 190 |
+
show_default=True,
|
| 191 |
+
help="Корень датасета (JSON и DICOM‑пути считаются относительно него).",
|
| 192 |
+
)
|
| 193 |
+
@click.option("--fold", type=int, default=4, show_default=True, help="Fold number.")
|
| 194 |
+
@click.option(
|
| 195 |
+
"-a",
|
| 196 |
+
"--artery",
|
| 197 |
+
type=str,
|
| 198 |
+
default="right",
|
| 199 |
+
show_default=True,
|
| 200 |
+
help="Артерия: left или right.",
|
| 201 |
+
)
|
| 202 |
+
@click.option(
|
| 203 |
+
"--variant",
|
| 204 |
+
type=str,
|
| 205 |
+
default="lstm_mean",
|
| 206 |
+
show_default=True,
|
| 207 |
+
help="Вариант head‑модели: mean_out, mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
|
| 208 |
+
)
|
| 209 |
+
@click.option("-nc", "--num-classes", type=int, default=2, show_default=True,
|
| 210 |
+
help="Число выходов head‑модели (clf + reg).")
|
| 211 |
+
@click.option("-b", "--batch-size", type=int, default=8, show_default=True, help="Batch size.")
|
| 212 |
+
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True,
|
| 213 |
+
help="Количество кадров в клипе.")
|
| 214 |
+
@click.option(
|
| 215 |
+
"-v",
|
| 216 |
+
"--video-size",
|
| 217 |
+
type=click.Tuple([int, int]),
|
| 218 |
+
default=(256, 256),
|
| 219 |
+
show_default=True,
|
| 220 |
+
help="Размер кадра (H, W).",
|
| 221 |
+
)
|
| 222 |
+
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох full train.")
|
| 223 |
+
@click.option("--num-workers", type=int, default=16, show_default=True, help="DataLoader workers.")
|
| 224 |
+
@click.option(
|
| 225 |
+
"--devices",
|
| 226 |
+
type=list[int],
|
| 227 |
+
multiple=True,
|
| 228 |
+
default=[0],
|
| 229 |
+
show_default=True,
|
| 230 |
+
help="Список GPU id",
|
| 231 |
+
)
|
| 232 |
+
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим числовой точности.")
|
| 233 |
+
@click.option(
|
| 234 |
+
"--logdir",
|
| 235 |
+
type=click.Path(),
|
| 236 |
+
default="./logs/rnn",
|
| 237 |
+
show_default=True,
|
| 238 |
+
help="Каталог для логов и чекпоинтов head‑модели.",
|
| 239 |
+
)
|
| 240 |
+
@click.option(
|
| 241 |
+
"--backbone-logdir",
|
| 242 |
+
type=click.Path(exists=True),
|
| 243 |
+
default=None,
|
| 244 |
+
help="Каталог с логами backbone (Lightning .ckpt).",
|
| 245 |
+
)
|
| 246 |
+
@click.option(
|
| 247 |
+
"--backbone-pt-dir",
|
| 248 |
+
type=click.Path(exists=True),
|
| 249 |
+
default="backbone_weights",
|
| 250 |
+
show_default=True,
|
| 251 |
+
help="Каталог с .pt‑файлами backbone (rightBinSyntax_R3D_full_foldXX.pt, leftBinSyntax_R3D_full_foldXX.pt).",
|
| 252 |
+
)
|
| 253 |
+
@click.option(
|
| 254 |
+
"--backbone-from-pt",
|
| 255 |
+
is_flag=True,
|
| 256 |
+
default=True,
|
| 257 |
+
show_default=True,
|
| 258 |
+
help="Если включено — backbone берётся из .pt в backbone-pt-dir, иначе из Lightning‑логов backbone-logdir.",
|
| 259 |
+
)
|
| 260 |
+
@click.option(
|
| 261 |
+
"--rnn-folds-dir",
|
| 262 |
+
type=click.Path(),
|
| 263 |
+
default="rnn_folds",
|
| 264 |
+
show_default=True,
|
| 265 |
+
help="Каталог с rnn_folds (относительно dataset_root).",
|
| 266 |
+
)
|
| 267 |
+
@click.option(
|
| 268 |
+
"--use-weighted-sampler",
|
| 269 |
+
is_flag=True,
|
| 270 |
+
default=False,
|
| 271 |
+
show_default=True,
|
| 272 |
+
help="Использовать ли WeightedRandomSampler по score.",
|
| 273 |
+
)
|
| 274 |
+
@click.option(
|
| 275 |
+
"--pt-weights-format",
|
| 276 |
+
is_flag=True,
|
| 277 |
+
default=False,
|
| 278 |
+
show_default=True,
|
| 279 |
+
help="Формат pl_weight_path для full‑трейна: True → .pt (raw state_dict), False → Lightning .ckpt.",
|
| 280 |
+
)
|
| 281 |
+
@click.option("--seed", type=int, default=42, show_default=True, help="Random seed.")
|
| 282 |
+
def main(
|
| 283 |
+
dataset_root: str,
|
| 284 |
+
fold: int,
|
| 285 |
+
artery: str,
|
| 286 |
+
variant: str,
|
| 287 |
+
num_classes: int,
|
| 288 |
+
batch_size: int,
|
| 289 |
+
frames_per_clip: int,
|
| 290 |
+
video_size: Any,
|
| 291 |
+
max_epochs: int,
|
| 292 |
+
num_workers: int,
|
| 293 |
+
devices: int,
|
| 294 |
+
precision: str,
|
| 295 |
+
logdir: str,
|
| 296 |
+
backbone_logdir: str | None,
|
| 297 |
+
backbone_pt_dir: str | None,
|
| 298 |
+
backbone_from_pt: bool,
|
| 299 |
+
rnn_folds_dir: str,
|
| 300 |
+
use_weighted_sampler: bool,
|
| 301 |
+
pt_weights_format: bool,
|
| 302 |
+
seed: int,
|
| 303 |
+
):
|
| 304 |
+
"""Обучение RNN‑head поверх backbone."""
|
| 305 |
+
VARIANTS = "mean_out mean lstm_mean lstm_last gru_mean gru_last bert_mean bert_cls bert_cls2".split()
|
| 306 |
+
if variant not in VARIANTS:
|
| 307 |
+
raise ValueError(f"Unknown variant '{variant}', expected one of: {VARIANTS}")
|
| 308 |
+
|
| 309 |
+
artery = artery.lower()
|
| 310 |
+
if artery not in ("left", "right"):
|
| 311 |
+
raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")
|
| 312 |
+
|
| 313 |
+
pl.seed_everything(seed)
|
| 314 |
+
|
| 315 |
+
imagenet_mean = [0.485, 0.456, 0.406]
|
| 316 |
+
imagenet_std = [0.229, 0.224, 0.225]
|
| 317 |
+
|
| 318 |
+
train_meta = os.path.join(rnn_folds_dir, f"rnn_fold{fold:02d}_train.json")
|
| 319 |
+
eval_meta = os.path.join(rnn_folds_dir, f"rnn_fold{fold:02d}_eval.json")
|
| 320 |
+
|
| 321 |
+
train_set = SyntaxDataset(
|
| 322 |
+
root=dataset_root,
|
| 323 |
+
meta=train_meta,
|
| 324 |
+
train=True,
|
| 325 |
+
length=frames_per_clip,
|
| 326 |
+
label=f"syntax_{artery}",
|
| 327 |
+
artery=artery,
|
| 328 |
+
inference=False,
|
| 329 |
+
validation=True,
|
| 330 |
+
transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
val_set = SyntaxDataset(
|
| 334 |
+
root=dataset_root,
|
| 335 |
+
meta=eval_meta,
|
| 336 |
+
train=False,
|
| 337 |
+
length=frames_per_clip,
|
| 338 |
+
label=f"syntax_{artery}",
|
| 339 |
+
artery=artery,
|
| 340 |
+
inference=False,
|
| 341 |
+
validation=True,
|
| 342 |
+
transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
|
| 346 |
+
train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
|
| 347 |
+
val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)
|
| 348 |
+
|
| 349 |
+
x, *_ = next(iter(train_loader_pre))
|
| 350 |
+
video_shape = x.shape[2:]
|
| 351 |
+
print(f"RNN head input per clip: {video_shape}")
|
| 352 |
+
|
| 353 |
+
# Выбор источника backbone
|
| 354 |
+
if backbone_from_pt:
|
| 355 |
+
if backbone_pt_dir is None:
|
| 356 |
+
raise ValueError("backbone-from-pt=True, но backbone-pt-dir не указан.")
|
| 357 |
+
backbone_weight_path = build_backbone_pt_path(backbone_pt_dir, artery=artery, fold=fold)
|
| 358 |
+
else:
|
| 359 |
+
if backbone_logdir is None:
|
| 360 |
+
raise ValueError("backbone-from-pt=False, но backbone-logdir не указан.")
|
| 361 |
+
backbone_weight_path = find_backbone_ckpt_lightning(
|
| 362 |
+
backbone_logdir=backbone_logdir,
|
| 363 |
+
artery=artery,
|
| 364 |
+
fold=fold,
|
| 365 |
+
phase="full",
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Pretrain head (замороженный backbone)
|
| 369 |
+
callbacks_pre = make_callbacks(phase="pre")
|
| 370 |
+
|
| 371 |
+
model_pre = make_model(
|
| 372 |
+
num_classes=num_classes,
|
| 373 |
+
lr=1e-4,
|
| 374 |
+
variant=variant,
|
| 375 |
+
weight_decay=0.01,
|
| 376 |
+
max_epochs=max_epochs,
|
| 377 |
+
weight_path=backbone_weight_path,
|
| 378 |
+
pl_weight_path=None,
|
| 379 |
+
pt_weights_format=False,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
trainer_pre = make_trainer(
|
| 383 |
+
max_epochs=max_epochs,
|
| 384 |
+
logdir=logdir,
|
| 385 |
+
logger_name=f"{artery}BinSyntax_R3D_fold{fold:02d}_{variant}_pre",
|
| 386 |
+
devices=devices,
|
| 387 |
+
precision=precision,
|
| 388 |
+
callbacks=callbacks_pre,
|
| 389 |
+
)
|
| 390 |
+
trainer_pre.fit(model_pre, train_dataloaders=train_loader_pre, val_dataloaders=val_loader)
|
| 391 |
+
|
| 392 |
+
# Full train head
|
| 393 |
+
callbacks_full = make_callbacks(phase="full")
|
| 394 |
+
|
| 395 |
+
model_full = make_model(
|
| 396 |
+
num_classes=num_classes,
|
| 397 |
+
lr=2e-5,
|
| 398 |
+
variant=variant,
|
| 399 |
+
weight_decay=0.01,
|
| 400 |
+
max_epochs=max_epochs,
|
| 401 |
+
weight_path=None,
|
| 402 |
+
pl_weight_path=trainer_pre.checkpoint_callback.best_model_path,
|
| 403 |
+
pt_weights_format=pt_weights_format,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
trainer_full = make_trainer(
|
| 407 |
+
max_epochs=max_epochs,
|
| 408 |
+
logdir=logdir,
|
| 409 |
+
logger_name=f"{artery}BinSyntax_R3D_fold{fold:02d}_{variant}_post",
|
| 410 |
+
devices=devices,
|
| 411 |
+
precision=precision,
|
| 412 |
+
callbacks=callbacks_full,
|
| 413 |
+
)
|
| 414 |
+
trainer_full.fit(model_full, train_dataloaders=train_loader_post, val_dataloaders=val_loader)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
if __name__ == "__main__":
|
| 418 |
+
main()
|
full_model_weights/LeftBinSyntax_R3D_fold00_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1e9b585f99e863620185bc88f724488f0ac09e8cf25aa8ddc9a120fd893a99b
|
| 3 |
+
size 133809489
|
full_model_weights/LeftBinSyntax_R3D_fold01_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e8132d382b352afd22bb3fe3e1d4d7e3a9781f8a927018db9bc5085d0e8d109
|
| 3 |
+
size 133809489
|
full_model_weights/LeftBinSyntax_R3D_fold02_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c18065b289591280a8051af4aa4124ffa03a42df69c876f125be5abaf7912486
|
| 3 |
+
size 133809489
|
full_model_weights/LeftBinSyntax_R3D_fold03_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf88400ea1a47d5a376ef9b366d0126036ddd02e3c1ba9071522fbac42b941bc
|
| 3 |
+
size 133809489
|
full_model_weights/LeftBinSyntax_R3D_fold04_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eea9012bb3d975ee9cb2d5b651e401c47801400d3a2cc6da04b3c1f761be1793
|
| 3 |
+
size 133809489
|
full_model_weights/RightBinSyntax_R3D_fold00_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f47a7bd428011d8570bfcca7b36ff8cc074d07303e0e6477a0b9e398d72bbe2
|
| 3 |
+
size 133809614
|
full_model_weights/RightBinSyntax_R3D_fold01_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a3f0857954b5af689d6582fb95bad39f7e9e7575bc49187cc36dbd19211847dc
|
| 3 |
+
size 133809614
|
full_model_weights/RightBinSyntax_R3D_fold02_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59b3b3719e67aed73e8d087f1e748da852f7c30a390610d4b4f676d60fbc3f89
|
| 3 |
+
size 133809614
|
full_model_weights/RightBinSyntax_R3D_fold03_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:15adb5ce5a9f3473dfcead2ac841c3c3e83ef9ad46dea9d4fa7f373e3b0df7c5
|
| 3 |
+
size 133809614
|
full_model_weights/RightBinSyntax_R3D_fold04_lstm_mean_post_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2312b11b6c865bf8f751a421557b1a9813a33f71906978518588d0a4d193f3e
|
| 3 |
+
size 133809614
|
inference/__init__.py
ADDED
|
File without changes
|
inference/metrics_visualization.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SYNTAX predictions visualization:
|
| 3 |
+
- points (SYNTAX ground truth vs model predictions) for multiple datasets;
|
| 4 |
+
- risk zones (low / high risk);
|
| 5 |
+
- ±σ and ±2σ bands around the diagonal;
|
| 6 |
+
- logistic trends for each dataset.
|
| 7 |
+
|
| 8 |
+
The script is independent of PyTorch/Lightning and is used at inference time.
|
| 9 |
+
Output is saved to the `visualizations/` folder inside the project.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import numpy as np
|
| 14 |
+
import plotly.graph_objects as go
|
| 15 |
+
from scipy.optimize import curve_fit # type: ignore
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ================= GLOBAL STYLE CONSTANTS =================

# Data domain of the SYNTAX score axis; PADDING extends the plotted range
# slightly beyond [DATA_MIN, DATA_MAX] on both sides.
DATA_MIN = 0.0
DATA_MAX = 60.0
PADDING = 0.5

# Half-width of the ±σ agreement bands grows linearly with the score:
# sigma(x) = SIGMA_BASE + SIGMA_SLOPE * x.
SIGMA_SLOPE = 0.15
SIGMA_BASE = 1.4
SIGMA_POINTS = 400  # number of x samples used to draw the σ bands
TREND_POINTS = 500  # number of x samples used to draw each logistic trend

PLOT_WIDTH = 980
PLOT_HEIGHT = 980

# Fonts
FONT_FAMILY = "Inter, Roboto, Helvetica Neue, Arial, sans-serif"
BASE_FONT_SIZE = 20
TITLE_FONT_SIZE = 26
AXIS_TITLE_FONT_SIZE = 32
AXIS_TICK_FONT_SIZE = 30
LEGEND_FONT_SIZE = 20

# Markers / lines
MARKER_SIZE = 15
MARKER_LINE_WIDTH = 1.5
LINE_WIDTH = 3
TREND_LINE_WIDTH = 3.5

# Colors
PLOT_BG_COLOR = "rgba(235,238,245,1)"
PAPER_BG_COLOR = "white"
LEGEND_BG_COLOR = "rgba(255,255,255,0.45)"
GRID_COLOR = "rgba(100,116,139,0.18)"

# Layout margins (pixels)
MARGIN_LEFT = 100
MARGIN_RIGHT = 15
MARGIN_TOP = 0
MARGIN_BOTTOM = 100

# Legend anchor in paper coordinates (top-left corner of the plot area).
LEGEND_X = 0.008
LEGEND_Y = 0.985

# Per-dataset marker colors and symbols; indexed modulo len() in the plotting
# loops, so any number of datasets can be shown.
COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _logistic_time(t, R0, Rmax, t50, k):
|
| 66 |
+
"""Logistic function over SYNTAX score."""
|
| 67 |
+
t = np.asarray(t, dtype=float)
|
| 68 |
+
t_safe = np.where(t <= 0, 1e-3, t)
|
| 69 |
+
return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _fit_logistic(x, y, domain, n=TREND_POINTS):
    """Fit a four-parameter logistic curve to the (x, y) points.

    Parameters
    ----------
    x, y : array-like
        Ground-truth scores and predictions.
    domain : tuple[float, float]
        Allowed x range; the fitted curve is clipped to it.
    n : int
        Number of samples along the fitted curve.

    Returns
    -------
    tuple
        (X, Y) arrays of the fitted curve, or (None, None) when the data are
        degenerate or the optimizer fails.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    finite = np.isfinite(xs) & np.isfinite(ys)
    # The model has four free parameters, so at least four points are needed.
    if finite.sum() < 4:
        return None, None

    xf = xs[finite]
    yf = ys[finite]
    lo = max(float(np.min(xf)), float(domain[0]))
    hi = min(float(np.max(xf)), float(domain[1]))
    if not (np.isfinite(lo) and np.isfinite(hi)) or hi <= lo:
        return None, None

    positive = xf[xf > 0]
    if positive.size == 0:
        return None, None

    # Initial guesses: baseline/plateau from y-percentiles, midpoint from the
    # median positive score, unit steepness.
    p0 = [
        float(np.percentile(yf, 10)),
        float(np.percentile(yf, 90)),
        float(np.median(positive)),
        1.0,
    ]
    bounds = ([-10.0, 0.0, 1e-3, 0.01], [60.0, 80.0, 60.0, 10.0])

    try:
        popt, _ = curve_fit(
            _logistic_time,
            xf,
            yf,
            p0=p0,
            bounds=bounds,
            maxfev=20000,
        )
    except Exception:
        # Optimizer failed to converge on this dataset; caller skips the trend.
        return None, None

    grid = np.linspace(lo, hi, n)
    return grid, _logistic_time(grid, *popt)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,  # Pearson per dataset
    gt_row,
    postfix=None,
    threshold: float = 22.0,
    recall_values=None,
    backbone: bool = False,
    show_title: bool = False,
):
    """
    Unified SYNTAX visualization: points, risk zones and logistic trends.

    Saves an interactive HTML scatter plot (ground truth vs. prediction) with
    threshold zones, ±σ/±2σ agreement bands, the identity diagonal, one marker
    trace per dataset and one fitted logistic trend per dataset.

    Parameters
    ----------
    datasets : dict[str, tuple[list[float], list[float]]]
        {dataset_name: (syntax_true_list, syntax_pred_list)}.
    r2_values : dict[str, float]
        Pearson correlation per dataset.
    gt_row : str
        String for the plot title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Suffix for the saved file name.
    threshold : float
        SYNTAX threshold (typically 22.0) to separate risk zones.
    recall_values : dict[str, float] | None
        Mean recall per dataset (may be None).
    backbone : bool
        If True, saves into `visualizations/backbone`, else into `visualizations/`.
    show_title : bool
        If True, renders the title above the plot; otherwise the figure is
        saved without a title.
    """
    fig = go.Figure()

    # Plotted range: the data domain extended by PADDING on both sides.
    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)

    # NOTE(review): base_font is currently unused — the layout below declares
    # the same family/size inline instead.
    base_font = dict(
        family=FONT_FAMILY,
        size=BASE_FONT_SIZE,
    )

    # ---------- Risk zones and bands (legendrank=0) ----------
    # Square region below the threshold on both axes.
    # NOTE(review): the below-threshold square uses a red fill but is named
    # "Low-risk zone", while the above-threshold square is green and named
    # "High-risk zone" — confirm the color/name pairing is intentional.
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Thresholds & lines",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    # Square region above the threshold on both axes.
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    # Threshold cross-hair: the None entries split the single trace into a
    # vertical and a horizontal dashed segment.
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=f"SYNTAX = {threshold}",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    # Agreement bands around the diagonal; half-width grows linearly with the
    # score: SIGMA_BASE + SIGMA_SLOPE * x (doubled for the ±2σ band).
    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals
    sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals
    two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals
    two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals

    # ±2σ band drawn first so the narrower ±σ band renders on top of it.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_vals, x_vals[::-1]]),
            y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,193,7,0.18)",
            line=dict(color="rgba(0,0,0,0)"),
            name="± 2σ",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_vals, x_vals[::-1]]),
            y=np.concatenate([sigma_lower, sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,152,0,0.30)",
            line=dict(color="rgba(0,0,0,0)"),
            name="± σ",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    # Identity diagonal y = x.
    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name="Perfect prediction",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # ---------- Datasets (legendrank=20) ----------
    # first_dataset gates the legend group title so it is emitted only once.
    first_dataset = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        # Per-dataset metrics are embedded in the hover tooltip when available.
        pearson = r2_values.get(label, None)
        recall = recall_values.get(label, None) if recall_values else None
        hover_lines = [f"<b>{label}</b>"]
        if pearson is not None:
            hover_lines.append(f"Pearson = {pearson:.3f}")
        if recall is not None:
            hover_lines.append(f"Mean recall = {recall:.3f}")
        hovertemplate = (
            "<br>".join(hover_lines)
            + "<br>Ground truth: %{x:.3f}<br>Prediction: %{y:.3f}<extra></extra>"
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                legendgrouptitle_text=("Datasets" if first_dataset else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[i % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[i % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH,
                        color="rgba(255,255,255,0.95)",
                    ),
                ),
                hovertemplate=hovertemplate,
                legendrank=20,
            )
        )
        first_dataset = False

    # ---------- Logistic trends (legendrank=30) ----------
    first_trend = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        # Fit may legitimately fail (degenerate data); the trend is then skipped.
        Xc, Yc = _fit_logistic(x, y, domain=domain)
        if Xc is not None:
            fig.add_trace(
                go.Scatter(
                    x=Xc,
                    y=Yc,
                    mode="lines",
                    name=label,
                    legendgroup="trends",
                    legendgrouptitle_text=("Logistic trends" if first_trend else None),
                    showlegend=True,
                    line=dict(
                        color=COLORS[i % len(COLORS)],
                        width=TREND_LINE_WIDTH,
                    ),
                    hoverinfo="skip",
                    legendrank=30,
                )
            )
        first_trend = False

    # ---------- Layout ----------
    # Build title_text as before, but it is only applied when show_title=True.
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"

    layout_kwargs = dict(
        font=dict(
            family=FONT_FAMILY,
            size=BASE_FONT_SIZE,
        ),
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,  # semi-transparent white background
            bordercolor="rgba(203,213,225,0.7)",  # slightly transparent border (optional)
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=FONT_FAMILY),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            title=dict(
                text="SYNTAX ground truth",
                font=dict(
                    size=AXIS_TITLE_FONT_SIZE,
                    family=FONT_FAMILY,
                    color="rgba(15,23,42,1)",
                ),
            ),
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(
                size=AXIS_TICK_FONT_SIZE,
                family=FONT_FAMILY,
            ),
            range=[line_min, line_max],
            constrain="domain",
        ),
        yaxis=dict(
            title=dict(
                text="SYNTAX predictions",
                font=dict(
                    size=AXIS_TITLE_FONT_SIZE,
                    family=FONT_FAMILY,
                    color="rgba(15,23,42,1)",
                ),
            ),
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(
                size=AXIS_TICK_FONT_SIZE,
                family=FONT_FAMILY,
            ),
            range=[line_min, line_max],
            # Lock y to x so the plot stays square and the diagonal is 45°.
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    if show_title:
        layout_kwargs["title"] = dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=FONT_FAMILY,
                color="rgba(15,23,42,1)",
            ),
        )

    fig.update_layout(**layout_kwargs)

    # ---------- Saving ----------
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)

    postfix_html = f"{postfix}" if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{postfix_html}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")
|
inference/rnn_apply.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference/rnn_apply.py
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import tqdm
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import click
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import lightning.pytorch as pl
|
| 10 |
+
import sklearn.metrics as skm
|
| 11 |
+
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from torchvision.transforms import transforms as T
|
| 14 |
+
from torchvision.transforms._transforms_video import ToTensorVideo
|
| 15 |
+
from pytorchvideo.transforms import Normalize
|
| 16 |
+
|
| 17 |
+
from full_model.rnn_dataset import SyntaxDataset
|
| 18 |
+
from full_model.rnn_model import SyntaxLightningModule
|
| 19 |
+
from inference.metrics_visualization import visualize_final_syntax_plotly_multi
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Inference device. NOTE(review): GPU index 1 is hard-coded here — confirm this
# matches the target machine, or consider plain "cuda" / a CLI option.
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def safe_sample_std(values):
    """Sample standard deviation (ddof=1); returns 0.0 for empty or single-element input."""
    data = np.asarray(values, dtype=float)
    # Bessel's correction needs at least two observations to be defined.
    if data.size > 1:
        return float(data.std(ddof=1))
    return 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute Pearson correlation and mean per-class recall at a threshold.

    Parameters
    ----------
    y_true, y_pred : array-like of float
        Ground-truth and predicted SYNTAX scores.
    thr : float
        Binarization threshold (score >= thr → class 1).

    Returns
    -------
    tuple[float, float]
        (pearson, mean_recall); each falls back to 0.0 when undefined
        (fewer than 2 points, constant input, or a single observed class).
    """
    y_true_arr = np.array(y_true, dtype=float)
    y_pred_arr = np.array(y_pred, dtype=float)

    # np.corrcoef returns NaN for constant inputs; guard on the std so the
    # degenerate case yields 0.0 like the too-few-points case.
    if len(y_true_arr) > 1 and y_true_arr.std() > 0 and y_pred_arr.std() > 0:
        pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1])
    else:
        pearson = 0.0

    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)
    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    if len(unique_classes) > 1:
        # zero_division=0 matches sklearn's default fallback value for a label
        # absent from y_true, but without emitting UndefinedMetricWarning.
        per_class = skm.recall_score(
            y_true_bin, y_pred_bin, average=None, labels=[0, 1], zero_division=0
        )
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return pearson, mean_recall
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@click.command()
|
| 52 |
+
@click.option("-d", "--dataset-paths", multiple=True,
|
| 53 |
+
help="JSON с метаданными датасетов (относительно dataset_root).")
|
| 54 |
+
@click.option("-n", "--dataset-names", multiple=True,
|
| 55 |
+
help="Имена датасетов для метрик/графиков.")
|
| 56 |
+
@click.option("-p", "--postfixes", multiple=True,
|
| 57 |
+
help="Суффиксы для файлов предсказаний.")
|
| 58 |
+
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (где лежат JSON и DICOM).",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True),
    default="full_model_weights",
    show_default=True,
    help="Каталог с .pt/.ckpt весами full‑моделей (RNN‑head + backbone).",
)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256),
              show_default=True, help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int, default=32,
              show_default=True, help="Количество кадров в клипе.")
@click.option("--num-workers", type=int, default=8,
              show_default=True, help="Число DataLoader workers.")
@click.option("--seed", type=int, default=42,
              show_default=True, help="Random seed.")
@click.option(
    "--pt-weights-format",
    is_flag=True,
    default=True,
    show_default=True,
    help="Формат весов full‑моделей: True → .pt (raw state_dict), False → Lightning .ckpt.",
)
@click.option("--use-scaling", is_flag=True, default=False,
              show_default=True, help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Вариант head‑модели: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, variant, ensemble_name, metrics_file):
    """Evaluate the 5-fold left+right artery ensemble on one or more datasets.

    For every (dataset_path, dataset_name, postfix) triple: run (or load cached)
    per-fold predictions for the left and right arteries, sum them per fold,
    optionally apply an affine ``a*x + b`` scaling per fold, average the folds
    into an ensemble prediction, compute Pearson / mean-recall metrics
    (ensemble and per-fold), append the results to ``metrics_file`` (if given),
    and render a combined Plotly visualization.
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Per-fold checkpoint paths follow the training naming convention in model_dir.
    model_paths = {
        "left": [
            os.path.join(model_dir, f"LeftBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
        "right": [
            os.path.join(model_dir, f"RightBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
    }

    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        # BUGFIX: -e/--ensemble-name and --scaling-file have no defaults, so
        # they may be None; the original code crashed with TypeError here.
        if ensemble_name:
            ensemble_name += "_scaled"
        if scaling_file:
            scaling_path = os.path.join(dataset_root, scaling_file)
            if os.path.exists(scaling_path):
                with open(scaling_path, "r") as f:
                    scaling_params_dict = json.load(f)
                print(f"Loaded scaling from {scaling_path}")
            else:
                print(f"⚠️ Scaling file not found: {scaling_path}")
        else:
            # Identity scaling (a=1, b=0) is used for every fold via get_ab().
            print("⚠️ --use-scaling given without --scaling-file; using identity scaling.")

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "variant": variant,
        "datasets": {},
    }

    all_datasets, all_pearson, all_recalls = {}, {}, {}

    def get_ab(i: int):
        """Return the (a, b) scaling pair for fold *i*; identity if absent."""
        params = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
        if isinstance(params, dict):
            return params.get("a", 1.0), params.get("b", 0.0)
        return params[0], params[1]

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join("results", f"{postfix}.json")

        if os.path.exists(results_file):
            # Reuse cached per-fold predictions instead of re-running inference.
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            # Both arteries must iterate the samples in the same order.
            assert left_sids == right_sids

            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all,
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        # -------- ensemble with/without scaling --------
        if use_scaling:
            syntax_pred = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                scaled_folds = []
                for i, (l_val, r_val) in enumerate(zip(l_list, r_list)):
                    s = l_val + r_val
                    a, b = get_ab(i)
                    scaled_folds.append(a * s + b)
                # SYNTAX score cannot be negative, hence the clamp at 0.
                syntax_pred.append(max(0.0, float(np.mean(scaled_folds))))
        else:
            syntax_pred = [
                max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                for l_list, r_list in zip(left_preds_all, right_preds_all)
            ]

        pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")

        # -------- per-fold metrics --------
        n_folds = len(left_preds_all[0]) if left_preds_all else 0
        fold_metrics = {metric: [] for metric in ["Pearson", "Mean_Recall"]}

        for k in range(n_folds):
            pred_k = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                s = l_list[k] + r_list[k]
                if use_scaling:
                    a, b = get_ab(k)
                    s = a * s + b
                pred_k.append(max(0.0, float(s)))

            fold_pearson, fold_recall = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(
                fold_metrics.keys(),
                [fold_pearson, fold_recall],
            ):
                fold_metrics[metric].append(value)

        fold_summary = {
            k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
            for k, v in fold_metrics.items()
        }

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_pearson[dataset_name] = pearson
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            "Pearson": round(pearson, 4),
            "Mean_Recall": round(mean_recall, 4),
            "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
        }

    # BUGFIX: -m/--metrics-file has no default; only persist metrics when it
    # was provided (the original code crashed on os.path.exists(None)).
    if metrics_file:
        full_history = {}
        if os.path.exists(metrics_file):
            try:
                with open(metrics_file, "r") as f:
                    full_history = json.load(f)
            except json.JSONDecodeError:
                print("⚠️ Metrics file corrupted. Creating new.")

        full_history[ensemble_name] = ensemble_results
        with open(metrics_file, "w") as f:
            json.dump(full_history, f, indent=4)
        print(f"✅ Metrics saved: {metrics_file}")

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_pearson,  # these are Pearson values now
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
        recall_values=all_recalls,
    )
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Run fold-ensemble inference for a single artery.

    Loads up to five fold checkpoints from *model_paths* (missing files are
    skipped with a warning), streams the dataset one sample at a time, and
    returns ``(preds_all, sids)`` where ``preds_all[i]`` holds one prediction
    per loaded model for sample ``i`` and ``sids`` lists the sample ids in
    iteration order.
    """
    # Standard ImageNet normalization, matching the backbone's pretraining.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform,
    )
    # batch_size=1 is assumed by the unpacking in the inference loop below.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )
    print(f"{artery} artery: {len(val_loader)} samples")

    models = []
    for ckpt in model_paths:
        if not os.path.exists(ckpt):
            print(f"⚠️ Model not found: {ckpt}")
            continue

        net = SyntaxLightningModule(
            num_classes=2,
            lr=1e-5,
            variant=variant,
            weight_decay=0.001,
            max_epochs=1,
            weight_path=None,
            pl_weight_path=ckpt,
            pt_weights_format=pt_weights_format,
        )
        net.to(DEVICE)
        net.eval()
        models.append(net)

    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        for clip, [_label], [_target], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(clip.shape) == 1:
                # Degenerate sample (no usable video): predict 0 for every fold.
                fold_preds = [0.0 for _ in models]
            else:
                clip = clip.to(DEVICE)
                fold_preds = []
                for net in models:
                    out = net(clip)
                    # Column 1 is the regression head; the model predicts
                    # log(1 + syntax), so invert with exp(.) - 1.
                    log_syntax = out[:, 1:].squeeze(-1)
                    fold_preds.append(float(torch.exp(log_syntax).cpu()) - 1.0)
            preds_all.append(fold_preds)
            sids.append(sid)

    return preds_all, sids
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# Script entry point: click parses the CLI arguments and invokes main().
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core ML/DL
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
lightning>=2.1.0
|
| 5 |
+
pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo@main#egg=pytorchvideo
|
| 6 |
+
|
| 7 |
+
# Data processing
|
| 8 |
+
numpy>=1.24.0
|
| 9 |
+
scikit-learn>=1.3.0
|
| 10 |
+
tqdm>=4.65.0
|
| 11 |
+
pydicom>=2.4.0
|
| 12 |
+
python-gdcm>=3.0.10
|
| 13 |
+
|
| 14 |
+
# Visualization
|
| 15 |
+
plotly>=5.17.0
|
| 16 |
+
scipy>=1.11.0
|
| 17 |
+
tensorboard>=2.9
|
| 18 |
+
tensorboardX>=2.6
|
| 19 |
+
|
| 20 |
+
# CLI & Utilities
|
| 21 |
+
click>=8.1.0
|
scaling_coeffs/scaling_coeffs.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold0": {
|
| 3 |
+
"a": 1.65,
|
| 4 |
+
"b": 0.4,
|
| 5 |
+
"mean_recall": 0.715961
|
| 6 |
+
},
|
| 7 |
+
"fold1": {
|
| 8 |
+
"a": 1.29,
|
| 9 |
+
"b": 0.38,
|
| 10 |
+
"mean_recall": 0.767792
|
| 11 |
+
},
|
| 12 |
+
"fold2": {
|
| 13 |
+
"a": 1.28,
|
| 14 |
+
"b": 0.365,
|
| 15 |
+
"mean_recall": 0.800703
|
| 16 |
+
},
|
| 17 |
+
"fold3": {
|
| 18 |
+
"a": 1.11,
|
| 19 |
+
"b": 0.42,
|
| 20 |
+
"mean_recall": 0.761545
|
| 21 |
+
},
|
| 22 |
+
"fold4": {
|
| 23 |
+
"a": 1.61,
|
| 24 |
+
"b": 0.385,
|
| 25 |
+
"mean_recall": 0.736111
|
| 26 |
+
}
|
| 27 |
+
}
|