from typing import Any, Optional
import torch
from torch import nn, optim
import lightning.pytorch as pl
import torchvision.models.video as tvmv
import sklearn.metrics as skm


class SyntaxLightningModule(pl.LightningModule):
"""
LightningModule для обучения 3D-backbone на SYNTAX score.
Архитектура:
- backbone: ResNet3D (r3d_18) из torchvision
- выходной полносвязный слой: два нейрона
[0] — логит для бинарной классификации (значимое поражение)
[1] — регрессионный выход для SYNTAX score (log1p)
Режимы обучения:
- pretrain (weight_path is None):
замораживается весь backbone, обучается только финальный слой (fc)
- finetune (weight_path задан):
загружаются веса из чекпоинта и дообучается вся сеть целиком.
"""
def __init__(
self,
num_classes: int,
lr: float,
weight_decay: float = 0.0,
max_epochs: Optional[int] = None,
weight_path: Optional[str] = None,
sigma_a: float = 0.0,
sigma_b: float = 1.0,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.num_classes = int(num_classes)
self.lr = float(lr)
self.weight_decay = float(weight_decay)
self.max_epochs = max_epochs
self.weight_path = weight_path
self.sigma_a = float(sigma_a)
self.sigma_b = float(sigma_b)
        # Initialize a 3D ResNet-18 with torchvision's pretrained (Kinetics-400) weights
self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)
        # Replace the final fc layer with one that has num_classes outputs
in_features = self.model.fc.in_features
self.model.fc = nn.Linear(in_features=in_features, out_features=self.num_classes, bias=True)
        # If a weight path is provided, load the backbone from it
if self.weight_path is not None:
self._load_backbone_weights(self.weight_path)
        # Losses (element-wise; reduced manually so per-sample weights can be applied)
self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
self.loss_reg = nn.MSELoss(reduction="none")
        # Buffers that accumulate per-sample validation outputs across batches
self._y_true = []
self._y_prob = []
self._y_pred = []
self._t_true = []
self._t_pred = []

    def _load_backbone_weights(self, weight_path: str) -> None:
        """
        Loads backbone weights from either:
        - a Lightning checkpoint (a dict with a 'state_dict' key), or
        - a raw state_dict (.pt/.pth) saved via model.state_dict().
        Logs the source type and key statistics.
        """
        # weights_only=False is needed because Lightning checkpoints contain more
        # than bare tensors; only load checkpoints from trusted sources.
        obj = torch.load(weight_path, map_location="cpu", weights_only=False)
        if isinstance(obj, dict) and "state_dict" in obj:
            # Strip the LightningModule attribute prefix ("model.") so the keys
            # match the bare torchvision backbone; removeprefix (unlike
            # str.replace) cannot mangle keys containing "model." elsewhere.
            state_dict = {k.removeprefix("model."): v for k, v in obj["state_dict"].items()}
            src_type = "lightning_checkpoint"
        else:
            state_dict = obj
            src_type = "raw_state_dict"
        incompatible = self.model.load_state_dict(state_dict, strict=False)
loaded_keys = [k for k in state_dict.keys() if k not in incompatible.missing_keys]
print(
f"[Backbone] Loaded weights from '{weight_path}' "
f"(type={src_type}): {len(loaded_keys)} params, "
f"missing={len(incompatible.missing_keys)}, "
f"unexpected={len(incompatible.unexpected_keys)}"
)
if incompatible.missing_keys:
print(f"[Backbone] Missing keys example: {incompatible.missing_keys[:5]}")
if incompatible.unexpected_keys:
print(f"[Backbone] Unexpected keys example: {incompatible.unexpected_keys[:5]}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:
            x: (B, C, T, H, W)
        Output:
            y_hat: (B, 2), holding [clf_logit, reg_output]
        """
return self.model(x)
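
    # Shape example (sketch): for a batch of two 3-channel clips of 16 frames
    # at 112x112, self(torch.randn(2, 3, 16, 112, 112)) returns a (2, 2) tensor:
    # column 0 is the classification logit, column 1 the regression output.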

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One training step of the backbone.
        """
        x, y, target, sample_weight, path, original_label = batch
        y_hat = self(x)
        yp_clf = y_hat[:, 0:1]
        yp_reg = y_hat[:, 1:2]
        # Down-weight the negative class (factor 0.45) in the classification loss.
        weights_clf = torch.where(y > 0, 1.0, 0.45).to(y.dtype)
        clf_loss = (self.loss_clf(yp_clf, y) * weights_clf).mean()
        # Heteroscedastic weighting: each sample's squared error is divided by
        # sigma^2, where sigma grows linearly with the target.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
        loss = clf_loss + 0.5 * reg_loss
        # Per-batch classification metrics; ravel to the 1D arrays sklearn expects.
        y_prob = torch.sigmoid(yp_clf).detach()
        y_pred = (y_prob > 0.5).int().cpu().numpy().ravel()
        y_true = y.detach().int().cpu().numpy().ravel()
self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
self.log("train_reg_loss", reg_loss, prog_bar=True, sync_dist=True)
self.log("train_loss", loss, prog_bar=True, sync_dist=True)
self.log("train_f1", skm.f1_score(y_true, y_pred, zero_division=0),
prog_bar=True, sync_dist=True)
self.log("train_acc", skm.accuracy_score(y_true, y_pred),
prog_bar=True, sync_dist=True)
return loss
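
    # Weighting example (sketch, illustrative values): with sigma_a=0.05 and
    # sigma_b=1.0, a sample with target 4.0 (log1p scale) gets sigma=1.2, so its
    # squared error is divided by 1.44; larger scores are penalized more gently.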

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        One validation step of the backbone.
        """
x, y, target, sample_weight, path, original_label = batch
y_hat = self(x)
yp_clf = y_hat[:, 0:1]
yp_reg = y_hat[:, 1:2]
clf_loss = self.loss_clf(yp_clf, y).mean()
reg_loss_raw = self.loss_reg(yp_reg, target)
sigma = self.sigma_a * target + self.sigma_b
reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
loss = clf_loss + 0.5 * reg_loss
        y_prob = torch.sigmoid(yp_clf).float()
        # Accumulate per-sample values. float()/int() on a tensor only works for
        # single-element tensors, so extend with lists to support batch size > 1.
        self._y_true.extend(y[..., 0].float().cpu().tolist())
        self._y_prob.extend(y_prob[..., 0].cpu().tolist())
        self._y_pred.extend((y_prob[..., 0] > 0.5).int().cpu().tolist())
        self._t_true.extend(target[..., 0].float().cpu().tolist())
        self._t_pred.extend(yp_reg[..., 0].float().cpu().tolist())
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
self.log("val_clf_loss", clf_loss, prog_bar=False, sync_dist=True)
self.log("val_reg_loss", reg_loss, prog_bar=False, sync_dist=True)
return loss

    def on_validation_epoch_end(self) -> None:
        """
        Compute and log metrics at the end of the validation epoch.
        """
        if len(self._t_true) > 0:
            # skm.root_mean_squared_error requires scikit-learn >= 1.4.
            rmse = skm.root_mean_squared_error(self._t_true, self._t_pred)
mae = skm.mean_absolute_error(self._t_true, self._t_pred)
self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)
self.log("val_reg_mae", mae, prog_bar=True, sync_dist=True)
        # ROC AUC is undefined when only one class appears in the epoch.
        if len(set(self._y_true)) > 1:
auc = skm.roc_auc_score(self._y_true, self._y_prob)
f1 = skm.f1_score(self._y_true, self._y_pred, zero_division=0)
acc = skm.accuracy_score(self._y_true, self._y_pred)
self.log("val_auc", auc, prog_bar=True, sync_dist=True)
self.log("val_f1", f1, prog_bar=True, sync_dist=True)
self.log("val_acc", acc, prog_bar=True, sync_dist=True)
self._y_true.clear()
self._y_prob.clear()
self._y_pred.clear()
self._t_true.clear()
self._t_pred.clear()

    def on_train_epoch_end(self) -> None:
        """
        Log the current learning rate.
        """
        # self.optimizers() returns a LightningOptimizer wrapper; .optimizer is
        # the underlying torch optimizer.
        opt = self.optimizers()
        self.log(
            "lr",
            opt.optimizer.param_groups[0]["lr"],
on_step=False,
on_epoch=True,
sync_dist=True,
)

    def configure_optimizers(self):
        """
        Set up the optimizer and the OneCycleLR scheduler.
        If weight_path is None:
            only self.model.fc is trained (pretrain).
        Otherwise:
            the whole model is trained (full finetune).
        """
        if self.weight_path is None:
            # Pretrain: freeze everything, then unfreeze only the final layer.
            for p in self.parameters():
                p.requires_grad = False
            for p in self.model.fc.parameters():
                p.requires_grad = True
            params = self.model.fc.parameters()
        else:
            # Finetune: train the whole network.
            for p in self.parameters():
                p.requires_grad = True
            params = self.parameters()
optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
        # OneCycleLR needs the total step count up front; configure_optimizers is
        # called by the Trainer, so the estimate is available here.
        if self.max_epochs is not None:
            total_steps = self.trainer.estimated_stepping_batches
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=self.lr,
total_steps=total_steps,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
},
}
return optimizer

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Inference step of the backbone.
        """
x, y, target, sample_weight, path, original_label = batch
y_hat = self(x)
yp_clf = y_hat[:, 0:1]
yp_reg = y_hat[:, 1:2]
y_prob = torch.sigmoid(yp_clf)
return {
"y": y,
"y_pred": (y_prob > 0.5).int(),
"y_prob": y_prob,
"y_reg": yp_reg,
"target": target,
"original_label": original_label,
"path": path,
}
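

# Usage sketch (illustrative, not shipped with this repository): the two training
# modes described in the class docstring. The checkpoint path below is a
# hypothetical placeholder.
if __name__ == "__main__":
    # Pretrain: weight_path=None -> backbone frozen, only fc is optimized
    # (see configure_optimizers).
    module = SyntaxLightningModule(num_classes=2, lr=1e-3, max_epochs=10)

    # Finetune: load previously trained weights and train the whole network:
    # module = SyntaxLightningModule(
    #     num_classes=2,
    #     lr=1e-4,
    #     weight_path="checkpoints/pretrain.ckpt",  # hypothetical path
    #     max_epochs=30,
    # )

    # Dummy forward pass: (B, C, T, H, W) -> (B, 2).
    clip = torch.randn(2, 3, 16, 112, 112)
    with torch.no_grad():
        out = module(clip)
    print(out.shape)  # torch.Size([2, 2])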