"""Training script for the SYNTAX-score backbone (3D-ResNet).

Two stages:
    1) pretraining ("pre")  -- only the last layer is trained;
    2) full fine-tuning ("full") -- the whole backbone is trained.
"""

import json
import os
import warnings

import click
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler
from pytorchvideo.transforms import Normalize, Permute, RandAugment
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.transforms import InterpolationMode
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo

from dataset import SyntaxDataset
from pl_model import SyntaxLightningModule

warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")
torch.set_float32_matmul_precision("medium")


# ------------------- Transforms -------------------
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Build the video transform pipeline.

    Args:
        video_size: target (H, W) of the output frames.
        imagenet_mean: per-channel normalization mean.
        imagenet_std: per-channel normalization std.
        train: if True, include augmentation (RandAugment, random flip,
            random choice of resize interpolation); otherwise a plain
            resize + normalize pipeline.

    Returns:
        A torchvision ``Compose`` producing (C, T, H, W) float tensors.
    """
    interpolation_choices = [
        InterpolationMode.BILINEAR,
        InterpolationMode.BICUBIC,
    ]
    if train:
        return T.Compose([
            ToTensorVideo(),  # (T, H, W, 3) -> (C, T, H, W)
            # RandAugment expects (T, C, H, W), so permute around it.
            Permute(dims=[1, 0, 2, 3]),  # (C, T, H, W) -> (T, C, H, W)
            RandAugment(magnitude=10, num_layers=2),
            T.RandomHorizontalFlip(),
            Permute(dims=[1, 0, 2, 3]),  # back: (T, C, H, W) -> (C, T, H, W)
            # Randomly vary the resize interpolation as a mild augmentation.
            T.RandomChoice([
                T.Resize(size=video_size, interpolation=interp, antialias=True)
                for interp in interpolation_choices
            ]),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
    else:
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])


# ------------------- DataLoader -------------------
def make_dataloader(dataset, batch_size, num_workers, train=True):
    """Create a DataLoader for ``dataset``.

    Args:
        dataset: dataset to wrap (must support ``len``/``__getitem__``).
        batch_size: samples per batch.
        num_workers: worker processes for loading.
        train: if True (default, backward-compatible), shuffle and drop the
            last incomplete batch; if False (validation/test), keep the
            natural order and every sample.

    Returns:
        A configured ``DataLoader``.

    Note:
        To switch to class-balanced sampling, replace ``shuffle`` with
        ``sampler=WeightedRandomSampler(dataset.get_sample_weights(),
        len(dataset), replacement=True)``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=train,
        drop_last=train,
        pin_memory=True,
    )


# ------------------- Model -------------------
def make_model(num_classes, video_shape, lr, weight_decay, max_epochs, weight_path=None):
    """Create a ``SyntaxLightningModule``, identically for both stages.

    Args:
        num_classes: number of model output channels.
        video_shape: input clip shape (C, T, H, W); currently unused but
            kept for interface stability -- TODO confirm whether the module
            should receive it.
        lr: initial learning rate.
        weight_decay: optimizer weight decay.
        max_epochs: schedule length passed to the module.
        weight_path: optional checkpoint to initialize weights from
            (used when starting full fine-tuning from the pretrain stage).

    Returns:
        A fresh ``SyntaxLightningModule``.
    """
    return SyntaxLightningModule(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )


# ------------------- Callbacks -------------------
def make_callbacks(artery: str, fold: int, phase: str):
    """Build the callback list for one training stage.

    Args:
        artery: artery name; currently unused but kept for interface
            stability (callers pass it).
        fold: fold index; currently unused, see above.
        phase: "pre" (keep best 1 checkpoint) or "full" (keep best 3).

    Returns:
        ``[LearningRateMonitor, ModelCheckpoint]``.

    Raises:
        ValueError: if ``phase`` is neither "pre" nor "full".
    """
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # The two phases differ only in how many top checkpoints are kept.
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        # Fix: the filename previously interpolated "val_rmse", which is not
        # the monitored metric (and may not be logged at all); use "val_mae".
        filename="model" + "-{epoch:02d}-{val_mae:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]


# ------------------- Trainer -------------------
def make_trainer(max_epochs, logger_name, callbacks):
    """Create a Lightning ``Trainer`` with a TensorBoard logger.

    Args:
        max_epochs: number of epochs to train.
        logger_name: experiment name under ``backbone_logs/``.
        callbacks: callback list (see ``make_callbacks``).

    Returns:
        A configured ``pl.Trainer``.

    Note:
        Log paths and device settings may need adapting per cluster.
    """
    logger = TensorBoardLogger(
        save_dir="backbone_logs",
        name=logger_name,
    )
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        strategy="ddp_find_unused_parameters_true",
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=logger,
    )


@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    required=True,
    help="Путь к корню датасета (директория, внутри которой лежат JSON и DICOM).",
)
@click.option("--fold", type=int, default=0, required=True, help="Номер фолда (0–4).")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    required=True,
    help="Название артерии: 'left' или 'right'.",
)
@click.option("-nc", "--num-classes", type=int, default=2, help="Число выходных каналов модели.")
@click.option("-b", "--batch-size", type=int, default=50, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="Количество кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, help="Число эпох на этапе full fine-tuning.")
@click.option("--num-workers", type=int, default=8, help="Число воркеров для DataLoader.")
@click.option(
    "--fast-dev-run",
    is_flag=True,
    default=False,
    show_default=True,
    help="Режим быстрой проверки пайплайна (1–2 батча).",
)
@click.option("--seed", type=int, default=42, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    fast_dev_run,
    seed,
):
    """Run two-stage training: pretrain, then full fine-tuning."""
    pl.seed_everything(seed)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'.")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # ------------------- Datasets -------------------
    # Fold metadata paths are relative; SyntaxDataset resolves them against
    # its ``root`` -- TODO confirm against SyntaxDataset's implementation.
    train_meta = os.path.join("folds", f"step2_fold{fold:02d}_train.json")
    val_meta = os.path.join("folds", f"step2_fold{fold:02d}_eval.json")

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # Pretrain uses a doubled batch size (only the head is trained, so the
    # memory footprint is smaller).
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    # Fix: the validation loader must not shuffle or drop samples.
    val_loader = make_dataloader(val_set, 1, num_workers, train=False)

    # Probe one batch for the input clip shape (C, T, H, W).
    # NOTE: this spins up loader workers once just for the shape.
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]

    # ------------------- Callbacks -------------------
    callbacks_pre = make_callbacks(artery=artery, fold=fold, phase="pre")
    callbacks_full = make_callbacks(artery=artery, fold=fold, phase="full")

    # ------------------- Pretrain -------------------
    num_pre_epochs = 10
    model_pre = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
    )
    trainer_pre = make_trainer(
        num_pre_epochs, f"{artery}BinSyntax_R3D_pre_fold{fold:02d}", callbacks_pre
    )
    trainer_pre.fit(model_pre, train_loader_pre, val_loader, ckpt_path=None)

    # ------------------- Full train -------------------
    # Warm-start from the last pretrain checkpoint.
    model_full = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=trainer_pre.checkpoint_callback.last_model_path,
    )
    trainer_full = make_trainer(
        max_epochs, f"{artery}BinSyntax_R3D_full_fold{fold:02d}", callbacks_full
    )
    trainer_full.fit(model_full, train_loader_post, val_loader, ckpt_path=None)


if __name__ == "__main__":
    main()