| | import os |
| | import json |
| | import pydicom |
| | import numpy as np |
| | import torch |
| |
|
| | from typing import Callable, Optional, Tuple |
| | from torch import Tensor |
| | from torch.utils.data import Dataset |
| |
|
| | |
| | |
# Shared dtype for every tensor target / weight produced by the dataset.
DTYPE = torch.float16


class SyntaxDataset(Dataset):
    """
    PyTorch Dataset for training a video backbone on the SYNTAX score task.

    Functionality:
      - reads metadata from a JSON file (path relative to ``root``);
      - filters records by artery (0 = left, 1 = right);
      - optionally keeps only samples with a positive SYNTAX score
        (``validation=True``);
      - computes per-sample weights over SYNTAX bins
        (for ``WeightedRandomSampler``);
      - converts a DICOM video to a (T, H, W, 3) uint8 tensor in [0, 255];
      - returns:
        video, label_bin, target_log, weight, rel_path, original_label.
    """

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery_bin: int,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            root: dataset root directory; all relative paths (metadata,
                DICOM files) are resolved against it.
            meta: JSON metadata file path, relative to ``root``.
            train: if True, ``__getitem__`` takes a random temporal crop of
                ``length`` frames; otherwise the full (padded) clip is kept.
            length: number of frames per training clip.
            label: metadata key holding the SYNTAX score.
            artery_bin: 0 (left artery) or 1 (right artery); mandatory.
            validation: if True, keep only records with a positive score.
            transform: optional callable applied to the video tensor.

        Raises:
            ValueError: if ``artery_bin`` is not 0 or 1.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.transform = transform
        self.validation = validation

        meta_path = os.path.join(root, meta)
        with open(meta_path, "r") as f:
            dataset = json.load(f)

        # Explicit raises instead of `assert`: asserts vanish under `python -O`,
        # silently disabling this input validation.
        if artery_bin is None:
            raise ValueError("artery_bin должен быть явно задан (0 или 1).")
        if artery_bin not in (0, 1):
            raise ValueError("artery_bin должен быть 0 (левая) или 1 (правая)")
        self.artery_bin = artery_bin
        dataset = [rec for rec in dataset if rec["artery"] == artery_bin]

        # Validation mode keeps only samples with a strictly positive score.
        if validation:
            dataset = [rec for rec in dataset if rec[self.label] > 0]

        # Default per-record weight. NOTE(review): get_sample_weights() builds
        # a separate tensor and never writes this field back, so __getitem__
        # always returns 1.0 here unless a caller overwrites it — confirm this
        # is intended.
        for rec in dataset:
            rec["weight"] = 1.0

        self.dataset = dataset

    def get_sample_weights(self) -> Tensor:
        """
        Compute inverse-frequency sample weights over five SYNTAX bins.

        Each artery has its own threshold set; every record gets the weight
        ``total / count(bin)`` of the bin its score falls into (0.0 for an
        empty bin, which then can never be sampled).

        Returns:
            1-D float16 tensor aligned with ``self.dataset`` (also stored as
            ``self.weights``), suitable for ``WeightedRandomSampler``.
        """
        # Bin edges per artery: left (0) vs right (1).
        bin_thresholds = {
            0: [0, 5, 10, 15],
            1: [0, 2, 5, 8],
        }
        thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]

        def bin_index(score) -> int:
            # Single source of truth for binning: used both for the counts
            # and for the per-record weight assignment, so the two can no
            # longer disagree (previously out-of-range scores were weighted
            # as bin 4 but counted in no bin at all).
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        # Partition records into bins; keep the attributes for inspection.
        buckets = [[], [], [], [], []]
        for rec in self.dataset:
            buckets[bin_index(rec[self.label])].append(rec)
        (
            self.dataset_0,
            self.dataset_1,
            self.dataset_2,
            self.dataset_3,
            self.dataset_4,
        ) = buckets

        total = sum(len(bucket) for bucket in buckets)

        def safe_weight(count: int) -> float:
            # Empty bins get weight 0.0 instead of a ZeroDivisionError.
            return total / count if count > 0 else 0.0

        bin_weights = [safe_weight(len(bucket)) for bucket in buckets]
        (
            self.weights_0,
            self.weights_1,
            self.weights_2,
            self.weights_3,
            self.weights_4,
        ) = bin_weights

        print(
            "Weights: ",
            self.weights_0,
            self.weights_1,
            self.weights_2,
            self.weights_3,
            self.weights_4,
        )
        print(
            "Counts: ",
            len(self.dataset_0),
            len(self.dataset_1),
            len(self.dataset_2),
            len(self.dataset_3),
            len(self.dataset_4),
        )

        weights = [bin_weights[bin_index(rec[self.label])] for rec in self.dataset]
        self.weights = torch.tensor(weights, dtype=DTYPE)
        return self.weights

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]:
        """
        Load and return one sample.

        Returns:
            video: uint8 tensor (T, H, W, 3); after ``transform`` usually
                (C, T, H, W).
            label: binary target — score above the per-artery cut-off.
            target: log(1 + score) regression target.
            weight: per-record weight (for sampler / loss).
            path: DICOM path relative to ``root``.
            original_label: raw SYNTAX score.

        Raises:
            ValueError: on an all-zero uint16 clip or an unexpected pixel
                dtype (explicit raises replace `assert`s that `-O` strips).
        """
        rec = self.dataset[idx]

        path = rec["path"]
        weight = rec["weight"]

        full_path = os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array

        # Rescale 16-bit DICOM pixel data to uint8 [0, 255] by the clip max.
        if video.dtype == np.uint16:
            vmax = np.max(video)
            if vmax == 0:
                # Would otherwise divide by zero below; fail loudly instead.
                raise ValueError(f"All-zero uint16 video: {full_path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).astype(np.uint8)
        if video.dtype != np.uint8:
            raise ValueError(f"Unsupported pixel dtype {video.dtype}: {full_path}")

        # Binarization cut-off per artery (left: score > 15, right: > 5).
        bin_thresholds = {
            0: 15,
            1: 5,
        }

        syntax_value = rec[self.label]
        label = torch.tensor(
            [int(syntax_value > bin_thresholds[self.artery_bin])],
            dtype=DTYPE,
        )
        target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE)
        original_label = torch.tensor([syntax_value], dtype=DTYPE)

        # Tile the clip until it has at least `length` frames.
        while len(video) < self.length:
            video = np.concatenate([video, video])
        t = len(video)

        if self.train:
            # Random temporal crop of exactly `length` frames. `int(...)`
            # instead of slicing with a 1-element Tensor, which only worked
            # through the fragile single-element `__index__` path.
            begin = int(torch.randint(low=0, high=t - self.length + 1, size=(1,)))
            video = video[begin:begin + self.length, :, :]
        # NOTE(review): in eval mode the full (padded) clip is returned, so
        # clip lengths vary between samples — confirm the eval dataloader
        # handles variable-length batches (e.g. batch_size=1).

        # Replicate the grayscale channel to RGB: (T, H, W) -> (T, H, W, 3).
        # np.stack allocates a fresh array, so zero-copy from_numpy is safe.
        video = torch.from_numpy(np.stack([video, video, video], axis=-1))

        if self.transform is not None:
            video = self.transform(video)

        sample_weight = torch.tensor([weight], dtype=DTYPE)

        return video, label, target, sample_weight, path, original_label
| |
|