from typing import Any, Optional

import lightning.pytorch as pl
import numpy as np
import sklearn.metrics as skm
import torch
import torchvision.models.video as tvmv
from torch import nn, optim
| |
|
| |
|
class SyntaxLightningModule(pl.LightningModule):
    """
    LightningModule that trains a 3D-ResNet (r3d_18) backbone to predict
    the SYNTAX score from coronary angiography video.

    The single fc head is split into two outputs:
    - yp_clf (column 0): lesion-presence logit (syntax > threshold) — binary classification.
    - yp_reg (remaining columns): log-transformed SYNTAX value — regression.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            num_classes: Width of the final fc layer (classification logit in
                column 0, regression output(s) in the remaining columns).
            lr: Peak learning rate for AdamW (and OneCycleLR when scheduled).
            weight_decay: AdamW weight decay.
            max_epochs: Total OneCycleLR steps; when ``None`` no scheduler is used.
            weight_path: Optional Lightning checkpoint to warm-start from;
                also switches ``configure_optimizers`` to full fine-tuning.
            sigma_a: Slope of the linear noise model sigma = sigma_a * target + sigma_b.
            sigma_b: Intercept of the same noise model.
            **kwargs: Ignored; accepted so extra hparams can be passed through.
        """
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.weight_path = weight_path
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b

        # Kinetics-pretrained 3D ResNet-18 backbone.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        # Replace the classifier head with a task-specific one.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)

        if self.weight_path is not None:
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            ckpt = torch.load(self.weight_path, map_location="cpu", weights_only=False)
            state_dict = ckpt["state_dict"]
            # Strip the LightningModule "model." prefix so keys match the bare
            # torchvision backbone; strict=False tolerates a mismatched fc head.
            new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
            self.model.load_state_dict(new_state_dict, strict=False)

        # reduction="none" so per-sample weights can be applied before averaging.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Per-epoch validation accumulators, reset in on_validation_epoch_end.
        self.y_val = []   # ground-truth binary labels
        self.p_val = []   # predicted probabilities
        self.r_val = []   # probabilities rounded at 0.5
        self.ty_val = []  # regression targets
        self.tp_val = []  # regression predictions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone; returns raw head outputs of shape (B, num_classes)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        One training step:
        - binary lesion classification (BCE with negatives down-weighted to 0.45);
        - regression of the log-SYNTAX value scaled by sigma(target).
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]  # classification logit
        yp_reg = y_hat[:, 1:]   # regression output(s)

        # Down-weight the negative class in the BCE term.
        weights_clf = torch.where(y > 0, 1.0, 0.45)
        clf_loss = self.loss_clf(yp_clf, y)
        clf_loss = (clf_loss * weights_clf).mean()

        # Heteroscedastic MSE: samples with larger sigma(target) contribute less.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        # Batch-level classification metrics for progress-bar feedback.
        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.detach().cpu()).int()
        y_pred_bin = torch.round(y_pred.detach().cpu()).int()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train_f1",
            skm.f1_score(y_bin, y_pred_bin, zero_division=0),
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "train_acc",
            skm.accuracy_score(y_bin, y_pred_bin),
            prog_bar=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: compute the same combined loss and accumulate
        per-sample predictions for epoch-level metrics.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]

        # Same combined loss as training, but without class down-weighting.
        clf_loss = self.loss_clf(yp_clf, y)
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
        loss = clf_loss.mean() + 0.5 * reg_loss

        y_pred = torch.sigmoid(yp_clf)

        # Accumulate every sample in the batch.  The previous
        # int(...)/float(...) scalar conversions only worked for
        # batch_size == 1 and raised RuntimeError for larger batches.
        probs = y_pred[..., 0].detach().cpu().tolist()
        self.y_val.extend(int(v) for v in y[..., 0].detach().cpu().tolist())
        self.p_val.extend(probs)
        self.r_val.extend(round(p) for p in probs)
        self.ty_val.extend(target[..., 0].detach().cpu().tolist())
        self.tp_val.extend(yp_reg[..., 0].detach().cpu().tolist())

        return loss

    def on_validation_epoch_end(self) -> None:
        """
        Compute validation metrics over the whole epoch and log them.
        """
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # NOTE(review): MAE is computed on the binary labels vs rounded
            # predictions (numerically 1 - accuracy) — confirm this is
            # intended rather than MAE on the regression outputs.
            mae = skm.mean_absolute_error(self.y_val, self.r_val)
            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)

            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)

        except ValueError as err:
            # roc_auc_score raises when the epoch contains a single class
            # (e.g. during the sanity check); dump the raw values to debug.
            print(err)
            print("Y_VAL", self.y_val)
            print("P_VAL", self.p_val)

        # Reset accumulators for the next epoch.
        self.y_val.clear()
        self.p_val.clear()
        self.r_val.clear()
        self.ty_val.clear()
        self.tp_val.clear()

    def on_train_epoch_end(self) -> None:
        """Log the current learning rate."""
        opt = self.optimizers()
        # self.optimizers() may return a Lightning wrapper or a bare optimizer.
        if hasattr(opt, "optimizer"):
            lr = opt.optimizer.param_groups[0]["lr"]
        else:
            lr = opt.param_groups[0]["lr"]
        self.log("lr", lr, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """
        - weight_path unset → pretrain phase: train only the final fc layer.
        - weight_path set → full fine-tuning: train the whole backbone.
        """
        if not self.weight_path:
            # Freeze everything, then unfreeze only the new head.
            for param in self.parameters():
                param.requires_grad = False
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = list(self.model.fc.parameters())
        else:
            # Full fine-tuning of the warm-started backbone.
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None:
            # Scheduler is stepped per epoch (Lightning default for this
            # return form), hence total_steps == max_epochs.
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.lr,
                total_steps=self.max_epochs,
            )
            return [optimizer], [scheduler]
        else:
            return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Inference: return a dict with the binary prediction (0.5 threshold),
        probabilities, and the raw regression output.
        """
        x, y, target, sample_weight, path, original_label = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:]
        y_prob = torch.sigmoid(yp_clf)
        y_pred = torch.round(y_prob)

        return {
            "y": y,
            "y_pred": y_pred,
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
            "original_label": original_label,
        }