dreamlessx committed on
Commit
c1dadad
·
verified ·
1 Parent(s): ded6c17

Update landmarkdiff/validation.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/validation.py +117 -43
landmarkdiff/validation.py CHANGED
@@ -30,6 +30,7 @@ class ValidationCallback:
30
  val_dataset=val_dataset,
31
  output_dir=Path("checkpoints/val"),
32
  num_samples=8,
 
33
  )
34
 
35
  # In training loop:
@@ -53,6 +54,7 @@ class ValidationCallback:
53
  num_samples: int = 8,
54
  num_inference_steps: int = 25,
55
  guidance_scale: float = 7.5,
 
56
  ):
57
  self.val_dataset = val_dataset
58
  self.output_dir = Path(output_dir)
@@ -60,8 +62,49 @@ class ValidationCallback:
60
  self.num_samples = min(num_samples, len(val_dataset))
61
  self.num_inference_steps = num_inference_steps
62
  self.guidance_scale = guidance_scale
 
63
  self.history: list[dict] = []
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  @torch.no_grad()
66
  def run(
67
  self,
@@ -76,9 +119,9 @@ class ValidationCallback:
76
  ) -> dict:
77
  """Run validation: generate samples and compute metrics.
78
 
79
- Returns dict with aggregate metrics.
80
  """
81
- from diffusers import DPMSolverMultistepScheduler
82
 
83
  t0 = time.time()
84
  controlnet.eval()
@@ -86,62 +129,70 @@ class ValidationCallback:
86
  step_dir = self.output_dir / f"step-{global_step}"
87
  step_dir.mkdir(parents=True, exist_ok=True)
88
 
89
- # Set up inference scheduler (DPM++ 2M for quality)
90
- scheduler = DPMSolverMultistepScheduler.from_config(noise_scheduler.config)
91
  scheduler.set_timesteps(self.num_inference_steps, device=device)
92
 
93
  ssim_scores = []
94
  lpips_scores = []
95
  generated_images = []
96
 
97
- for i in range(self.num_samples):
98
- sample = self.val_dataset[i]
 
 
 
 
 
 
 
99
  conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
100
  target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)
101
 
102
- # Encode target for latent shape
103
- latents = vae.encode(target * 2 - 1).latent_dist.sample()
104
- latents = latents * vae.config.scaling_factor
105
 
106
  # Start from noise
107
  noise = torch.randn_like(latents)
108
  sample_latents = noise * scheduler.init_noise_sigma
109
  encoder_hidden_states = text_embeddings[:1]
110
 
111
- # Denoising loop with classifier-free guidance
112
- for t in scheduler.timesteps:
113
- scaled = scheduler.scale_model_input(sample_latents, t)
114
-
115
- # ControlNet
116
- down_samples, mid_sample = controlnet(
117
- scaled,
118
- t,
119
- encoder_hidden_states=encoder_hidden_states,
120
- controlnet_cond=conditioning,
121
- return_dict=False,
122
- )
123
-
124
- # UNet with ControlNet residuals
125
- noise_pred = unet(
126
- scaled,
127
- t,
128
- encoder_hidden_states=encoder_hidden_states,
129
- down_block_additional_residuals=down_samples,
130
- mid_block_additional_residual=mid_sample,
131
- ).sample
132
-
133
- sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample
134
-
135
- # Decode (use float32 for VAE to avoid color banding)
 
 
136
  decoded = vae.decode(sample_latents.float() / vae.config.scaling_factor).sample
 
137
  decoded = ((decoded + 1) / 2).clamp(0, 1)
138
 
139
  # Convert to numpy for metrics
140
  gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
141
  tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
142
- cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(
143
- np.uint8
144
- )
145
 
146
  # BGR for metrics (our metrics expect BGR)
147
  gen_bgr = gen_np[:, :, ::-1].copy()
@@ -154,12 +205,19 @@ class ValidationCallback:
154
  lpips_scores.append(lpips_val)
155
  generated_images.append(gen_np)
156
 
 
 
 
 
157
  # Save comparison: conditioning | generated | target
 
158
  comparison = np.hstack([cond_np, gen_np, tgt_np])
159
- Image.fromarray(comparison).save(step_dir / f"val_{i:02d}.png")
 
 
160
 
161
  # Aggregate metrics
162
- metrics = {
163
  "step": global_step,
164
  "ssim_mean": float(np.nanmean(ssim_scores)),
165
  "ssim_std": float(np.nanstd(ssim_scores)),
@@ -168,6 +226,16 @@ class ValidationCallback:
168
  "time_seconds": round(time.time() - t0, 1),
169
  }
170
 
 
 
 
 
 
 
 
 
 
 
171
  self.history.append(metrics)
172
 
173
  # Save metrics
@@ -182,7 +250,7 @@ class ValidationCallback:
182
  if generated_images:
183
  grid_rows = []
184
  for i in range(0, len(generated_images), 4):
185
- row_imgs = generated_images[i : i + 4]
186
  while len(row_imgs) < 4:
187
  row_imgs.append(np.zeros_like(generated_images[0]))
188
  grid_rows.append(np.hstack(row_imgs))
@@ -191,12 +259,19 @@ class ValidationCallback:
191
 
192
  controlnet.train()
193
 
 
 
 
 
 
194
  print(
195
  f" Validation @ step {global_step}: "
196
- f"SSIM={metrics['ssim_mean']:.4f}±{metrics['ssim_std']:.4f} "
197
- f"LPIPS={metrics['lpips_mean']:.4f}±{metrics['lpips_std']:.4f} "
198
  f"({metrics['time_seconds']:.1f}s)"
199
  )
 
 
200
 
201
  return metrics
202
 
@@ -207,7 +282,6 @@ class ValidationCallback:
207
 
208
  try:
209
  import matplotlib
210
-
211
  matplotlib.use("Agg")
212
  import matplotlib.pyplot as plt
213
  except ImportError:
 
30
  val_dataset=val_dataset,
31
  output_dir=Path("checkpoints/val"),
32
  num_samples=8,
33
+ samples_per_procedure=2,
34
  )
35
 
36
  # In training loop:
 
54
  num_samples: int = 8,
55
  num_inference_steps: int = 25,
56
  guidance_scale: float = 7.5,
57
+ samples_per_procedure: int = 2,
58
  ):
59
  self.val_dataset = val_dataset
60
  self.output_dir = Path(output_dir)
 
62
  self.num_samples = min(num_samples, len(val_dataset))
63
  self.num_inference_steps = num_inference_steps
64
  self.guidance_scale = guidance_scale
65
+ self.samples_per_procedure = samples_per_procedure
66
  self.history: list[dict] = []
67
 
68
+ # Pre-build per-procedure index map for stratified sampling
69
+ self._procedure_indices = self._build_procedure_map()
70
+
71
+ def _build_procedure_map(self) -> dict[str, list[int]]:
72
+ """Build a mapping of procedure name to dataset indices."""
73
+ from collections import defaultdict
74
+
75
+ proc_indices: dict[str, list[int]] = defaultdict(list)
76
+ ds = self.val_dataset
77
+
78
+ if hasattr(ds, "_sample_procedures") and ds._sample_procedures:
79
+ for idx, pair_path in enumerate(ds.pairs):
80
+ prefix = pair_path.stem.replace("_input", "")
81
+ proc = ds._sample_procedures.get(prefix, "unknown")
82
+ proc_indices[proc].append(idx)
83
+ elif hasattr(ds, "get_procedure"):
84
+ for idx in range(len(ds)):
85
+ proc = ds.get_procedure(idx)
86
+ proc_indices[proc].append(idx)
87
+
88
+ # Drop "unknown" if we have labeled procedures
89
+ known = {k: v for k, v in proc_indices.items() if k != "unknown"}
90
+ return dict(known) if known else dict(proc_indices)
91
+
92
+ def _select_per_procedure_indices(self) -> list[tuple[int, str]]:
93
+ """Select sample indices ensuring each procedure is represented.
94
+
95
+ Returns list of (dataset_index, procedure_name) tuples.
96
+ Falls back to first N sequential indices when no procedure metadata
97
+ is available.
98
+ """
99
+ if not self._procedure_indices:
100
+ return [(i, "unknown") for i in range(self.num_samples)]
101
+
102
+ selected: list[tuple[int, str]] = []
103
+ for proc, indices in sorted(self._procedure_indices.items()):
104
+ for idx in indices[: self.samples_per_procedure]:
105
+ selected.append((idx, proc))
106
+ return selected
107
+
108
  @torch.no_grad()
109
  def run(
110
  self,
 
119
  ) -> dict:
120
  """Run validation: generate samples and compute metrics.
121
 
122
+ Returns dict with aggregate and per-procedure metrics.
123
  """
124
+ from diffusers import DDIMScheduler
125
 
126
  t0 = time.time()
127
  controlnet.eval()
 
129
  step_dir = self.output_dir / f"step-{global_step}"
130
  step_dir.mkdir(parents=True, exist_ok=True)
131
 
132
+ # Set up inference scheduler (DDIM for robustness during validation)
133
+ scheduler = DDIMScheduler.from_config(noise_scheduler.config)
134
  scheduler.set_timesteps(self.num_inference_steps, device=device)
135
 
136
  ssim_scores = []
137
  lpips_scores = []
138
  generated_images = []
139
 
140
+ # Per-procedure metric accumulators
141
+ proc_ssim: dict[str, list[float]] = {}
142
+ proc_lpips: dict[str, list[float]] = {}
143
+
144
+ # Use per-procedure selection instead of sequential indices
145
+ per_proc = self._select_per_procedure_indices()
146
+
147
+ for sample_num, (idx, proc) in enumerate(per_proc):
148
+ sample = self.val_dataset[idx]
149
  conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
150
  target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)
151
 
152
+ # Encode target for latent shape (VAE needs float32)
153
+ latents = vae.encode((target * 2 - 1).float()).latent_dist.sample()
154
+ latents = (latents * vae.config.scaling_factor).to(weight_dtype)
155
 
156
  # Start from noise
157
  noise = torch.randn_like(latents)
158
  sample_latents = noise * scheduler.init_noise_sigma
159
  encoder_hidden_states = text_embeddings[:1]
160
 
161
+ # Denoising loop with autocast to handle BF16/FP32 dtype
162
+ # mismatches in timestep embeddings
163
+ with torch.autocast("cuda", dtype=weight_dtype):
164
+ for t in scheduler.timesteps:
165
+ scaled = scheduler.scale_model_input(sample_latents, t)
166
+
167
+ # ControlNet
168
+ down_samples, mid_sample = controlnet(
169
+ scaled, t, encoder_hidden_states=encoder_hidden_states,
170
+ controlnet_cond=conditioning, return_dict=False,
171
+ )
172
+
173
+ # UNet with ControlNet residuals
174
+ noise_pred = unet(
175
+ scaled, t, encoder_hidden_states=encoder_hidden_states,
176
+ down_block_additional_residuals=down_samples,
177
+ mid_block_additional_residual=mid_sample,
178
+ ).sample
179
+
180
+ sample_latents = scheduler.step(
181
+ noise_pred, t, sample_latents,
182
+ ).prev_sample
183
+
184
+ # Decode -- cast VAE to float32 temporarily to avoid color banding
185
+ # and prevent dtype mismatch (latents float32 vs VAE weights bf16)
186
+ vae_dtype = next(vae.parameters()).dtype
187
+ vae.to(torch.float32)
188
  decoded = vae.decode(sample_latents.float() / vae.config.scaling_factor).sample
189
+ vae.to(vae_dtype)
190
  decoded = ((decoded + 1) / 2).clamp(0, 1)
191
 
192
  # Convert to numpy for metrics
193
  gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
194
  tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
195
+ cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
 
 
196
 
197
  # BGR for metrics (our metrics expect BGR)
198
  gen_bgr = gen_np[:, :, ::-1].copy()
 
205
  lpips_scores.append(lpips_val)
206
  generated_images.append(gen_np)
207
 
208
+ # Accumulate per-procedure metrics
209
+ proc_ssim.setdefault(proc, []).append(ssim_val)
210
+ proc_lpips.setdefault(proc, []).append(lpips_val)
211
+
212
  # Save comparison: conditioning | generated | target
213
+ proc_tag = proc.replace(" ", "_")
214
  comparison = np.hstack([cond_np, gen_np, tgt_np])
215
+ Image.fromarray(comparison).save(
216
+ step_dir / f"val_{sample_num:02d}_{proc_tag}.png"
217
+ )
218
 
219
  # Aggregate metrics
220
+ metrics: dict = {
221
  "step": global_step,
222
  "ssim_mean": float(np.nanmean(ssim_scores)),
223
  "ssim_std": float(np.nanstd(ssim_scores)),
 
226
  "time_seconds": round(time.time() - t0, 1),
227
  }
228
 
229
+ # Per-procedure breakdown
230
+ per_procedure: dict[str, dict] = {}
231
+ for proc in sorted(proc_ssim.keys()):
232
+ per_procedure[proc] = {
233
+ "ssim_mean": float(np.nanmean(proc_ssim[proc])),
234
+ "lpips_mean": float(np.nanmean(proc_lpips[proc])),
235
+ "n_samples": len(proc_ssim[proc]),
236
+ }
237
+ metrics["per_procedure"] = per_procedure
238
+
239
  self.history.append(metrics)
240
 
241
  # Save metrics
 
250
  if generated_images:
251
  grid_rows = []
252
  for i in range(0, len(generated_images), 4):
253
+ row_imgs = generated_images[i:i + 4]
254
  while len(row_imgs) < 4:
255
  row_imgs.append(np.zeros_like(generated_images[0]))
256
  grid_rows.append(np.hstack(row_imgs))
 
259
 
260
  controlnet.train()
261
 
262
+ # Log summary with per-procedure breakdown
263
+ proc_summary = " | ".join(
264
+ f"{p}: SSIM={v['ssim_mean']:.3f}"
265
+ for p, v in sorted(per_procedure.items())
266
+ )
267
  print(
268
  f" Validation @ step {global_step}: "
269
+ f"SSIM={metrics['ssim_mean']:.4f}+/-{metrics['ssim_std']:.4f} "
270
+ f"LPIPS={metrics['lpips_mean']:.4f}+/-{metrics['lpips_std']:.4f} "
271
  f"({metrics['time_seconds']:.1f}s)"
272
  )
273
+ if proc_summary:
274
+ print(f" Per-procedure: {proc_summary}")
275
 
276
  return metrics
277
 
 
282
 
283
  try:
284
  import matplotlib
 
285
  matplotlib.use("Agg")
286
  import matplotlib.pyplot as plt
287
  except ImportError: