syntax-model / backbone /dataset.py
MesserMMP's picture
add backbone model
3ca1d12
import os
import json
import pydicom
import numpy as np
import torch
from typing import Callable, Optional, Tuple
from torch import Tensor
from torch.utils.data import Dataset
# Полуточность достаточно для хранения весов и таргетов,
# а сами вычисления в модели идут в float32 / bf16.
DTYPE = torch.float16
class SyntaxDataset(Dataset):
"""
PyTorch Dataset для обучения видеобэкбона на задаче SYNTAX.
Функциональность:
- читает метаданные из JSON (относительный путь относительно root);
- фильтрует по артерии (левая / правая);
- опционально отфильтровывает только примеры с положительным SYNTAX
(validation=True);
- рассчитывает sample weights по бинам SYNTAX (для WeightedRandomSampler);
- конвертирует DICOM-видео в тензор (T, H, W, 3) c uint8 [0–255];
- возвращает:
video, label_bin, target_log, weight, rel_path, original_label.
"""
def __init__(
self,
root: str, # корневая директория датасета
meta: str, # относительный путь к JSON с метаданными
train: bool, # режим: train / eval
length: int, # длина клипа (кол-во кадров)
label: str, # имя поля с SYNTAX score в JSON
artery_bin: int, # 0 — левая, 1 — правая артерия
validation: bool = False, # отбрасывать ли нулевые SYNTAX
transform: Optional[Callable] = None,
) -> None:
super().__init__()
self.root = root
self.train = train
self.length = length
self.label = label
self.transform = transform
self.validation = validation
# meta теперь трактуется как ОТНОСИТЕЛЬНЫЙ путь от root
meta_path = os.path.join(root, meta)
with open(meta_path, "r") as f:
dataset = json.load(f)
# Фильтр по артерии (0 — левая, 1 — правая)
if artery_bin is not None:
assert artery_bin in (0, 1), "artery_bin должен быть 0 (левая) или 1 (правая)"
dataset = [rec for rec in dataset if rec["artery"] == artery_bin]
self.artery_bin = artery_bin
else:
# Для корректной работы get_sample_weights ожидаем известный artery_bin
raise ValueError("artery_bin должен быть явно задан (0 или 1).")
# Валидационный набор: берём только записи с положительным SYNTAX
if validation:
dataset = [rec for rec in dataset if rec[self.label] > 0]
# Инициализируем веса с единиц
for rec in dataset:
rec["weight"] = 1.0
self.dataset = dataset
# ------------------------------------------------------------------
# Веса для WeightedRandomSampler
# ------------------------------------------------------------------
def get_sample_weights(self) -> Tensor:
"""
Считает веса для примеров по бинам SYNTAX.
Для каждой артерии определён свой набор порогов,
после чего каждый пример получает вес, обратный частоте своего бина.
"""
# Пороговые значения по артериям (подбирались эмпирически)
bin_thresholds = {
0: [0, 5, 10, 15], # левая
1: [0, 2, 5, 8], # правая
}
thresholds = bin_thresholds[self.artery_bin]
thr0, thr1, thr2, thr3 = thresholds
# Бины по значениям SYNTAX
self.dataset_0 = [rec for rec in self.dataset if rec[self.label] == thr0]
self.dataset_1 = [rec for rec in self.dataset if thr0 < rec[self.label] <= thr1]
self.dataset_2 = [rec for rec in self.dataset if thr1 < rec[self.label] <= thr2]
self.dataset_3 = [rec for rec in self.dataset if thr2 < rec[self.label] <= thr3]
self.dataset_4 = [rec for rec in self.dataset if rec[self.label] > thr3]
total = (
len(self.dataset_0)
+ len(self.dataset_1)
+ len(self.dataset_2)
+ len(self.dataset_3)
+ len(self.dataset_4)
)
def safe_weight(count: int) -> float:
# Если в бине нет примеров, вес ставим 0.0
return total / count if count > 0 else 0.0
self.weights_0 = safe_weight(len(self.dataset_0))
self.weights_1 = safe_weight(len(self.dataset_1))
self.weights_2 = safe_weight(len(self.dataset_2))
self.weights_3 = safe_weight(len(self.dataset_3))
self.weights_4 = safe_weight(len(self.dataset_4))
print(
"Weights: ",
self.weights_0,
self.weights_1,
self.weights_2,
self.weights_3,
self.weights_4,
)
print(
"Counts: ",
len(self.dataset_0),
len(self.dataset_1),
len(self.dataset_2),
len(self.dataset_3),
len(self.dataset_4),
)
# Назначаем вес каждому примеру
weights = []
for rec in self.dataset:
syntax_score = rec[self.label]
if syntax_score == thr0:
weights.append(self.weights_0)
elif thr0 < syntax_score <= thr1:
weights.append(self.weights_1)
elif thr1 < syntax_score <= thr2:
weights.append(self.weights_2)
elif thr2 < syntax_score <= thr3:
weights.append(self.weights_3)
else:
weights.append(self.weights_4)
self.weights = torch.tensor(weights, dtype=DTYPE)
return self.weights
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self.dataset)
# ------------------------------------------------------------------
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]:
"""
Возвращает один пример:
- video: Tensor (T, H, W, 3) → после transform обычно (C, T, H, W)
- label: бинарный таргет по порогу для конкретной артерии
- target: логарифмированный SYNTAX score (регрессия)
- weight: вес примера (для самплера / лосса)
- path: относительный путь к DICOM файлу
- original_label: исходный SYNTAX score
"""
rec = self.dataset[idx]
# Относительный путь к DICOM из JSON (мы не храним абсолютные пути)
path = rec["path"]
weight = rec["weight"]
full_path = os.path.join(self.root, path)
video = pydicom.dcmread(full_path).pixel_array # (T, H, W)
# Приводим 16-битный сигнал к диапазону [0, 255] uint8
if video.dtype == np.uint16:
vmax = np.max(video)
assert vmax > 0
video = video.astype(np.float32)
video = video * (255.0 / vmax)
video = video.astype(np.uint8)
assert video.dtype == np.uint8
# Порог для бинарной классификации зависит от артерии
bin_thresholds = {
0: 15, # левая
1: 5, # правая
}
syntax_value = rec[self.label]
label = torch.tensor(
[int(syntax_value > bin_thresholds[self.artery_bin])],
dtype=DTYPE,
)
target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE)
original_label = torch.tensor([syntax_value], dtype=DTYPE)
# Дублируем видео по времени, пока не наберём нужную длину клипа
while len(video) < self.length:
video = np.concatenate([video, video])
t = len(video)
if self.train:
# Случайный подотрезок длины self.length
begin = torch.randint(low=0, high=t - self.length + 1, size=(1,))
end = begin + self.length
video = video[begin:end, :, :]
else:
# В валидации используем весь видеоряд (обрежется трансформами / моделью)
video = video
# Превращаем (T, H, W) → (T, H, W, 3) путём копирования каналов (grayscale→RGB)
video = torch.tensor(np.stack([video, video, video], axis=-1))
if self.transform is not None:
video = self.transform(video)
sample_weight = torch.tensor([weight], dtype=DTYPE)
return video, label, target, sample_weight, path, original_label