| import os |
| import json |
| import torch |
| import numpy as np |
| import click |
| import lightning.pytorch as pl |
| from lightning.pytorch.loggers import TensorBoardLogger |
| from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
|
|
| from pytorchvideo.transforms import Normalize, Permute, RandAugment |
| from torch.utils.data import DataLoader |
| from torchvision.transforms import transforms as T |
| from torchvision.transforms._transforms_video import ToTensorVideo |
| from torchvision.transforms import InterpolationMode |
|
|
| from rnn_dataset import SyntaxDataset |
| from rnn_model import SyntaxLightningModule |
|
|
# Process-wide setting: let float32 matmuls use the "medium" internal
# precision mode (PyTorch trades some accuracy for speed in this mode).
torch.set_float32_matmul_precision("medium")
|
|
|
|
"""
Train an RNN head on top of a pretrained backbone to predict the SYNTAX score.

Stages:
1) pretrain - only the head is trained (the backbone is frozen);
2) full - fine-tuning of the whole model (backbone + head).
"""
|
|
|
|
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Build the video preprocessing pipeline.

    Args:
        video_size: Target size handed to ``T.Resize``.
        imagenet_mean: Per-channel mean used by ``Normalize``.
        imagenet_std: Per-channel std used by ``Normalize``.
        train: When true, return the augmenting pipeline (RandAugment,
            random horizontal flip, resize with a randomly chosen
            interpolation); otherwise return the deterministic pipeline
            that always resizes with bicubic interpolation.

    Returns:
        A ``T.Compose`` ending in normalization with the given mean/std.
    """
    if not train:
        # Evaluation: no augmentation, fixed interpolation.
        eval_steps = [
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        return T.Compose(eval_steps)

    # One resize op per candidate interpolation; RandomChoice picks one per call.
    resize_candidates = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ]
    train_steps = [
        ToTensorVideo(),
        # The first two axes are swapped around the augmentations and swapped
        # back afterwards. NOTE(review): presumably (C, T, H, W) <-> (T, C, H, W)
        # so the augmentations see a frame-major tensor - confirm.
        Permute(dims=[1, 0, 2, 3]),
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),
        T.RandomChoice(resize_candidates),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(train_steps)
|
|
|
|
def make_dataloader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a ``DataLoader``.

    Samples are shuffled unless ``dataset.inference`` is truthy. The last
    incomplete batch is always dropped and pinned memory is requested.

    Args:
        dataset: Dataset exposing an ``inference`` attribute.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    # Shuffle only when the dataset is not in inference mode.
    should_shuffle = not dataset.inference
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=should_shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    return loader
|
|
|
|
def make_model(num_classes, video_shape, lr, variant, weight_decay, max_epochs,
               weight_path=None, pl_weight_path=None, pt_weights_format=False):
    """Instantiate ``SyntaxLightningModule`` from the given hyperparameters.

    Note:
        ``video_shape`` is part of the signature but is not forwarded to the
        module; every other argument is passed through by keyword unchanged.

    Returns:
        The constructed ``SyntaxLightningModule``.
    """
    module_kwargs = {
        "num_classes": num_classes,
        "lr": lr,
        "variant": variant,
        "weight_decay": weight_decay,
        "max_epochs": max_epochs,
        "weight_path": weight_path,
        "pl_weight_path": pl_weight_path,
        "pt_weights_format": pt_weights_format,
    }
    return SyntaxLightningModule(**module_kwargs)
|
|
|
|
def make_callbacks(artery: str, fold: int, phase: str):
    """Build the training callbacks: an LR monitor and a ``val_mae`` checkpoint.

    The two phases differ only in how many best checkpoints are kept:
    ``"pre"`` keeps the single best one, ``"full"`` keeps the best three.
    In both phases the last checkpoint is saved as well.

    Args:
        artery: Currently unused; kept for call-site compatibility.
        fold: Currently unused; kept for call-site compatibility.
        phase: Training phase, either ``"pre"`` or ``"full"``.

    Returns:
        ``[LearningRateMonitor, ModelCheckpoint]``.

    Raises:
        ValueError: If ``phase`` is neither ``"pre"`` nor ``"full"``.
    """
    # Validate first so a bad phase fails before any callback is constructed.
    if phase == "pre":
        save_top_k = 1
    elif phase == "full":
        save_top_k = 3
    else:
        raise ValueError(f"phase must be 'pre' or 'full', got {phase!r}")

    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=save_top_k,
        mode="min",
        # The file name reports val_rmse while selection is driven by val_mae.
        filename="model-{epoch:02d}-{val_rmse:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]
|
|
|
|
def make_trainer(max_epochs, logger_name, callbacks, fast_dev_run=False):
    """Build a single-GPU Lightning ``Trainer`` that logs to TensorBoard.

    Args:
        max_epochs: Maximum number of training epochs.
        logger_name: Experiment name for the TensorBoard logger, whose
            ``save_dir`` is ``rnn_logs``.
        callbacks: Callbacks to attach to the trainer.
        fast_dev_run: Forwarded to ``pl.Trainer(fast_dev_run=...)`` to run a
            quick smoke test instead of a full training run.

    Returns:
        The configured ``pl.Trainer``.
    """
    logger = TensorBoardLogger(save_dir="rnn_logs", name=logger_name)
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        strategy="ddp_find_unused_parameters_true",
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=logger,
        fast_dev_run=fast_dev_run,
    )


@click.command()
@click.option(
    "-r", "--dataset-root", type=click.Path(exists=True), required=True,
    help="Корень датасета (где лежат folds/*.json и DICOM).",
)
@click.option("--fold", type=int, default=0, help="Номер фолда (0-4).")
@click.option("-a", "--artery", type=str, default="right", help="'left' или 'right'.")
@click.option("--variant", type=str, default="lstm_mean", help="Тип head (lstm_mean и др.).")
@click.option("-nc", "--num-classes", type=int, default=2)
@click.option("-b", "--batch-size", type=int, default=8)
@click.option("-f", "--frames-per-clip", type=int, default=32)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256))
@click.option("--max-epochs", type=int, default=10)
@click.option("--num-workers", type=int, default=8)
@click.option("--fast-dev-run", is_flag=True)
@click.option("--seed", type=int, default=42)
@click.option("--backbone-ckpt", type=str, help="Путь к backbone-чекпоинту для pretrain.")
def main(
    dataset_root, fold, artery, variant, num_classes, batch_size, frames_per_clip,
    video_size, max_epochs, num_workers, fast_dev_run, seed, backbone_ckpt,
):
    """Train the RNN head in two stages: pretrain, then full fine-tuning."""
    pl.seed_everything(seed)

    # Reject unknown arteries with a proper CLI error instead of a KeyError.
    artery = artery.lower()
    if artery not in ("left", "right"):
        raise click.BadParameter(
            f"expected 'left' or 'right', got {artery!r}", param_hint="--artery"
        )

    print(f"Training {variant} head for {artery} artery, fold {fold}")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Fold metadata files; these paths are built relative to "rnn_folds",
    # not to dataset_root.
    train_meta = os.path.join("rnn_folds", f"step2_rnn_fold{fold:02d}_train.json")
    val_meta = os.path.join("rnn_folds", f"step2_rnn_fold{fold:02d}_eval.json")

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # Stage one uses a doubled batch size; validation runs one sample at a time.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    val_loader = make_dataloader(val_set, 1, num_workers)

    # Pull one batch up front to read the per-sample input shape.
    first_batch, *_ = next(iter(train_loader_pre))
    video_shape = first_batch.shape[1:]

    # Stage 1 ("pre"): start from the backbone checkpoint, if one was given.
    callbacks_pre = make_callbacks(artery, fold, "pre")
    model_pre = make_model(
        num_classes, video_shape, lr=1e-4, variant=variant,
        weight_decay=0.01, max_epochs=max_epochs, weight_path=backbone_ckpt,
    )
    trainer_pre = make_trainer(
        max_epochs, f"{artery}_{variant}_pre_fold{fold:02d}", callbacks_pre,
        fast_dev_run=fast_dev_run,
    )
    trainer_pre.fit(model_pre, train_loader_pre, val_loader)

    # Lightning writes no checkpoints under fast_dev_run, so best_model_path
    # stays empty; pass None rather than an empty path in that case.
    pre_ckpt_path = trainer_pre.checkpoint_callback.best_model_path
    if fast_dev_run and not pre_ckpt_path:
        pre_ckpt_path = None

    # Stage 2 ("full"): continue from the best stage-1 checkpoint with a lower LR.
    callbacks_full = make_callbacks(artery, fold, "full")
    model_full = make_model(
        num_classes, video_shape, lr=2e-5, variant=variant,
        weight_decay=0.01, max_epochs=max_epochs,
        pl_weight_path=pre_ckpt_path,
    )
    trainer_full = make_trainer(
        max_epochs, f"{artery}_{variant}_full_fold{fold:02d}", callbacks_full,
        fast_dev_run=fast_dev_run,
    )
    trainer_full.fit(model_full, train_loader_post, val_loader)
|
|
|
# Script entry point: run the click command only when executed directly.
if __name__ == "__main__":
    main()
|
|