MesserMMP commited on
Commit
f621d73
·
1 Parent(s): 927c73f

Add model code and full model weights

Browse files
Files changed (34) hide show
  1. .gitignore +144 -0
  2. backbone/__init__.py +0 -0
  3. backbone/dataset.py +221 -0
  4. backbone/pl_model.py +275 -0
  5. backbone/pl_train.py +335 -0
  6. backbone_weights/leftBinSyntax_R3D_full_fold00.pt +3 -0
  7. backbone_weights/leftBinSyntax_R3D_full_fold01.pt +3 -0
  8. backbone_weights/leftBinSyntax_R3D_full_fold02.pt +3 -0
  9. backbone_weights/leftBinSyntax_R3D_full_fold03.pt +3 -0
  10. backbone_weights/leftBinSyntax_R3D_full_fold04.pt +3 -0
  11. backbone_weights/rightBinSyntax_R3D_full_fold00.pt +3 -0
  12. backbone_weights/rightBinSyntax_R3D_full_fold01.pt +3 -0
  13. backbone_weights/rightBinSyntax_R3D_full_fold02.pt +3 -0
  14. backbone_weights/rightBinSyntax_R3D_full_fold03.pt +3 -0
  15. backbone_weights/rightBinSyntax_R3D_full_fold04.pt +3 -0
  16. full_model/__init__.py +0 -0
  17. full_model/rnn_dataset.py +257 -0
  18. full_model/rnn_model.py +386 -0
  19. full_model/rnn_train.py +418 -0
  20. full_model_weights/LeftBinSyntax_R3D_fold00_lstm_mean_post_best.pt +3 -0
  21. full_model_weights/LeftBinSyntax_R3D_fold01_lstm_mean_post_best.pt +3 -0
  22. full_model_weights/LeftBinSyntax_R3D_fold02_lstm_mean_post_best.pt +3 -0
  23. full_model_weights/LeftBinSyntax_R3D_fold03_lstm_mean_post_best.pt +3 -0
  24. full_model_weights/LeftBinSyntax_R3D_fold04_lstm_mean_post_best.pt +3 -0
  25. full_model_weights/RightBinSyntax_R3D_fold00_lstm_mean_post_best.pt +3 -0
  26. full_model_weights/RightBinSyntax_R3D_fold01_lstm_mean_post_best.pt +3 -0
  27. full_model_weights/RightBinSyntax_R3D_fold02_lstm_mean_post_best.pt +3 -0
  28. full_model_weights/RightBinSyntax_R3D_fold03_lstm_mean_post_best.pt +3 -0
  29. full_model_weights/RightBinSyntax_R3D_fold04_lstm_mean_post_best.pt +3 -0
  30. inference/__init__.py +0 -0
  31. inference/metrics_visualization.py +426 -0
  32. inference/rnn_apply.py +344 -0
  33. requirements.txt +21 -0
  34. scaling_coeffs/scaling_coeffs.json +27 -0
.gitignore ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+
11
+ # PyInstaller
12
+ # Usually these files are written by a python script from a template
13
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
14
+ *.manifest
15
+ *.spec
16
+
17
+ # Installer logs
18
+ pip-log.txt
19
+ pip-delete-this-directory.txt
20
+
21
+ # Unit test / coverage reports
22
+ htmlcov/
23
+ .tox/
24
+ .nox/
25
+ .coverage
26
+ .coverage.*
27
+ .cache
28
+ nosetests.xml
29
+ coverage.xml
30
+ *.cover
31
+ *.py,cover
32
+ .hypothesis/
33
+ .pytest_cache/
34
+ cover/
35
+
36
+ # Translations
37
+ *.mo
38
+ *.pot
39
+
40
+ # Django stuff:
41
+ *.log
42
+ local_settings.py
43
+ db.sqlite3
44
+ db.sqlite3-journal
45
+
46
+ # Flask stuff:
47
+ instance/
48
+ .webassets-cache
49
+
50
+ # Scrapy stuff:
51
+ .scrapy
52
+
53
+ # Sphinx documentation
54
+ docs/_build/
55
+
56
+ # PyBuilder
57
+ target/
58
+
59
+ # Jupyter Notebook
60
+ .ipynb_checkpoints
61
+
62
+ # IPython
63
+ profile_default/
64
+ ipython_config.py
65
+
66
+ # pyenv
67
+ .python-version
68
+
69
+ # pipenv
70
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
71
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
72
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
73
+ # install all needed dependencies.
74
+ #Pipfile.lock
75
+
76
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
77
+ __pypackages__/
78
+
79
+ # Celery stuff
80
+ celerybeat-schedule
81
+ celerybeat.pid
82
+
83
+ # SageMath parsed files
84
+ *.sage.py
85
+
86
+ # Environments
87
+ .env
88
+ .venv
89
+ env/
90
+ venv/
91
+ ENV/
92
+ env.bak/
93
+ venv.bak/
94
+
95
+ # Spyder project settings
96
+ .spyderproject
97
+ .spyproject
98
+
99
+ # Rope project settings
100
+ .ropeproject
101
+
102
+ # mkdocs documentation
103
+ /site
104
+
105
+ # mypy
106
+ .mypy_cache/
107
+ .dmypy.json
108
+ dmypy.json
109
+
110
+ # Pyre type checker
111
+ .pyre/
112
+
113
+ # Cython debug symbols
114
+ cython_debug/
115
+
116
+ # PyCharm
117
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
118
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
119
+ # and can be added to your project. Official documentation:
120
+ # https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
121
+ .idea/
122
+
123
+ # VS Code
124
+ .vscode/
125
+
126
+ # Pyright
127
+ .pyright/
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ __pycache__/
133
+ *.pyc
134
+ *.pyo
135
+ logs/
136
+ rnn_logs/
137
+ backbone_logs/
138
+ checkpoints/
139
+ lightning_logs/
140
+ wandb/
141
+ runs/
142
+ tensorboard/
143
+ results/
144
+ visualizations/
backbone/__init__.py ADDED
File without changes
backbone/dataset.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Callable, Optional, Tuple, Any
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pydicom
8
+ import torch
9
+ from torch import Tensor
10
+ from torch.utils.data import Dataset
11
+
12
+
13
class SyntaxDataset(Dataset):
    """
    PyTorch Dataset for training a 3D backbone on DICOM videos.

    Expectations:
        - meta (JSON) contains a list of dicts with the fields:
            "path": DICOM file path relative to the directory of this JSON
            "artery": 0 (left) or 1 (right)
            "<label>": numeric SYNTAX score value (e.g. "syntax_left")
        - the video is loaded from a multi-frame DICOM (pydicom) as a 3D array.
    """

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery_bin: int,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        # Dataset root (used when meta is passed as a relative path)
        self.root = Path(root).resolve()
        # Training/validation mode flag (controls the temporal crop below)
        self.train = train
        # Required clip length (number of temporal frames)
        self.length = int(length)
        # Name of the label field (e.g. "syntax_left")
        self.label = label
        # Augmentation/transform callable
        self.transform = transform
        # Flag: keep only records with a positive score
        self.validation = validation

        # Validate the artery code: 0 — left, 1 — right
        if artery_bin not in (0, 1):
            raise ValueError("artery_bin must be 0 (left) or 1 (right)")
        self.artery_bin = artery_bin

        # Full path to the metadata JSON file
        meta_path = meta if os.path.isabs(meta) else self.root / meta
        meta_path = Path(meta_path).resolve()

        # Directory containing the JSON; "path" entries are relative to it
        json_dir = meta_path.parent

        print(f"Backbone dataset: root={self.root}, meta={meta_path}, json_dir={json_dir}")

        # Load the list of records from the JSON
        with open(meta_path, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        # Filter by artery
        dataset = [rec for rec in dataset if rec.get("artery") == artery_bin]

        # For validation, optionally keep only records with a positive score
        if validation:
            dataset = [rec for rec in dataset if float(rec.get(self.label, 0.0)) > 0]

        # Keep the JSON base directory for building DICOM paths later
        self.json_dir = json_dir
        self.dataset = dataset

        # Initialize each record's weight (if absent) to 1.0
        for rec in self.dataset:
            rec.setdefault("weight", 1.0)

        print(f"Backbone dataset loaded: {len(self.dataset)} samples after filtering")

    def get_sample_weights(self) -> Tensor:
        """
        Return per-sample weights for use with WeightedRandomSampler.

        Logic:
            - records are split into severity (score) intervals, separately
              for the left and right artery;
            - the rarer an interval, the larger the weight of its records.
        """
        # Score thresholds splitting records into intervals (per artery)
        bin_thresholds = {
            0: [0, 5, 10, 15],  # left
            1: [0, 2, 5, 8],  # right
        }
        thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]

        # Map a score value to its bin index
        def in_bin(score: float) -> int:
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        # Scores and their corresponding bins
        scores = [float(rec.get(self.label, 0.0)) for rec in self.dataset]
        bins = [in_bin(s) for s in scores]

        # Count the number of items in each bin
        counts = np.bincount(np.array(bins, dtype=np.int64), minlength=5)
        total = int(counts.sum())

        # Bin weight — inverse frequency
        weights_by_bin = np.array(
            [(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
            dtype=np.float64,
        )

        # Per-record weight keyed by the record's bin index
        weights = np.array([weights_by_bin[b] for b in bins], dtype=np.float64)

        return torch.as_tensor(weights, dtype=torch.double)

    def __len__(self) -> int:
        # Dataset size — number of JSON records after all filters
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, float, str, Tensor]:
        """
        Returns:
            video: Tensor, video clip (T, H, W, C) before transform is applied
            label: Tensor(1,), binary classification label
            target: Tensor(1,), regression target (log1p(score))
            sample_weight: float, raw record weight from the JSON
            path: str, DICOM path relative to the JSON (as stored)
            original_label: Tensor(1,), original score value
        """
        rec = self.dataset[idx]

        # Relative DICOM path as stored in the JSON
        rel_path = rec["path"]
        sample_weight = float(rec.get("weight", 1.0))

        # Full path: JSON directory + relative path from the JSON
        full_path = (self.json_dir / rel_path).resolve()

        if not full_path.exists():
            raise FileNotFoundError(
                f"DICOM not found: {full_path}\n"
                f" json_dir={self.json_dir}\n"
                f" rel_path='{rel_path}'"
            )

        # Read the DICOM and extract the pixel array
        video = pydicom.dcmread(str(full_path)).pixel_array

        # A 3D array is expected: (T, H, W) or (H, W, T)
        if video.ndim != 3:
            raise ValueError(f"Expected 3D video array, got shape={video.shape} for {rel_path}")

        # If the time axis ended up last (H, W, T), move it to the front (T, H, W).
        # NOTE(review): heuristic — assumes clips are <= 128 frames long and
        # frames are > 128 px; confirm against the actual DICOM exports.
        if video.shape[0] > 128 and video.shape[-1] <= 128:
            video = np.moveaxis(video, -1, 0)

        # Normalize uint16 → uint8, rescaling into [0, 255]
        if video.dtype == np.uint16:
            vmax = int(np.max(video))
            if vmax <= 0:
                raise ValueError(f"Invalid vmax={vmax} for {rel_path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
        else:
            video = video.astype(np.uint8)

        # Numeric score from the record
        score = float(rec.get(self.label, 0.0))

        # Threshold for binary classification (per artery)
        bin_thresholds = {
            0: 15,  # left
            1: 5,  # right
        }

        # Binary label: 1 if the score exceeds the threshold, else 0
        label = torch.tensor(
            [1.0 if score > bin_thresholds[self.artery_bin] else 0.0],
            dtype=torch.float32,
        )
        # Regression target — log1p of the score
        target = torch.tensor([np.log1p(score)], dtype=torch.float32)
        # Original score value
        original_label = torch.tensor([score], dtype=torch.float32)

        # Tile the video along time until it reaches the required clip length
        while video.shape[0] < self.length:
            video = np.concatenate([video, video], axis=0)

        # Random or fixed temporal crop
        t = int(video.shape[0])
        if self.train:
            # In training, take a random window of self.length frames
            begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)).item()
            video = video[begin: begin + self.length]
        else:
            # In validation, take the first self.length frames
            video = video[:self.length]

        # Convert (T, H, W) to (T, H, W, C), C=3 (replicate grayscale across channels)
        video = torch.from_numpy(np.stack([video, video, video], axis=-1))

        # Apply the transform chain if provided
        if self.transform is not None:
            video = self.transform(video)

        # Return the video, labels, and relative path
        return video, label, target, sample_weight, str(rel_path), original_label
backbone/pl_model.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn, optim
6
+ import lightning.pytorch as pl
7
+ import torchvision.models.video as tvmv
8
+ import sklearn.metrics as skm
9
+
10
+
11
class SyntaxLightningModule(pl.LightningModule):
    """
    LightningModule for training a 3D backbone on the SYNTAX score.

    Architecture:
        - backbone: ResNet3D (r3d_18) from torchvision
        - final fully connected layer: two neurons
            [0] — logit for binary classification (significant lesion)
            [1] — regression output for the SYNTAX score (log1p)

    Training modes:
        - pretrain (weight_path is None):
            the entire backbone is frozen; only the final layer (fc) is trained
        - finetune (weight_path is set):
            weights are loaded from a checkpoint and the full network is trained.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = int(num_classes)
        self.lr = float(lr)
        self.weight_decay = float(weight_decay)
        self.max_epochs = max_epochs
        self.weight_path = weight_path

        # Coefficients of the linear scale sigma = sigma_a * target + sigma_b
        # used to down-weight the regression loss for larger targets
        self.sigma_a = float(sigma_a)
        self.sigma_b = float(sigma_b)

        # 3D ResNet-18 initialized with pretrained weights
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        # Replace the final fc layer with one that has num_classes outputs
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=self.num_classes, bias=True)

        # If a weights path is given, load the backbone from it
        if self.weight_path is not None:
            self._load_backbone_weights(self.weight_path)

        # Losses; reduction="none" so per-sample weights can be applied
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Validation buffers, accumulated per step and consumed at epoch end
        self._y_true = []
        self._y_prob = []
        self._y_pred = []
        self._t_true = []
        self._t_pred = []

    def _load_backbone_weights(self, weight_path: str) -> None:
        """
        Load backbone weights from either:
            - a Lightning checkpoint (dict with a 'state_dict' key)
            - or a "raw" state_dict (.pt/.pth) saved via model.state_dict().

        Logs the source and key statistics.
        """
        obj = torch.load(weight_path, map_location="cpu", weights_only=False)

        if isinstance(obj, dict) and "state_dict" in obj:
            state_dict = obj["state_dict"]
            # Lightning prefixes keys with "model."; strip it to match self.model
            state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"

        # strict=False: tolerate missing/unexpected keys, report them below
        incompatible = self.model.load_state_dict(state_dict, strict=False)

        loaded_keys = [k for k in state_dict.keys() if k not in incompatible.missing_keys]
        print(
            f"[Backbone] Loaded weights from '{weight_path}' "
            f"(type={src_type}): {len(loaded_keys)} params, "
            f"missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)}"
        )
        if incompatible.missing_keys:
            print(f"[Backbone] Missing keys example: {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            print(f"[Backbone] Unexpected keys example: {incompatible.unexpected_keys[:5]}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:
            x: (B, C, T, H, W)

        Output:
            y_hat: (B, 2) — [clf_logit, reg_output]
        """
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One backbone training step.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]

        # Down-weight negative samples (0.45) vs positive ones (1.0)
        weights_clf = torch.where(y > 0, 1.0, 0.45).to(y.dtype)
        clf_loss = (self.loss_clf(yp_clf, y) * weights_clf).mean()

        # Heteroscedastic regression loss: divide by sigma^2 per sample
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        # Combined objective: classification + half-weighted regression
        loss = clf_loss + 0.5 * reg_loss

        y_prob = torch.sigmoid(yp_clf).detach()
        y_pred = (y_prob > 0.5).int().cpu().numpy()
        y_true = y.detach().int().cpu().numpy()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_reg_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)

        self.log("train_f1", skm.f1_score(y_true, y_pred, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_true, y_pred),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One backbone validation step.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]

        clf_loss = self.loss_clf(yp_clf, y).mean()

        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        # NOTE(review): float(...) on a tensor assumes validation batch_size == 1
        # (the training script does use a batch size of 1 for validation) — confirm
        # if the val loader is ever changed.
        y_prob = torch.sigmoid(yp_clf).float()
        self._y_true.append(float(y[..., 0].float().cpu()))
        self._y_prob.append(float(y_prob[..., 0].cpu()))
        self._y_pred.append(int((y_prob[..., 0] > 0.5).cpu()))

        self._t_true.append(float(target[..., 0].float().cpu()))
        self._t_pred.append(float(yp_reg[..., 0].cpu()))

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_clf_loss", clf_loss, prog_bar=False, sync_dist=True)
        self.log("val_reg_loss", reg_loss, prog_bar=False, sync_dist=True)

        return loss

    def on_validation_epoch_end(self) -> None:
        """
        Compute and log metrics at the end of validation, then reset buffers.
        """
        if len(self._t_true) > 0:
            rmse = skm.root_mean_squared_error(self._t_true, self._t_pred)
            mae = skm.mean_absolute_error(self._t_true, self._t_pred)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)
            self.log("val_reg_mae", mae, prog_bar=True, sync_dist=True)

        # AUC is undefined for a single class; skip when only one class was seen
        if len(set(self._y_true)) > 1:
            auc = skm.roc_auc_score(self._y_true, self._y_prob)
            f1 = skm.f1_score(self._y_true, self._y_pred, zero_division=0)
            acc = skm.accuracy_score(self._y_true, self._y_pred)
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)

        self._y_true.clear()
        self._y_prob.clear()
        self._y_pred.clear()
        self._t_true.clear()
        self._t_pred.clear()

    def on_train_epoch_end(self) -> None:
        """
        Log the current learning rate.
        """
        opt = self.optimizers()
        self.log(
            "lr",
            opt.optimizer.param_groups[0]["lr"],
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        """
        Configure the optimizer and OneCycleLR schedule.

        If weight_path is None:
            only self.model.fc is trained (pretrain).
        Otherwise:
            the whole model is trained (full finetune).
        """
        if self.weight_path is None:
            # Pretrain: freeze everything, then unfreeze only the fc head
            for p in self.parameters():
                p.requires_grad = False
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = self.model.fc.parameters()
        else:
            # Finetune: train the whole model
            for p in self.parameters():
                p.requires_grad = True
            params = self.parameters()

        optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None and getattr(self, "trainer", None) is not None:
            # Step the schedule once per optimizer step over the full run
            total_steps = self.trainer.estimated_stepping_batches
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.lr,
                total_steps=total_steps,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                },
            }

        return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Backbone inference step.
        """
        x, y, target, sample_weight, path, original_label = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]
        y_prob = torch.sigmoid(yp_clf)

        return {
            "y": y,
            "y_pred": (y_prob > 0.5).int(),
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
            "original_label": original_label,
            "path": path,
        }
backbone/pl_train.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ import click
5
+ import lightning.pytorch as pl
6
+ import torch
7
+ from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
8
+ from lightning.pytorch.loggers import TensorBoardLogger
9
+ from pytorchvideo.transforms import Normalize, Permute, RandAugment
10
+ from torch.utils.data import DataLoader, WeightedRandomSampler
11
+ from torchvision.transforms import transforms as T
12
+ from torchvision.transforms._transforms_video import ToTensorVideo
13
+ from torchvision.transforms import InterpolationMode
14
+
15
+ from backbone.dataset import SyntaxDataset
16
+ from backbone.pl_model import SyntaxLightningModule
17
+
18
+
19
# Silence Lightning's warning about a missing device id during DDP init_process_group
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")

# Trade a little matmul precision for throughput (TF32 on supporting GPUs)
torch.set_float32_matmul_precision("medium")
24
+
25
+
26
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
    """
    Build the video preprocessing pipeline.

    Input:
        - a video Tensor shaped (T, H, W, C), dtype uint8

    Output:
        - a normalized Tensor shaped (C, T, H, W), ready for the 3D ResNet.
    """
    interpolation_choices = [InterpolationMode.BILINEAR, InterpolationMode.BICUBIC]

    if not train:
        # Deterministic eval path: tensor conversion, resize, normalize only.
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    # One Resize per interpolation mode; RandomChoice picks one at call time.
    resize_variants = [
        T.Resize(size=video_size, interpolation=interp, antialias=True)
        for interp in interpolation_choices
    ]

    return T.Compose([
        # (T, H, W, C) uint8 → (C, T, H, W) float in [0, 1]
        ToTensorVideo(),
        # (C, T, H, W) → (T, C, H, W): RandAugment expects channels second
        Permute(dims=[1, 0, 2, 3]),
        # Random spatial/temporal augmentations
        RandAugment(magnitude=10, num_layers=2),
        # Random horizontal flip
        T.RandomHorizontalFlip(),
        # Back to (C, T, H, W)
        Permute(dims=[1, 0, 2, 3]),
        # Randomly chosen interpolation for the resize
        T.RandomChoice(resize_variants),
        # Normalize with ImageNet statistics
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
65
+
66
+
67
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool):
    """
    Create a DataLoader, optionally driven by a WeightedRandomSampler.

    When use_weighted_sampler is True:
        - sampling follows the weights returned by dataset.get_sample_weights()
        - shuffle is disabled, since the sampler defines the ordering
    """
    sampler = None
    shuffle = True
    if use_weighted_sampler:
        # A sampler and shuffle are mutually exclusive in DataLoader.
        weights = dataset.get_sample_weights().cpu()
        sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
        shuffle = False

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=shuffle,
        drop_last=True,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    return DataLoader(dataset, **loader_kwargs)
93
+
94
+
95
def make_model(num_classes: int, lr: float, weight_decay: float, max_epochs: int, weight_path: str = None):
    """
    Construct the backbone LightningModule.

    num_classes:
        number of output neurons (typically 2: classification + regression).
    """
    module_kwargs = {
        "num_classes": num_classes,
        "lr": lr,
        "weight_decay": weight_decay,
        "max_epochs": max_epochs,
        "weight_path": weight_path,
    }
    return SyntaxLightningModule(**module_kwargs)
109
+
110
+
111
def make_callbacks(phase: str):
    """
    Build the Trainer callback list:
        - learning-rate monitoring
        - checkpointing on the val_rmse metric
    """
    # Pretraining keeps only the single best checkpoint; full training keeps three.
    top_k = 1 if phase == "pre" else 3

    return [
        LearningRateMonitor(logging_interval="epoch"),
        ModelCheckpoint(
            monitor="val_rmse",
            save_top_k=top_k,
            mode="min",
            filename="model-{epoch:02d}-{val_rmse:.3f}",
            save_last=True,
        ),
    ]
127
+
128
+
129
def make_trainer(max_epochs: int, logdir: str, logger_name: str, devices: list[int], precision: str):
    """
    Construct a Trainer with the given settings:
        - logdir: TensorBoard log directory
        - logger_name: experiment subdirectory name
        - devices: list of GPU device ids
        - precision: numeric precision mode (e.g. "bf16-mixed")
    """
    # Multi-GPU runs need DDP; single-device runs keep Lightning's default strategy.
    use_ddp = len(devices) > 1

    trainer_kwargs = dict(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=devices,
        strategy="ddp_find_unused_parameters_true" if use_ddp else "auto",
        precision=precision,
        callbacks=[],
        log_every_n_steps=10,
        logger=TensorBoardLogger(save_dir=logdir, name=logger_name),
    )
    return pl.Trainer(**trainer_kwargs)
152
+
153
+
154
@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (JSON и DICOM-пути считаются относительно него).",
)
@click.option("--fold", type=int, default=4, show_default=True, help="Номер фолда.")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    show_default=True,
    help="Название артерии: left или right.",
)
@click.option(
    "-nc",
    "--num-classes",
    type=int,
    default=2,
    show_default=True,
    help="Число выходных нейронов (обычно 2: clf + reg).",
)
@click.option("-b", "--batch-size", type=int, default=50, show_default=True, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True, help="Число кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    show_default=True,
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох для full train.")
@click.option("--num-workers", type=int, default=8, show_default=True, help="Число DataLoader workers.")
# FIX: type=list[int] is not a valid click type — click calls the type on each raw
# string value (list[int]("0") -> ["0"]) and cannot convert the default. With
# multiple=True, each occurrence must be converted with type=int; click collects
# the values into a tuple.
@click.option(
    "--devices",
    type=int,
    multiple=True,
    default=(0,),
    show_default=True,
    help="Список GPU id",
)
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим точности.")
@click.option(
    "--logdir",
    type=click.Path(),
    default="./logs/backbone",
    show_default=True,
    help="Каталог для логов и чекпоинтов backbone.",
)
@click.option(
    "--use-weighted-sampler",
    is_flag=True,
    default=False,
    show_default=True,
    help="Использовать ли WeightedRandomSampler по score-интервалам.",
)
@click.option("--seed", type=int, default=42, show_default=True, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    devices,
    precision,
    logdir,
    use_weighted_sampler,
    seed,
):
    """
    Entry point for training the backbone model.

    Sequence:
        1) pretrain: train only the final fc layer
        2) full train: finetune the whole model starting from the last
           pretrain checkpoint.
    """
    # Seed every supported library for reproducibility
    pl.seed_everything(seed)

    # click collects repeated --devices occurrences into a tuple; downstream
    # code (make_trainer) expects a list of GPU ids
    devices = list(devices)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")

    # ImageNet statistics for input normalization
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Fold metadata JSON paths, relative to dataset_root
    train_meta = f"folds/step2_fold{fold:02d}_train.json"
    eval_meta = f"folds/step2_fold{fold:02d}_eval.json"

    # Training dataset
    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )

    # Validation dataset
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=eval_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # DataLoaders: pretraining (fc only) can afford a doubled batch size
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
    val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)

    # Inspect one batch to report the input video shape (C, T, H, W)
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]
    print(f"Backbone input video shape: {video_shape}")

    # Callbacks for the pretrain and full-train phases
    callbacks_pre = make_callbacks(phase="pre")
    callbacks_full = make_callbacks(phase="full")

    # FIX: keep a handle on the ModelCheckpoint we created. Because the Trainer is
    # built with callbacks=[] (Lightning then inserts its own default ModelCheckpoint)
    # and ours is only extended in afterwards, trainer.checkpoint_callback may resolve
    # to the default one whose last_model_path is never populated.
    checkpoint_cb_pre = next(cb for cb in callbacks_pre if isinstance(cb, ModelCheckpoint))

    # ------------------- Pretrain (fc only) -------------------
    num_pre_epochs = 10

    model_pre = make_model(
        num_classes=num_classes,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
        weight_path=None,
    )

    trainer_pre = make_trainer(
        max_epochs=num_pre_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_pre.callbacks.extend(callbacks_pre)
    trainer_pre.fit(model_pre, train_loader_pre, val_loader)

    # ------------------- Full train (finetune) -------------------
    model_full = make_model(
        num_classes=num_classes,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        # Resume from the last checkpoint saved by OUR pretrain ModelCheckpoint
        weight_path=checkpoint_cb_pre.last_model_path,
    )

    trainer_full = make_trainer(
        max_epochs=max_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_full.callbacks.extend(callbacks_full)
    trainer_full.fit(model_full, train_loader_post, val_loader)


if __name__ == "__main__":
    main()
backbone_weights/leftBinSyntax_R3D_full_fold00.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:299db032713b8b8247fa42e1bfcf993d6c0ea162a2d47211cd4ef83e8e7083ac
3
+ size 398135752
backbone_weights/leftBinSyntax_R3D_full_fold01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fef1d7fdc6d2fc3e8b64b807c21a3d15af832bad8dfd1cbe59c1122c5f020a62
3
+ size 398135752
backbone_weights/leftBinSyntax_R3D_full_fold02.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81b7f04e2d6735c505fe50101c2f7dedacb1503681c14f28c6fd8007bb7a4255
3
+ size 398135752
backbone_weights/leftBinSyntax_R3D_full_fold03.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7decf5e0fec66d734901a8cb0c47a09b0d7ec33a341f245caa7aeee0011162fc
3
+ size 398135752
backbone_weights/leftBinSyntax_R3D_full_fold04.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e85aa8b198f29e9375d93a37e11eb130b73b9d8a03f68a60c7aa7fd16d899f0f
3
+ size 398135752
backbone_weights/rightBinSyntax_R3D_full_fold00.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6285307728b4c15a5f9ded2c311981d9c71f81bfbab4c11663fa73af57ffd35
3
+ size 398135752
backbone_weights/rightBinSyntax_R3D_full_fold01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b77467117b220d4245d7c0bc00b9072719c0c563672a34a643d986986be32866
3
+ size 398135752
backbone_weights/rightBinSyntax_R3D_full_fold02.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d0d1fb3aac64364e2e99b0c25d17bf1650b15b9da71cbba4ea6dd3ce2cb2bb0
3
+ size 398135752
backbone_weights/rightBinSyntax_R3D_full_fold03.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c3387e5bb99b615d7c8d5a67b41d53f4cc4d872e82f7b391be4da3321490fdb
3
+ size 398135752
backbone_weights/rightBinSyntax_R3D_full_fold04.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e3cf35b8d0450f23bb105d6367efc8ab32bc73a0bc4170e0fc07fbe54e593f9
3
+ size 398135752
full_model/__init__.py ADDED
File without changes
full_model/rnn_dataset.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pydicom
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.utils.data import Dataset
10
+
11
+
12
DTYPE = torch.float16


class SyntaxDataset(Dataset):
    """
    Dataset for the head model (RNN/LSTM) that runs on top of the backbone.

    Expected JSON structure:
        [
            {
                "study_uid": "...",
                "syntax_left": 12.5,
                "syntax_right": 8.2,
                "videos_left":  [{"path": "../data/anon_data/.../IM-0001-0001.dcm"}, ...],
                "videos_right": [{"path": "../data/anon_data/.../IM-0002-0001.dcm"}, ...],
            },
            ...
        ]

    IMPORTANT: the "videos_{artery}[i]['path']" entries are DICOM paths
    relative to the directory containing the JSON file (the rnn_folds/ dir).
    """

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery: str,
        inference: bool = False,
        validation: bool = False,
        transform: Optional[Callable] = None,
        clips_per_sample: int = 4,
    ) -> None:
        """
        Args:
            root: dataset root; `meta` may be given relative to it.
            meta: path to the metadata JSON (absolute or relative to `root`).
            train: if True, take a random temporal window; else a center window.
            length: temporal length (frames) of a single clip.
            label: name of the numeric score field ("syntax_left"/"syntax_right"),
                or "" for unlabeled inference.
            artery: "left" or "right".
            inference: if True, __getitem__ returns all available clips.
            validation: if True (and `label` set), keep only records with score > 0.
            transform: per-clip transform; expected to produce (C, T, H, W).
            clips_per_sample: number of randomly drawn clips in training mode
                (was hard-coded to 4; default keeps the old behavior).
        """
        self.root = Path(root).resolve()
        self.train = train
        self.length = int(length)
        self.label = label
        self.artery = artery.lower()
        self.inference = inference
        self.validation = validation
        self.transform = transform
        self.clips_per_sample = int(clips_per_sample)

        # Resolve the metadata JSON path.
        meta_path = Path(meta)
        if not meta_path.is_absolute():
            meta_path = self.root / meta_path
        meta_path = meta_path.resolve()

        # DICOM paths in the JSON are relative to the JSON's own directory.
        self.base_dir = meta_path.parent

        print(f"RNN Dataset: root={self.root}, meta={meta_path}, base_dir={self.base_dir}")

        with open(meta_path, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        # In train/validation mode, drop records with no videos for this artery.
        if not self.inference:
            dataset = [rec for rec in dataset if len(rec.get(f"videos_{self.artery}", [])) > 0]

        # Optionally keep only records with a strictly positive score.
        if validation and self.label:
            dataset = [rec for rec in dataset if float(rec.get(self.label, 0.0)) > 0]

        self.dataset = dataset
        print(f"RNN Dataset loaded: {len(self.dataset)} samples after filtering")

        # Artery code used to select binning thresholds below.
        artery_bin = {"left": 0, "right": 1}.get(self.artery)
        if artery_bin is None:
            raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")
        self.artery_bin = artery_bin

    def __len__(self) -> int:
        # One dataset item == one study record.
        return len(self.dataset)

    def get_sample_weights(self) -> Tensor:
        """
        Per-sample weights for WeightedRandomSampler, based on score intervals.

        Each artery has its own thresholds; a sample's weight is the inverse
        frequency of its score bin.

        FIX: weights are returned in float64 instead of the module-wide fp16
        DTYPE. Inverse frequencies (total/count) easily exceed the fp16 max
        of 65504 for rare bins in large datasets (-> inf) and are quantized
        badly even below that.
        """
        bin_thresholds = {
            0: [0, 5, 10, 15],  # left artery
            1: [0, 2, 5, 8],    # right artery
        }
        thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]

        def in_bin(score: float) -> int:
            # Bin 0 is exactly thr0 (typically a zero score); scores above
            # thr3 fall into the open-ended bin 4.
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        scores = [float(rec.get(self.label, 0.0)) for rec in self.dataset]
        bins = [in_bin(s) for s in scores]

        counts = np.bincount(np.array(bins, dtype=np.int64), minlength=5)
        total = int(counts.sum())

        # Inverse-frequency weight per bin; empty bins get weight 0.
        weights_by_bin = np.array(
            [(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
            dtype=np.float64,
        )

        weights = np.array([weights_by_bin[b] for b in bins], dtype=np.float64)
        print(
            "RNN sample weights counts:",
            int(counts[0]),
            int(counts[1]),
            int(counts[2]),
            int(counts[3]),
            int(counts[4]),
        )
        return torch.as_tensor(weights, dtype=torch.float64)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Any]:
        """
        Returns:
            clips:  Tensor, stacked clips (N_clips, C, T, H, W) after transform
            label:  Tensor(1,), binary label (0/1)
            target: Tensor(1,), regression target (log1p(score))
            suid:   study identifier (study_uid)
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]

        # Build labels when a score field is configured.
        if self.label:
            # Binarization threshold per artery (left: 15, right: 5).
            bin_thresholds = {
                0: 15,  # left
                1: 5,   # right
            }
            score = float(rec.get(self.label, 0.0))
            label = torch.tensor(
                [1.0 if score > bin_thresholds[self.artery_bin] else 0.0],
                dtype=DTYPE,
            )
            target = torch.tensor([np.log1p(score)], dtype=DTYPE)
        else:
            # Pure inference on unlabeled data: dummy label/target.
            label = torch.tensor([0.0], dtype=DTYPE)
            target = torch.tensor([0.0], dtype=DTYPE)

        videos_list = rec.get(f"videos_{self.artery}", [])
        nv = len(videos_list)

        # Choose which clips to load.
        if self.inference:
            # Inference: return every available clip.
            if nv == 0:
                # Degenerate fallback: empty clip tensor.
                return torch.zeros(0), label, target, suid
            seq_indices = range(nv)
        else:
            if nv == 0:
                raise ValueError(f"No videos for artery={self.artery} in record {suid}")
            # Random clip indices, drawn with replacement.
            seq_indices = torch.randint(low=0, high=nv, size=(self.clips_per_sample,))

        clips = []
        for vi in seq_indices:
            vi_idx = int(vi)
            video_rec = videos_list[vi_idx]
            rel_path = video_rec["path"]

            # Full DICOM path: JSON directory + relative path.
            full_path = (self.base_dir / rel_path).resolve()
            if not full_path.exists():
                raise FileNotFoundError(
                    f"DICOM not found: {full_path}\n"
                    f"  base_dir={self.base_dir}\n"
                    f"  rel_path='{rel_path}'\n"
                    f"  study={suid}"
                )

            # Load the DICOM pixel data.
            video = pydicom.dcmread(str(full_path)).pixel_array

            # Expect a 3D (time-first or time-last) array.
            if video.ndim != 3:
                raise ValueError(f"Expected 3D video, got {video.shape} in {full_path}")

            # Heuristic: if time appears to be the last axis, move it first.
            # NOTE(review): assumes spatial dims are > 128 px and frame count
            # is <= 128 — confirm against the acquisition protocol.
            if video.shape[0] > 128 and video.shape[-1] <= 128:
                video = np.moveaxis(video, -1, 0)

            # Normalize to uint8.
            if video.dtype == np.uint16:
                vmax = int(np.max(video))
                if vmax <= 0:
                    raise ValueError(f"Invalid vmax={vmax} in {full_path}")
                video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
            else:
                video = video.astype(np.uint8)

            # Tile along time until the clip is long enough.
            while video.shape[0] < self.length:
                video = np.concatenate([video, video], axis=0)

            t = int(video.shape[0])
            if self.train:
                # Training: random temporal window.
                begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)).item()
            else:
                # Validation/inference: center temporal window.
                begin = (t - self.length) // 2
            video = video[begin: begin + self.length]

            # (T, H, W) -> (T, H, W, C) by replicating grayscale into 3 channels.
            video = torch.from_numpy(np.stack([video, video, video], axis=-1))

            # Apply per-clip transforms (expected output: (C, T, H, W)).
            if self.transform is not None:
                video = self.transform(video)

            clips.append(video)

        # Stack clips: (N_clips, C, T, H, W).
        clips = torch.stack(clips, dim=0) if clips else torch.zeros(0, dtype=DTYPE)

        return clips, label, target, suid
full_model/rnn_model.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Optional, Tuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn, optim
5
+ import lightning.pytorch as pl
6
+ import torchvision.models.video as tvmv
7
+ import sklearn.metrics as skm
8
+
9
+
10
+ """
11
+ The head of model
12
+ """
13
class SyntaxLightningModule(pl.LightningModule):
    """
    Head model: an r3d_18 video backbone plus a sequence-aggregation head.

    A study is a sequence of clips; per-clip backbone features are aggregated
    (mean / LSTM / GRU / transformer, selected by `variant`) into two outputs:
    column 0 — binary classification logit, column 1 — regression value for
    log1p(score).
    """

    def __init__(
        self,
        num_classes,
        lr: float,
        variant: str,  # mean_out, mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2
        weight_decay: float = 0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,     # backbone (r3d_18) weights (ckpt or pt)
        save_path: Optional[str] = None,       # path for saving by best auc (legacy)
        pl_weight_path: Optional[str] = None,  # full Lightning checkpoint of the whole model
        pt_weights_format: bool = False,       # pl_weight_path is a raw .pt state_dict
        sigma_a: float = 0,
        sigma_b: float = 1,
        **kwargs,
    ):
        # FIX: initialize the LightningModule before calling any Lightning API;
        # the original invoked save_hyperparameters() before super().__init__().
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.save_path = save_path
        self.weight_path = weight_path
        self.variant = variant
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b

        # Video ResNet backbone.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        self.lr = lr
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Default backbone head: 2 outputs (clf logit + regression).
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=2, bias=True)

        # Backbone pretrain weights; supports both .ckpt and .pt/.pth.
        if weight_path is not None:
            print("Load model weights (backbone)")
            self.load_weights_backbone(weight_path, self.model)

        # Every head except "mean_out" consumes raw backbone features.
        if self.variant != "mean_out":
            self.model.fc = nn.Identity()

        if self.variant == "mean_out":
            # Only self.model with fc=Linear(..., 2).
            pass
        elif self.variant in ("gru_mean", "gru_last"):
            self.rnn = nn.GRU(in_features, in_features // 4, batch_first=True)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features=in_features // 4, out_features=num_classes, bias=True)
        elif self.variant in ("lstm_mean", "lstm_last"):
            # proj_size makes the LSTM emit num_classes directly (no extra fc).
            self.lstm = nn.LSTM(
                input_size=in_features,
                hidden_size=in_features // 4,
                proj_size=num_classes,
                batch_first=True,
            )
        elif self.variant == "mean":
            self.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=in_features,
                nhead=4,
                batch_first=True,
                dim_feedforward=in_features // 4,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
            if self.variant == "bert_cls2":
                # Learnable CLS token prepended to the clip sequence.
                self.cls = nn.Parameter(torch.randn(1, 1, in_features))
        else:
            raise ValueError(f"Unknown model variant {self.variant}")

        # Optional full checkpoint (backbone + head).
        if pl_weight_path is not None:
            print(f"Load LightningModule weights from {pl_weight_path}")

            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            if pt_weights_format:
                pl_state_dict = torch.load(pl_weight_path, weights_only=False)
            else:
                pl_state_dict = torch.load(pl_weight_path, weights_only=False)["state_dict"]

            # Backbone weights.
            self.load_weights(pl_state_dict, self.model, "model")

            # Head weights, depending on the variant.
            if self.variant == "mean_out":
                pass  # only self.model
            elif self.variant in ("gru_mean", "gru_last"):
                self.load_weights(pl_state_dict, self.rnn, "rnn")
                self.load_weights(pl_state_dict, self.fc, "fc")
            elif self.variant in ("lstm_mean", "lstm_last"):
                self.load_weights(pl_state_dict, self.lstm, "lstm")
            elif self.variant == "mean":
                self.load_weights(pl_state_dict, self.fc, "fc")
            elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
                self.load_weights(pl_state_dict, self.encoder, "encoder")
                self.load_weights(pl_state_dict, self.fc, "fc")
                if self.variant == "bert_cls2":
                    old_shape = self.cls.shape
                    self.cls = nn.Parameter(pl_state_dict["cls"])
                    assert old_shape == self.cls.shape
            else:
                raise ValueError(f"Unknown model variant {self.variant}")

        self.max_epochs = max_epochs
        self.weight_decay = weight_decay

        # Validation accumulators, cleared at each validation-epoch end.
        self.y_val = []   # ground-truth binary labels
        self.p_val = []   # predicted probabilities
        self.r_val = []   # rounded predictions
        self.ty_val = []  # regression targets
        self.tp_val = []  # regression predictions

    def load_weights_backbone(self, weight_path: str, model: nn.Module) -> None:
        """
        Universal backbone (r3d_18) weight loading:

        - Lightning checkpoint (dict with 'state_dict'): take it and strip the
          'model.' prefix.
        - raw state_dict (.pt/.pth saved via model.state_dict()): load directly.

        Keys whose tensor shape does not match the current model (e.g.
        fc.weight/fc.bias with a different output count) are skipped.
        """
        obj = torch.load(weight_path, weights_only=False, map_location="cpu")

        if isinstance(obj, dict) and "state_dict" in obj:
            raw_state = obj["state_dict"]
            state_dict = {k.replace("model.", ""): v for k, v in raw_state.items()}
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"

        current_state = model.state_dict()
        filtered_state = {}
        mismatched_keys = []

        # Keep only weights whose tensor shape matches the current model.
        for k, v in state_dict.items():
            if k in current_state and current_state[k].shape == v.shape:
                filtered_state[k] = v
            else:
                # Either the key is absent or the shape differs.
                mismatched_keys.append(k)

        # Load only the compatible subset.
        incompatible = model.load_state_dict(filtered_state, strict=False)

        loaded_keys = [k for k in filtered_state.keys() if k not in incompatible.missing_keys]
        print(
            f"[Backbone] Loaded weights from '{weight_path}' "
            f"(type={src_type}): {len(loaded_keys)} params, "
            f"missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)}, "
            f"skipped_mismatched={len(mismatched_keys)}"
        )
        if mismatched_keys:
            print(f"[Backbone] Size‑mismatched keys (skipped), example: {mismatched_keys[:5]}")
        if incompatible.missing_keys:
            print(f"[Backbone] Missing keys after filtering, example: {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            print(f"[Backbone] Unexpected keys after filtering, example: {incompatible.unexpected_keys[:5]}")

    def load_weights(self, state_dict, module, prefix: str):
        """Load only the weights belonging to one sub-module (keys under `prefix.`)."""
        module_state = {
            k.replace(f"{prefix}.", ""): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
        missing, unexpected = module.load_state_dict(module_state, strict=False)
        if missing:
            print(f"Missing keys for {prefix}: {missing}")
        if unexpected:
            print(f"Unexpected keys for {prefix}: {unexpected}")

    def forward(self, x):
        # x: (batch, seq, C, T, H, W) — a study is a sequence of clips.
        batch_seq_shape = x.shape[0:2]
        x = torch.flatten(x, start_dim=0, end_dim=1)  # (batch*seq, C, T, H, W)
        x = self.model(x)
        x = torch.unflatten(x, 0, batch_seq_shape)    # (batch, seq, feat)

        if self.variant == "mean_out":
            x = torch.mean(x, dim=1)
        elif self.variant in ("gru_mean", "gru_last"):
            outs, h_n = self.rnn(x)
            if self.variant == "gru_mean":
                x = torch.mean(outs, dim=1)
            else:
                # Last layer's final hidden state: (batch, hidden).
                x = h_n[-1]
            x = self.dropout(x)
            x = self.fc(x)
        elif self.variant in ("lstm_mean", "lstm_last"):
            outs, (h_n, _c_n) = self.lstm(x)
            if self.variant == "lstm_mean":
                x = torch.mean(outs, dim=1)
            else:
                # FIX: the original assigned the full h_n tensor, which keeps
                # the leading num_layers dim and yields (1, batch, out),
                # breaking the y_hat[:, 0:1] slicing in the steps below.
                x = h_n[-1]
        elif self.variant == "mean":
            x = torch.mean(x, dim=1)
            x = self.fc(x)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            if self.variant == "bert_cls":
                # Zero CLS slot prepended along the sequence axis.
                x = F.pad(x, (0, 0, 1, 0), "constant", 0)
            elif self.variant == "bert_cls2":
                bs = x.size(0)
                x = torch.cat([self.cls.expand(bs, -1, -1), x], dim=1)
            x = self.encoder(x)
            if self.variant == "bert_mean":
                x = torch.mean(x, dim=1)
            else:
                x = x[:, 0, :]  # CLS position
            x = self.dropout(x)
            x = self.fc(x)
        else:
            raise ValueError(f"Unknown model variant {self.variant}")

        return x

    def training_step(self, batch, batch_idx):
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # Down-weight negatives 5x in the classification loss.
        weights_clf = torch.where(y > 0, 1.0, 0.2)
        clf_loss = self.loss_clf(yp_clf, y)
        clf_loss = (clf_loss * weights_clf).mean()

        # Heteroscedastic scaling: sigma grows linearly with the target.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.cpu().detach()).int()
        y_pred_bin = torch.round(y_pred.cpu().detach()).int()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_bin, y_pred_bin),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        # NOTE: the original computed clf/reg losses twice; deduplicated here
        # with identical returned value.
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        y_pred = torch.sigmoid(yp_clf)

        # Accumulate per-sample stats (validation batch size is 1).
        self.y_val.append(int(y[..., 0].cpu()))
        self.p_val.append(float(y_pred[..., 0].cpu()))
        self.r_val.append(round(float(y_pred[..., 0].cpu())))
        self.ty_val.append(float(target[..., 0].cpu()))
        self.tp_val.append(float(yp_reg[..., 0].cpu()))

        clf_loss = self.loss_clf(yp_clf, y)
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        return clf_loss + 0.5 * reg_loss

    def on_validation_epoch_end(self):
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # NOTE(review): MAE is computed on binary labels vs rounded
            # probabilities, not on the regression targets — confirm intended.
            mae = skm.mean_absolute_error(self.y_val, self.r_val)
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)

            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)

        except ValueError as err:
            # roc_auc_score raises when only one class is present.
            print(err)
            print("Y_VAL", self.y_val)
            print("P_VAL", self.p_val)
        self.y_val.clear()
        self.p_val.clear()
        self.r_val.clear()
        self.ty_val.clear()
        self.tp_val.clear()

    def on_train_epoch_end(self) -> None:
        # Log the current learning rate once per epoch.
        self.log(
            "lr",
            self.optimizers().optimizer.param_groups[0]["lr"],
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        # Head-only training when backbone pretrain weights were supplied.
        if self.weight_path:
            if self.variant == "mean_out":
                trainable = [self.model.fc]
            elif self.variant in ("gru_mean", "gru_last"):
                trainable = [self.rnn, self.fc]
            elif self.variant in ("lstm_mean", "lstm_last"):
                trainable = [self.lstm]
            elif self.variant == "mean":
                trainable = [self.fc]
            elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
                trainable = [self.encoder, self.fc]
                if self.variant == "bert_cls2":
                    trainable.append(self.cls)
            else:
                trainable = []

            # FIX: self.cls is an nn.Parameter, not a Module — the original
            # called .parameters() on every entry and crashed for bert_cls2.
            params = []
            for entry in trainable:
                if isinstance(entry, nn.Parameter):
                    params.append(entry)
                else:
                    params.extend(entry.parameters())

            for p in self.parameters():
                p.requires_grad = False
            for p in params:
                p.requires_grad = True
        else:
            for p in self.parameters():
                p.requires_grad = True
            params = self.parameters()

        optimizer = optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None:
            # One scheduler step per epoch (Lightning default for list returns).
            lr_scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer, max_lr=self.lr, total_steps=self.max_epochs
            )
            return [optimizer], [lr_scheduler]
        return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Inference: return labels, rounded/raw probabilities and regression output."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]
        y_pred = torch.sigmoid(yp_clf)

        return {
            "y": y,
            "y_pred": torch.round(y_pred),
            "y_prob": y_pred,
            "y_reg": yp_reg,
            "target": target,
        }
+ }
full_model/rnn_train.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # full_model/rnn_train.py
2
+ import os
3
+ import glob
4
+ from typing import Any
5
+
6
+ import click
7
+ import lightning.pytorch as pl
8
+ import torch
9
+ from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
10
+ from lightning.pytorch.loggers import TensorBoardLogger
11
+ from pytorchvideo.transforms import Normalize, Permute, RandAugment
12
+ from torch.utils.data import DataLoader, WeightedRandomSampler
13
+ from torchvision.transforms import transforms as T
14
+ from torchvision.transforms._transforms_video import ToTensorVideo
15
+ from torchvision.transforms import InterpolationMode
16
+
17
+ from full_model.rnn_dataset import SyntaxDataset
18
+ from full_model.rnn_model import SyntaxLightningModule
19
+
20
+ torch.set_float32_matmul_precision("medium")
21
+
22
+
23
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
    """Clip-level preprocessing: augmentations for training, plain resize otherwise."""
    if not train:
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    # Randomly pick one of two resize interpolations per sample.
    resize_variants = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ]
    return T.Compose([
        ToTensorVideo(),
        Permute(dims=[1, 0, 2, 3]),  # (C, T, H, W) -> (T, C, H, W) for frame-wise augs
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),  # back to (C, T, H, W)
        T.RandomChoice(resize_variants),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
46
+
47
+
48
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool):
    """Build a DataLoader, optionally with a score-based WeightedRandomSampler."""
    sampler = None
    if use_weighted_sampler:
        per_sample_weights = dataset.get_sample_weights().cpu()
        sampler = WeightedRandomSampler(
            per_sample_weights, num_samples=len(dataset), replacement=True
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=False,  # ordering comes from the sampler (or is left as-is)
        drop_last=True,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
68
+
69
+
70
def make_model(
    num_classes: int,
    lr: float,
    variant: str,
    weight_decay: float,
    max_epochs: int,
    weight_path: str | None = None,
    pl_weight_path: str | None = None,
    pt_weights_format: bool = False,
) -> SyntaxLightningModule:
    """
    Build the head model.

    weight_path       -- backbone (r3d_18) pretrain weights, .pt or .ckpt.
    pl_weight_path    -- full head-model checkpoint (Lightning .ckpt or raw .pt).
    pt_weights_format -- True:  pl_weight_path is a raw state_dict (.pt);
                         False: pl_weight_path is a Lightning .ckpt with 'state_dict'.
    """
    return SyntaxLightningModule(
        num_classes=num_classes,
        lr=lr,
        variant=variant,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
        pl_weight_path=pl_weight_path,
        # BUG FIX: this was passed as `yulie_model=pt_weights_format`, which
        # SyntaxLightningModule silently swallowed via **kwargs — the flag
        # never reached the module's `pt_weights_format` parameter.
        pt_weights_format=pt_weights_format,
    )
98
+
99
+
100
def make_callbacks(phase: str):
    """Callbacks: LR monitor + val_rmse checkpoint (top-1 for 'pre', top-3 for 'full')."""
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    if phase == "pre":
        top_k = 1
    elif phase == "full":
        top_k = 3
    else:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    checkpoint = ModelCheckpoint(
        monitor="val_rmse",
        save_top_k=top_k,
        mode="min",
        filename="rnn_model-{epoch:02d}-{val_rmse:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]
124
+
125
+
126
def make_trainer(max_epochs: int, logdir: str, logger_name: str, devices: list[int], precision: str, callbacks):
    """Build a Lightning Trainer with a TensorBoard logger."""
    tb_logger = TensorBoardLogger(save_dir=logdir, name=logger_name)
    # NOTE(review): find_unused_parameters is presumably needed because some
    # phases freeze parts of the model — confirm before switching to plain DDP.
    ddp_strategy = "ddp_find_unused_parameters_true" if len(devices) > 1 else "auto"

    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=devices,
        strategy=ddp_strategy,
        precision=precision,
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=tb_logger,
    )
142
+
143
+
144
def find_backbone_ckpt_lightning(backbone_logdir: str, artery: str, fold: int, phase: str = "full") -> str:
    """
    Locate a backbone Lightning checkpoint under the log directory.

    Expected layout:
        backbone_logdir/
            {artery}BinSyntax_R3D_{phase}_foldXX/version_*/checkpoints/*.ckpt

    Returns the most recently created .ckpt; raises FileNotFoundError when
    nothing matches.
    """
    logger_name = f"{artery}BinSyntax_R3D_{phase}_fold{fold:02d}"
    pattern = os.path.join(backbone_logdir, logger_name, "version_*/checkpoints", "*.ckpt")
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(
            f"No backbone Lightning checkpoints found for\n"
            f"  artery={artery}, fold={fold}, phase={phase}\n"
            f"  in '{backbone_logdir}' (pattern: {pattern})"
        )
    newest = max(candidates, key=os.path.getctime)
    print(f"[Backbone] Using Lightning checkpoint: {newest}")
    return newest
164
+
165
+
166
def build_backbone_pt_path(backbone_pt_dir: str, artery: str, fold: int) -> str:
    """
    Resolve the backbone .pt file path following the naming convention:
        rightBinSyntax_R3D_full_fold00.pt
        leftBinSyntax_R3D_full_fold00.pt
        ...

    Raises FileNotFoundError when the expected file is absent.
    """
    fname = f"{artery}BinSyntax_R3D_full_fold{fold:02d}.pt"
    candidate = os.path.join(backbone_pt_dir, fname)
    if os.path.exists(candidate):
        print(f"[Backbone] Using .pt file: {candidate}")
        return candidate
    raise FileNotFoundError(
        f"Backbone .pt not found for artery={artery}, fold={fold} in '{backbone_pt_dir}'\n"
        f"Expected file: {fname}"
    )
182
+
183
+
184
@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (JSON и DICOM‑пути считаются относительно него).",
)
@click.option("--fold", type=int, default=4, show_default=True, help="Fold number.")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    show_default=True,
    help="Артерия: left или right.",
)
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Вариант head‑модели: mean_out, mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-nc", "--num-classes", type=int, default=2, show_default=True,
              help="Число выходов head‑модели (clf + reg).")
@click.option("-b", "--batch-size", type=int, default=8, show_default=True, help="Batch size.")
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True,
              help="Количество кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    show_default=True,
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох full train.")
@click.option("--num-workers", type=int, default=16, show_default=True, help="DataLoader workers.")
@click.option(
    "--devices",
    # FIX: `type=list[int]` is not a valid click parameter type (each value
    # would be "converted" by calling list[int]("0") -> ['0']). With
    # `multiple=True`, click collects repeated `--devices N` options into a
    # tuple of ints when `type=int` is used.
    type=int,
    multiple=True,
    default=[0],
    show_default=True,
    help="Список GPU id",
)
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим числовой точности.")
@click.option(
    "--logdir",
    type=click.Path(),
    default="./logs/rnn",
    show_default=True,
    help="Каталог для логов и чекпоинтов head‑модели.",
)
@click.option(
    "--backbone-logdir",
    type=click.Path(exists=True),
    default=None,
    help="Каталог с логами backbone (Lightning .ckpt).",
)
@click.option(
    "--backbone-pt-dir",
    type=click.Path(exists=True),
    default="backbone_weights",
    show_default=True,
    help="Каталог с .pt‑файлами backbone (rightBinSyntax_R3D_full_foldXX.pt, leftBinSyntax_R3D_full_foldXX.pt).",
)
@click.option(
    "--backbone-from-pt",
    is_flag=True,
    default=True,
    show_default=True,
    help="Если включено — backbone берётся из .pt в backbone-pt-dir, иначе из Lightning‑логов backbone-logdir.",
)
@click.option(
    "--rnn-folds-dir",
    type=click.Path(),
    default="rnn_folds",
    show_default=True,
    help="Каталог с rnn_folds (относительно dataset_root).",
)
@click.option(
    "--use-weighted-sampler",
    is_flag=True,
    default=False,
    show_default=True,
    help="Использовать ли WeightedRandomSampler по score.",
)
@click.option(
    "--pt-weights-format",
    is_flag=True,
    default=False,
    show_default=True,
    help="Формат pl_weight_path для full‑трейна: True → .pt (raw state_dict), False → Lightning .ckpt.",
)
@click.option("--seed", type=int, default=42, show_default=True, help="Random seed.")
def main(
    dataset_root: str,
    fold: int,
    artery: str,
    variant: str,
    num_classes: int,
    batch_size: int,
    frames_per_clip: int,
    video_size: tuple[int, int],
    max_epochs: int,
    num_workers: int,
    devices: tuple[int, ...],
    precision: str,
    logdir: str,
    backbone_logdir: str | None,
    backbone_pt_dir: str | None,
    backbone_from_pt: bool,
    rnn_folds_dir: str,
    use_weighted_sampler: bool,
    pt_weights_format: bool,
    seed: int,
):
    """Train the RNN head on top of the backbone.

    Two stages are run back-to-back:
      1. "pre"  - pretrain the head with the backbone loaded from
         `weight_path` (larger batch, lr=1e-4);
      2. "full" - continue from the best "pre" checkpoint (lr=2e-5).
    """
    VARIANTS = "mean_out mean lstm_mean lstm_last gru_mean gru_last bert_mean bert_cls bert_cls2".split()
    if variant not in VARIANTS:
        raise ValueError(f"Unknown variant '{variant}', expected one of: {VARIANTS}")

    artery = artery.lower()
    if artery not in ("left", "right"):
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")

    # click's `multiple=True` yields a tuple; downstream code expects a list.
    devices = list(devices)

    pl.seed_everything(seed)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_meta = os.path.join(rnn_folds_dir, f"rnn_fold{fold:02d}_train.json")
    eval_meta = os.path.join(rnn_folds_dir, f"rnn_fold{fold:02d}_eval.json")

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        inference=False,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )

    val_set = SyntaxDataset(
        root=dataset_root,
        meta=eval_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        inference=False,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # The pretrain stage uses a doubled batch (the backbone is frozen, so the
    # memory footprint is lower); validation runs one clip at a time.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
    val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)

    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[2:]
    print(f"RNN head input per clip: {video_shape}")

    # Select the backbone weight source (.pt directory vs Lightning logs).
    if backbone_from_pt:
        if backbone_pt_dir is None:
            raise ValueError("backbone-from-pt=True, но backbone-pt-dir не указан.")
        backbone_weight_path = build_backbone_pt_path(backbone_pt_dir, artery=artery, fold=fold)
    else:
        if backbone_logdir is None:
            raise ValueError("backbone-from-pt=False, но backbone-logdir не указан.")
        backbone_weight_path = find_backbone_ckpt_lightning(
            backbone_logdir=backbone_logdir,
            artery=artery,
            fold=fold,
            phase="full",
        )

    # Stage 1: pretrain the head (frozen backbone).
    callbacks_pre = make_callbacks(phase="pre")

    model_pre = make_model(
        num_classes=num_classes,
        lr=1e-4,
        variant=variant,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=backbone_weight_path,
        pl_weight_path=None,
        pt_weights_format=False,
    )

    trainer_pre = make_trainer(
        max_epochs=max_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_fold{fold:02d}_{variant}_pre",
        devices=devices,
        precision=precision,
        callbacks=callbacks_pre,
    )
    trainer_pre.fit(model_pre, train_dataloaders=train_loader_pre, val_dataloaders=val_loader)

    # Stage 2: full train of the head, resuming from the best "pre" checkpoint.
    callbacks_full = make_callbacks(phase="full")

    model_full = make_model(
        num_classes=num_classes,
        lr=2e-5,
        variant=variant,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=None,
        pl_weight_path=trainer_pre.checkpoint_callback.best_model_path,
        pt_weights_format=pt_weights_format,
    )

    trainer_full = make_trainer(
        max_epochs=max_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_fold{fold:02d}_{variant}_post",
        devices=devices,
        precision=precision,
        callbacks=callbacks_full,
    )
    trainer_full.fit(model_full, train_dataloaders=train_loader_post, val_dataloaders=val_loader)


if __name__ == "__main__":
    main()
full_model_weights/LeftBinSyntax_R3D_fold00_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1e9b585f99e863620185bc88f724488f0ac09e8cf25aa8ddc9a120fd893a99b
3
+ size 133809489
full_model_weights/LeftBinSyntax_R3D_fold01_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e8132d382b352afd22bb3fe3e1d4d7e3a9781f8a927018db9bc5085d0e8d109
3
+ size 133809489
full_model_weights/LeftBinSyntax_R3D_fold02_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c18065b289591280a8051af4aa4124ffa03a42df69c876f125be5abaf7912486
3
+ size 133809489
full_model_weights/LeftBinSyntax_R3D_fold03_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf88400ea1a47d5a376ef9b366d0126036ddd02e3c1ba9071522fbac42b941bc
3
+ size 133809489
full_model_weights/LeftBinSyntax_R3D_fold04_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eea9012bb3d975ee9cb2d5b651e401c47801400d3a2cc6da04b3c1f761be1793
3
+ size 133809489
full_model_weights/RightBinSyntax_R3D_fold00_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f47a7bd428011d8570bfcca7b36ff8cc074d07303e0e6477a0b9e398d72bbe2
3
+ size 133809614
full_model_weights/RightBinSyntax_R3D_fold01_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3f0857954b5af689d6582fb95bad39f7e9e7575bc49187cc36dbd19211847dc
3
+ size 133809614
full_model_weights/RightBinSyntax_R3D_fold02_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59b3b3719e67aed73e8d087f1e748da852f7c30a390610d4b4f676d60fbc3f89
3
+ size 133809614
full_model_weights/RightBinSyntax_R3D_fold03_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15adb5ce5a9f3473dfcead2ac841c3c3e83ef9ad46dea9d4fa7f373e3b0df7c5
3
+ size 133809614
full_model_weights/RightBinSyntax_R3D_fold04_lstm_mean_post_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2312b11b6c865bf8f751a421557b1a9813a33f71906978518588d0a4d193f3e
3
+ size 133809614
inference/__init__.py ADDED
File without changes
inference/metrics_visualization.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SYNTAX predictions visualization:
3
+ - points (SYNTAX ground truth vs model predictions) for multiple datasets;
4
+ - risk zones (low / high risk);
5
+ - ±σ and ±2σ bands around the diagonal;
6
+ - logistic trends for each dataset.
7
+
8
+ The script is independent of PyTorch/Lightning and is used at inference time.
9
+ Output is saved to the `visualizations/` folder inside the project.
10
+ """
11
+
12
+ import os
13
+ import numpy as np
14
+ import plotly.graph_objects as go
15
+ from scipy.optimize import curve_fit # type: ignore
16
+
17
+
18
+ # ================= GLOBAL STYLE CONSTANTS =================
19
+
20
+ DATA_MIN = 0.0
21
+ DATA_MAX = 60.0
22
+ PADDING = 0.5
23
+
24
+ SIGMA_SLOPE = 0.15
25
+ SIGMA_BASE = 1.4
26
+ SIGMA_POINTS = 400
27
+ TREND_POINTS = 500
28
+
29
+ PLOT_WIDTH = 980
30
+ PLOT_HEIGHT = 980
31
+
32
+ # Fonts
33
+ FONT_FAMILY = "Inter, Roboto, Helvetica Neue, Arial, sans-serif"
34
+ BASE_FONT_SIZE = 20
35
+ TITLE_FONT_SIZE = 26
36
+ AXIS_TITLE_FONT_SIZE = 32
37
+ AXIS_TICK_FONT_SIZE = 30
38
+ LEGEND_FONT_SIZE = 20
39
+
40
+ # Markers / lines
41
+ MARKER_SIZE = 15
42
+ MARKER_LINE_WIDTH = 1.5
43
+ LINE_WIDTH = 3
44
+ TREND_LINE_WIDTH = 3.5
45
+
46
+ # Colors
47
+ PLOT_BG_COLOR = "rgba(235,238,245,1)"
48
+ PAPER_BG_COLOR = "white"
49
+ LEGEND_BG_COLOR = "rgba(255,255,255,0.45)"
50
+ GRID_COLOR = "rgba(100,116,139,0.18)"
51
+
52
+ # Layout
53
+ MARGIN_LEFT = 100
54
+ MARGIN_RIGHT = 15
55
+ MARGIN_TOP = 0
56
+ MARGIN_BOTTOM = 100
57
+
58
+ LEGEND_X = 0.008
59
+ LEGEND_Y = 0.985
60
+
61
+ COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
62
+ SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]
63
+
64
+
65
+ def _logistic_time(t, R0, Rmax, t50, k):
66
+ """Logistic function over SYNTAX score."""
67
+ t = np.asarray(t, dtype=float)
68
+ t_safe = np.where(t <= 0, 1e-3, t)
69
+ return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)
70
+
71
+
72
def _fit_logistic(x, y, domain, n=TREND_POINTS):
    """
    Fit logistic curve.
    Returns X, Y or (None, None) if the fit fails.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    valid = np.isfinite(xs) & np.isfinite(ys)
    # Four free parameters, so require at least four finite points.
    if valid.sum() < 4:
        return None, None

    xv, yv = xs[valid], ys[valid]
    lo = max(float(np.min(xv)), float(domain[0]))
    hi = min(float(np.max(xv)), float(domain[1]))
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return None, None

    positive = xv[xv > 0]
    if positive.size == 0:
        return None, None

    # Initial guesses: plateaus from the y percentiles, midpoint from the
    # median positive x, and a moderate slope.
    p0 = [
        float(np.percentile(yv, 10)),
        float(np.percentile(yv, 90)),
        float(np.median(positive)),
        1.0,
    ]
    bounds = ([-10.0, 0.0, 1e-3, 0.01], [60.0, 80.0, 60.0, 10.0])

    try:
        popt, _ = curve_fit(
            _logistic_time,
            xv,
            yv,
            p0=p0,
            bounds=bounds,
            maxfev=20000,
        )
    except Exception:
        return None, None

    grid = np.linspace(lo, hi, n)
    return grid, _logistic_time(grid, *popt)
116
+
117
+
118
def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,  # Pearson per dataset
    gt_row,
    postfix=None,
    threshold: float = 22.0,
    recall_values=None,
    backbone: bool = False,
    show_title: bool = False,
):
    """
    Unified SYNTAX visualization: points, risk zones and logistic trends.

    Parameters
    ----------
    datasets : dict[str, tuple[list[float], list[float]]]
        {dataset_name: (syntax_true_list, syntax_pred_list)}.
    r2_values : dict[str, float]
        Pearson correlation per dataset.
    gt_row : str
        String for the plot title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Suffix for the saved file name.
    threshold : float
        SYNTAX threshold (typically 22.0) to separate risk zones.
    recall_values : dict[str, float] | None
        Mean recall per dataset (may be None).
    backbone : bool
        If True, saves into `visualizations/backbone`, else into `visualizations/`.
    show_title : bool
        If True, renders the figure title; otherwise the figure is untitled.
    """
    fig = go.Figure()

    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)

    base_font = dict(
        family=FONT_FAMILY,
        size=BASE_FONT_SIZE,
    )

    # ---------- Risk zones and bands (legendrank=0) ----------
    # NOTE(review): the sub-threshold quadrant is filled red but labeled
    # "Low-risk zone", while the supra-threshold quadrant is green and labeled
    # "High-risk zone" — confirm the color/label pairing is intended.
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Thresholds & lines",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    # Cross-hair at the threshold: a vertical and a horizontal dashed line
    # drawn as one trace, separated by a None gap.
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=f"SYNTAX = {threshold}",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    # σ bands around the diagonal; half-width grows linearly with the score.
    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals
    sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals
    two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals
    two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals

    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_vals, x_vals[::-1]]),
            y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,193,7,0.18)",
            line=dict(color="rgba(0,0,0,0)"),
            name="± 2σ",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_vals, x_vals[::-1]]),
            y=np.concatenate([sigma_lower, sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,152,0,0.30)",
            line=dict(color="rgba(0,0,0,0)"),
            name="± σ",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name="Perfect prediction",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # ---------- Datasets (legendrank=20) ----------
    first_dataset = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        pearson = r2_values.get(label, None)
        recall = recall_values.get(label, None) if recall_values else None
        hover_lines = [f"<b>{label}</b>"]
        if pearson is not None:
            hover_lines.append(f"Pearson = {pearson:.3f}")
        if recall is not None:
            hover_lines.append(f"Mean recall = {recall:.3f}")
        hovertemplate = (
            "<br>".join(hover_lines)
            + "<br>Ground truth: %{x:.3f}<br>Prediction: %{y:.3f}<extra></extra>"
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                # Only the first dataset trace carries the legend group title.
                legendgrouptitle_text=("Datasets" if first_dataset else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[i % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[i % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH,
                        color="rgba(255,255,255,0.95)",
                    ),
                ),
                hovertemplate=hovertemplate,
                legendrank=20,
            )
        )
        first_dataset = False

    # ---------- Logistic trends (legendrank=30) ----------
    first_trend = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        Xc, Yc = _fit_logistic(x, y, domain=domain)
        if Xc is not None:
            fig.add_trace(
                go.Scatter(
                    x=Xc,
                    y=Yc,
                    mode="lines",
                    name=label,
                    legendgroup="trends",
                    legendgrouptitle_text=("Logistic trends" if first_trend else None),
                    showlegend=True,
                    line=dict(
                        color=COLORS[i % len(COLORS)],
                        width=TREND_LINE_WIDTH,
                    ),
                    hoverinfo="skip",
                    legendrank=30,
                )
            )
            first_trend = False

    # ---------- Layout ----------
    # Build title_text as before, but apply it only when show_title=True.
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"

    layout_kwargs = dict(
        font=dict(
            family=FONT_FAMILY,
            size=BASE_FONT_SIZE,
        ),
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,  # semi-transparent white background
            bordercolor="rgba(203,213,225,0.7)",  # slightly transparent border (optional)
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=FONT_FAMILY),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            title=dict(
                text="SYNTAX ground truth",
                font=dict(
                    size=AXIS_TITLE_FONT_SIZE,
                    family=FONT_FAMILY,
                    color="rgba(15,23,42,1)",
                ),
            ),
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(
                size=AXIS_TICK_FONT_SIZE,
                family=FONT_FAMILY,
            ),
            range=[line_min, line_max],
            constrain="domain",
        ),
        yaxis=dict(
            title=dict(
                text="SYNTAX predictions",
                font=dict(
                    size=AXIS_TITLE_FONT_SIZE,
                    family=FONT_FAMILY,
                    color="rgba(15,23,42,1)",
                ),
            ),
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(
                size=AXIS_TICK_FONT_SIZE,
                family=FONT_FAMILY,
            ),
            range=[line_min, line_max],
            # Lock the aspect ratio so the diagonal is drawn at 45 degrees.
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    if show_title:
        layout_kwargs["title"] = dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=FONT_FAMILY,
                color="rgba(15,23,42,1)",
            ),
        )

    fig.update_layout(**layout_kwargs)

    # ---------- Saving ----------
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)

    postfix_html = f"{postfix}" if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{postfix_html}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")
inference/rnn_apply.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference/rnn_apply.py
2
+ import os
3
+ import json
4
+ import tqdm
5
+ import torch
6
+ import numpy as np
7
+ import click
8
+ from datetime import datetime
9
+ import lightning.pytorch as pl
10
+ import sklearn.metrics as skm
11
+
12
+ from torch.utils.data import DataLoader
13
+ from torchvision.transforms import transforms as T
14
+ from torchvision.transforms._transforms_video import ToTensorVideo
15
+ from pytorchvideo.transforms import Normalize
16
+
17
+ from full_model.rnn_dataset import SyntaxDataset
18
+ from full_model.rnn_model import SyntaxLightningModule
19
+ from inference.metrics_visualization import visualize_final_syntax_plotly_multi
20
+
21
+
22
+ DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
23
+ print(f"DEVICE: {DEVICE}")
24
+
25
+
26
def safe_sample_std(values):
    """Sample standard deviation (ddof=1); 0.0 for empty or single-value input."""
    data = np.array(values, dtype=float)
    if data.size > 1:
        return float(data.std(ddof=1))
    return 0.0
32
+
33
+
34
def compute_metrics(y_true, y_pred, thr=22.0):
    """Pearson correlation and mean per-class recall at threshold `thr`."""
    gt = np.array(y_true, dtype=float)
    pred = np.array(y_pred, dtype=float)

    # Pearson is undefined for fewer than two samples; report 0.0 then.
    if len(gt) > 1:
        pearson = float(np.corrcoef(gt, pred)[0, 1])
    else:
        pearson = 0.0

    gt_bin = (gt >= thr).astype(int)
    pred_bin = (pred >= thr).astype(int)
    # Mean recall only makes sense when both classes appear somewhere in the
    # binarized labels/predictions; otherwise report 0.0.
    observed = np.unique(np.concatenate([gt_bin, pred_bin]))
    if len(observed) > 1:
        per_class = skm.recall_score(gt_bin, pred_bin, average=None, labels=[0, 1])
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return pearson, mean_recall
49
+
50
+
51
+ @click.command()
52
+ @click.option("-d", "--dataset-paths", multiple=True,
53
+ help="JSON с метаданными датасетов (относительно dataset_root).")
54
+ @click.option("-n", "--dataset-names", multiple=True,
55
+ help="Имена датасетов для метрик/графиков.")
56
+ @click.option("-p", "--postfixes", multiple=True,
57
+ help="Суффиксы для файлов предсказаний.")
58
+ @click.option(
59
+ "-r",
60
+ "--dataset-root",
61
+ type=click.Path(exists=True),
62
+ default=".",
63
+ show_default=True,
64
+ help="Корень датасета (где лежат JSON и DICOM).",
65
+ )
66
+ @click.option(
67
+ "--model-dir",
68
+ type=click.Path(exists=True),
69
+ default="full_model_weights",
70
+ show_default=True,
71
+ help="Каталог с .pt/.ckpt весами full‑моделей (RNN‑head + backbone).",
72
+ )
73
+ @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256),
74
+ show_default=True, help="Размер видео (H, W).")
75
+ @click.option("--frames-per-clip", type=int, default=32,
76
+ show_default=True, help="Количество кадров в клипе.")
77
+ @click.option("--num-workers", type=int, default=8,
78
+ show_default=True, help="Число DataLoader workers.")
79
+ @click.option("--seed", type=int, default=42,
80
+ show_default=True, help="Random seed.")
81
+ @click.option(
82
+ "--pt-weights-format",
83
+ is_flag=True,
84
+ default=True,
85
+ show_default=True,
86
+ help="Формат весов full‑моделей: True → .pt (raw state_dict), False → Lightning .ckpt.",
87
+ )
88
+ @click.option("--use-scaling", is_flag=True, default=False,
89
+ show_default=True, help="Применить a*x+b scaling из JSON.")
90
+ @click.option("--scaling-file",
91
+ help="JSON с коэффициентами scaling (относительно dataset_root).")
92
+ @click.option(
93
+ "--variant",
94
+ type=str,
95
+ default="lstm_mean",
96
+ show_default=True,
97
+ help="Вариант head‑модели: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
98
+ )
99
+ @click.option("-e", "--ensemble-name",
100
+ help="Имя ансамбля в metrics.json.")
101
+ @click.option("-m", "--metrics-file",
102
+ help="JSON с метриками экспериментов.")
103
+ def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
104
+ frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
105
+ scaling_file, variant, ensemble_name, metrics_file):
106
+
107
+ pl.seed_everything(seed)
108
+ postfix_plotly = "Ensemble"
109
+
110
+ # Пути к моделям берутся из model_dir по шаблону
111
+ model_paths = {
112
+ "left": [
113
+ os.path.join(model_dir, f"LeftBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
114
+ for fold in range(5)
115
+ ],
116
+ "right": [
117
+ os.path.join(model_dir, f"RightBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
118
+ for fold in range(5)
119
+ ],
120
+ }
121
+
122
+ scaling_params_dict = {}
123
+ if use_scaling:
124
+ postfix_plotly += "_scaled"
125
+ ensemble_name += "_scaled"
126
+ scaling_path = os.path.join(dataset_root, scaling_file)
127
+ if os.path.exists(scaling_path):
128
+ with open(scaling_path, "r") as f:
129
+ scaling_params_dict = json.load(f)
130
+ print(f"Loaded scaling from {scaling_path}")
131
+ else:
132
+ print(f"⚠️ Scaling file not found: {scaling_path}")
133
+
134
+ ensemble_results = {
135
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
136
+ "use_scaling": use_scaling,
137
+ "pt_weights_format": pt_weights_format,
138
+ "variant": variant,
139
+ "datasets": {},
140
+ }
141
+
142
+ all_datasets, all_pearson, all_recalls = {}, {}, {}
143
+
144
+ # вспомогательная функция для получения (a, b)
145
+ def get_ab(i: int):
146
+ params = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
147
+ if isinstance(params, dict):
148
+ return params.get("a", 1.0), params.get("b", 0.0)
149
+ return params[0], params[1]
150
+
151
+ for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
152
+ abs_dataset_path = os.path.join(dataset_root, dataset_path)
153
+ results_file = os.path.join("results", f"{postfix}.json")
154
+
155
+ if os.path.exists(results_file):
156
+ print(f"[{postfix}] Loading from {results_file}")
157
+ with open(results_file, "r") as f:
158
+ data = json.load(f)
159
+ syntax_true = data["syntax_true"]
160
+ left_preds_all = data["left_preds"]
161
+ right_preds_all = data["right_preds"]
162
+ else:
163
+ print(f"[{postfix}] Computing predictions...")
164
+ left_preds_all, left_sids = run_artery(
165
+ abs_dataset_path, "left", model_paths["left"],
166
+ video_size, frames_per_clip, num_workers,
167
+ variant=variant, pt_weights_format=pt_weights_format,
168
+ )
169
+ right_preds_all, right_sids = run_artery(
170
+ abs_dataset_path, "right", model_paths["right"],
171
+ video_size, frames_per_clip, num_workers,
172
+ variant=variant, pt_weights_format=pt_weights_format,
173
+ )
174
+ assert left_sids == right_sids
175
+
176
+ with open(abs_dataset_path, "r") as f:
177
+ dataset = json.load(f)
178
+ syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]
179
+
180
+ os.makedirs(os.path.dirname(results_file), exist_ok=True)
181
+ save_data = {
182
+ "syntax_true": syntax_true,
183
+ "left_preds": left_preds_all,
184
+ "right_preds": right_preds_all,
185
+ }
186
+ with open(results_file, "w") as f:
187
+ json.dump(save_data, f)
188
+ print(f"[{postfix}] Saved to {results_file}")
189
+
190
+ # -------- ансамбль с/без scaling --------
191
+ if use_scaling:
192
+ syntax_pred = []
193
+ for l_list, r_list in zip(left_preds_all, right_preds_all):
194
+ scaled_folds = []
195
+ for i, (l_val, r_val) in enumerate(zip(l_list, r_list)):
196
+ s = l_val + r_val
197
+ a, b = get_ab(i)
198
+ scaled_folds.append(a * s + b)
199
+ syntax_pred.append(max(0.0, float(np.mean(scaled_folds))))
200
+ else:
201
+ syntax_pred = [
202
+ max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
203
+ for l_list, r_list in zip(left_preds_all, right_preds_all)
204
+ ]
205
+
206
+ pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
207
+ print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")
208
+
209
+ # -------- per-fold метрики --------
210
+ n_folds = len(left_preds_all[0]) if left_preds_all else 0
211
+ fold_metrics = {metric: [] for metric in ["Pearson", "Mean_Recall"]}
212
+
213
+ for k in range(n_folds):
214
+ pred_k = []
215
+ for l_list, r_list in zip(left_preds_all, right_preds_all):
216
+ s = l_list[k] + r_list[k]
217
+ if use_scaling:
218
+ a, b = get_ab(k)
219
+ s = a * s + b
220
+ pred_k.append(max(0.0, float(s)))
221
+
222
+ fold_pearson, fold_recall = compute_metrics(syntax_true, pred_k)
223
+ for metric, value in zip(
224
+ fold_metrics.keys(),
225
+ [fold_pearson, fold_recall],
226
+ ):
227
+ fold_metrics[metric].append(value)
228
+
229
+ fold_summary = {
230
+ k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
231
+ for k, v in fold_metrics.items()
232
+ }
233
+
234
+ all_datasets[dataset_name] = (syntax_true, syntax_pred)
235
+ all_pearson[dataset_name] = pearson
236
+ all_recalls[dataset_name] = mean_recall
237
+
238
+ ensemble_results["datasets"][dataset_name] = {
239
+ "Pearson": round(pearson, 4),
240
+ "Mean_Recall": round(mean_recall, 4),
241
+ "N_samples": len(syntax_true),
242
+ **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
243
+ **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
244
+ **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
245
+ }
246
+
247
+ metrics_path = metrics_file
248
+ full_history = {}
249
+ if os.path.exists(metrics_path):
250
+ try:
251
+ with open(metrics_path, "r") as f:
252
+ full_history = json.load(f)
253
+ except json.JSONDecodeError:
254
+ print("⚠️ Metrics file corrupted. Creating new.")
255
+
256
+ full_history[ensemble_name] = ensemble_results
257
+ with open(metrics_path, "w") as f:
258
+ json.dump(full_history, f, indent=4)
259
+ print(f"✅ Metrics saved: {metrics_path}")
260
+
261
+ visualize_final_syntax_plotly_multi(
262
+ datasets=all_datasets,
263
+ r2_values=all_pearson, # здесь теперь Pearson
264
+ gt_row="ENSEMBLE",
265
+ postfix=postfix_plotly,
266
+ recall_values=all_recalls,
267
+ )
268
+
269
+
270
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Run inference for one artery across all fold checkpoints.

    Args:
        dataset_path: Absolute path to the dataset JSON metadata file.
        artery: Which artery to evaluate ("left" or "right").
        model_paths: Checkpoint paths, one per fold; missing paths are
            skipped with a warning.
        video_size: Target spatial size for the Resize transform.
        frames_per_clip: Number of frames sampled per clip.
        num_workers: DataLoader worker count.
        variant: Backbone variant name forwarded to SyntaxLightningModule.
        pt_weights_format: Whether checkpoints use the plain .pt format.

    Returns:
        Tuple ``(preds_all, sids)``:
            preds_all — list over samples; each entry is a list with one
            SYNTAX-score prediction per loaded fold model (0.0 for
            samples the dataset could not decode);
            sids — study ids in loader order.

    Raises:
        RuntimeError: If none of the checkpoint paths exist.
    """
    # ImageNet statistics — the video backbone was pretrained on ImageNet.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform,
    )
    # batch_size=1: per-sample predictions are collected individually.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )
    print(f"{artery} artery: {len(val_loader)} samples")

    # Load one model per fold checkpoint; skip (with a warning) any
    # checkpoint that is absent on disk.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue

        model = SyntaxLightningModule(
            num_classes=2,
            lr=1e-5,
            variant=variant,
            weight_decay=0.001,
            max_epochs=1,
            weight_path=None,
            pl_weight_path=path,
            pt_weights_format=pt_weights_format,
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)

    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        # DataLoader with batch_size=1 collates labels/ids into 1-element
        # lists, hence the [_y], [_t], [sid] unpacking.
        for x, [_y], [_t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if x.ndim == 1:
                # Sentinel returned by the dataset for unreadable samples:
                # predict 0.0 from every fold model.
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    y_hat = model(x)
                    # Channel 0 is classification; channel 1 is the
                    # log-space regression head.
                    yp_reg = y_hat[:, 1:]
                    val_log = yp_reg.squeeze(-1)
                    # Targets were encoded as log(1 + syntax); expm1 is
                    # the numerically stable inverse of that mapping.
                    val = float(torch.expm1(val_log).cpu())
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid)

    return preds_all, sids
341
+
342
+
343
# Script entry point: run the full ensemble evaluation pipeline.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML/DL
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ lightning>=2.1.0
5
+ pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo@main#egg=pytorchvideo
6
+
7
+ # Data processing
8
+ numpy>=1.24.0
9
+ scikit-learn>=1.3.0
10
+ tqdm>=4.65.0
11
+ pydicom>=2.4.0
12
+ python-gdcm>=3.0.10
13
+
14
+ # Visualization & analysis
16
+ scipy>=1.11.0
17
+ tensorboard>=2.9
18
+ tensorboardX>=2.6
19
+
20
+ # CLI & Utilities
21
+ click>=8.1.0
scaling_coeffs/scaling_coeffs.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold0": {
3
+ "a": 1.65,
4
+ "b": 0.4,
5
+ "mean_recall": 0.715961
6
+ },
7
+ "fold1": {
8
+ "a": 1.29,
9
+ "b": 0.38,
10
+ "mean_recall": 0.767792
11
+ },
12
+ "fold2": {
13
+ "a": 1.28,
14
+ "b": 0.365,
15
+ "mean_recall": 0.800703
16
+ },
17
+ "fold3": {
18
+ "a": 1.11,
19
+ "b": 0.42,
20
+ "mean_recall": 0.761545
21
+ },
22
+ "fold4": {
23
+ "a": 1.61,
24
+ "b": 0.385,
25
+ "mean_recall": 0.736111
26
+ }
27
+ }