Spaces:
Running
Running
Update landmarkdiff/validation.py to v0.3.2
Browse files- landmarkdiff/validation.py +117 -43
landmarkdiff/validation.py
CHANGED
|
@@ -30,6 +30,7 @@ class ValidationCallback:
|
|
| 30 |
val_dataset=val_dataset,
|
| 31 |
output_dir=Path("checkpoints/val"),
|
| 32 |
num_samples=8,
|
|
|
|
| 33 |
)
|
| 34 |
|
| 35 |
# In training loop:
|
|
@@ -53,6 +54,7 @@ class ValidationCallback:
|
|
| 53 |
num_samples: int = 8,
|
| 54 |
num_inference_steps: int = 25,
|
| 55 |
guidance_scale: float = 7.5,
|
|
|
|
| 56 |
):
|
| 57 |
self.val_dataset = val_dataset
|
| 58 |
self.output_dir = Path(output_dir)
|
|
@@ -60,8 +62,49 @@ class ValidationCallback:
|
|
| 60 |
self.num_samples = min(num_samples, len(val_dataset))
|
| 61 |
self.num_inference_steps = num_inference_steps
|
| 62 |
self.guidance_scale = guidance_scale
|
|
|
|
| 63 |
self.history: list[dict] = []
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
@torch.no_grad()
|
| 66 |
def run(
|
| 67 |
self,
|
|
@@ -76,9 +119,9 @@ class ValidationCallback:
|
|
| 76 |
) -> dict:
|
| 77 |
"""Run validation: generate samples and compute metrics.
|
| 78 |
|
| 79 |
-
Returns dict with aggregate metrics.
|
| 80 |
"""
|
| 81 |
-
from diffusers import
|
| 82 |
|
| 83 |
t0 = time.time()
|
| 84 |
controlnet.eval()
|
|
@@ -86,62 +129,70 @@ class ValidationCallback:
|
|
| 86 |
step_dir = self.output_dir / f"step-{global_step}"
|
| 87 |
step_dir.mkdir(parents=True, exist_ok=True)
|
| 88 |
|
| 89 |
-
# Set up inference scheduler (
|
| 90 |
-
scheduler =
|
| 91 |
scheduler.set_timesteps(self.num_inference_steps, device=device)
|
| 92 |
|
| 93 |
ssim_scores = []
|
| 94 |
lpips_scores = []
|
| 95 |
generated_images = []
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
|
| 100 |
target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)
|
| 101 |
|
| 102 |
-
# Encode target for latent shape
|
| 103 |
-
latents = vae.encode(target * 2 - 1).latent_dist.sample()
|
| 104 |
-
latents = latents * vae.config.scaling_factor
|
| 105 |
|
| 106 |
# Start from noise
|
| 107 |
noise = torch.randn_like(latents)
|
| 108 |
sample_latents = noise * scheduler.init_noise_sigma
|
| 109 |
encoder_hidden_states = text_embeddings[:1]
|
| 110 |
|
| 111 |
-
# Denoising loop with
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
#
|
|
|
|
|
|
|
| 136 |
decoded = vae.decode(sample_latents.float() / vae.config.scaling_factor).sample
|
|
|
|
| 137 |
decoded = ((decoded + 1) / 2).clamp(0, 1)
|
| 138 |
|
| 139 |
# Convert to numpy for metrics
|
| 140 |
gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 141 |
tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 142 |
-
cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(
|
| 143 |
-
np.uint8
|
| 144 |
-
)
|
| 145 |
|
| 146 |
# BGR for metrics (our metrics expect BGR)
|
| 147 |
gen_bgr = gen_np[:, :, ::-1].copy()
|
|
@@ -154,12 +205,19 @@ class ValidationCallback:
|
|
| 154 |
lpips_scores.append(lpips_val)
|
| 155 |
generated_images.append(gen_np)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
# Save comparison: conditioning | generated | target
|
|
|
|
| 158 |
comparison = np.hstack([cond_np, gen_np, tgt_np])
|
| 159 |
-
Image.fromarray(comparison).save(
|
|
|
|
|
|
|
| 160 |
|
| 161 |
# Aggregate metrics
|
| 162 |
-
metrics = {
|
| 163 |
"step": global_step,
|
| 164 |
"ssim_mean": float(np.nanmean(ssim_scores)),
|
| 165 |
"ssim_std": float(np.nanstd(ssim_scores)),
|
|
@@ -168,6 +226,16 @@ class ValidationCallback:
|
|
| 168 |
"time_seconds": round(time.time() - t0, 1),
|
| 169 |
}
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
self.history.append(metrics)
|
| 172 |
|
| 173 |
# Save metrics
|
|
@@ -182,7 +250,7 @@ class ValidationCallback:
|
|
| 182 |
if generated_images:
|
| 183 |
grid_rows = []
|
| 184 |
for i in range(0, len(generated_images), 4):
|
| 185 |
-
row_imgs = generated_images[i
|
| 186 |
while len(row_imgs) < 4:
|
| 187 |
row_imgs.append(np.zeros_like(generated_images[0]))
|
| 188 |
grid_rows.append(np.hstack(row_imgs))
|
|
@@ -191,12 +259,19 @@ class ValidationCallback:
|
|
| 191 |
|
| 192 |
controlnet.train()
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
print(
|
| 195 |
f" Validation @ step {global_step}: "
|
| 196 |
-
f"SSIM={metrics['ssim_mean']:.4f}
|
| 197 |
-
f"LPIPS={metrics['lpips_mean']:.4f}
|
| 198 |
f"({metrics['time_seconds']:.1f}s)"
|
| 199 |
)
|
|
|
|
|
|
|
| 200 |
|
| 201 |
return metrics
|
| 202 |
|
|
@@ -207,7 +282,6 @@ class ValidationCallback:
|
|
| 207 |
|
| 208 |
try:
|
| 209 |
import matplotlib
|
| 210 |
-
|
| 211 |
matplotlib.use("Agg")
|
| 212 |
import matplotlib.pyplot as plt
|
| 213 |
except ImportError:
|
|
|
|
| 30 |
val_dataset=val_dataset,
|
| 31 |
output_dir=Path("checkpoints/val"),
|
| 32 |
num_samples=8,
|
| 33 |
+
samples_per_procedure=2,
|
| 34 |
)
|
| 35 |
|
| 36 |
# In training loop:
|
|
|
|
| 54 |
num_samples: int = 8,
|
| 55 |
num_inference_steps: int = 25,
|
| 56 |
guidance_scale: float = 7.5,
|
| 57 |
+
samples_per_procedure: int = 2,
|
| 58 |
):
|
| 59 |
self.val_dataset = val_dataset
|
| 60 |
self.output_dir = Path(output_dir)
|
|
|
|
| 62 |
self.num_samples = min(num_samples, len(val_dataset))
|
| 63 |
self.num_inference_steps = num_inference_steps
|
| 64 |
self.guidance_scale = guidance_scale
|
| 65 |
+
self.samples_per_procedure = samples_per_procedure
|
| 66 |
self.history: list[dict] = []
|
| 67 |
|
| 68 |
+
# Pre-build per-procedure index map for stratified sampling
|
| 69 |
+
self._procedure_indices = self._build_procedure_map()
|
| 70 |
+
|
| 71 |
+
def _build_procedure_map(self) -> dict[str, list[int]]:
    """Group validation-dataset indices by procedure name.

    Two metadata sources are tried, in this order:

    1. A non-empty ``_sample_procedures`` mapping on the dataset. Each
       entry of ``ds.pairs`` is keyed by its ``stem`` with ``"_input"``
       removed; stems missing from the mapping are labelled ``"unknown"``.
    2. A ``get_procedure(idx)`` method on the dataset, called for every
       index in ``range(len(ds))``.

    Labels that come back as ``None`` or an empty string are folded into
    ``"unknown"`` and every other label is coerced to ``str``, so the keys
    can always be sorted together and used as file-name tags.

    Returns:
        Mapping of procedure name to the dataset indices carrying it.
        The ``"unknown"`` bucket is dropped whenever at least one labelled
        procedure exists. The mapping is empty when the dataset exposes
        neither metadata source.
    """
    from collections import defaultdict

    def _label(raw) -> str:
        # A None/empty label would otherwise become a dict key that cannot
        # be sorted alongside string keys and has no ``str`` methods.
        return "unknown" if raw is None or raw == "" else str(raw)

    proc_indices: dict[str, list[int]] = defaultdict(list)
    ds = self.val_dataset

    sample_procedures = getattr(ds, "_sample_procedures", None)
    if sample_procedures:
        # presumably ds.pairs holds pathlib.Path objects named
        # "<prefix>_input.<ext>" -- TODO confirm against the dataset class
        for idx, pair_path in enumerate(ds.pairs):
            prefix = pair_path.stem.replace("_input", "")
            proc_indices[_label(sample_procedures.get(prefix))].append(idx)
    elif hasattr(ds, "get_procedure"):
        for idx in range(len(ds)):
            proc_indices[_label(ds.get_procedure(idx))].append(idx)

    # Drop "unknown" if we have labeled procedures
    known = {k: v for k, v in proc_indices.items() if k != "unknown"}
    return known if known else dict(proc_indices)
|
| 91 |
+
|
| 92 |
+
def _select_per_procedure_indices(self) -> list[tuple[int, str]]:
    """Choose which validation samples to render, stratified by procedure.

    For every procedure in ``self._procedure_indices`` (visited in sorted
    name order) the first ``self.samples_per_procedure`` dataset indices
    are taken. ``self.num_samples`` is used only by the fallback path.

    Returns:
        List of ``(dataset_index, procedure_name)`` tuples. When no
        procedure metadata is available this is the first
        ``self.num_samples`` sequential indices, each tagged ``"unknown"``.
    """
    if not self._procedure_indices:
        return [(i, "unknown") for i in range(self.num_samples)]

    # Clamp at zero: a negative value used as a slice stop would silently
    # mean "all but the last k" instead of "none".
    per_proc = max(0, self.samples_per_procedure)

    return [
        (idx, proc)
        for proc, indices in sorted(self._procedure_indices.items())
        for idx in indices[:per_proc]
    ]
|
| 107 |
+
|
| 108 |
@torch.no_grad()
|
| 109 |
def run(
|
| 110 |
self,
|
|
|
|
| 119 |
) -> dict:
|
| 120 |
"""Run validation: generate samples and compute metrics.
|
| 121 |
|
| 122 |
+
Returns dict with aggregate and per-procedure metrics.
|
| 123 |
"""
|
| 124 |
+
from diffusers import DDIMScheduler
|
| 125 |
|
| 126 |
t0 = time.time()
|
| 127 |
controlnet.eval()
|
|
|
|
| 129 |
step_dir = self.output_dir / f"step-{global_step}"
|
| 130 |
step_dir.mkdir(parents=True, exist_ok=True)
|
| 131 |
|
| 132 |
+
# Set up inference scheduler (DDIM for robustness during validation)
|
| 133 |
+
scheduler = DDIMScheduler.from_config(noise_scheduler.config)
|
| 134 |
scheduler.set_timesteps(self.num_inference_steps, device=device)
|
| 135 |
|
| 136 |
ssim_scores = []
|
| 137 |
lpips_scores = []
|
| 138 |
generated_images = []
|
| 139 |
|
| 140 |
+
# Per-procedure metric accumulators
|
| 141 |
+
proc_ssim: dict[str, list[float]] = {}
|
| 142 |
+
proc_lpips: dict[str, list[float]] = {}
|
| 143 |
+
|
| 144 |
+
# Use per-procedure selection instead of sequential indices
|
| 145 |
+
per_proc = self._select_per_procedure_indices()
|
| 146 |
+
|
| 147 |
+
for sample_num, (idx, proc) in enumerate(per_proc):
|
| 148 |
+
sample = self.val_dataset[idx]
|
| 149 |
conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
|
| 150 |
target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)
|
| 151 |
|
| 152 |
+
# Encode target for latent shape (VAE needs float32)
|
| 153 |
+
latents = vae.encode((target * 2 - 1).float()).latent_dist.sample()
|
| 154 |
+
latents = (latents * vae.config.scaling_factor).to(weight_dtype)
|
| 155 |
|
| 156 |
# Start from noise
|
| 157 |
noise = torch.randn_like(latents)
|
| 158 |
sample_latents = noise * scheduler.init_noise_sigma
|
| 159 |
encoder_hidden_states = text_embeddings[:1]
|
| 160 |
|
| 161 |
+
# Denoising loop with autocast to handle BF16/FP32 dtype
|
| 162 |
+
# mismatches in timestep embeddings
|
| 163 |
+
with torch.autocast("cuda", dtype=weight_dtype):
|
| 164 |
+
for t in scheduler.timesteps:
|
| 165 |
+
scaled = scheduler.scale_model_input(sample_latents, t)
|
| 166 |
+
|
| 167 |
+
# ControlNet
|
| 168 |
+
down_samples, mid_sample = controlnet(
|
| 169 |
+
scaled, t, encoder_hidden_states=encoder_hidden_states,
|
| 170 |
+
controlnet_cond=conditioning, return_dict=False,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# UNet with ControlNet residuals
|
| 174 |
+
noise_pred = unet(
|
| 175 |
+
scaled, t, encoder_hidden_states=encoder_hidden_states,
|
| 176 |
+
down_block_additional_residuals=down_samples,
|
| 177 |
+
mid_block_additional_residual=mid_sample,
|
| 178 |
+
).sample
|
| 179 |
+
|
| 180 |
+
sample_latents = scheduler.step(
|
| 181 |
+
noise_pred, t, sample_latents,
|
| 182 |
+
).prev_sample
|
| 183 |
+
|
| 184 |
+
# Decode -- cast VAE to float32 temporarily to avoid color banding
|
| 185 |
+
# and prevent dtype mismatch (latents float32 vs VAE weights bf16)
|
| 186 |
+
vae_dtype = next(vae.parameters()).dtype
|
| 187 |
+
vae.to(torch.float32)
|
| 188 |
decoded = vae.decode(sample_latents.float() / vae.config.scaling_factor).sample
|
| 189 |
+
vae.to(vae_dtype)
|
| 190 |
decoded = ((decoded + 1) / 2).clamp(0, 1)
|
| 191 |
|
| 192 |
# Convert to numpy for metrics
|
| 193 |
gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 194 |
tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 195 |
+
cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
|
|
|
|
|
|
| 196 |
|
| 197 |
# BGR for metrics (our metrics expect BGR)
|
| 198 |
gen_bgr = gen_np[:, :, ::-1].copy()
|
|
|
|
| 205 |
lpips_scores.append(lpips_val)
|
| 206 |
generated_images.append(gen_np)
|
| 207 |
|
| 208 |
+
# Accumulate per-procedure metrics
|
| 209 |
+
proc_ssim.setdefault(proc, []).append(ssim_val)
|
| 210 |
+
proc_lpips.setdefault(proc, []).append(lpips_val)
|
| 211 |
+
|
| 212 |
# Save comparison: conditioning | generated | target
|
| 213 |
+
proc_tag = proc.replace(" ", "_")
|
| 214 |
comparison = np.hstack([cond_np, gen_np, tgt_np])
|
| 215 |
+
Image.fromarray(comparison).save(
|
| 216 |
+
step_dir / f"val_{sample_num:02d}_{proc_tag}.png"
|
| 217 |
+
)
|
| 218 |
|
| 219 |
# Aggregate metrics
|
| 220 |
+
metrics: dict = {
|
| 221 |
"step": global_step,
|
| 222 |
"ssim_mean": float(np.nanmean(ssim_scores)),
|
| 223 |
"ssim_std": float(np.nanstd(ssim_scores)),
|
|
|
|
| 226 |
"time_seconds": round(time.time() - t0, 1),
|
| 227 |
}
|
| 228 |
|
| 229 |
+
# Per-procedure breakdown
|
| 230 |
+
per_procedure: dict[str, dict] = {}
|
| 231 |
+
for proc in sorted(proc_ssim.keys()):
|
| 232 |
+
per_procedure[proc] = {
|
| 233 |
+
"ssim_mean": float(np.nanmean(proc_ssim[proc])),
|
| 234 |
+
"lpips_mean": float(np.nanmean(proc_lpips[proc])),
|
| 235 |
+
"n_samples": len(proc_ssim[proc]),
|
| 236 |
+
}
|
| 237 |
+
metrics["per_procedure"] = per_procedure
|
| 238 |
+
|
| 239 |
self.history.append(metrics)
|
| 240 |
|
| 241 |
# Save metrics
|
|
|
|
| 250 |
if generated_images:
|
| 251 |
grid_rows = []
|
| 252 |
for i in range(0, len(generated_images), 4):
|
| 253 |
+
row_imgs = generated_images[i:i + 4]
|
| 254 |
while len(row_imgs) < 4:
|
| 255 |
row_imgs.append(np.zeros_like(generated_images[0]))
|
| 256 |
grid_rows.append(np.hstack(row_imgs))
|
|
|
|
| 259 |
|
| 260 |
controlnet.train()
|
| 261 |
|
| 262 |
+
# Log summary with per-procedure breakdown
|
| 263 |
+
proc_summary = " | ".join(
|
| 264 |
+
f"{p}: SSIM={v['ssim_mean']:.3f}"
|
| 265 |
+
for p, v in sorted(per_procedure.items())
|
| 266 |
+
)
|
| 267 |
print(
|
| 268 |
f" Validation @ step {global_step}: "
|
| 269 |
+
f"SSIM={metrics['ssim_mean']:.4f}+/-{metrics['ssim_std']:.4f} "
|
| 270 |
+
f"LPIPS={metrics['lpips_mean']:.4f}+/-{metrics['lpips_std']:.4f} "
|
| 271 |
f"({metrics['time_seconds']:.1f}s)"
|
| 272 |
)
|
| 273 |
+
if proc_summary:
|
| 274 |
+
print(f" Per-procedure: {proc_summary}")
|
| 275 |
|
| 276 |
return metrics
|
| 277 |
|
|
|
|
| 282 |
|
| 283 |
try:
|
| 284 |
import matplotlib
|
|
|
|
| 285 |
matplotlib.use("Agg")
|
| 286 |
import matplotlib.pyplot as plt
|
| 287 |
except ImportError:
|