import os import json import pydicom import numpy as np import torch from typing import Callable, Optional, Tuple from torch import Tensor from torch.utils.data import Dataset # Полуточность достаточно для хранения весов и таргетов, # а сами вычисления в модели идут в float32 / bf16. DTYPE = torch.float16 class SyntaxDataset(Dataset): """ PyTorch Dataset для обучения видеобэкбона на задаче SYNTAX. Функциональность: - читает метаданные из JSON (относительный путь относительно root); - фильтрует по артерии (левая / правая); - опционально отфильтровывает только примеры с положительным SYNTAX (validation=True); - рассчитывает sample weights по бинам SYNTAX (для WeightedRandomSampler); - конвертирует DICOM-видео в тензор (T, H, W, 3) c uint8 [0–255]; - возвращает: video, label_bin, target_log, weight, rel_path, original_label. """ def __init__( self, root: str, # корневая директория датасета meta: str, # относительный путь к JSON с метаданными train: bool, # режим: train / eval length: int, # длина клипа (кол-во кадров) label: str, # имя поля с SYNTAX score в JSON artery_bin: int, # 0 — левая, 1 — правая артерия validation: bool = False, # отбрасывать ли нулевые SYNTAX transform: Optional[Callable] = None, ) -> None: super().__init__() self.root = root self.train = train self.length = length self.label = label self.transform = transform self.validation = validation # meta теперь трактуется как ОТНОСИТЕЛЬНЫЙ путь от root meta_path = os.path.join(root, meta) with open(meta_path, "r") as f: dataset = json.load(f) # Фильтр по артерии (0 — левая, 1 — правая) if artery_bin is not None: assert artery_bin in (0, 1), "artery_bin должен быть 0 (левая) или 1 (правая)" dataset = [rec for rec in dataset if rec["artery"] == artery_bin] self.artery_bin = artery_bin else: # Для корректной работы get_sample_weights ожидаем известный artery_bin raise ValueError("artery_bin должен быть явно задан (0 или 1).") # Валидационный набор: берём только записи с положительным SYNTAX if validation: dataset = [rec for rec in dataset if rec[self.label] > 0] # Инициализируем веса с единиц for rec in dataset: rec["weight"] = 1.0 self.dataset = dataset # ------------------------------------------------------------------ # Веса для WeightedRandomSampler # ------------------------------------------------------------------ def get_sample_weights(self) -> Tensor: """ Считает веса для примеров по бинам SYNTAX. Для каждой артерии определён свой набор порогов, после чего каждый пример получает вес, обратный частоте своего бина. """ # Пороговые значения по артериям (подбирались эмпирически) bin_thresholds = { 0: [0, 5, 10, 15], # левая 1: [0, 2, 5, 8], # правая } thresholds = bin_thresholds[self.artery_bin] thr0, thr1, thr2, thr3 = thresholds # Бины по значениям SYNTAX self.dataset_0 = [rec for rec in self.dataset if rec[self.label] == thr0] self.dataset_1 = [rec for rec in self.dataset if thr0 < rec[self.label] <= thr1] self.dataset_2 = [rec for rec in self.dataset if thr1 < rec[self.label] <= thr2] self.dataset_3 = [rec for rec in self.dataset if thr2 < rec[self.label] <= thr3] self.dataset_4 = [rec for rec in self.dataset if rec[self.label] > thr3] total = ( len(self.dataset_0) + len(self.dataset_1) + len(self.dataset_2) + len(self.dataset_3) + len(self.dataset_4) ) def safe_weight(count: int) -> float: # Если в бине нет примеров, вес ставим 0.0 return total / count if count > 0 else 0.0 self.weights_0 = safe_weight(len(self.dataset_0)) self.weights_1 = safe_weight(len(self.dataset_1)) self.weights_2 = safe_weight(len(self.dataset_2)) self.weights_3 = safe_weight(len(self.dataset_3)) self.weights_4 = safe_weight(len(self.dataset_4)) print( "Weights: ", self.weights_0, self.weights_1, self.weights_2, self.weights_3, self.weights_4, ) print( "Counts: ", len(self.dataset_0), len(self.dataset_1), len(self.dataset_2), len(self.dataset_3), len(self.dataset_4), ) # Назначаем вес каждому примеру weights = [] for rec in self.dataset: syntax_score = rec[self.label] if syntax_score == thr0: weights.append(self.weights_0) elif thr0 < syntax_score <= thr1: weights.append(self.weights_1) elif thr1 < syntax_score <= thr2: weights.append(self.weights_2) elif thr2 < syntax_score <= thr3: weights.append(self.weights_3) else: weights.append(self.weights_4) self.weights = torch.tensor(weights, dtype=DTYPE) return self.weights # ------------------------------------------------------------------ def __len__(self) -> int: return len(self.dataset) # ------------------------------------------------------------------ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]: """ Возвращает один пример: - video: Tensor (T, H, W, 3) → после transform обычно (C, T, H, W) - label: бинарный таргет по порогу для конкретной артерии - target: логарифмированный SYNTAX score (регрессия) - weight: вес примера (для самплера / лосса) - path: относительный путь к DICOM файлу - original_label: исходный SYNTAX score """ rec = self.dataset[idx] # Относительный путь к DICOM из JSON (мы не храним абсолютные пути) path = rec["path"] weight = rec["weight"] full_path = os.path.join(self.root, path) video = pydicom.dcmread(full_path).pixel_array # (T, H, W) # Приводим 16-битный сигнал к диапазону [0, 255] uint8 if video.dtype == np.uint16: vmax = np.max(video) assert vmax > 0 video = video.astype(np.float32) video = video * (255.0 / vmax) video = video.astype(np.uint8) assert video.dtype == np.uint8 # Порог для бинарной классификации зависит от артерии bin_thresholds = { 0: 15, # левая 1: 5, # правая } syntax_value = rec[self.label] label = torch.tensor( [int(syntax_value > bin_thresholds[self.artery_bin])], dtype=DTYPE, ) target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE) original_label = torch.tensor([syntax_value], dtype=DTYPE) # Дублируем видео по времени, пока не наберём нужную длину клипа while len(video) < self.length: video = np.concatenate([video, video]) t = len(video) if self.train: # Случайный подотрезок длины self.length begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)) end = begin + self.length video = video[begin:end, :, :] else: # В валидации используем весь видеоряд (обрежется трансформами / моделью) video = video # Превращаем (T, H, W) → (T, H, W, 3) путём копирования каналов (grayscale→RGB) video = torch.tensor(np.stack([video, video, video], axis=-1)) if self.transform is not None: video = self.transform(video) sample_weight = torch.tensor([weight], dtype=DTYPE) return video, label, target, sample_weight, path, original_label