from typing import Any import torch from torch import nn, optim import lightning.pytorch as pl import torchvision.models.video as tvmv import sklearn.metrics as skm import numpy as np class SyntaxLightningModule(pl.LightningModule): """ LightningModule для обучения 3D-ResNet (r3d_18) как backbone в задаче предсказания SYNTAX score по видеоангиографии. Модель предсказывает: - yp_clf: вероятность поражения (syntax > порог) — бинарная классификация - yp_reg: логарифмированное значение SYNTAX — регрессия """ def __init__( self, num_classes: int, lr: float, weight_decay: float = 0.0, max_epochs: int = None, weight_path: str = None, sigma_a: float = 0.0, sigma_b: float = 1.0, **kwargs, ): super().__init__() self.save_hyperparameters() self.num_classes = num_classes self.lr = lr self.weight_decay = weight_decay self.max_epochs = max_epochs self.weight_path = weight_path self.sigma_a = sigma_a self.sigma_b = sigma_b # Базовый 3D-ResNet с ImageNet Kinetics-предобученными весами self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT) # Последний слой заменяем на Linear с num_classes выходами: # 1 канал для классификации, 1 для регрессии in_features = self.model.fc.in_features self.model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True) # Если передан путь к чекпоинту Lightning — загружаем backbone if self.weight_path is not None: ckpt = torch.load(self.weight_path, map_location="cpu", weights_only=False) state_dict = ckpt["state_dict"] # Чистим префикс "model." у ключей new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} self.model.load_state_dict(new_state_dict, strict=False) # Лоссы self.loss_clf = nn.BCEWithLogitsLoss(reduction="none") self.loss_reg = nn.MSELoss(reduction="none") # Буферы для валидационных метрик self.y_val = [] self.p_val = [] self.r_val = [] self.ty_val = [] self.tp_val = [] # ------------------------------------------------------------------ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) # ------------------------------------------------------------------ def training_step(self, batch, batch_idx): """ Один шаг обучения: - бинарная классификация поражения (BCE с down-weight для нулей); - регрессия логарифмированного SYNTAX с учётом get_sigma(target). """ x, y, target, sample_weight, path, original_label = batch y_hat = self(x) yp_clf = y_hat[:, 0:1] # logits для классификации yp_reg = y_hat[:, 1:] # регрессия (лог SYNTAX) # BCE с меньшим весом для класса 0 (нет поражения) weights_clf = torch.where(y > 0, 1.0, 0.45) clf_loss = self.loss_clf(yp_clf, y) clf_loss = (clf_loss * weights_clf).mean() # Регрессионный лосс с «вариабельностью по красной линии» reg_loss_raw = self.loss_reg(yp_reg, target) sigma = self.sigma_a * target + self.sigma_b reg_loss = (reg_loss_raw / (sigma ** 2)).mean() loss = clf_loss + 0.5 * reg_loss # Метрики на бинарную задачу y_pred = torch.sigmoid(yp_clf) y_bin = torch.round(y.detach().cpu()).int() y_pred_bin = torch.round(y_pred.detach().cpu()).int() self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True) self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True) self.log("train_full_loss", loss, prog_bar=True, sync_dist=True) self.log( "train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0), prog_bar=True, sync_dist=True, ) self.log( "train_acc", skm.accuracy_score(y_bin, y_pred_bin), prog_bar=True, sync_dist=True, ) return loss # ------------------------------------------------------------------ def validation_step(self, batch, batch_idx): """ Валидационный шаг: считаем тот же комбинированный лосс и аккумулируем предсказания для расчёта метрик на эпоху. """ x, y, target, sample_weight, path, original_label = batch y_hat = self(x) yp_clf = y_hat[:, 0:1] yp_reg = y_hat[:, 1:] # Комбинированный лосс clf_loss = self.loss_clf(yp_clf, y) reg_loss_raw = self.loss_reg(yp_reg, target) sigma = self.sigma_a * target + self.sigma_b reg_loss = (reg_loss_raw / (sigma ** 2)).mean() loss = clf_loss.mean() + 0.5 * reg_loss # Для метрик y_pred = torch.sigmoid(yp_clf) self.y_val.append(int(y[..., 0].cpu())) self.p_val.append(float(y_pred[..., 0].cpu())) self.r_val.append(round(float(y_pred[..., 0].cpu()))) self.ty_val.append(float(target[..., 0].cpu())) self.tp_val.append(float(yp_reg[..., 0].cpu())) return loss # ------------------------------------------------------------------ def on_validation_epoch_end(self) -> None: """ Подсчёт валидационных метрик по всей эпохе и логирование в Logger. """ try: auc = skm.roc_auc_score(self.y_val, self.p_val) f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0) acc = skm.accuracy_score(self.y_val, self.r_val) mae = skm.mean_absolute_error(self.y_val, self.r_val) rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val) self.log("val_auc", auc, prog_bar=True, sync_dist=True) self.log("val_f1", f1, prog_bar=True, sync_dist=True) self.log("val_acc", acc, prog_bar=True, sync_dist=True) self.log("val_mae", mae, prog_bar=True, sync_dist=True) self.log("val_rmse", rmse, prog_bar=True, sync_dist=True) except ValueError as err: # Случаи, когда метрики нельзя посчитать (например, только один класс) print(err) print("Y_VAL", self.y_val) print("P_VAL", self.p_val) # Чистим буферы к следующей эпохе self.y_val.clear() self.p_val.clear() self.r_val.clear() self.ty_val.clear() self.tp_val.clear() # ------------------------------------------------------------------ def on_train_epoch_end(self) -> None: """Логирование текущего learning rate.""" opt = self.optimizers() if hasattr(opt, "optimizer"): lr = opt.optimizer.param_groups[0]["lr"] else: lr = opt.param_groups[0]["lr"] self.log("lr", lr, on_step=False, on_epoch=True, sync_dist=True) # ------------------------------------------------------------------ def configure_optimizers(self): """ - Если weight_path не задан → pretrain: обучаем только финальный fc-слой. - Если weight_path задан → full fine-tuning: обучаем весь backbone. """ if not self.weight_path: # Pretrain: замораживаем всё, кроме финального слоя for param in self.parameters(): param.requires_grad = False for p in self.model.fc.parameters(): p.requires_grad = True params = list(self.model.fc.parameters()) else: # Full fine-tune: обучаем все параметры модели for param in self.parameters(): param.requires_grad = True params = self.parameters() optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay) if self.max_epochs is not None: scheduler = optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.lr, total_steps=self.max_epochs, ) return [optimizer], [scheduler] else: return optimizer # ------------------------------------------------------------------ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: """ Инференс: возвращает словарь с бинарным предсказанием, вероятностями и регрессионным выходом. """ x, y, target, sample_weight, path, original_label = batch y_hat = self(x) yp_clf = y_hat[:, 0:1] yp_reg = y_hat[:, 1:] y_prob = torch.sigmoid(yp_clf) y_pred = torch.round(y_prob) return { "y": y, "y_pred": y_pred, "y_prob": y_prob, "y_reg": yp_reg, "target": target, "original_label": original_label, }