# inference/rnn_apply.py
"""Ensemble inference CLI for the left/right artery RNN-head models.

For every dataset the script runs (or loads cached) per-fold predictions for
the left and right artery models, sums the two arteries per fold, averages
the folds into an ensemble prediction, computes Pearson correlation and mean
recall, stores the metrics in a JSON history file and calls the plotly
visualization helper.
"""
import json
import os
from datetime import datetime

import click
import lightning.pytorch as pl
import numpy as np
import sklearn.metrics as skm
import torch
import tqdm
from pytorchvideo.transforms import Normalize
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo

from full_model.rnn_dataset import SyntaxDataset
from full_model.rnn_model import SyntaxLightningModule
from inference.metrics_visualization import visualize_final_syntax_plotly_multi


def _select_device():
    """Pick the inference device: second GPU if present, else first GPU, else CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # The original hard-coded "cuda:1", which fails on single-GPU hosts.
    index = 1 if torch.cuda.device_count() > 1 else 0
    return torch.device(f"cuda:{index}")


DEVICE = _select_device()
print(f"DEVICE: {DEVICE}")


def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1) of *values*.

    Returns 0.0 when there are fewer than two values, where the sample
    standard deviation is undefined.
    """
    arr = np.array(values, dtype=float)
    if arr.size <= 1:
        return 0.0
    return float(arr.std(ddof=1))


def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute Pearson correlation and mean per-class recall.

    Args:
        y_true: Sequence of ground-truth scores.
        y_pred: Sequence of predicted scores, same length as ``y_true``.
        thr: Threshold used to binarize both sequences (``value >= thr``).

    Returns:
        Tuple ``(pearson, mean_recall)``. Each value falls back to 0.0 when
        it is undefined: Pearson for fewer than two samples or a non-finite
        correlation (e.g. constant input), mean recall when only one class
        occurs across the binarized truth and predictions.
    """
    y_true_arr = np.array(y_true, dtype=float)
    y_pred_arr = np.array(y_pred, dtype=float)

    pearson = 0.0
    if len(y_true_arr) > 1:
        # Constant input makes the correlation 0/0; silence the warning and
        # map the resulting NaN to 0.0 so it never reaches the metrics JSON.
        with np.errstate(invalid="ignore", divide="ignore"):
            pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1])
        if not np.isfinite(pearson):
            pearson = 0.0

    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)

    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    mean_recall = float(
        np.mean(skm.recall_score(y_true_bin, y_pred_bin, average=None, labels=[0, 1]))
    ) if len(unique_classes) > 1 else 0.0

    return pearson, mean_recall


@click.command()
@click.option("-d", "--dataset-paths", multiple=True, help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True, help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True, help="Суффиксы для файлов предсказаний.")
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (где лежат JSON и DICOM).",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True),
    default="full_model_weights",
    show_default=True,
    help="Каталог с .pt/.ckpt весами full‑моделей (RNN‑head + backbone).",
)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256), show_default=True, help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int, default=32, show_default=True, help="Количество кадров в клипе.")
@click.option("--num-workers", type=int, default=8, show_default=True, help="Число DataLoader workers.")
@click.option("--seed", type=int, default=42, show_default=True, help="Random seed.")
@click.option(
    # On/off pair: a plain is_flag option with default=True could never be
    # turned off, which made the Lightning .ckpt path unreachable.
    "--pt-weights-format/--no-pt-weights-format",
    default=True,
    show_default=True,
    help="Формат весов full‑моделей: True → .pt (raw state_dict), False → Lightning .ckpt.",
)
@click.option("--use-scaling", is_flag=True, default=False, show_default=True, help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file", help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Вариант head‑модели: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-e", "--ensemble-name", help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file", help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, variant, ensemble_name, metrics_file):
    """Run ensemble inference on the given datasets and record the metrics."""
    # Fail fast on missing/mismatched options instead of crashing with a
    # TypeError after the (expensive) inference has already run.
    if not (len(dataset_paths) == len(dataset_names) == len(postfixes)):
        raise click.UsageError(
            "--dataset-paths, --dataset-names and --postfixes must be given the same number of times."
        )
    if metrics_file is None:
        raise click.UsageError("--metrics-file is required.")
    if ensemble_name is None:
        raise click.UsageError("--ensemble-name is required.")
    if use_scaling and scaling_file is None:
        raise click.UsageError("--scaling-file is required with --use-scaling.")

    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Model paths are derived from model_dir using a fixed file-name template.
    model_paths = {
        "left": [
            os.path.join(model_dir, f"LeftBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
        "right": [
            os.path.join(model_dir, f"RightBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
    }

    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r", encoding="utf-8") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            # Best effort: continue with identity scaling (a=1, b=0).
            print(f"⚠️ Scaling file not found: {scaling_path}")

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "variant": variant,
        "datasets": {},
    }

    all_datasets, all_pearson, all_recalls = {}, {}, {}

    def get_ab(i: int):
        """Return the (a, b) scaling coefficients for fold *i* (identity if absent)."""
        params = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
        if isinstance(params, dict):
            return params.get("a", 1.0), params.get("b", 0.0)
        return params[0], params[1]

    def fold_prediction(left_val, right_val, fold_idx: int):
        """Sum both arteries for one fold and apply a*x+b scaling if enabled."""
        total = left_val + right_val
        if use_scaling:
            a, b = get_ab(fold_idx)
            total = a * total + b
        return total

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join("results", f"{postfix}.json")

        if os.path.exists(results_file):
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            # Explicit check instead of assert, which is stripped under -O.
            if left_sids != right_sids:
                raise RuntimeError(
                    f"[{postfix}] Left and right artery sample ids do not match."
                )

            with open(abs_dataset_path, "r", encoding="utf-8") as f:
                dataset = json.load(f)
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all,
            }
            with open(results_file, "w", encoding="utf-8") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        # -------- ensemble (with or without scaling) --------
        # Per sample: mean over folds of the (optionally scaled) left+right
        # sum, clipped at zero.
        syntax_pred = [
            max(0.0, float(np.mean([
                fold_prediction(l_val, r_val, i)
                for i, (l_val, r_val) in enumerate(zip(l_list, r_list))
            ])))
            for l_list, r_list in zip(left_preds_all, right_preds_all)
        ]

        pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")

        # -------- per-fold metrics --------
        # Use the common fold count so a left/right mismatch cannot raise
        # IndexError here (the ensemble above truncates the same way via zip).
        if left_preds_all and right_preds_all:
            n_folds = min(len(left_preds_all[0]), len(right_preds_all[0]))
        else:
            n_folds = 0

        fold_metrics = {"Pearson": [], "Mean_Recall": []}
        for k in range(n_folds):
            pred_k = [
                max(0.0, float(fold_prediction(l_list[k], r_list[k], k)))
                for l_list, r_list in zip(left_preds_all, right_preds_all)
            ]
            fold_pearson, fold_recall = compute_metrics(syntax_true, pred_k)
            fold_metrics["Pearson"].append(fold_pearson)
            fold_metrics["Mean_Recall"].append(fold_recall)

        fold_summary = {
            name: {
                # np.mean([]) would be NaN; report 0.0 when there are no folds.
                "mean": float(np.mean(values)) if values else 0.0,
                "std": safe_sample_std(values),
                "values": values,
            }
            for name, values in fold_metrics.items()
        }

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_pearson[dataset_name] = pearson
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            "Pearson": round(pearson, 4),
            "Mean_Recall": round(mean_recall, 4),
            "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
        }

    metrics_path = metrics_file
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r", encoding="utf-8") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            print("⚠️ Metrics file corrupted. Creating new.")
        else:
            # Valid JSON that is not an object cannot hold named experiments.
            if not isinstance(full_history, dict):
                print("⚠️ Metrics file corrupted. Creating new.")
                full_history = {}

    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_pearson,  # this argument now carries Pearson values
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
        recall_values=all_recalls,
    )


def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Run inference for one artery with every available fold model.

    Args:
        dataset_path: Path to the dataset metadata JSON; its directory is
            used as the dataset root.
        artery: Artery selector passed to ``SyntaxDataset`` ("left"/"right").
        model_paths: Checkpoint paths, one per fold.
        video_size: Target (H, W) for the resize transform.
        frames_per_clip: Clip length passed to the dataset.
        num_workers: Number of DataLoader workers.
        variant: Head variant passed to ``SyntaxLightningModule``.
        pt_weights_format: Weight format flag passed to the module.

    Returns:
        Tuple ``(preds_all, sids)``: for each sample a list with one
        prediction per loaded model, and the matching sample ids.

    Raises:
        RuntimeError: If none of ``model_paths`` exists.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform,
    )

    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

    print(f"{artery} artery: {len(val_loader)} samples")

    # NOTE(review): missing checkpoints are skipped, so positions in the
    # per-sample prediction lists follow the order of *loaded* models; when a
    # fold is missing they no longer equal the original fold numbers that the
    # caller uses to look up scaling coefficients.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue

        model = SyntaxLightningModule(
            num_classes=2,
            lr=1e-5,
            variant=variant,
            weight_decay=0.001,
            max_epochs=1,
            weight_path=None,
            pl_weight_path=path,
            pt_weights_format=pt_weights_format,
        )

        model.to(DEVICE)
        model.eval()
        models.append(model)

    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        # The label and the third batch element are not used during inference.
        for x, [_], [_], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if x.ndim == 1:
                # A 1-D input gets a zero prediction from every model —
                # presumably the dataset's placeholder for a sample without
                # video for this artery; TODO confirm against SyntaxDataset.
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    y_hat = model(x)
                    # Columns from index 1 on hold the regression output,
                    # which is mapped back with exp(v) - 1.
                    yp_reg = y_hat[:, 1:]
                    val_log = yp_reg.squeeze(-1)
                    val = float(torch.exp(val_log).cpu()) - 1.0
                    val_syntax_list.append(val)

            preds_all.append(val_syntax_list)
            sids.append(sid)

    return preds_all, sids


if __name__ == "__main__":
    main()