| | import os |
| | import json |
| | import torch |
| | import numpy as np |
| | import click |
| | import lightning.pytorch as pl |
| | from lightning.pytorch.loggers import TensorBoardLogger |
| | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
| | from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler |
| |
|
| | from pytorchvideo.transforms import Normalize, Permute, RandAugment |
| | from torch.utils.data import DataLoader, WeightedRandomSampler |
| | from torchvision.transforms import transforms as T |
| | from torchvision.transforms._transforms_video import ToTensorVideo |
| | from torchvision.transforms import InterpolationMode |
| |
|
| | from dataset import SyntaxDataset |
| | from pl_model import SyntaxLightningModule |
| |
|
import warnings

# Silence the "No device id is provided via `init_process_group`" warning,
# presumably emitted when the DDP strategy used by the trainers starts up.
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")

# "medium" allows PyTorch to use faster, lower-precision internal math for
# float32 matrix multiplications on supported hardware.
torch.set_float32_matmul_precision("medium")

# Backbone (3D-ResNet) training script for SYNTAX score prediction.
#
# Steps:
#   1) pretraining ("pre") - intended to train only the last layer;
#   2) full fine-tuning ("full") - fine-tunes the whole backbone.
#
# NOTE(review): in this file the two phases differ only in lr, epochs, batch
# size and weight_path; any layer freezing must happen inside
# SyntaxLightningModule - TODO confirm.
| |
|
| |
|
| | |
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Return the video preprocessing pipeline for training or evaluation.

    Both pipelines convert the clip with ``ToTensorVideo``, resize frames to
    ``video_size`` and normalise with ``imagenet_mean`` / ``imagenet_std``.

    The training pipeline additionally swaps the first two tensor axes, applies
    RandAugment (magnitude 10, two layers) and a random horizontal flip, then
    swaps the axes back. It also picks bilinear or bicubic interpolation at
    random for the resize. Evaluation always resizes bicubically.

    Args:
        video_size: target frame size passed to ``T.Resize``.
        imagenet_mean: per-channel mean for normalisation.
        imagenet_std: per-channel std for normalisation.
        train: build the augmenting training pipeline when true.

    Returns:
        A ``T.Compose`` pipeline.
    """
    if not train:
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    # Swapping axes 0 and 1 presumably turns the (C, T, H, W) output of
    # ToTensorVideo into (T, C, H, W), so the frame-wise augmentations see a
    # stack of frames; applying the same permutation again undoes it.
    swap_first_axes = [1, 0, 2, 3]
    resize_variants = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ]
    return T.Compose([
        ToTensorVideo(),
        Permute(dims=swap_first_axes),
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=swap_first_axes),
        T.RandomChoice(resize_variants),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
| |
|
| |
|
| | |
def make_dataloader(
    dataset,
    batch_size,
    num_workers,
    *,
    shuffle=True,
    drop_last=True,
    use_weighted_sampler=False,
):
    """Create a DataLoader over ``dataset``.

    With the defaults the loader shuffles, drops the last incomplete batch and
    pins memory. With ``use_weighted_sampler=True`` it instead draws
    ``len(weights)`` samples with replacement from a ``WeightedRandomSampler``.
    The weights come from ``dataset.get_sample_weights()``.

    Args:
        dataset: map-style dataset; must provide ``get_sample_weights()``
            only when ``use_weighted_sampler`` is true.
        batch_size: samples per batch.
        num_workers: number of DataLoader worker processes.
        shuffle: shuffle each epoch (ignored when the weighted sampler is used).
        drop_last: drop the final incomplete batch.
        use_weighted_sampler: sample according to the dataset's sample weights.

    Returns:
        The configured ``DataLoader``.
    """
    sampler = None
    if use_weighted_sampler:
        # Weights are requested only when they are actually used.
        weights = dataset.get_sample_weights()
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        # DataLoader rejects a custom sampler combined with shuffle=True.
        shuffle = False

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=True,
    )
| |
|
| |
|
| | |
def make_model(num_classes, video_shape, lr, weight_decay, max_epochs, weight_path=None):
    """Build a ``SyntaxLightningModule`` for either training phase.

    Used for both pretraining and full fine-tuning so the model is constructed
    the same way in each phase.

    Args:
        num_classes: forwarded as ``num_classes``.
        video_shape: accepted for call-site compatibility but NOT forwarded;
            it has no effect here.
        lr: forwarded as ``lr``.
        weight_decay: forwarded as ``weight_decay``.
        max_epochs: forwarded as ``max_epochs``.
        weight_path: forwarded as ``weight_path``; how it is used is up to
            ``SyntaxLightningModule``.

    Returns:
        The constructed ``SyntaxLightningModule``.
    """
    module_kwargs = dict(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )
    return SyntaxLightningModule(**module_kwargs)
| |
|
| |
|
| | |
def make_callbacks(artery: str, fold: int, phase: str):
    """Build the Lightning callbacks for one training phase.

    Returns ``[LearningRateMonitor, ModelCheckpoint]``:
      - the monitor logs the learning rate once per epoch;
      - the checkpoint callback ranks checkpoints by the lowest ``val_mae`` and
        always keeps the last one as well. Phase "pre" keeps the single best
        checkpoint; phase "full" keeps the best three.

    Args:
        artery: currently unused.
        fold: currently unused.
        phase: "pre" or "full".

    Raises:
        ValueError: if ``phase`` is not "pre" or "full".
    """
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        # Name checkpoints after the metric they are ranked by. The old
        # pattern embedded val_rmse, which Lightning renders as 0 when it is
        # not logged.
        filename="model-{epoch:02d}-{val_mae:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]
| |
|
| |
|
| | |
def make_trainer(
    max_epochs,
    logger_name,
    callbacks,
    *,
    fast_dev_run=False,
    devices=1,
    strategy="ddp_find_unused_parameters_true",
):
    """Create a GPU Lightning Trainer with bf16 mixed precision and a TensorBoardLogger.

    Logs are written under ``backbone_logs/<logger_name>``. Log paths and
    device settings may need adapting to your cluster.

    Args:
        max_epochs: maximum number of training epochs.
        logger_name: TensorBoard experiment name.
        callbacks: Lightning callbacks to attach.
        fast_dev_run: forwarded to ``pl.Trainer``. When enabled, Lightning runs
            only a quick smoke test and skips checkpointing and logging.
        devices: number of devices (default 1, as before).
        strategy: training strategy (default keeps the original DDP setting).

    Returns:
        The configured ``pl.Trainer``.
    """
    logger = TensorBoardLogger(save_dir="backbone_logs", name=logger_name)
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=devices,
        strategy=strategy,
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=logger,
        fast_dev_run=fast_dev_run,
    )


@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    required=True,
    help="Путь к корню датасета (директория, внутри которой лежат JSON и DICOM).",
)
@click.option("--fold", type=int, default=0, required=True, help="Номер фолда (0–4).")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    required=True,
    help="Название артерии: 'left' или 'right'.",
)
@click.option("-nc", "--num-classes", type=int, default=2, help="Число выходных каналов модели.")
@click.option("-b", "--batch-size", type=int, default=50, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="Количество кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, help="Число эпох на этапе full fine-tuning.")
@click.option("--num-workers", type=int, default=8, help="Число воркеров для DataLoader.")
@click.option(
    "--fast-dev-run",
    is_flag=True,
    default=False,
    show_default=True,
    help="Режим быстрой проверки пайплайна (1–2 батча).",
)
@click.option("--seed", type=int, default=42, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    fast_dev_run,
    seed,
):
    """Train the SYNTAX-score backbone for one artery and fold.

    The first phase ("pre") pretrains for 10 epochs. The second phase ("full")
    fine-tunes, starting from the last pretraining checkpoint.
    """
    pl.seed_everything(seed)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'.")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Relative paths; whether they are resolved against the CWD or
    # dataset_root is up to SyntaxDataset - TODO confirm.
    train_meta = os.path.join("folds", f"step2_fold{fold:02d}_train.json")
    val_meta = os.path.join("folds", f"step2_fold{fold:02d}_eval.json")

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )

    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # Pretraining uses a doubled batch size.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    # make_dataloader's defaults shuffle and drop the last batch, which
    # evaluation should not do, so the validation loader is built directly.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

    # One training batch is drawn only to read the per-sample video shape.
    # make_model currently ignores it; the draw is kept so the seeded RNG
    # stream of the run is unchanged.
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]

    callbacks_pre = make_callbacks(artery=artery, fold=fold, phase="pre")
    callbacks_full = make_callbacks(artery=artery, fold=fold, phase="full")

    # Phase 1: pretraining.
    num_pre_epochs = 10
    model_pre = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
    )
    trainer_pre = make_trainer(
        num_pre_epochs,
        f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
        callbacks_pre,
        fast_dev_run=fast_dev_run,
    )
    trainer_pre.fit(model_pre, train_loader_pre, val_loader, ckpt_path=None)

    # Phase 2: full fine-tuning from the LAST (not best) pretraining checkpoint.
    # Under fast_dev_run no checkpoint is written and last_model_path stays "",
    # so pass None instead of an empty path.
    pre_checkpoint = trainer_pre.checkpoint_callback.last_model_path or None
    model_full = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=pre_checkpoint,
    )
    trainer_full = make_trainer(
        max_epochs,
        f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
        callbacks_full,
        fast_dev_run=fast_dev_run,
    )
    trainer_full.fit(model_full, train_loader_post, val_loader, ckpt_path=None)


if __name__ == "__main__":
    main()
| |
|