| """ |
| Evaluation and WandB visualization for diffusion models on The Well. |
| |
| Produces: |
| - Single-step comparison images: Condition | Ground Truth | Prediction |
| - Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side) |
| - Per-step MSE metrics for rollout quality analysis |
| """ |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
def _get_colormap(name="RdBu_r"):
    """Return a colormap callable for *name* (callers cache the result).

    Forces the headless "Agg" backend so this works without a display.
    `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed
    in 3.9, so prefer the `matplotlib.colormaps` registry when available
    and fall back to the legacy API on old installs.
    """
    import matplotlib
    matplotlib.use("Agg")
    try:
        return matplotlib.colormaps[name]  # matplotlib >= 3.6
    except AttributeError:
        # Legacy path for matplotlib < 3.6 (no `colormaps` registry).
        import matplotlib.cm as cm
        return cm.get_cmap(name)
|
|
# Per-name colormap cache so matplotlib lookup happens once per colormap.
_CMAP_CACHE = {}


def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map a [H, W] float array in [0, 1] to a [H, W, 3] uint8 RGB image."""
    cmap = _CMAP_CACHE.get(cmap_name)
    if cmap is None:
        cmap = _CMAP_CACHE[cmap_name] = _get_colormap(cmap_name)
    rgba = cmap(np.clip(field_01, 0, 1))  # -> [H, W, 4] floats in [0, 1]
    return (rgba[..., :3] * 255).astype(np.uint8)  # drop alpha, scale to bytes
|
|
|
|
def normalize_for_vis(f, vmin=None, vmax=None):
    """Normalize *f* to [0, 1] using robust (2nd/98th percentile) bounds.

    Args:
        f: numpy array of field values.
        vmin, vmax: optional explicit bounds; computed from *f* when None.

    Returns:
        Tuple of (array clipped to [0, 1], vmin used, vmax used).
    """
    lo = np.percentile(f, 2) if vmin is None else vmin
    hi = np.percentile(f, 98) if vmax is None else vmax
    span = max(hi - lo, 1e-8)  # guard against division by zero on flat fields
    return np.clip((f - lo) / span, 0, 1), lo, hi
|
|
|
|
| |
| |
| |
|
|
def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Build one [H, W*3+4, 3] uint8 panel: Condition | Ground Truth | Prediction.

    All three fields share a single color scale (2nd/98th percentile of the
    pooled values) so their magnitudes are directly comparable.
    """
    pooled = np.concatenate([cond.flat, gt.flat, pred.flat])
    vmin, vmax = np.percentile(pooled, 2), np.percentile(pooled, 98)

    def rgb(field):
        normed, _, _ = normalize_for_vis(field, vmin, vmax)
        return apply_colormap(normed, cmap)

    height = cond.shape[0]
    divider = np.full((height, 2, 3), 200, dtype=np.uint8)  # light-gray gap
    panels = [rgb(cond), divider, rgb(gt), divider, rgb(pred)]
    return np.concatenate(panels, axis=1)
|
|
|
|
@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute validation MSE and build Cond | GT | Pred comparison images.

    Args:
        model: diffusion model exposing `sample_ddim(cond, shape=..., steps=...)`.
        val_loader: validation DataLoader.
        device: torch device used for inference.
        n_batches: how many validation batches to evaluate.
        ddim_steps: DDIM denoising steps per sampled prediction.

    Returns:
        metrics: dict {'val/mse': float}
        comparisons: list of (image_array, caption_string)
    """
    from data_pipeline import prepare_batch

    model.eval()
    mse_sum = 0.0
    count = 0
    saved = None  # (cond, target, pred) kept from the first batch for images

    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= n_batches:
            break
        x_cond, x_target = prepare_batch(batch, device)
        x_pred = model.sample_ddim(x_cond, shape=x_target.shape, steps=ddim_steps)

        # Weight each batch-mean MSE by its batch size for a correct average.
        batch_size = x_target.shape[0]
        mse_sum += F.mse_loss(x_pred, x_target).item() * batch_size
        count += batch_size

        if batch_idx == 0:
            saved = (x_cond[:4].cpu(), x_target[:4].cpu(), x_pred[:4].cpu())

    comparisons = []
    if saved is not None:
        xc, xt, xp = saved
        for b in range(xc.shape[0]):
            # Visualize at most 4 channels per sample.
            for ch in range(min(xc.shape[1], 4)):
                panel = _comparison_image(
                    xc[b, ch].numpy(), xt[b, ch].numpy(), xp[b, ch].numpy()
                )
                comparisons.append((panel, f"sample{b}_ch{ch}"))

    model.train()
    return {"val/mse": mse_sum / max(count, 1)}, comparisons
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def rollout_eval(
    model, rollout_loader, device,
    n_rollout=20, ddim_steps=50, channel=0, cmap="RdBu_r",
):
    """Autoregressive rollout with GT comparison video.

    Creates side-by-side video: Ground Truth | Prediction
    and computes per-step MSE.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [n_steps+1, 3, H, W_combined] uint8 for wandb.Video (the
            extra leading frame shows the initial condition on both sides).
        per_step_mse: list[float] of length min(n_rollout, available GT steps).
    """
    model.eval()
    # One batch is enough: only the first sample is rolled out.
    batch = next(iter(rollout_loader))

    # Keep only the first sample. Fields appear to be channels-last
    # [B, T, H, W, C] — implied by the permute(0, 3, 1, 2) calls and
    # C = inp.shape[-1] below; confirm against the dataset definition.
    inp = batch["input_fields"][:1]
    out = batch["output_fields"][:1]

    T_out = out.shape[1]
    # Cannot roll out past the ground-truth horizon.
    n_steps = min(n_rollout, T_out)
    C = inp.shape[-1]

    # Initial condition: first input frame, moved to [B, C, H, W] on device.
    x_cond = inp[:, 0].permute(0, 3, 1, 2).float().to(device)

    # Ground-truth frames stay on CPU, also in [B, C, H, W] layout.
    gt_frames = [out[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    # Autoregressive loop: each prediction is fed back as the next condition,
    # so errors compound over steps (that growth is what per_step_mse tracks).
    pred_frames = []
    per_step_mse = []
    cond = x_cond

    for t in range(n_steps):
        # eta=0.0 → deterministic DDIM sampling.
        pred = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        pred_cpu = pred.cpu()
        pred_frames.append(pred_cpu)

        mse_t = F.mse_loss(pred_cpu, gt_frames[t]).item()
        per_step_mse.append(mse_t)

        cond = pred
        if (t + 1) % 5 == 0:
            logger.info(f" rollout step {t+1}/{n_steps}, mse={mse_t:.6f}")

    # Clamp the requested channel in case it is out of range.
    ch = min(channel, C - 1)

    # Compute one shared color scale over the initial condition and the whole
    # GT + predicted trajectory, so frames are visually comparable over time.
    all_vals = [x_cond[0, ch].cpu().numpy().flat]
    for t in range(n_steps):
        all_vals.append(gt_frames[t][0, ch].numpy().flat)
        all_vals.append(pred_frames[t][0, ch].numpy().flat)
    all_vals = np.concatenate(list(all_vals))
    vmin, vmax = np.percentile(all_vals, 2), np.percentile(all_vals, 98)

    def to_rgb(field_2d):
        # [H, W] float → [H, W, 3] uint8, using the shared vmin/vmax.
        n, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(n, cmap)

    H, W = x_cond.shape[2], x_cond.shape[3]  # W currently unused
    # Light-gray vertical divider between the GT and prediction panels.
    sep = np.full((H, 4, 3), 200, dtype=np.uint8)

    def _label_frame(gt_rgb, pred_rgb):
        """Concatenate with separator."""
        return np.concatenate([gt_rgb, sep, pred_rgb], axis=1)

    frames = []

    # Frame 0: the shared initial condition shown on both sides.
    init_rgb = to_rgb(x_cond[0, ch].cpu().numpy())
    frames.append(_label_frame(init_rgb, init_rgb).transpose(2, 0, 1))

    # Subsequent frames: GT on the left, prediction on the right.
    # transpose(2, 0, 1) → channels-first [3, H, W], as wandb.Video expects.
    for t in range(n_steps):
        gt_rgb = to_rgb(gt_frames[t][0, ch].numpy())
        pred_rgb = to_rgb(pred_frames[t][0, ch].numpy())
        frames.append(_label_frame(gt_rgb, pred_rgb).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)

    model.train()
    return video, per_step_mse
|
|
|
|
| |
| |
| |
|
|
def run_evaluation(
    model, val_loader, rollout_loader, device,
    global_step, wandb_run=None,
    n_val_batches=4, n_rollout=20, ddim_steps=50,
):
    """Run the full evaluation suite: single-step metrics + rollout video.

    Logs metrics, comparison images, the rollout video, and an MSE-vs-step
    curve to WandB when *wandb_run* is provided.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f" val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )
    logger.info(f" rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}")

    # Fold rollout statistics into the metrics dict.
    metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
    metrics["val/rollout_mse_final"] = rollout_mse[-1]
    for step_idx, step_mse in enumerate(rollout_mse):
        metrics[f"val/rollout_mse_step{step_idx}"] = step_mse

    if wandb_run is None:
        return metrics

    import wandb

    # Scalar metrics.
    wandb_run.log(metrics, step=global_step)

    # Up to 8 single-step comparison panels.
    for img, caption in comparisons[:8]:
        wandb_run.log(
            {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")},
            step=global_step,
        )

    # Side-by-side rollout video.
    wandb_run.log(
        {"eval/rollout_video": wandb.Video(video, fps=4, format="mp4",
                                           caption="Left=GT Right=Prediction")},
        step=global_step,
    )

    # Per-step MSE as a line plot.
    table = wandb.Table(columns=["step", "mse"], data=[[t, m] for t, m in enumerate(rollout_mse)])
    wandb_run.log(
        {"eval/rollout_mse_curve": wandb.plot.line(
            table, "step", "mse", title="Rollout MSE vs Step"
        )},
        step=global_step,
    )

    return metrics
|
|