# syntax-model / backbone / pl_train.py
# MesserMMP — "add backbone model" (commit 3ca1d12)
import os
import json
import torch
import numpy as np
import click
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler
from pytorchvideo.transforms import Normalize, Permute, RandAugment
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from torchvision.transforms import InterpolationMode
from dataset import SyntaxDataset
from pl_model import SyntaxLightningModule
import warnings
# Silence the warning whose message starts with
# "No device id is provided via `init_process_group`".
# NOTE(review): presumably triggered by the DDP strategy configured in make_trainer — confirm.
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")
# Let float32 matmuls trade some precision for speed ("medium" level, see torch docs).
torch.set_float32_matmul_precision("medium")

# Training script for the backbone (3D-ResNet) that predicts the SYNTAX score.
# Stages:
#   1) pretrain — only the last layer is trained;
#   2) full     — fine-tuning of the whole backbone.
#
# This description used to be a bare string literal placed after the imports,
# so it was never the module docstring (__doc__) — only a discarded expression.
# ------------------- Трансформации -------------------
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Build the per-clip video preprocessing pipeline.

    Args:
        video_size: Target frame size handed to ``T.Resize``.
        imagenet_mean: Per-channel mean used by ``Normalize``.
        imagenet_std: Per-channel std used by ``Normalize``.
        train: If True, return the randomized training pipeline; otherwise the
            deterministic evaluation pipeline.

    Returns:
        A ``T.Compose`` to apply to a single clip.
    """
    if not train:
        # Evaluation: no randomness, always a bicubic resize.
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    to_time_first = Permute(dims=[1, 0, 2, 3])  # (C, T, H, W) -> (T, C, H, W)
    to_channel_first = Permute(dims=[1, 0, 2, 3])  # (T, C, H, W) -> (C, T, H, W)
    # Resize with an interpolation mode picked at random for each clip.
    random_resize = T.RandomChoice([
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ])

    train_steps = [
        ToTensorVideo(),  # (T, H, W, 3) -> (C, T, H, W)
        to_time_first,  # RandAugment and the flip run in time-first layout
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        to_channel_first,
        random_resize,
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(train_steps)
# ------------------- DataLoader -------------------
def make_dataloader(
    dataset,
    batch_size,
    num_workers,
    *,
    shuffle=True,
    drop_last=True,
    use_weighted_sampler=False,
    pin_memory=True,
):
    """Create a DataLoader over ``dataset``.

    The defaults reproduce the original training loader: shuffle, drop the last
    incomplete batch, pin memory. Pass ``use_weighted_sampler=True`` to draw
    ``len(dataset)`` samples with replacement according to
    ``dataset.get_sample_weights()`` instead of plain shuffling.

    Args:
        dataset: Map-style dataset to load from.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Shuffle every epoch. Ignored when ``use_weighted_sampler`` is
            set, because DataLoader forbids combining a sampler with shuffle.
        drop_last: Drop the final incomplete batch.
        use_weighted_sampler: Sample via ``WeightedRandomSampler`` using the
            dataset's ``get_sample_weights()``.
        pin_memory: Forwarded to DataLoader.

    Returns:
        The configured ``DataLoader``.
    """
    sampler = None
    if use_weighted_sampler:
        # Query weights only when they are used. The previous version computed
        # them on every call (validation included) and then discarded them.
        sample_weights = dataset.get_sample_weights()
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
        shuffle = False  # mutually exclusive with an explicit sampler
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=pin_memory,
    )
# ------------------- Модель -------------------
def make_model(num_classes, video_shape, lr, weight_decay, max_epochs, weight_path=None):
    """Construct a ``SyntaxLightningModule``.

    Both the pretrain and the full fine-tuning stages call this, so the model
    is built the same way in each.

    Args:
        num_classes: Number of model output channels.
        video_shape: Input clip shape. Accepted for call-site symmetry but NOT
            forwarded to the module — it is unused here.
        lr: Forwarded unchanged as ``lr``.
        weight_decay: Forwarded unchanged as ``weight_decay``.
        max_epochs: Forwarded unchanged as ``max_epochs``.
        weight_path: Optional checkpoint path, forwarded unchanged as
            ``weight_path``.

    Returns:
        A new ``SyntaxLightningModule`` instance.
    """
    module_kwargs = dict(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )
    return SyntaxLightningModule(**module_kwargs)
# ------------------- Callbacks -------------------
def make_callbacks(artery: str, fold: int, phase: str):
    """Build the callbacks for one training phase.

    Returns ``[LearningRateMonitor, ModelCheckpoint]``:
      - the LR monitor logs the learning rate once per epoch;
      - the checkpoint ranks by ``val_mae`` (lower is better) and also keeps
        the last checkpoint (``save_last=True``).

    Args:
        artery: Artery name. Currently unused; kept for interface symmetry.
        fold: Fold index. Currently unused; kept for interface symmetry.
        phase: ``"pre"`` keeps the single best checkpoint; ``"full"`` keeps
            the best three.

    Returns:
        The callback list.

    Raises:
        ValueError: If ``phase`` is neither ``"pre"`` nor ``"full"``.
    """
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        # Put the metric the checkpoints are ranked by (val_mae) in the file
        # name. The old name only showed val_rmse, which is kept for continuity.
        filename="model-{epoch:02d}-{val_mae:.3f}-{val_rmse:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]
# ------------------- Trainer -------------------
def make_trainer(
    max_epochs,
    logger_name,
    callbacks,
    *,
    fast_dev_run=False,
    save_dir="backbone_logs",
):
    """Create a Lightning ``Trainer`` that logs to TensorBoard.

    Log paths and device settings may need adapting to your cluster.

    Args:
        max_epochs: Number of training epochs.
        logger_name: TensorBoard experiment name (a sub-directory of ``save_dir``).
        callbacks: Callbacks attached to the trainer.
        fast_dev_run: Forwarded to ``pl.Trainer``. When truthy, Lightning runs
            only a few batches as a pipeline smoke test.
        save_dir: Root directory for TensorBoard logs.

    Returns:
        A configured ``pl.Trainer``.
    """
    logger = TensorBoardLogger(save_dir=save_dir, name=logger_name)
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        # DDP strategy on a single device.
        # NOTE(review): presumably kept for multi-GPU runs; confirm it is still needed.
        strategy="ddp_find_unused_parameters_true",
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=logger,
        fast_dev_run=fast_dev_run,
    )


@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    required=True,
    help="Путь к корню датасета (директория, внутри которой лежат JSON и DICOM).",
)
@click.option("--fold", type=int, default=0, required=True, help="Номер фолда (0–4).")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    required=True,
    help="Название артерии: 'left' или 'right'.",
)
@click.option("-nc", "--num-classes", type=int, default=2, help="Число выходных каналов модели.")
@click.option("-b", "--batch-size", type=int, default=50, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="Количество кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, help="Число эпох на этапе full fine-tuning.")
@click.option("--pre-epochs", type=int, default=10, help="Число эпох на этапе pretrain.")
@click.option("--num-workers", type=int, default=8, help="Число воркеров для DataLoader.")
@click.option(
    "--fast-dev-run",
    is_flag=True,
    default=False,
    show_default=True,
    help="Режим быстрой проверки пайплайна (1–2 батча).",
)
@click.option("--seed", type=int, default=42, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    fast_dev_run,
    seed,
    pre_epochs=10,
):
    """Train the SYNTAX backbone for one fold and artery: pretrain, then full fine-tuning."""
    pl.seed_everything(seed)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'.")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # ------------------- Datasets -------------------
    # Metadata JSON paths are relative to dataset_root.
    train_meta = os.path.join("folds", f"step2_fold{fold:02d}_train.json")
    val_meta = os.path.join("folds", f"step2_fold{fold:02d}_eval.json")
    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # The pretrain stage uses twice the fine-tuning batch size.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    # Validation must visit every sample in a fixed order: no shuffling and
    # no dropped batch. The training helper hard-codes both.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )

    # Input clip shape (C, T, H, W), read from one sample. This avoids starting
    # worker processes just to load a full training batch.
    x, *_ = train_set[0]
    video_shape = x.shape

    # ------------------- Callbacks -------------------
    callbacks_pre = make_callbacks(artery=artery, fold=fold, phase="pre")
    callbacks_full = make_callbacks(artery=artery, fold=fold, phase="full")

    # ------------------- Pretrain -------------------
    model_pre = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=pre_epochs,
    )
    trainer_pre = make_trainer(
        pre_epochs,
        f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
        callbacks_pre,
        fast_dev_run=fast_dev_run,
    )
    trainer_pre.fit(model_pre, train_loader_pre, val_loader, ckpt_path=None)

    # ------------------- Full train -------------------
    # Initialise from the *last* pretrain checkpoint (not the best-by-val_mae one).
    # In fast-dev-run mode Lightning writes no checkpoints and last_model_path
    # stays "", so pass None instead of an empty path.
    pre_weights = trainer_pre.checkpoint_callback.last_model_path or None
    model_full = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=pre_weights,
    )
    trainer_full = make_trainer(
        max_epochs,
        f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
        callbacks_full,
        fast_dev_run=fast_dev_run,
    )
    trainer_full.fit(model_full, train_loader_post, val_loader, ckpt_path=None)


if __name__ == "__main__":
    main()