| from typing import Any, Optional |
|
|
| import numpy as np |
| import torch |
| from torch import nn, optim |
| import lightning.pytorch as pl |
| import torchvision.models.video as tvmv |
| import sklearn.metrics as skm |
|
|
|
|
class SyntaxLightningModule(pl.LightningModule):
    """
    LightningModule for training a 3D backbone on the SYNTAX score.

    Architecture:
        - backbone: ResNet3D (r3d_18) from torchvision
        - final fully-connected layer with two outputs:
            [0] — logit for binary classification (significant lesion)
            [1] — regression output for the SYNTAX score (log1p scale)

    Training modes:
        - pretrain (weight_path is None):
            the whole backbone is frozen; only the final layer (fc) is trained
        - finetune (weight_path is set):
            weights are loaded from the checkpoint and the full network is
            trained end to end.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            num_classes: size of the final fc layer (2 expected:
                [clf_logit, reg_output]).
            lr: peak learning rate for AdamW / OneCycleLR.
            weight_decay: AdamW weight decay.
            max_epochs: if set (and a trainer is attached), enables the
                OneCycleLR schedule in ``configure_optimizers``.
            weight_path: optional checkpoint path; when given, backbone
                weights are loaded and full finetuning is enabled.
            sigma_a, sigma_b: coefficients of the per-sample regression
                noise model ``sigma = sigma_a * target + sigma_b`` used to
                down-weight the MSE loss for large targets.
            **kwargs: absorbed for hparams compatibility; unused.
        """
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = int(num_classes)
        self.lr = float(lr)
        self.weight_decay = float(weight_decay)
        self.max_epochs = max_epochs
        self.weight_path = weight_path

        self.sigma_a = float(sigma_a)
        self.sigma_b = float(sigma_b)

        # Pretrained Kinetics-400 backbone from torchvision.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)

        # Replace the classifier head with a task-specific output layer.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=in_features, out_features=self.num_classes, bias=True)

        # Finetune mode: restore weights saved by a previous run.
        if self.weight_path is not None:
            self._load_backbone_weights(self.weight_path)

        # reduction="none" so that per-sample weights can be applied
        # before averaging in training_step.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Epoch-level accumulators for validation metrics
        # (cleared in on_validation_epoch_end).
        self._y_true = []  # binary ground-truth labels
        self._y_prob = []  # predicted probabilities
        self._y_pred = []  # thresholded predictions (p > 0.5)
        self._t_true = []  # regression targets
        self._t_pred = []  # regression predictions

    def _load_backbone_weights(self, weight_path: str) -> None:
        """
        Load backbone weights from either:
        - a Lightning checkpoint (dict with a 'state_dict' key), or
        - a raw state_dict (.pt/.pth) saved via model.state_dict().

        Logs the source type and key statistics.
        """
        obj = torch.load(weight_path, map_location="cpu", weights_only=False)

        if isinstance(obj, dict) and "state_dict" in obj:
            state_dict = obj["state_dict"]
            # Lightning prefixes keys with the attribute name ("model.").
            # Strip only the leading prefix: a blanket str.replace would
            # also corrupt keys that contain "model." elsewhere.
            prefix = "model."
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"

        # strict=False: tolerate head/backbone shape or naming drift;
        # mismatches are reported below instead of raising.
        incompatible = self.model.load_state_dict(state_dict, strict=False)

        loaded_keys = [k for k in state_dict.keys() if k not in incompatible.missing_keys]
        print(
            f"[Backbone] Loaded weights from '{weight_path}' "
            f"(type={src_type}): {len(loaded_keys)} params, "
            f"missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)}"
        )
        if incompatible.missing_keys:
            print(f"[Backbone] Missing keys example: {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            print(f"[Backbone] Unexpected keys example: {incompatible.unexpected_keys[:5]}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:
            x: (B, C, T, H, W)

        Output:
            y_hat: (B, 2) — [clf_logit, reg_output]
        """
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One backbone training step.

        Expects ``batch`` as
        ``(x, y, target, sample_weight, path, original_label)``;
        only x/y/target are used here.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]  # classification logit
        yp_reg = y_hat[:, 1:2]  # regression output

        # Down-weight negatives (0.45) to counter class imbalance.
        weights_clf = torch.where(y > 0, 1.0, 0.45).to(y.dtype)
        clf_loss = (self.loss_clf(yp_clf, y) * weights_clf).mean()

        # Heteroscedastic weighting: larger targets are noisier, so their
        # MSE contribution is scaled down by sigma^2.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        y_prob = torch.sigmoid(yp_clf).detach()
        # Flatten to 1-D: sklearn treats (B, 1) arrays as
        # multilabel-indicator input, which breaks binary metrics.
        y_pred = (y_prob > 0.5).int().cpu().numpy().ravel()
        y_true = y.detach().int().cpu().numpy().ravel()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_reg_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)

        self.log("train_f1", skm.f1_score(y_true, y_pred, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_true, y_pred),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One backbone validation step.

        Losses mirror ``training_step`` (minus the class weights);
        per-sample predictions are accumulated for epoch-level metrics.
        """
        x, y, target, sample_weight, path, original_label = batch

        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]

        clf_loss = self.loss_clf(yp_clf, y).mean()

        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        y_prob = torch.sigmoid(yp_clf).float()
        # Accumulate per sample via extend/tolist: the previous
        # float(tensor)/int(tensor) casts only worked for batch size 1
        # and raised a RuntimeError for any larger batch.
        self._y_true.extend(y[:, 0].float().cpu().tolist())
        self._y_prob.extend(y_prob[:, 0].cpu().tolist())
        self._y_pred.extend((y_prob[:, 0] > 0.5).int().cpu().tolist())

        self._t_true.extend(target[:, 0].float().cpu().tolist())
        self._t_pred.extend(yp_reg[:, 0].float().cpu().tolist())

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_clf_loss", clf_loss, prog_bar=False, sync_dist=True)
        self.log("val_reg_loss", reg_loss, prog_bar=False, sync_dist=True)

        return loss

    def on_validation_epoch_end(self) -> None:
        """
        Compute and log epoch-level metrics, then reset the accumulators.
        """
        if len(self._t_true) > 0:
            rmse = skm.root_mean_squared_error(self._t_true, self._t_pred)
            mae = skm.mean_absolute_error(self._t_true, self._t_pred)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)
            self.log("val_reg_mae", mae, prog_bar=True, sync_dist=True)

        # AUC is undefined on a single class (e.g. sanity-check batches),
        # so classification metrics are skipped in that case.
        if len(set(self._y_true)) > 1:
            auc = skm.roc_auc_score(self._y_true, self._y_prob)
            f1 = skm.f1_score(self._y_true, self._y_pred, zero_division=0)
            acc = skm.accuracy_score(self._y_true, self._y_pred)
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)

        self._y_true.clear()
        self._y_prob.clear()
        self._y_pred.clear()
        self._t_true.clear()
        self._t_pred.clear()

    def on_train_epoch_end(self) -> None:
        """
        Log the current learning rate (first param group).
        """
        opt = self.optimizers()
        self.log(
            "lr",
            opt.optimizer.param_groups[0]["lr"],
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        """
        Configure the optimizer and (optionally) OneCycleLR.

        If weight_path is None:
            only self.model.fc is trained (pretrain).
        Otherwise:
            the whole model is trained (full finetune).
        """
        if self.weight_path is None:
            # Pretrain: freeze everything, then unfreeze only the head.
            for p in self.parameters():
                p.requires_grad = False
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = self.model.fc.parameters()
        else:
            for p in self.parameters():
                p.requires_grad = True
            params = self.parameters()

        optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.max_epochs is not None and getattr(self, "trainer", None) is not None:
            # Let the trainer account for accumulation / multi-GPU when
            # sizing the schedule.
            total_steps = self.trainer.estimated_stepping_batches
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.lr,
                total_steps=total_steps,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                },
            }

        return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Backbone inference step.

        Returns a dict with ground truth, thresholded predictions,
        probabilities, regression outputs and sample metadata.
        """
        x, y, target, sample_weight, path, original_label = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]
        y_prob = torch.sigmoid(yp_clf)

        return {
            "y": y,
            "y_pred": (y_prob > 0.5).int(),
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
            "original_label": original_label,
            "path": path,
        }
|
|