MesserMMP's picture
Add model code and full model weights
f621d73
import os
import json
from typing import Callable, Optional, Tuple, Any
from pathlib import Path
import numpy as np
import pydicom
import torch
from torch import Tensor
from torch.utils.data import Dataset
class SyntaxDataset(Dataset):
"""
PyTorch Dataset для обучения 3D-backbone по DICOM-видео.
Ожидается, что:
- meta (JSON) содержит список словарей с полями:
"path": относительный путь к DICOM-файлу от директории этого JSON
"artery": 0 (левая) или 1 (правая)
"<label>": численное значение SYNTAX score (например, "syntax_left")
- видео загружается из multi-frame DICOM (pydicom), как 3D-массив.
"""
def __init__(
self,
root: str,
meta: str,
train: bool,
length: int,
label: str,
artery_bin: int,
validation: bool = False,
transform: Optional[Callable] = None,
) -> None:
# Корень датасета (используется, если meta передан как относительный путь)
self.root = Path(root).resolve()
# Режим обучения/валидации
self.train = train
# Требуемая длина клипа (количество временных кадров)
self.length = int(length)
# Имя поля с меткой (например, "syntax_left")
self.label = label
# Функция аугментаций/преобразований
self.transform = transform
# Флаг: использовать только записи с положительным score
self.validation = validation
# Проверка валидности кода артерии: 0 — left, 1 — right
if artery_bin not in (0, 1):
raise ValueError("artery_bin must be 0 (left) or 1 (right)")
self.artery_bin = artery_bin
# Полный путь к JSON-файлу метаданных
meta_path = meta if os.path.isabs(meta) else self.root / meta
meta_path = Path(meta_path).resolve()
# Директория, в которой расположен JSON; пути в "path" задаются относительно неё
json_dir = meta_path.parent
print(f"Backbone dataset: root={self.root}, meta={meta_path}, json_dir={json_dir}")
# Загрузка списка записей из JSON
with open(meta_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
# Фильтрация по артерии
dataset = [rec for rec in dataset if rec.get("artery") == artery_bin]
# Для валидации можно оставить только записи с положительным score
if validation:
dataset = [rec for rec in dataset if float(rec.get(self.label, 0.0)) > 0]
# Сохраняем базовую директорию JSON для последующей сборки путей DICOM
self.json_dir = json_dir
self.dataset = dataset
# Инициализация веса записи (если не задан) значением 1.0
for rec in self.dataset:
rec.setdefault("weight", 1.0)
print(f"Backbone dataset loaded: {len(self.dataset)} samples after filtering")
def get_sample_weights(self) -> Tensor:
"""
Возвращает веса для выборки с помощью WeightedRandomSampler.
Логика:
- делим записи на интервалы по severity (score) отдельно для левой и правой артерии;
- чем реже встречается интервал, тем больше вес его записей.
"""
# Пороговые значения score для разбиения на интервалы (по артерии)
bin_thresholds = {
0: [0, 5, 10, 15], # левая
1: [0, 2, 5, 8], # правая
}
thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]
# Функция определения номера бина по значению score
def in_bin(score: float) -> int:
if score == thr0:
return 0
if thr0 < score <= thr1:
return 1
if thr1 < score <= thr2:
return 2
if thr2 < score <= thr3:
return 3
return 4
# Список score'ов и соответствующих бинов
scores = [float(rec.get(self.label, 0.0)) for rec in self.dataset]
bins = [in_bin(s) for s in scores]
# Подсчёт количества элементов в каждом бине
counts = np.bincount(np.array(bins, dtype=np.int64), minlength=5)
total = int(counts.sum())
# Вес бина — обратная частота
weights_by_bin = np.array(
[(total / counts[b]) if counts[b] > 0 else 0.0 for b in range(5)],
dtype=np.float64,
)
# Веса для каждой записи по номеру её бина
weights = np.array([weights_by_bin[b] for b in bins], dtype=np.float64)
return torch.as_tensor(weights, dtype=torch.double)
def __len__(self) -> int:
# Размер датасета — количество записей в JSON после всех фильтров
return len(self.dataset)
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, float, str, Tensor]:
"""
Возвращает:
video: Tensor, видео-клип (T, H, W, C) до применения transform
label: Tensor(1,), бинарная метка для классификации
target: Tensor(1,), регрессионная цель (log1p(score))
sample_weight: float, исходный вес записи из JSON
path: str, относительный путь к DICOM (как в JSON)
original_label: Tensor(1,), исходное значение score
"""
rec = self.dataset[idx]
# Относительный путь к DICOM, как он хранится в JSON
rel_path = rec["path"]
sample_weight = float(rec.get("weight", 1.0))
# Полный путь: директория JSON + относительный путь из JSON
full_path = (self.json_dir / rel_path).resolve()
if not full_path.exists():
raise FileNotFoundError(
f"DICOM not found: {full_path}\n"
f" json_dir={self.json_dir}\n"
f" rel_path='{rel_path}'"
)
# Загрузка DICOM и получение массива пикселей
video = pydicom.dcmread(str(full_path)).pixel_array
# Ожидается 3D-массив (T, H, W) или (H, W, T)
if video.ndim != 3:
raise ValueError(f"Expected 3D video array, got shape={video.shape} for {rel_path}")
# Если временная ось оказалась последней (H, W, T) — перенесём её в первую (T, H, W)
if video.shape[0] > 128 and video.shape[-1] <= 128:
video = np.moveaxis(video, -1, 0)
# Нормализация uint16 → uint8 с масштабированием на [0, 255]
if video.dtype == np.uint16:
vmax = int(np.max(video))
if vmax <= 0:
raise ValueError(f"Invalid vmax={vmax} for {rel_path}")
video = (video.astype(np.float32) * (255.0 / vmax)).clip(0, 255).astype(np.uint8)
else:
video = video.astype(np.uint8)
# Получение численного score из записи
score = float(rec.get(self.label, 0.0))
# Порог для бинарной классификации (отдельно для каждой артерии)
bin_thresholds = {
0: 15, # левая
1: 5, # правая
}
# Бинарная метка: 1, если score выше порога, иначе 0
label = torch.tensor(
[1.0 if score > bin_thresholds[self.artery_bin] else 0.0],
dtype=torch.float32,
)
# Регрессионная цель — логарифм score с единицей
target = torch.tensor([np.log1p(score)], dtype=torch.float32)
# Исходное значение score
original_label = torch.tensor([score], dtype=torch.float32)
# При необходимости "дублируем" видео по времени до нужной длины клипа
while video.shape[0] < self.length:
video = np.concatenate([video, video], axis=0)
# Случайная или фиксированная вырезка окна по времени
t = int(video.shape[0])
if self.train:
# В обучении берём случайный отрезок длины self.length
begin = torch.randint(low=0, high=t - self.length + 1, size=(1,)).item()
video = video[begin: begin + self.length]
else:
# На валидации можно взять первые self.length кадров
video = video[:self.length]
# Преобразуем (T, H, W) в (T, H, W, C), где C=3 (дублируем градации серого по каналам)
video = torch.from_numpy(np.stack([video, video, video], axis=-1))
# Применяем цепочку трансформаций, если задана
if self.transform is not None:
video = self.transform(video)
# Возвращаем видео, метки и относительный путь
return video, label, target, sample_weight, str(rel_path), original_label