letitiaaa commited on
Commit
d2c56ce
·
verified ·
1 Parent(s): f0cfe9c

Upload train_vae.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_vae.py +729 -0
train_vae.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+ from torch.nn.parallel import DistributedDataParallel as DDP
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.amp import autocast, GradScaler
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import numpy as np
11
+ from PIL import Image
12
+ from glob import glob
13
+ from time import time
14
+ import argparse
15
+ import logging
16
+ import os
17
+ import json
18
+
19
+ from models import AutoencoderKL, DiT_models
20
+ from custom_dataset import StyleTransferDataset, create_style_transfer_dataloader
21
+
22
+
23
+ #################################################################################
24
+ # VAE loss function #
25
+ #################################################################################
26
+
27
+ class SSIMLoss(nn.Module):
28
+ def __init__(self, window_size=11, size_average=True):
29
+ super(SSIMLoss, self).__init__()
30
+ self.window_size = window_size
31
+ self.size_average = size_average
32
+ self.channel = 1
33
+ self.window = self.create_window(window_size, self.channel)
34
+
35
+ def gaussian(self, window_size, sigma):
36
+ gauss = torch.Tensor([np.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
37
+ return gauss/gauss.sum()
38
+
39
+ def create_window(self, window_size, channel):
40
+ _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
41
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
42
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
43
+ return window
44
+
45
+ def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
46
+ mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
47
+ mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
48
+
49
+ mu1_sq = mu1.pow(2)
50
+ mu2_sq = mu2.pow(2)
51
+ mu1_mu2 = mu1 * mu2
52
+
53
+ sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
54
+ sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
55
+ sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
56
+
57
+ C1 = 0.01**2
58
+ C2 = 0.03**2
59
+
60
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
61
+
62
+ if size_average:
63
+ return ssim_map.mean()
64
+ else:
65
+ return ssim_map.mean(1).mean(1).mean(1)
66
+
67
+ def forward(self, img1, img2):
68
+ (_, channel, _, _) = img1.size()
69
+
70
+ if channel == self.channel and self.window.data.type() == img1.data.type():
71
+ window = self.window
72
+ else:
73
+ window = self.create_window(self.window_size, channel)
74
+
75
+ if img1.is_cuda:
76
+ window = window.cuda(img1.get_device())
77
+ window = window.type_as(img1)
78
+
79
+ self.window = window
80
+ self.channel = channel
81
+
82
+ ssim_val = self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
83
+ return 1 - ssim_val
84
+
85
+
86
+ class VAELoss(nn.Module):
87
+ def __init__(
88
+ self,
89
+ kl_weight=1e-6,
90
+ l1_weight=1.0,
91
+ ssim_weight=1.0,
92
+ ):
93
+ super().__init__()
94
+ self.kl_weight = kl_weight
95
+ self.l1_weight = l1_weight
96
+ self.ssim_weight = ssim_weight
97
+
98
+ self.l1_loss = nn.L1Loss()
99
+ self.ssim_loss = SSIMLoss()
100
+
101
+ def forward(self, recon, target, posterior):
102
+
103
+ l1_loss = self.l1_loss(recon, target)
104
+
105
+ # Convert from [-1, 1] to [0, 1] for SSIM calculation
106
+ # SSIM constants (C1, C2) are designed for [0, 1] range
107
+ recon_01 = (recon + 1.0) / 2.0
108
+ target_01 = (target + 1.0) / 2.0
109
+ ssim_loss = self.ssim_loss(recon_01, target_01)
110
+
111
+ kl_loss = posterior.kl().mean()
112
+
113
+ total_loss = (
114
+ self.l1_weight * l1_loss +
115
+ self.ssim_weight * ssim_loss +
116
+ self.kl_weight * kl_loss
117
+ )
118
+
119
+ return {
120
+ 'total_loss': total_loss,
121
+ 'l1_loss': self.l1_weight * l1_loss ,
122
+ 'ssim_loss': self.ssim_weight * ssim_loss,
123
+ 'kl_loss': self.kl_weight * kl_loss,
124
+ }
125
+
126
+
127
+ #################################################################################
128
+ # Training Helper Functions #
129
+ #################################################################################
130
+
131
+ def create_logger(experiment_dir):
132
+ if experiment_dir is not None:
133
+ logging.basicConfig(
134
+ level=logging.INFO,
135
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
136
+ datefmt='%Y-%m-%d %H:%M:%S',
137
+ handlers=[
138
+ logging.StreamHandler(),
139
+ logging.FileHandler(f"{experiment_dir}/log.txt")
140
+ ]
141
+ )
142
+ logger = logging.getLogger(__name__)
143
+ else:
144
+ logger = logging.getLogger(__name__)
145
+ logger.addHandler(logging.NullHandler())
146
+ return logger
147
+
148
+
149
+ def cleanup():
150
+ if dist.is_initialized():
151
+ dist.destroy_process_group()
152
+
153
+
154
+ def get_lr_scheduler(optimizer, args, steps_per_epoch):
155
+ if args.lr_scheduler == 'none':
156
+ return None
157
+
158
+ total_steps = args.epochs * steps_per_epoch
159
+ warmup_steps = args.warmup_epochs * steps_per_epoch
160
+
161
+ if args.lr_scheduler == 'linear':
162
+ # Warmup + Linear Decay
163
+ def lr_lambda(current_step):
164
+ if current_step < warmup_steps:
165
+ return float(current_step) / float(max(1, warmup_steps))
166
+ else:
167
+ progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
168
+ return max(0.0, 1.0 - progress)
169
+
170
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
171
+
172
+ elif args.lr_scheduler == 'cosine':
173
+ # Warmup + Cosine Decay
174
+ def lr_lambda(current_step):
175
+ if current_step < warmup_steps:
176
+ return float(current_step) / float(max(1, warmup_steps))
177
+ else:
178
+ progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
179
+ return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))
180
+
181
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
182
+
183
+ elif args.lr_scheduler == 'constant':
184
+ # Warmup + Constant
185
+ def lr_lambda(current_step):
186
+ if current_step < warmup_steps:
187
+ return float(current_step) / float(max(1, warmup_steps))
188
+ else:
189
+ return 1.0
190
+
191
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
192
+
193
+ else:
194
+ raise ValueError(f"Unknown lr_scheduler: {args.lr_scheduler}")
195
+
196
+
197
+ @torch.no_grad()
198
+ def save_samples(vae, dataloader, device, save_dir, num_samples=8, patches_per_image=4, patch_size=512, is_conditional=False, use_fp16=False, vae_domain1=False, vae_domain2=False, multiscale=False, multiscale_levels=None):
199
+ vae.eval()
200
+ os.makedirs(save_dir, exist_ok=True)
201
+
202
+ if multiscale_levels is None:
203
+ multiscale_levels = [32, 64, 128, 256, 512]
204
+
205
+ saved_count = 0
206
+ for a_large_images, b_large_images, paths in dataloader:
207
+ a_large_images = a_large_images.to(device)
208
+ b_large_images = b_large_images.to(device)
209
+ # multiscale_sizes=None disables multiscale, otherwise enables it
210
+ # If multiscale_levels is provided, automatically enable multiscale
211
+ if multiscale:
212
+ multiscale_sizes_param = multiscale_levels if multiscale_levels is not None else [32, 64, 128, 256, 512]
213
+ elif multiscale_levels is not None:
214
+ # User provided multiscale_levels without multiscale flag, auto-enable
215
+ multiscale_sizes_param = multiscale_levels
216
+ else:
217
+ multiscale_sizes_param = None
218
+ if vae_domain1:
219
+ b_images, _, pos_info = StyleTransferDataset.crop_patches_from_large_images_with_pos(
220
+ a_large_images,
221
+ b_large_images,
222
+ patch_size=patch_size,
223
+ patches_per_image=patches_per_image,
224
+ width_norm=15000.0,
225
+ height_norm=20000.0,
226
+ multiscale_sizes=multiscale_sizes_param
227
+ )
228
+ elif vae_domain2:
229
+ _, b_images, pos_info = StyleTransferDataset.crop_patches_from_large_images_with_pos(
230
+ a_large_images,
231
+ b_large_images,
232
+ patch_size=patch_size,
233
+ patches_per_image=patches_per_image,
234
+ width_norm=15000.0,
235
+ height_norm=20000.0,
236
+ multiscale_sizes=multiscale_sizes_param
237
+ )
238
+
239
+ del a_large_images, b_large_images
240
+ torch.cuda.empty_cache()
241
+
242
+ batch = b_images[:num_samples - saved_count]
243
+ pos_batch = pos_info[:num_samples - saved_count] if pos_info is not None else None
244
+
245
+ # Use autocast for inference if fp16 is enabled
246
+ with autocast('cuda', enabled=use_fp16):
247
+ if is_conditional and pos_batch is not None:
248
+ recon, _ = vae(batch, sample_posterior=False, pos_context=pos_batch)
249
+ else:
250
+ recon, _ = vae(batch, sample_posterior=False)
251
+
252
+ if vae_domain2:
253
+ batch_ori = (batch[:, 0:1, :, :] + 1.0) / 2.0
254
+ recon_ori = (recon[:, 0:1, :, :] + 1.0) / 2.0
255
+
256
+ # Z-score denormalization for retar: (normalized * std + mean)
257
+ # batch_retar = (batch[:, 1:2, :, :] + 1.0) / 2.0 * 9
258
+ # recon_retar = (recon[:, 1:2, :, :] + 1.0) / 2.0 * 9
259
+ batch_retar = batch[:, 1:2, :, :] * 11.41 + 5.61
260
+ recon_retar = recon[:, 1:2, :, :] * 11.41 + 5.61
261
+
262
+ for i in range(batch.shape[0]):
263
+ orig_ori = (batch_ori[i, 0].cpu().numpy() * 255).astype(np.uint8)
264
+ recon_ori_img = (recon_ori[i, 0].cpu().numpy() * 255).astype(np.uint8)
265
+
266
+ # Clip retar values to [0, 90] range and scale to [0, 255] for display
267
+ orig_retar = (batch_retar[i, 0].cpu().numpy() * 255 / 9).astype(np.uint8)
268
+ recon_retar_img = (recon_retar[i, 0].cpu().numpy() * 255 / 9).astype(np.uint8)
269
+
270
+ orig_ori_pil = Image.fromarray(orig_ori)
271
+ recon_ori_pil = Image.fromarray(recon_ori_img)
272
+ orig_retar_pil = Image.fromarray(orig_retar)
273
+ recon_retar_pil = Image.fromarray(recon_retar_img)
274
+
275
+ combined = Image.new('L', (orig_ori_pil.width * 4, orig_ori_pil.height))
276
+ combined.paste(orig_ori_pil, (0, 0))
277
+ combined.paste(recon_ori_pil, (orig_ori_pil.width, 0))
278
+ combined.paste(orig_retar_pil, (orig_ori_pil.width * 2, 0))
279
+ combined.paste(recon_retar_pil, (orig_ori_pil.width * 3, 0))
280
+
281
+ elif vae_domain1:
282
+ batch_a = (batch[:, 0:1, :, :] + 1.0) / 2.0
283
+ recon_a = (recon[:, 0:1, :, :] + 1.0) / 2.0
284
+
285
+ for i in range(batch.shape[0]):
286
+ orig_a = (batch_a[i, 0].cpu().numpy() * 255).astype(np.uint8)
287
+ recon_a_img = (recon_a[i, 0].cpu().numpy() * 255).astype(np.uint8)
288
+
289
+ orig_a_pil = Image.fromarray(orig_a)
290
+ recon_a_pil = Image.fromarray(recon_a_img)
291
+
292
+ combined = Image.new('L', (orig_a_pil.width * 2, orig_a_pil.height))
293
+ combined.paste(orig_a_pil, (0, 0))
294
+ combined.paste(recon_a_pil, (orig_a_pil.width, 0))
295
+
296
+
297
+ combined.save(f"{save_dir}/sample_{saved_count}.png")
298
+ saved_count += 1
299
+ if saved_count >= num_samples:
300
+ break
301
+
302
+ vae.train()
303
+
304
+
305
+
306
+ def main(args):
307
+ assert torch.cuda.is_available(), "Training requires at least one GPU."
308
+
309
+ # Setup DDP
310
+ rank = int(os.environ.get("RANK", 0))
311
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
312
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
313
+
314
+ torch.cuda.set_device(local_rank)
315
+ device = local_rank
316
+
317
+ dist.init_process_group("nccl")
318
+
319
+ seed = args.global_seed * world_size + rank
320
+ torch.manual_seed(seed)
321
+ print(f"Starting rank={rank}, local_rank={local_rank}, seed={seed}, world_size={world_size}.")
322
+
323
+ # Setup experiment folder
324
+ is_master = (rank == 0)
325
+ if is_master:
326
+ os.makedirs(args.results_dir, exist_ok=True)
327
+ experiment_index = len(glob(f"{args.results_dir}/*"))
328
+ model_name = args.vae_model if args.vae_model else "VAE-Custom"
329
+ experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_name}"
330
+ checkpoint_dir = f"{experiment_dir}/checkpoints"
331
+ sample_dir = f"{experiment_dir}/samples"
332
+ tensorboard_dir = f"{experiment_dir}/tensorboard"
333
+ os.makedirs(checkpoint_dir, exist_ok=True)
334
+ os.makedirs(sample_dir, exist_ok=True)
335
+ os.makedirs(tensorboard_dir, exist_ok=True)
336
+ logger = create_logger(experiment_dir)
337
+ logger.info(f"Experiment directory created at {experiment_dir}")
338
+
339
+ writer = SummaryWriter(tensorboard_dir)
340
+ logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
341
+ else:
342
+ logger = create_logger(None)
343
+ sample_dir = None
344
+ writer = None
345
+
346
+ if args.vae_model:
347
+ if is_master:
348
+ logger.info(f"Creating VAE model: {args.vae_model}")
349
+ if args.vae_model not in DiT_models:
350
+ raise ValueError(f"Unknown VAE model: {args.vae_model}. Available: {[k for k in DiT_models.keys() if k.startswith('VAE')]}")
351
+
352
+ vae_fn = DiT_models[args.vae_model]
353
+ vae = vae_fn(
354
+ in_channels=args.in_channels,
355
+ out_ch=args.out_channels,
356
+ resolution=args.image_size,
357
+ ).to(device)
358
+ if is_master:
359
+ logger.info(f"Using predefined VAE model: {args.vae_model}")
360
+
361
+ else:
362
+ if is_master:
363
+ logger.info("Creating VAE model with custom parameters")
364
+ vae = AutoencoderKL(
365
+ embed_dim=args.embed_dim,
366
+ in_channels=args.in_channels,
367
+ out_ch=args.out_channels,
368
+ ch=args.ch,
369
+ ch_mult=tuple(args.ch_mult),
370
+ num_res_blocks=args.num_res_blocks,
371
+ attn_resolutions=args.attn_resolutions,
372
+ dropout=args.dropout,
373
+ resolution=args.image_size,
374
+ z_channels=args.z_channels,
375
+ double_z=args.double_z,
376
+ use_mid_attn=False,
377
+ ).to(device)
378
+ if is_master:
379
+ logger.info(f"VAE middle attention: DISABLED (saves ~68GB memory)")
380
+
381
+ if is_master:
382
+ logger.info(f"VAE Parameters: {sum(p.numel() for p in vae.parameters()):,}")
383
+
384
+ is_conditional_vae = hasattr(vae, 'condition_net')
385
+ if is_master:
386
+ if is_conditional_vae:
387
+ logger.info("✓ Using Conditional VAE with position information")
388
+ else:
389
+ logger.info("Using standard VAE (no position conditioning)")
390
+
391
+
392
+ vae = DDP(vae, device_ids=[device], find_unused_parameters=True)
393
+ if is_master:
394
+ logger.info("Using find_unused_parameters=True to handle attention layers")
395
+
396
+ opt = torch.optim.AdamW(vae.parameters(), lr=args.learning_rate, weight_decay=0.0)
397
+ criterion = VAELoss(
398
+ kl_weight=args.kl_weight,
399
+ l1_weight=args.l1_weight,
400
+ ssim_weight=args.ssim_weight,
401
+ ).to(device)
402
+
403
+
404
+ scaler = GradScaler('cuda', enabled=args.fp16)
405
+ actual_batch_size = int(args.global_batch_size // dist.get_world_size())
406
+ world_size = dist.get_world_size()
407
+ loader = create_style_transfer_dataloader(
408
+ pairing_json_path=args.data_path,
409
+ batch_size=actual_batch_size,
410
+ patch_size=args.image_size,
411
+ patches_per_image=args.patches_per_image,
412
+ num_workers=args.num_workers,
413
+ shuffle=True,
414
+ drop_last=True,
415
+ device=device,
416
+ distributed=(world_size > 1),
417
+ rank=rank,
418
+ world_size=world_size
419
+ )
420
+
421
+ if is_master:
422
+ logger.info(f"Dataset contains {len(loader.dataset):,} large images")
423
+ logger.info(f"Global batch size: {args.global_batch_size}, Actual batch size: {actual_batch_size}")
424
+ logger.info(f"Patches per image: {args.patches_per_image}")
425
+
426
+ steps_per_epoch = len(loader)
427
+ scheduler = get_lr_scheduler(opt, args, steps_per_epoch)
428
+ if is_master:
429
+ if scheduler:
430
+ logger.info(f"Using LR scheduler: {args.lr_scheduler} with {args.warmup_epochs} warmup epochs")
431
+ logger.info(f"Total steps: {args.epochs * steps_per_epoch}, Warmup steps: {args.warmup_epochs * steps_per_epoch}")
432
+ else:
433
+ logger.info("No LR scheduler (constant learning rate)")
434
+
435
+ train_steps = 0
436
+ start_epoch = 0
437
+
438
+ if args.resume:
439
+ if is_master:
440
+ logger.info(f"Resuming from checkpoint: {args.resume}")
441
+ checkpoint = torch.load(args.resume, map_location=f"cuda:{device}", weights_only=False)
442
+ vae.module.load_state_dict(checkpoint["vae"], strict=False)
443
+ if is_master:
444
+ logger.info(f"Note: using strict=False to ignore unexpected keys (e.g., old attention weights)")
445
+ opt.load_state_dict(checkpoint["opt"])
446
+ train_steps = checkpoint.get("train_steps", 0)
447
+
448
+ if args.start_epoch is not None:
449
+ start_epoch = args.start_epoch
450
+ if is_master:
451
+ logger.info(f"Using manually specified start epoch: {start_epoch}")
452
+ else:
453
+ start_epoch = checkpoint.get("epoch", 0)
454
+ if is_master:
455
+ if "epoch" in checkpoint:
456
+ logger.info(f"Loaded epoch from checkpoint: {start_epoch}")
457
+ else:
458
+ logger.info(f"No epoch info in checkpoint, starting from epoch 0")
459
+
460
+ if scheduler and "scheduler" in checkpoint:
461
+ scheduler.load_state_dict(checkpoint["scheduler"])
462
+ if is_master:
463
+ logger.info(f"Resumed scheduler from step {train_steps}")
464
+
465
+ if is_master:
466
+ logger.info(f"Resumed from epoch {start_epoch}, step {train_steps}")
467
+ elif args.start_epoch is not None:
468
+ start_epoch = args.start_epoch
469
+ if is_master:
470
+ logger.info(f"Starting from manually specified epoch: {start_epoch} (without resume)")
471
+
472
+
473
+ vae.train()
474
+ running_loss = 0
475
+ running_l1_loss = 0
476
+ running_ssim_loss = 0
477
+ running_kl_loss = 0
478
+ log_steps = 0
479
+ start_time = time()
480
+
481
+ if is_master:
482
+ logger.info(f"Training for {args.epochs} epochs (from epoch {start_epoch} to {args.epochs})...")
483
+
484
+ for epoch in range(start_epoch, args.epochs):
485
+ if is_master:
486
+ logger.info(f"Beginning epoch {epoch}...")
487
+
488
+ for a_large_images, b_large_images, paths in loader:
489
+ # Use non_blocking transfer to overlap data loading with computation
490
+ a_large_images = a_large_images.to(device, non_blocking=True)
491
+ b_large_images = b_large_images.to(device, non_blocking=True)
492
+
493
+ # Crop patches on GPU (already on GPU from .to(device))
494
+ if args.multiscale:
495
+ multiscale_sizes = getattr(args, 'multiscale_levels', [32, 64, 128, 256, 512])
496
+ elif hasattr(args, 'multiscale_levels') and args.multiscale_levels is not None:
497
+ multiscale_sizes = args.multiscale_levels
498
+ else:
499
+ multiscale_sizes = None
500
+ if args.vae_domain2:
501
+ _, b_images, pos_info = StyleTransferDataset.crop_patches_from_large_images_with_pos(
502
+ a_large_images, b_large_images,
503
+ patch_size=args.image_size,
504
+ patches_per_image=args.patches_per_image,
505
+ width_norm=15000.0,
506
+ height_norm=20000.0,
507
+ multiscale_sizes=multiscale_sizes
508
+ )
509
+ elif args.vae_domain1:
510
+ b_images, _, pos_info = StyleTransferDataset.crop_patches_from_large_images_with_pos(
511
+ a_large_images,
512
+ b_large_images,
513
+ patch_size=args.image_size,
514
+ patches_per_image=args.patches_per_image,
515
+ width_norm=15000.0,
516
+ height_norm=20000.0,
517
+ multiscale_sizes=multiscale_sizes
518
+ )
519
+
520
+ images = b_images
521
+
522
+ with autocast('cuda', enabled=args.fp16):
523
+ if is_conditional_vae and pos_info is not None:
524
+ recon, posterior = vae(images, sample_posterior=True, pos_context=pos_info)
525
+ else:
526
+ recon, posterior = vae(images, sample_posterior=True)
527
+ losses = criterion(recon, images, posterior)
528
+ loss = losses['total_loss']
529
+
530
+ # Extract loss values without blocking
531
+ loss_val = loss.item()
532
+ l1_loss_val = losses['l1_loss'].item()
533
+ ssim_loss_val = losses['ssim_loss'].item()
534
+ kl_loss_val = losses['kl_loss'].item()
535
+
536
+ opt.zero_grad()
537
+ if args.fp16:
538
+ scaler.scale(loss).backward()
539
+ scaler.step(opt)
540
+ scaler.update()
541
+ else:
542
+ loss.backward()
543
+ opt.step()
544
+
545
+ # Accumulate losses locally (no sync needed)
546
+ running_loss += loss_val
547
+ running_l1_loss += l1_loss_val
548
+ running_ssim_loss += ssim_loss_val
549
+ running_kl_loss += kl_loss_val
550
+ log_steps += 1
551
+ train_steps += 1
552
+
553
+ if scheduler:
554
+ scheduler.step()
555
+
556
+ # Only sync and log periodically (reduces GPU-CPU synchronization overhead)
557
+ if train_steps % args.log_every == 0:
558
+ # Synchronize only when logging (not every step)
559
+ torch.cuda.synchronize()
560
+ end_time = time()
561
+ steps_per_sec = log_steps / (end_time - start_time)
562
+
563
+ # Compute local averages
564
+ local_avg_loss = running_loss / log_steps
565
+ local_avg_l1 = running_l1_loss / log_steps
566
+ local_avg_ssim = running_ssim_loss / log_steps
567
+ local_avg_kl = running_kl_loss / log_steps
568
+
569
+ avg_loss_tensor = torch.tensor(local_avg_loss, device=device)
570
+ avg_l1_tensor = torch.tensor(local_avg_l1, device=device)
571
+ avg_ssim_tensor = torch.tensor(local_avg_ssim, device=device)
572
+ avg_kl_tensor = torch.tensor(local_avg_kl, device=device)
573
+
574
+ dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)
575
+ dist.all_reduce(avg_l1_tensor, op=dist.ReduceOp.SUM)
576
+ dist.all_reduce(avg_ssim_tensor, op=dist.ReduceOp.SUM)
577
+ dist.all_reduce(avg_kl_tensor, op=dist.ReduceOp.SUM)
578
+
579
+ avg_loss = avg_loss_tensor.item() / dist.get_world_size()
580
+ avg_l1 = avg_l1_tensor.item() / dist.get_world_size()
581
+ avg_ssim = avg_ssim_tensor.item() / dist.get_world_size()
582
+ avg_kl = avg_kl_tensor.item() / dist.get_world_size()
583
+
584
+ if is_master:
585
+ if writer is not None:
586
+ writer.add_scalar('Loss/total', avg_loss, train_steps)
587
+ writer.add_scalar('Loss/l1', avg_l1, train_steps)
588
+ writer.add_scalar('Loss/ssim', avg_ssim, train_steps)
589
+ writer.add_scalar('Loss/kl', avg_kl, train_steps)
590
+ writer.add_scalar('Training/steps_per_sec', steps_per_sec, train_steps)
591
+ writer.add_scalar('Training/learning_rate', opt.param_groups[0]['lr'], train_steps)
592
+
593
+ current_lr = opt.param_groups[0]['lr']
594
+ logger.info(
595
+ f"(step={train_steps:07d}) "
596
+ f"Loss: {avg_loss:.4f} | "
597
+ f"L1: {avg_l1:.4f} | "
598
+ f"SSIM: {avg_ssim:.4f} | "
599
+ f"KL: {avg_kl:.6f} | "
600
+ f"LR: {current_lr:.2e} | "
601
+ f"Steps/Sec: {steps_per_sec:.2f}"
602
+ )
603
+
604
+ running_loss = 0
605
+ running_l1_loss = 0
606
+ running_ssim_loss = 0
607
+ running_kl_loss = 0
608
+ log_steps = 0
609
+ start_time = time()
610
+
611
+ if args.sample_every > 0 and train_steps % args.sample_every == 0 and train_steps > 0:
612
+ if is_master:
613
+ sample_subdir = f"{sample_dir}/step_{train_steps:07d}"
614
+ # Calculate multiscale_sizes using same logic as training loop
615
+ if args.multiscale:
616
+ sample_multiscale_levels = getattr(args, 'multiscale_levels', [32, 64, 128, 256, 512])
617
+ elif hasattr(args, 'multiscale_levels') and args.multiscale_levels is not None:
618
+ sample_multiscale_levels = args.multiscale_levels
619
+ else:
620
+ sample_multiscale_levels = None
621
+
622
+ save_samples(vae, loader, device, sample_subdir,
623
+ num_samples=args.vis_num_samples,
624
+ patches_per_image=args.patches_per_image,
625
+ patch_size=args.image_size,
626
+ is_conditional=is_conditional_vae,
627
+ use_fp16=args.fp16,
628
+ vae_domain1=args.vae_domain1,
629
+ vae_domain2=args.vae_domain2,
630
+ multiscale=(sample_multiscale_levels is not None),
631
+ multiscale_levels=sample_multiscale_levels)
632
+ logger.info(f"Saved {args.vis_num_samples} samples to {sample_subdir}")
633
+
634
+ if writer is not None:
635
+ sample_image_path = f"{sample_subdir}/sample_0.png"
636
+ if os.path.exists(sample_image_path):
637
+ sample_img = Image.open(sample_image_path)
638
+ sample_img_array = np.array(sample_img)
639
+ if len(sample_img_array.shape) == 2:
640
+ sample_img_tensor = torch.from_numpy(sample_img_array).unsqueeze(0).float() / 255.0
641
+ else:
642
+ sample_img_tensor = torch.from_numpy(sample_img_array).permute(2, 0, 1).float() / 255.0
643
+ writer.add_image('Samples/reconstruction', sample_img_tensor, train_steps)
644
+ dist.barrier()
645
+
646
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
647
+ if is_master:
648
+ checkpoint = {
649
+ "vae": vae.module.state_dict(),
650
+ "opt": opt.state_dict(),
651
+ "train_steps": train_steps,
652
+ "epoch": epoch,
653
+ "args": args
654
+ }
655
+ if scheduler:
656
+ checkpoint["scheduler"] = scheduler.state_dict()
657
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
658
+ torch.save(checkpoint, checkpoint_path)
659
+ logger.info(f"Saved checkpoint to {checkpoint_path} (epoch {epoch}, step {train_steps})")
660
+ dist.barrier()
661
+
662
+ if is_master:
663
+ logger.info("Done!")
664
+
665
+ if is_master and writer is not None:
666
+ writer.close()
667
+ logger.info("TensorBoard writer closed.")
668
+
669
+ cleanup()
670
+
671
+
672
+ if __name__ == "__main__":
673
+ parser = argparse.ArgumentParser()
674
+
675
+ # data parameters
676
+ parser.add_argument("--data-path", type=str, required=True, help="Path to pairing JSON file")
677
+ parser.add_argument("--image-size", type=int, default=256, help="Patch size")
678
+ parser.add_argument("--patches-per-image", type=int, default=4, help="Number of patches to crop from each large image")
679
+ parser.add_argument("--multiscale", action="store_true", help="Enable multiscale training: randomly crop from specified sizes and resize to image-size")
680
+ parser.add_argument("--multiscale-levels", type=int, nargs='+', default=[32, 64, 128, 256, 512],
681
+ help="Multiscale crop sizes (default: 32 64 128 256 512). Example: --multiscale-levels 128 256 512")
682
+ parser.add_argument("--results-dir", type=str, default="results_vae")
683
+
684
+ # VAE model selection (either use predefined model or custom parameters)
685
+ parser.add_argument("--vae-model", type=str, default=None,
686
+ help="Predefined VAE model name (e.g., VAE-KL-f8, VAE-KL-f16). If specified, overrides custom architecture parameters.")
687
+
688
+ # VAE architecture parameters (when --vae-model is not specified)
689
+ parser.add_argument("--embed-dim", type=int, default=4, help="Latent embedding dimension")
690
+ parser.add_argument("--z-channels", type=int, default=4, help="Number of latent channels")
691
+ parser.add_argument("--in-channels", type=int, default=3, help="Number of input channels")
692
+ parser.add_argument("--out-channels", type=int, default=3, help="Number of output channels")
693
+ parser.add_argument("--ch", type=int, default=128, help="Base channel count")
694
+ parser.add_argument("--ch-mult", type=int, nargs="+", default=[1, 2, 4, 4], help="Channel multipliers")
695
+ parser.add_argument("--num-res-blocks", type=int, default=2, help="Number of residual blocks per level")
696
+ parser.add_argument("--attn-resolutions", type=int, nargs="*", default=[], help="Resolutions at which to apply attention")
697
+ parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")
698
+ parser.add_argument("--double-z", action="store_true", default=True, help="Double z for mean and variance")
699
+
700
+ # loss parameters
701
+ parser.add_argument("--kl-weight", type=float, default=1e-6, help="Weight for KL divergence loss")
702
+ parser.add_argument("--l1-weight", type=float, default=1.0, help="Weight for L1 reconstruction loss")
703
+ parser.add_argument("--ssim-weight", type=float, default=1.0, help="Weight for SSIM reconstruction loss")
704
+
705
+ # training parameters
706
+ parser.add_argument("--epochs", type=int, default=100)
707
+ parser.add_argument("--global-batch-size", type=int, default=4)
708
+ parser.add_argument("--learning-rate", type=float, default=4.5e-6)
709
+ parser.add_argument("--global-seed", type=int, default=0)
710
+ parser.add_argument("--num-workers", type=int, default=4)
711
+ parser.add_argument("--log-every", type=int, default=100)
712
+ parser.add_argument("--ckpt-every", type=int, default=5000)
713
+ parser.add_argument("--sample-every", type=int, default=1000, help="Save reconstruction samples every N steps")
714
+ parser.add_argument("--vis_num-samples", type=int, default=8, help="Number of reconstruction samples to save")
715
+ parser.add_argument("--fp16", action="store_true", help="Use mixed precision training")
716
+ parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
717
+ parser.add_argument("--start-epoch", type=int, default=None, help="Manually specify starting epoch (useful for old checkpoints without epoch info)")
718
+ parser.add_argument("--vae-domain1", action="store_true", help="use domain1 for training")
719
+ parser.add_argument("--vae-domain2", action="store_true", help="use domain2 for training")
720
+
721
+ # scheduler parameters
722
+ parser.add_argument("--lr-scheduler", type=str, default="linear",
723
+ choices=["none", "linear", "cosine", "constant"],
724
+ help="Learning rate scheduler type")
725
+ parser.add_argument("--warmup-epochs", type=int, default=0,
726
+ help="Number of warmup epochs (linear warmup from 0 to initial lr)")
727
+
728
+ args = parser.parse_args()
729
+ main(args)