# Hugging Face repo-page residue (uploader "MesserMMP", commit f621d73,
# "Add model code and full model weights") — kept as a comment so the file parses.
# inference/rnn_apply.py
import os
import json
import tqdm
import torch
import numpy as np
import click
from datetime import datetime
import lightning.pytorch as pl
import sklearn.metrics as skm
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from pytorchvideo.transforms import Normalize
from full_model.rnn_dataset import SyntaxDataset
from full_model.rnn_model import SyntaxLightningModule
from inference.metrics_visualization import visualize_final_syntax_plotly_multi
# Run inference on the second GPU when CUDA is available, else fall back to CPU.
# NOTE(review): "cuda:1" is hard-coded — this raises on single-GPU machines; confirm intended.
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1) of *values*.

    The sample std is undefined for fewer than two observations, so an
    empty or single-element input yields 0.0 instead of NaN.
    """
    data = np.array(values, dtype=float)
    return float(data.std(ddof=1)) if data.size > 1 else 0.0
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute Pearson correlation and mean per-class recall.

    Parameters
    ----------
    y_true, y_pred : sequences of float scores.
    thr : float, threshold for binarization (score >= thr -> class 1).

    Returns
    -------
    (pearson, mean_recall) : tuple of floats. A metric is 0.0 when it is
    undefined: fewer than two samples or a constant input for Pearson,
    or only one class present across both binarized arrays for recall.
    """
    y_true_arr = np.array(y_true, dtype=float)
    y_pred_arr = np.array(y_pred, dtype=float)
    # np.corrcoef emits a RuntimeWarning and yields NaN when either input
    # is constant; guard on nonzero spread so we report 0.0 instead of NaN.
    if len(y_true_arr) > 1 and y_true_arr.std() > 0 and y_pred_arr.std() > 0:
        pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1])
    else:
        pearson = 0.0
    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)
    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    # Macro recall over both classes; skipped (0.0) when only one class occurs.
    if len(unique_classes) > 1:
        mean_recall = float(
            np.mean(skm.recall_score(y_true_bin, y_pred_bin, average=None, labels=[0, 1]))
        )
    else:
        mean_recall = 0.0
    return pearson, mean_recall
@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (где лежат JSON и DICOM).",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True),
    default="full_model_weights",
    show_default=True,
    help="Каталог с .pt/.ckpt весами full‑моделей (RNN‑head + backbone).",
)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256),
              show_default=True, help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int, default=32,
              show_default=True, help="Количество кадров в клипе.")
@click.option("--num-workers", type=int, default=8,
              show_default=True, help="Число DataLoader workers.")
@click.option("--seed", type=int, default=42,
              show_default=True, help="Random seed.")
@click.option(
    "--pt-weights-format",
    is_flag=True,
    default=True,
    show_default=True,
    help="Формат весов full‑моделей: True → .pt (raw state_dict), False → Lightning .ckpt.",
)
@click.option("--use-scaling", is_flag=True, default=False,
              show_default=True, help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Вариант head‑модели: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, variant, ensemble_name, metrics_file):
    """Run the 5-fold left/right artery model ensembles over each dataset,
    compute ensemble and per-fold metrics (Pearson, mean recall), append the
    results to the metrics JSON, and render a Plotly comparison figure.

    Per-dataset raw predictions are cached in results/<postfix>.json and
    reused on subsequent runs instead of re-running inference.
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"
    # Model checkpoint paths are derived from model_dir via a fixed
    # filename template, one file per fold (5 folds) and per artery side.
    model_paths = {
        "left": [
            os.path.join(model_dir, f"LeftBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
        "right": [
            os.path.join(model_dir, f"RightBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
    }
    scaling_params_dict = {}
    if use_scaling:
        # NOTE(review): ensemble_name has no default — `+=` raises TypeError
        # when -e is omitted together with --use-scaling; confirm -e is required.
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            # Missing scaling file is non-fatal: identity scaling (a=1, b=0) applies.
            print(f"⚠️ Scaling file not found: {scaling_path}")
    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "variant": variant,
        "datasets": {},
    }
    all_datasets, all_pearson, all_recalls = {}, {}, {}
    # Helper returning the (a, b) linear-scaling coefficients for fold i;
    # supports both dict {"a":..,"b":..} and pair [a, b] JSON layouts,
    # defaulting to the identity (1.0, 0.0).
    def get_ab(i: int):
        params = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
        if isinstance(params, dict):
            return params.get("a", 1.0), params.get("b", 0.0)
        return params[0], params[1]
    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join("results", f"{postfix}.json")
        if os.path.exists(results_file):
            # Cached predictions exist — skip inference entirely.
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            # Per-sample lists of 5 per-fold predictions, one list per artery side.
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            # Both passes must visit the same samples in the same order.
            assert left_sids == right_sids
            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            # Ground truth: prefer "mean_syntax", fall back to "syntax".
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]
            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all,
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")
        # -------- ensemble with/without scaling --------
        # Per sample: sum left+right per fold, optionally apply the fold's
        # a*x+b scaling, average over folds, and clamp negatives to 0.
        if use_scaling:
            syntax_pred = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                scaled_folds = []
                for i, (l_val, r_val) in enumerate(zip(l_list, r_list)):
                    s = l_val + r_val
                    a, b = get_ab(i)
                    scaled_folds.append(a * s + b)
                syntax_pred.append(max(0.0, float(np.mean(scaled_folds))))
        else:
            syntax_pred = [
                max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                for l_list, r_list in zip(left_preds_all, right_preds_all)
            ]
        pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")
        # -------- per-fold metrics --------
        n_folds = len(left_preds_all[0]) if left_preds_all else 0
        fold_metrics = {metric: [] for metric in ["Pearson", "Mean_Recall"]}
        for k in range(n_folds):
            pred_k = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                s = l_list[k] + r_list[k]
                if use_scaling:
                    a, b = get_ab(k)
                    s = a * s + b
                pred_k.append(max(0.0, float(s)))
            fold_pearson, fold_recall = compute_metrics(syntax_true, pred_k)
            # Relies on dict insertion order matching [Pearson, Mean_Recall] above.
            for metric, value in zip(
                fold_metrics.keys(),
                [fold_pearson, fold_recall],
            ):
                fold_metrics[metric].append(value)
        fold_summary = {
            k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
            for k, v in fold_metrics.items()
        }
        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_pearson[dataset_name] = pearson
        all_recalls[dataset_name] = mean_recall
        ensemble_results["datasets"][dataset_name] = {
            "Pearson": round(pearson, 4),
            "Mean_Recall": round(mean_recall, 4),
            "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
        }
    # Append this run under ensemble_name, preserving prior entries.
    # NOTE(review): metrics_file has no default — os.path.exists(None) raises
    # TypeError when -m is omitted; confirm the option is meant to be required.
    metrics_path = metrics_file
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            # Corrupt history is discarded and rebuilt from this run only.
            print("⚠️ Metrics file corrupted. Creating new.")
    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")
    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_pearson,  # parameter named r2_values, but these are Pearson values
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
        recall_values=all_recalls,
    )
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Run inference for one artery side across the fold models.

    Parameters
    ----------
    dataset_path : str — path to the dataset metadata JSON.
    artery : str — artery side ("left" or "right"), forwarded to the dataset.
    model_paths : list[str] — checkpoint paths, one per fold (missing files
        are skipped with a warning).
    video_size : (int, int) — target (H, W) for the resize transform.
    frames_per_clip : int — clip length passed to the dataset.
    num_workers : int — DataLoader worker count.
    variant : str — head-model variant passed to SyntaxLightningModule.
    pt_weights_format : bool — True for raw .pt state_dict, False for .ckpt.

    Returns
    -------
    (preds_all, sids) — preds_all is a list (one entry per sample) of lists
    of per-model predictions; sids is the matching list of sample ids.

    Raises
    ------
    RuntimeError — if none of the checkpoint files exist.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform,
    )
    # batch_size=1: the unpacking in the loop below assumes single-sample batches.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )
    print(f"{artery} artery: {len(val_loader)} samples")
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            # Best-effort ensemble: missing folds are skipped, not fatal.
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2,
            lr=1e-5,
            variant=variant,
            weight_decay=0.001,
            max_epochs=1,
            weight_path=None,
            pl_weight_path=path,
            pt_weights_format=pt_weights_format,
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")
    preds_all, sids = [], []
    with torch.no_grad():
        # Each item unpacks as (video, [label], [target], [sample_id]) —
        # the single-element lists come from batch_size=1 collation.
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(x.shape) == 1:
                # A 1-D tensor marks a sample with no usable video
                # (presumably a dataset sentinel — confirm in SyntaxDataset);
                # emit 0.0 for every fold model.
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    y_hat = model(x)
                    # Column 1 holds the regression output; exp(x) - 1
                    # inverts a log1p-style target transform —
                    # TODO confirm against the training pipeline.
                    yp_reg = y_hat[:, 1:]
                    val_log = yp_reg.squeeze(-1)
                    val = float(torch.exp(val_log).cpu()) - 1.0
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid)
    return preds_all, sids
# CLI entry point: click parses the command-line options and invokes main().
if __name__ == "__main__":
    main()