| import os |
| import json |
| from typing import Callable, Optional, Tuple, Any |
| from pathlib import Path |
|
|
| import numpy as np |
| import pydicom |
| import torch |
| from torch import Tensor |
| from torch.utils.data import Dataset |
|
|
|
|
class SyntaxDataset(Dataset):
    """
    PyTorch Dataset for training a 3D backbone on DICOM videos.

    Expects:
        - ``meta`` (JSON) contains a list of dicts with the fields:
            "path":    path to the DICOM file, relative to the directory of that JSON
            "artery":  0 (left) or 1 (right)
            "<label>": numeric SYNTAX score (e.g. "syntax_left")
        - the video is loaded from a multi-frame DICOM (pydicom) as a 3D array.
    """

    # Severity-bin edges used by get_sample_weights(), keyed by artery.
    # Records are bucketed as: ==edge0, (edge0, edge1], (edge1, edge2],
    # (edge2, edge3], and everything above edge3.
    _SAMPLER_BIN_EDGES = {
        0: (0, 5, 10, 15),  # left artery
        1: (0, 2, 5, 8),    # right artery
    }

    # Score cut-off above which the binary classification label is 1, keyed
    # by artery.  NOTE(review): intentionally not the same numbers as the
    # sampler edges for artery 1 (5 vs. 8) — the two tables serve different
    # purposes (class label vs. sampling balance).
    _CLS_THRESHOLD = {
        0: 15,
        1: 5,
    }

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery_bin: int,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Load and filter the metadata for one artery.

        Args:
            root: dataset root directory; ``meta`` is resolved against it
                when given as a relative path.
            meta: path to the JSON metadata file (absolute, or relative to
                ``root``).  DICOM paths inside the JSON are relative to the
                JSON file's own directory.
            train: if True, ``__getitem__`` takes a random temporal crop;
                otherwise it takes the first ``length`` frames.
            length: number of frames per returned clip.
            label: key in each record holding the SYNTAX score.
            artery_bin: 0 (left) or 1 (right); records for the other artery
                are dropped.
            validation: if True, additionally drop records whose score is
                not strictly positive.
            transform: optional callable applied to the (T, H, W, C) video
                tensor before it is returned.

        Raises:
            ValueError: if ``artery_bin`` is not 0 or 1.
        """
        self.root = Path(root).resolve()
        self.train = train
        self.length = int(length)
        self.label = label
        self.transform = transform
        self.validation = validation

        if artery_bin not in (0, 1):
            raise ValueError("artery_bin must be 0 (left) or 1 (right)")
        self.artery_bin = artery_bin

        # The meta path may be absolute or relative to the dataset root.
        meta_path = meta if os.path.isabs(meta) else self.root / meta
        meta_path = Path(meta_path).resolve()

        # DICOM paths in the JSON are resolved against the JSON's directory,
        # not against ``root`` — keep it for __getitem__.
        json_dir = meta_path.parent

        print(f"Backbone dataset: root={self.root}, meta={meta_path}, json_dir={json_dir}")

        with open(meta_path, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        # Keep only records for the requested artery.
        dataset = [rec for rec in dataset if rec.get("artery") == artery_bin]

        # Validation splits use only records with a strictly positive score.
        if validation:
            dataset = [rec for rec in dataset if float(rec.get(self.label, 0.0)) > 0]

        self.json_dir = json_dir
        self.dataset = dataset

        # Guarantee every record carries a per-sample weight.
        for rec in self.dataset:
            rec.setdefault("weight", 1.0)

        print(f"Backbone dataset loaded: {len(self.dataset)} samples after filtering")

    def get_sample_weights(self) -> Tensor:
        """
        Return per-record weights for torch's ``WeightedRandomSampler``.

        Records are bucketed into five severity bins (edges depend on the
        artery, see ``_SAMPLER_BIN_EDGES``); each record's weight is
        ``total / count(bin)``, so rarer bins are sampled more often.

        Returns:
            A float64 tensor of shape ``(len(self),)``.
        """
        thr0, thr1, thr2, thr3 = self._SAMPLER_BIN_EDGES[self.artery_bin]

        def in_bin(score: float) -> int:
            # Bin 0 is exactly the lowest edge (score == 0 in practice);
            # anything above the top edge falls into bin 4.
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        scores = [float(rec.get(self.label, 0.0)) for rec in self.dataset]
        bins = [in_bin(s) for s in scores]

        counts = np.bincount(np.array(bins, dtype=np.int64), minlength=5)
        total = int(counts.sum())

        # Inverse-frequency weight per bin; empty bins get 0 (never drawn).
        weights_by_bin = np.array(
            [(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
            dtype=np.float64,
        )

        weights = np.array([weights_by_bin[b] for b in bins], dtype=np.float64)

        return torch.as_tensor(weights, dtype=torch.double)

    def __len__(self) -> int:
        """Number of records after artery/validation filtering."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, float, str, Tensor]:
        """
        Load one clip and its targets.

        Returns:
            video: Tensor, video clip of shape (T, H, W, C) before ``transform``
            label: Tensor(1,), binary classification label (score above the
                per-artery cut-off)
            target: Tensor(1,), regression target (``log1p(score)``)
            sample_weight: float, the record's original weight from the JSON
            path: str, DICOM path relative to the JSON (as stored in the JSON)
            original_label: Tensor(1,), the raw score

        Raises:
            FileNotFoundError: if the DICOM file does not exist.
            ValueError: if the pixel array is not 3D, has zero frames, or a
                uint16 volume has a non-positive maximum.
        """
        rec = self.dataset[idx]

        rel_path = rec["path"]
        sample_weight = float(rec.get("weight", 1.0))

        full_path = (self.json_dir / rel_path).resolve()

        if not full_path.exists():
            raise FileNotFoundError(
                f"DICOM not found: {full_path}\n"
                f"  json_dir={self.json_dir}\n"
                f"  rel_path='{rel_path}'"
            )

        video = pydicom.dcmread(str(full_path)).pixel_array

        if video.ndim != 3:
            raise ValueError(f"Expected 3D video array, got shape={video.shape} for {rel_path}")

        # Heuristic: some DICOMs come frames-last; if the leading axis looks
        # like a spatial dimension (>128) and the trailing one looks like a
        # frame count (<=128), move frames to the front.
        if video.shape[0] > 128 and video.shape[-1] <= 128:
            video = np.moveaxis(video, -1, 0)

        # A zero-frame clip would make the doubling loop below spin forever.
        if video.shape[0] == 0:
            raise ValueError(f"Empty video (0 frames) for {rel_path}")

        # Normalize to uint8: rescale uint16 by its max, pass others through.
        if video.dtype == np.uint16:
            vmax = int(np.max(video))
            if vmax <= 0:
                raise ValueError(f"Invalid vmax={vmax} for {rel_path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
        else:
            video = video.astype(np.uint8)

        score = float(rec.get(self.label, 0.0))

        label = torch.tensor(
            [1.0 if score > self._CLS_THRESHOLD[self.artery_bin] else 0.0],
            dtype=torch.float32,
        )
        # log1p compresses the heavy-tailed score range for regression.
        target = torch.tensor([np.log1p(score)], dtype=torch.float32)
        original_label = torch.tensor([score], dtype=torch.float32)

        # Pad short clips by tiling until at least ``length`` frames.
        while video.shape[0] < self.length:
            video = np.concatenate([video, video], axis=0)

        t = int(video.shape[0])
        if self.train:
            # Random temporal crop for augmentation.
            begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)).item()
            video = video[begin: begin + self.length]
        else:
            # Deterministic crop for eval: first ``length`` frames.
            video = video[:self.length]

        # Replicate grayscale to 3 channels: (T, H, W) -> (T, H, W, 3).
        video = torch.from_numpy(np.stack([video, video, video], axis=-1))

        if self.transform is not None:
            video = self.transform(video)

        return video, label, target, sample_weight, str(rel_path), original_label
|
|