| |
| import os |
| import json |
| import tqdm |
| import torch |
| import numpy as np |
| import click |
| from datetime import datetime |
| import lightning.pytorch as pl |
| import sklearn.metrics as skm |
|
|
| from torch.utils.data import DataLoader |
| from torchvision.transforms import transforms as T |
| from torchvision.transforms._transforms_video import ToTensorVideo |
| from pytorchvideo.transforms import Normalize |
|
|
| from full_model.rnn_dataset import SyntaxDataset |
| from full_model.rnn_model import SyntaxLightningModule |
| from inference.metrics_visualization import visualize_final_syntax_plotly_multi |
|
|
|
|
# Inference device for all fold models.
# NOTE(review): GPU index 1 is hard-coded — confirm this matches the intended
# machine layout; falls back to CPU when CUDA is unavailable.
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
|
|
|
|
def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1); 0.0 for empty/singleton input."""
    data = np.asarray(values, dtype=float)
    # With fewer than two samples the ddof=1 estimator is undefined (NaN /
    # division by zero), so report 0.0 instead.
    return float(data.std(ddof=1)) if data.size > 1 else 0.0
|
|
|
|
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute Pearson correlation and mean per-class recall.

    Args:
        y_true: ground-truth syntax scores (sequence of floats).
        y_pred: predicted syntax scores (sequence of floats).
        thr: threshold used to binarize both score arrays for the recall
            metric (default 22.0).

    Returns:
        Tuple ``(pearson, mean_recall)``. Either value is 0.0 when it is
        undefined: fewer than two samples or a zero-variance array for
        Pearson; a single class after binarization for mean recall.
    """
    y_true_arr = np.asarray(y_true, dtype=float)
    y_pred_arr = np.asarray(y_pred, dtype=float)

    # FIX: np.corrcoef yields NaN for a constant (zero-variance) array,
    # which the original len()>1 guard did not cover — report 0.0 instead.
    pearson = 0.0
    if y_true_arr.size > 1 and y_true_arr.std() > 0.0 and y_pred_arr.std() > 0.0:
        pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1])

    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)
    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    mean_recall = 0.0
    if len(unique_classes) > 1:
        # Per-class recall over labels {0, 1}; a class absent from the
        # ground truth contributes 0.0 (same as sklearn's
        # recall_score(..., average=None, labels=[0, 1]) with zero_division=0,
        # but without the heavyweight dependency or its warning).
        recalls = []
        for cls in (0, 1):
            support = int(np.sum(y_true_bin == cls))
            hits = int(np.sum((y_true_bin == cls) & (y_pred_bin == cls)))
            recalls.append(hits / support if support else 0.0)
        mean_recall = float(np.mean(recalls))

    return pearson, mean_recall
|
|
|
|
@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (где лежат JSON и DICOM).",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True),
    default="full_model_weights",
    show_default=True,
    help="Каталог с .pt/.ckpt весами full‑моделей (RNN‑head + backbone).",
)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256),
              show_default=True, help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int, default=32,
              show_default=True, help="Количество кадров в клипе.")
@click.option("--num-workers", type=int, default=8,
              show_default=True, help="Число DataLoader workers.")
@click.option("--seed", type=int, default=42,
              show_default=True, help="Random seed.")
@click.option(
    # FIX: the original `is_flag=True, default=True` made this flag
    # impossible to turn off (passing it just kept True, contradicting the
    # help text). A paired flag keeps `--pt-weights-format` working and adds
    # `--no-pt-weights-format` to select Lightning .ckpt weights.
    "--pt-weights-format/--no-pt-weights-format",
    default=True,
    show_default=True,
    help="Формат весов full‑моделей: True → .pt (raw state_dict), False → Lightning .ckpt.",
)
@click.option("--use-scaling", is_flag=True, default=False,
              show_default=True, help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Вариант head‑модели: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, variant, ensemble_name, metrics_file):
    """Run the 5-fold left+right ensemble, compute metrics and plots.

    For each dataset the per-fold left/right predictions are loaded from the
    cached ``results/<postfix>.json`` when present, otherwise computed via
    ``run_artery`` and cached. Left and right values are summed per fold
    (optionally ``a*x + b`` scaled per fold), averaged over folds, scored
    with ``compute_metrics``, written into the metrics JSON under
    ``ensemble_name``, and visualized.
    """
    # Fail fast on option combinations that previously crashed with a raw
    # TypeError — possibly only after hours of inference (e.g.
    # os.path.exists(None) at the very end, or os.path.join(root, None)).
    if ensemble_name is None or metrics_file is None:
        raise click.UsageError("Both --ensemble-name and --metrics-file are required.")
    if use_scaling and not scaling_file:
        raise click.UsageError("--use-scaling requires --scaling-file.")
    if not (len(dataset_paths) == len(dataset_names) == len(postfixes)):
        # zip() would silently truncate mismatched option lists.
        raise click.UsageError(
            "-d/--dataset-paths, -n/--dataset-names and -p/--postfixes must be "
            "given the same number of times."
        )

    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Five fold checkpoints per artery; names follow the training script's
    # convention: <Artery>BinSyntax_R3D_fold<NN>_<variant>_post_best.pt
    model_paths = {
        "left": [
            os.path.join(model_dir, f"LeftBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
        "right": [
            os.path.join(model_dir, f"RightBinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt")
            for fold in range(5)
        ],
    }

    # Optional per-fold linear calibration a*x + b loaded from JSON.
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            print(f"⚠️ Scaling file not found: {scaling_path}")

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "variant": variant,
        "datasets": {},
    }

    all_datasets, all_pearson, all_recalls = {}, {}, {}

    def get_ab(i: int):
        """Return (a, b) scaling for fold ``i``; identity when absent.

        Supports both ``{"a": ..., "b": ...}`` dicts and ``[a, b]`` pairs.
        NOTE(review): run_artery silently skips missing checkpoints, which
        would shift the fold index here — confirm all 5 folds exist when
        scaling is used.
        """
        params = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
        if isinstance(params, dict):
            return params.get("a", 1.0), params.get("b", 0.0)
        return params[0], params[1]

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join("results", f"{postfix}.json")

        if os.path.exists(results_file):
            # Reuse cached predictions to avoid re-running inference.
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers,
                variant=variant, pt_weights_format=pt_weights_format,
            )
            # Both arteries must be evaluated on the same studies in the
            # same order, since predictions are paired positionally below.
            assert left_sids == right_sids

            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all,
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        # Ensemble prediction per sample: sum left+right per fold,
        # optionally scale per fold, average over folds, clamp at 0.
        if use_scaling:
            syntax_pred = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                scaled_folds = []
                for i, (l_val, r_val) in enumerate(zip(l_list, r_list)):
                    s = l_val + r_val
                    a, b = get_ab(i)
                    scaled_folds.append(a * s + b)
                syntax_pred.append(max(0.0, float(np.mean(scaled_folds))))
        else:
            syntax_pred = [
                max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                for l_list, r_list in zip(left_preds_all, right_preds_all)
            ]

        pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")

        # Per-fold (non-ensembled) metrics for the mean/std summary.
        n_folds = len(left_preds_all[0]) if left_preds_all else 0
        fold_metrics = {metric: [] for metric in ["Pearson", "Mean_Recall"]}

        for k in range(n_folds):
            pred_k = []
            for l_list, r_list in zip(left_preds_all, right_preds_all):
                s = l_list[k] + r_list[k]
                if use_scaling:
                    a, b = get_ab(k)
                    s = a * s + b
                pred_k.append(max(0.0, float(s)))

            fold_pearson, fold_recall = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(
                fold_metrics.keys(),
                [fold_pearson, fold_recall],
            ):
                fold_metrics[metric].append(value)

        fold_summary = {
            k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
            for k, v in fold_metrics.items()
        }

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_pearson[dataset_name] = pearson
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            "Pearson": round(pearson, 4),
            "Mean_Recall": round(mean_recall, 4),
            "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
        }

    # Merge this run into the (possibly pre-existing) metrics history file.
    metrics_path = metrics_file
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            # Best-effort: a corrupted history is replaced, not fatal.
            print("⚠️ Metrics file corrupted. Creating new.")

    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_pearson,
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
        recall_values=all_recalls,
    )
|
|
|
|
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Inference for a single artery (5 folds).

    Loads every checkpoint from ``model_paths`` that exists on disk, runs all
    of them over the dataset and returns a pair:
      * preds_all -- one list per sample with the per-model syntax values,
      * sids      -- the sample id of each sample, in DataLoader order.
    """
    # Standard ImageNet normalization constants.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # inference=True / train=False: deterministic evaluation split.
    # NOTE(review): label="" — presumably unused in inference mode; confirm
    # against SyntaxDataset.
    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform,
    )
    # batch_size=1: each batch holds a single clip; the caller's positional
    # pairing of left/right predictions relies on shuffle=False.
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )
    print(f"{artery} artery: {len(val_loader)} samples")

    # Load one model per fold checkpoint.
    # NOTE(review): missing checkpoints are skipped, so the position of a
    # model in `models` may not equal its fold number — the caller's
    # per-fold scaling (get_ab) assumes they line up; confirm all folds
    # exist when scaling is enabled.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue

        # lr / weight_decay / max_epochs are constructor-required training
        # hyperparameters; they have no effect during eval-mode inference.
        model = SyntaxLightningModule(
            num_classes=2,
            lr=1e-5,
            variant=variant,
            weight_decay=0.001,
            max_epochs=1,
            weight_path=None,
            pl_weight_path=path,
            pt_weights_format=pt_weights_format,
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)

    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        # Loader yields (clip, [label], [target], [sample_id]); the
        # single-element lists come from batch_size=1 collation.
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            # A 1-D tensor appears to be a sentinel for a missing/empty
            # artery video — predict 0.0 for every model. TODO confirm
            # against SyntaxDataset's inference path.
            if len(x.shape) == 1:
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    y_hat = model(x)
                    # Column 1 is the regression head output on a log1p
                    # scale; map back to syntax units via exp(x) - 1.
                    yp_reg = y_hat[:, 1:]
                    val_log = yp_reg.squeeze(-1)
                    # float() on a tensor assumes exactly one element
                    # (guaranteed by batch_size=1 above).
                    val = float(torch.exp(val_log).cpu()) - 1.0
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid)

    return preds_all, sids
|
|
|
|
# CLI entry point: click parses sys.argv and dispatches to main().
if __name__ == "__main__":
    main()
|
|