| """Training utils for VibeToken.""" |
| import json |
| import os |
| import time |
| import math |
| from pathlib import Path |
| import pprint |
| import glob |
| from collections import defaultdict |
| import random |
| import gc |
|
|
| from data import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset |
| import torch |
| from torch.utils.data import DataLoader |
| from omegaconf import OmegaConf |
| from torch.optim import AdamW |
| from utils.lr_schedulers import get_scheduler |
| from modeling.modules import EMAModel, ReconstructionLoss_Single_Stage |
| from modeling.vibetoken_model import VibeTokenModel, PretrainedTokenizer |
| from evaluator import VQGANEvaluator |
|
|
| from utils.viz_utils import make_viz_from_samples |
| from torchinfo import summary |
| import accelerate |
|
|
def get_config():
    """Build the run configuration.

    Parses command-line overrides first, loads the YAML file named by the
    ``config`` CLI key, then merges the two so CLI values win.
    """
    overrides = OmegaConf.from_cli()
    base = OmegaConf.load(overrides.config)
    return OmegaConf.merge(base, overrides)
|
|
|
|
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count and mean.

    Adapted from
    https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
|
|
def create_pretrained_tokenizer(config, accelerator=None):
    """Return a frozen proxy-code tokenizer, or None when the decoder is finetuned.

    When ``config.model.vq_model.finetune_decoder`` is set the proxy tokenizer
    is not needed, so None is returned. Otherwise the tokenizer is loaded from
    the configured weight path and, if an accelerator is given, moved onto its
    device.
    """
    if config.model.vq_model.finetune_decoder:
        return None
    pretrained_tokenizer = PretrainedTokenizer(config.model.vq_model.pretrained_tokenizer_weight)
    if accelerator is not None:
        pretrained_tokenizer.to(accelerator.device)
    return pretrained_tokenizer
|
|
|
|
def create_model_and_loss_module(config, logger, accelerator,
                                model_type="vibetoken"):
    """Creates model and loss module.

    Builds the VibeToken autoencoder and its reconstruction/GAN loss,
    optionally warm-starts from ``config.experiment.init_weight``, and when
    EMA is enabled registers accelerate save/load hooks so the EMA weights
    travel with accelerator state.

    Returns:
        (model, ema_model, loss_module); ``ema_model`` is None when
        ``config.training.use_ema`` is false.

    Raises:
        ValueError: for unsupported ``model_type`` / ``sub_model_type``.
    """
    logger.info("Creating model and loss module.")
    if model_type == "vibetoken":
        if config.model.sub_model_type == "vibetoken":
            model_cls = VibeTokenModel
            loss_cls = ReconstructionLoss_Single_Stage
        else:
            raise ValueError(f"Unsupported sub_model_type {config.model.sub_model_type}")
    else:
        raise ValueError(f"Unsupported model_type {model_type}")
    model = model_cls(config)

    # Optional warm start from a checkpoint on disk.
    if config.experiment.get("init_weight", ""):
        model_weight = torch.load(config.experiment.init_weight, map_location="cpu")
        if config.model.vq_model.finetune_decoder:
            # When finetuning the decoder, graft the pretrained pixel
            # tokenizer's non-encoder weights in under a "pixel_" prefix.
            pretrained_tokenizer_weight = torch.load(
                config.model.vq_model.pretrained_tokenizer_weight, map_location="cpu"
            )
            pretrained_tokenizer_weight = {"pixel_" + k:v for k,v in pretrained_tokenizer_weight.items() if not "encoder." in k}
            model_weight.update(pretrained_tokenizer_weight)

        # strict=False: partial initialization is expected here.
        msg = model.load_state_dict(model_weight, strict=False)
        logger.info(f"loading weight from {config.experiment.init_weight}, msg: {msg}")

    # Create the EMA shadow weights and hook them into accelerate's
    # save_state/load_state so they persist with checkpoints.
    ema_model = None
    if config.training.use_ema:
        ema_model = EMAModel(model.parameters(), decay=0.999,
                             model_cls=model_cls, config=config)
        def load_model_hook(models, input_dir):
            # Restore the EMA weights written by save_model_hook below.
            load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"),
                                                  model_cls=model_cls, config=config)
            ema_model.load_state_dict(load_model.state_dict())
            ema_model.to(accelerator.device)
            del load_model

        def save_model_hook(models, weights, output_dir):
            # Only rank 0 writes the EMA weights.
            if accelerator.is_main_process:
                ema_model.save_pretrained(os.path.join(output_dir, "ema_model"))

        accelerator.register_load_state_pre_hook(load_model_hook)
        accelerator.register_save_state_pre_hook(save_model_hook)

    loss_module = loss_cls(config=config) if loss_cls is not None else None

    if accelerator.is_main_process:
        if model_type in ["vibetoken"]:
            logger.info("VibeToken model summary not implemented yet.")
        else:
            raise NotImplementedError

    return model, ema_model, loss_module
|
|
|
|
def create_optimizer(config, logger, model, loss_module,
                     model_type="vibetoken", need_discrminator=True):
    """Creates optimizer for model and discriminator.

    Model parameters are split into two groups: gains/biases/norms/embedding
    tables (no weight decay) and everything else (decayed). The same split is
    applied to the discriminator (owned by ``loss_module``) when one is
    requested and enabled by the model configuration.

    Args:
        config: Config object; reads ``config.optimizer`` and
            ``config.model.vq_model.finetune_decoder``.
        logger: Logger for progress messages.
        model: The generator/autoencoder module.
        loss_module: Loss module that owns the discriminator parameters.
        model_type: Only "vibetoken" enables the discriminator path.
        need_discrminator: Set False to skip building the discriminator
            optimizer (None is returned in its place).

    Returns:
        (optimizer, discriminator_optimizer); the latter is None when no
        discriminator optimizer is needed.

    Raises:
        ValueError: If the configured optimizer type is unsupported.
        AssertionError: If a discriminator optimizer was requested
            (``need_discrminator=True``) but the configuration does not
            enable one.
    """
    logger.info("Creating optimizers.")
    optimizer_config = config.optimizer.params
    learning_rate = optimizer_config.learning_rate

    optimizer_type = config.optimizer.name
    if optimizer_type == "adamw":
        optimizer_cls = AdamW
    else:
        raise ValueError(f"Optimizer {optimizer_type} not supported")

    # Parameters that should not be weight-decayed: 1-D tensors (norm gains,
    # biases) and token/embedding tables, matched by name substring.
    exclude = (lambda n, p: p.ndim < 2 or "ln" in n or "bias" in n or 'latent_tokens' in n
               or 'mask_token' in n or 'embedding' in n or 'norm' in n or 'gamma' in n or 'embed' in n)
    include = lambda n, p: not exclude(n, p)
    named_parameters = list(model.named_parameters())
    gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
    optimizer = optimizer_cls(
        [
            {"params": gain_or_bias_params, "weight_decay": 0.},
            {"params": rest_params, "weight_decay": optimizer_config.weight_decay},
        ],
        lr=learning_rate,
        betas=(optimizer_config.beta1, optimizer_config.beta2)
    )

    if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and need_discrminator:
        discriminator_learning_rate = optimizer_config.discriminator_learning_rate
        discriminator_named_parameters = list(loss_module.named_parameters())
        discriminator_gain_or_bias_params = [p for n, p in discriminator_named_parameters if exclude(n, p) and p.requires_grad]
        discriminator_rest_params = [p for n, p in discriminator_named_parameters if include(n, p) and p.requires_grad]

        discriminator_optimizer = optimizer_cls(
            [
                {"params": discriminator_gain_or_bias_params, "weight_decay": 0.},
                {"params": discriminator_rest_params, "weight_decay": optimizer_config.weight_decay},
            ],
            lr=discriminator_learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2)
        )
    else:
        discriminator_optimizer = None

    # Fail loudly only when a discriminator was actually requested but could
    # not be built. The previous unconditional assert also fired on the
    # legitimate need_discrminator=False path, and its message was missing
    # the f-prefix so the config values were never interpolated.
    if need_discrminator:
        assert discriminator_optimizer is not None, (
            f"Discriminator optimizer is None with condition values: "
            f"{config.model.vq_model.finetune_decoder} {model_type} {need_discrminator}")

    return optimizer, discriminator_optimizer
|
|
|
|
def create_lr_scheduler(config, logger, accelerator, optimizer, discriminator_optimizer=None):
    """Creates learning rate scheduler for model and discriminator.

    Both schedulers share the configured schedule. Training and warmup
    horizons are scaled by the number of processes; the discriminator's
    horizon is shortened by ``config.losses.discriminator_start`` since it
    only begins training at that step.
    """
    logger.info("Creating lr_schedulers.")
    sched_cfg = config.lr_scheduler
    world = accelerator.num_processes
    shared_kwargs = dict(
        num_warmup_steps=sched_cfg.params.warmup_steps * world,
        base_lr=sched_cfg.params.learning_rate,
        end_lr=sched_cfg.params.end_lr,
    )
    lr_scheduler = get_scheduler(
        sched_cfg.scheduler,
        optimizer=optimizer,
        num_training_steps=config.training.max_train_steps * world,
        **shared_kwargs,
    )
    if discriminator_optimizer is None:
        return lr_scheduler, None
    discriminator_lr_scheduler = get_scheduler(
        sched_cfg.scheduler,
        optimizer=discriminator_optimizer,
        num_training_steps=config.training.max_train_steps * world - config.losses.discriminator_start,
        **shared_kwargs,
    )
    return lr_scheduler, discriminator_lr_scheduler
|
|
|
|
def create_dataloader(config, logger, accelerator):
    """Creates data loader for training and testing.

    Three dataset flavors are supported, selected by
    ``config.dataset.params``:
      * pretokenized web dataset with text labels,
      * simple (image) web dataset,
      * pretokenized JSONL dataset (train split only).

    Returns:
        (train_dataloader, eval_dataloader). ``eval_dataloader`` is None for
        the JSONL path, which has no eval split.

    Raises:
        ValueError: If no dataset branch matches the configuration.
    """
    logger.info("Creating dataloaders.")
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    total_batch_size = (
        config.training.per_gpu_batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps
    )
    preproc_config = config.dataset.preprocessing
    dataset_config = config.dataset.params

    if dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is True:
        dataset = PretokenizedWebDataset(
            train_shards_path=dataset_config.train_shards_path_or_url,
            eval_shards_path=dataset_config.eval_shards_path_or_url,
            num_train_examples=config.experiment.max_train_examples,
            per_gpu_batch_size=config.training.per_gpu_batch_size,
            global_batch_size=total_batch_size_without_accum,
            num_workers_per_gpu=dataset_config.num_workers_per_gpu,
            resize_shorter_edge=preproc_config.resize_shorter_edge,
            crop_size=preproc_config.crop_size,
            random_crop=preproc_config.random_crop,
            random_flip=preproc_config.random_flip,
            normalize_mean=preproc_config.normalize_mean,
            normalize_std=preproc_config.normalize_std,
            process_recap=preproc_config.get("preproc_recap", True),
            use_recap_prob=preproc_config.get("use_recap_prob", 0.95)
        )
        train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
    elif dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is False:
        dataset = SimpleImageDataset(
            train_shards_path=dataset_config.train_shards_path_or_url,
            eval_shards_path=dataset_config.eval_shards_path_or_url,
            num_train_examples=config.experiment.max_train_examples,
            per_gpu_batch_size=config.training.per_gpu_batch_size,
            global_batch_size=total_batch_size_without_accum,
            num_workers_per_gpu=dataset_config.num_workers_per_gpu,
            resize_shorter_edge=preproc_config.resize_shorter_edge,
            crop_size=preproc_config.crop_size,
            random_crop=preproc_config.random_crop,
            random_flip=preproc_config.random_flip,
            dataset_with_class_label=dataset_config.get("dataset_with_class_label", True),
            dataset_with_text_label=dataset_config.get("dataset_with_text_label", False),
            res_ratio_filtering=preproc_config.get("res_ratio_filtering", False),
            min_tokens=preproc_config.min_tokens,
            max_tokens=preproc_config.max_tokens,
        )
        train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
    else:
        # The JSONL pretokenized dataset has no eval split.
        eval_dataloader = None
        if dataset_config.get("pretokenization", ""):
            train_dataloader = DataLoader(
                PretoeknizedDataSetJSONL(dataset_config.pretokenization),
                batch_size=config.training.per_gpu_batch_size,
                shuffle=True, drop_last=True, pin_memory=True)
            # Fixed-size dataset: derive steps-per-epoch from the cap on
            # training examples.
            train_dataloader.num_batches = math.ceil(
                config.experiment.max_train_examples / total_batch_size_without_accum)
        else:
            # Previously this fell through and crashed at the return below
            # with a NameError on train_dataloader/eval_dataloader.
            raise ValueError(
                "No dataset configuration matched: expected "
                "`dataset.params.pretokenization` to be set.")

    return train_dataloader, eval_dataloader
|
|
|
|
class LazyVQGANEvaluator:
    """A lazy-loading wrapper for VQGANEvaluator that delays inception model initialization.

    Construction is cheap: the heavy VQGANEvaluator (which loads the
    inception network) is only built on the first call to reset_metrics /
    update / result. In multi-process runs rank 0 pre-warms the inception
    weights before a barrier so other ranks do not race on the download.
    """

    def __init__(self, device, enable_rfid=True, enable_inception_score=True,
                 enable_codebook_usage_measure=False, enable_codebook_entropy_measure=False,
                 num_codebook_entries=1024, accelerator=None):
        # Stash constructor arguments; the real evaluator is built lazily.
        self._device = device
        self._enable_rfid = enable_rfid
        self._enable_inception_score = enable_inception_score
        self._enable_codebook_usage_measure = enable_codebook_usage_measure
        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
        self._num_codebook_entries = num_codebook_entries
        self._accelerator = accelerator
        self._evaluator = None
        self._initialized = False

    def _ensure_initialized(self):
        """Initialize the real evaluator only when needed."""
        if not self._initialized:
            if self._accelerator and self._accelerator.num_processes > 1:
                if self._accelerator.is_main_process:
                    try:
                        # Pre-warm the inception weights on rank 0 so other
                        # ranks find them cached after the barrier below.
                        from evaluator.inception import get_inception_model
                        _ = get_inception_model()
                    except Exception as e:
                        print(f"Warning: Failed to pre-load inception model: {e}")

            if self._accelerator:
                self._accelerator.wait_for_everyone()

            try:
                self._evaluator = VQGANEvaluator(
                    device=self._device,
                    enable_rfid=self._enable_rfid,
                    enable_inception_score=self._enable_inception_score,
                    enable_codebook_usage_measure=self._enable_codebook_usage_measure,
                    enable_codebook_entropy_measure=self._enable_codebook_entropy_measure,
                    num_codebook_entries=self._num_codebook_entries
                )
                self._initialized = True
            except Exception as e:
                # Best-effort fallback: a no-op evaluator so training can
                # proceed; all metrics report 0.0.
                print(f"Warning: Failed to create VQGANEvaluator, using dummy: {e}")
                class DummyEvaluator:
                    def reset_metrics(self): pass
                    def update(self, real_images, fake_images, codebook_indices=None): pass
                    def result(self):
                        return {"InceptionScore": 0.0, "rFID": 0.0, "CodebookUsage": 0.0, "CodebookEntropy": 0.0}
                self._evaluator = DummyEvaluator()
                self._initialized = True

    def reset_metrics(self):
        """Clear accumulated metrics (builds the evaluator on first call)."""
        self._ensure_initialized()
        return self._evaluator.reset_metrics()

    def update(self, real_images, fake_images, codebook_indices=None):
        """Accumulate a batch of real/reconstructed images into the metrics."""
        self._ensure_initialized()
        return self._evaluator.update(real_images, fake_images, codebook_indices)

    def result(self):
        """Return the metric dict accumulated so far."""
        self._ensure_initialized()
        return self._evaluator.result()
|
|
|
|
def create_evaluator(config, logger, accelerator):
    """Creates evaluator.

    Returns a LazyVQGANEvaluator configured for the model's quantization
    mode; codebook usage/entropy metrics are enabled only for the discrete
    (vq-family) modes.
    """
    logger.info("Creating evaluator.")

    quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
    if quantize_mode in ["vq", "softvq", "mvq"]:
        evaluator = LazyVQGANEvaluator(
            device=accelerator.device,
            enable_rfid=True,
            enable_inception_score=True,
            enable_codebook_usage_measure=True,
            enable_codebook_entropy_measure=True,
            num_codebook_entries=config.model.vq_model.codebook_size,
            accelerator=accelerator,
        )
    elif quantize_mode == "vae":
        evaluator = LazyVQGANEvaluator(
            device=accelerator.device,
            enable_rfid=True,
            enable_inception_score=True,
            enable_codebook_usage_measure=False,
            enable_codebook_entropy_measure=False,
            accelerator=accelerator,
        )
    else:
        raise NotImplementedError

    logger.info("Lazy evaluator creation completed.")
    return evaluator
|
|
|
|
def auto_resume(config, logger, accelerator, ema_model,
                num_update_steps_per_epoch, strict=True):
    """Auto resuming the training.

    Looks for `checkpoint*` directories under the experiment output dir and
    resumes from the most recent one (highest trailing step number in the
    directory name). In multi-process runs only rank 0 globs the filesystem;
    the chosen path is broadcast to other ranks as a tensor of character
    codes since `accelerate.utils.broadcast` only handles tensors.

    Returns:
        (global_step, first_epoch) — both 0 when training from scratch.
    """
    global_step = 0
    first_epoch = 0
    if config.experiment.resume:
        accelerator.wait_for_everyone()
        # Only rank 0 touches the filesystem; other ranks start empty and
        # learn the result via broadcast below.
        if accelerator.is_main_process:
            local_ckpt_list = list(glob.glob(os.path.join(
                config.experiment.output_dir, "checkpoint*")))
            logger.info(f"All globbed checkpoints are: {local_ckpt_list}")
        else:
            local_ckpt_list = []

        if accelerator.num_processes > 1:
            # Broadcast whether any checkpoint exists.
            checkpoint_count = torch.tensor(len(local_ckpt_list), device=accelerator.device)
            accelerate.utils.broadcast(checkpoint_count, 0)

            if checkpoint_count > 0:
                if accelerator.is_main_process:
                    if len(local_ckpt_list) > 1:
                        # Sort by the numeric suffix of "checkpoint-<step>",
                        # newest first.
                        fn = lambda x: int(x.split('/')[-1].split('-')[-1])
                        checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
                    else:
                        checkpoint_paths = local_ckpt_list
                    latest_checkpoint = checkpoint_paths[0]
                else:
                    latest_checkpoint = ""

                # Encode the path as (length, ord codes) for broadcasting.
                if accelerator.is_main_process:
                    checkpoint_path_tensor = torch.tensor([ord(c) for c in latest_checkpoint], device=accelerator.device, dtype=torch.long)
                    path_length = torch.tensor(len(latest_checkpoint), device=accelerator.device)
                else:
                    path_length = torch.tensor(0, device=accelerator.device)

                accelerate.utils.broadcast(path_length, 0)

                # Non-main ranks allocate a matching buffer to receive into.
                if not accelerator.is_main_process:
                    checkpoint_path_tensor = torch.zeros(path_length.item(), device=accelerator.device, dtype=torch.long)

                accelerate.utils.broadcast(checkpoint_path_tensor, 0)

                if not accelerator.is_main_process:
                    latest_checkpoint = ''.join([chr(c.item()) for c in checkpoint_path_tensor])

                global_step = load_checkpoint(
                    Path(latest_checkpoint),
                    accelerator,
                    logger=logger,
                    strict=strict
                )
                if config.training.use_ema:
                    # Re-align the EMA decay schedule with the resumed step.
                    ema_model.set_step(global_step)
                first_epoch = global_step // num_update_steps_per_epoch
            else:
                logger.info("Training from scratch.")
        else:
            # Single-process path: no broadcasting needed.
            if len(local_ckpt_list) >= 1:
                if len(local_ckpt_list) > 1:
                    fn = lambda x: int(x.split('/')[-1].split('-')[-1])
                    checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
                else:
                    checkpoint_paths = local_ckpt_list
                global_step = load_checkpoint(
                    Path(checkpoint_paths[0]),
                    accelerator,
                    logger=logger,
                    strict=strict
                )
                if config.training.use_ema:
                    ema_model.set_step(global_step)
                first_epoch = global_step // num_update_steps_per_epoch
            else:
                logger.info("Training from scratch.")

    accelerator.wait_for_everyone()
    return global_step, first_epoch
|
|
|
|
def train_one_epoch(config, logger, accelerator,
                    model, ema_model, loss_module,
                    optimizer, discriminator_optimizer,
                    lr_scheduler, discriminator_lr_scheduler,
                    train_dataloader, eval_dataloader,
                    evaluator,
                    global_step,
                    model_type="vibetoken",
                    clip_tokenizer=None,
                    clip_encoder=None,
                    pretrained_tokenizer=None):
    """One epoch training.

    For every batch this runs the generator step and, once the loss module
    says the discriminator should train, the discriminator step — both under
    gradient accumulation. On each optimizer sync it also performs periodic
    scalar logging, checkpoint saving, reconstruction visualization, and
    evaluation, gated on the `config.experiment.*_every` counters.

    Returns:
        The updated global_step (stops early at
        config.training.max_train_steps).
    """
    batch_time_meter = AverageMeter()
    data_time_meter = AverageMeter()
    end = time.time()

    model.train()

    autoencoder_logs = defaultdict(float)
    discriminator_logs = defaultdict(float)
    for i, batch in enumerate(train_dataloader):
        model.train()
        if "image" in batch:
            images = batch["image"].to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
            if config.training.get("variable_resolution", False):
                # Variable-resolution training: sample input (and optionally
                # a distinct output) resolution from the configured
                # dims/ratios distribution.
                any2any = config.training.variable_resolution.get("any2any", True)

                dims = config.training.variable_resolution.dim
                ratios = config.training.variable_resolution.ratio
                assert len(dims) == len(ratios), "dims and ratios must have the same length"
                input_res = tuple(random.choices(dims, weights=ratios, k=1)[0])

                if any2any:
                    output_res = tuple(random.choices(dims, weights=ratios, k=1)[0])
                else:
                    output_res = input_res

                # `images` becomes the reconstruction target at output_res;
                # `input_images` is what the model encodes at input_res.
                images = torch.nn.functional.interpolate(images, size=output_res, mode="bilinear", align_corners=False)
                input_images = torch.nn.functional.interpolate(images, size=input_res, mode="bilinear", align_corners=False)
            else:
                input_images = images
                # (None, None) lets the model use its native output size.
                output_res = (None, None)

        fnames = batch["__key__"]
        data_time_meter.update(time.time() - end)

        # NOTE(review): proxy_codes is computed here but not passed anywhere
        # visible in this function — confirm whether the loss module obtains
        # it elsewhere or this is dead work.
        if pretrained_tokenizer is not None:
            pretrained_tokenizer.eval()
            proxy_codes = pretrained_tokenizer.encode(images)
        else:
            proxy_codes = None

        with accelerator.accumulate([model, loss_module]):
            additional_args = {}
            if config.model.get("train_with_attention", False):
                additional_args["key_attention_mask"] = batch["attention_mask"].to(
                    accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
                )
            # ---- Generator step ----
            reconstructed_images, extra_results_dict = model(input_images, height=output_res[0], width=output_res[1], **additional_args)
            autoencoder_loss, loss_dict = loss_module(
                images,
                reconstructed_images,
                extra_results_dict,
                global_step,
                mode="generator",
            )

            autoencoder_logs = {}
            for k, v in loss_dict.items():
                if k in ["discriminator_factor", "d_weight"]:
                    # Schedule scalars — identical on every rank, no gather.
                    if type(v) == torch.Tensor:
                        autoencoder_logs["train/" + k] = v.cpu().item()
                    else:
                        autoencoder_logs["train/" + k] = v
                else:
                    # Average loss terms across ranks for logging.
                    gathered_tensor = accelerator.gather(v)
                    autoencoder_logs["train/" + k] = gathered_tensor.mean().item()
                    del gathered_tensor

            torch.cuda.empty_cache()
            accelerator.backward(autoencoder_loss)

            if config.training.max_grad_norm is not None and accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

            if (
                accelerator.sync_gradients
                and (global_step + 1) % config.experiment.log_grad_norm_every == 0
                and accelerator.is_main_process
            ):
                log_grad_norm(model, accelerator, global_step + 1)

            optimizer.zero_grad(set_to_none=True)

            # ---- Discriminator step (only once the loss module's schedule
            # says the discriminator should train) ----
            discriminator_logs = defaultdict(float)
            if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and accelerator.unwrap_model(loss_module).should_discriminator_be_trained(global_step):
                discriminator_logs = defaultdict(float)
                discriminator_loss, loss_dict_discriminator = loss_module(
                    images,
                    reconstructed_images,
                    extra_results_dict,
                    global_step=global_step,
                    mode="discriminator",
                )

                for k, v in loss_dict_discriminator.items():
                    if k in ["logits_real", "logits_fake"]:
                        if type(v) == torch.Tensor:
                            discriminator_logs["train/" + k] = v.cpu().item()
                        else:
                            discriminator_logs["train/" + k] = v
                    else:
                        gathered_tensor = accelerator.gather(v)
                        discriminator_logs["train/" + k] = gathered_tensor.mean().item()
                        del gathered_tensor

                torch.cuda.empty_cache()
                accelerator.backward(discriminator_loss)

                if config.training.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(loss_module.parameters(), config.training.max_grad_norm)

                discriminator_optimizer.step()
                discriminator_lr_scheduler.step()

                if (
                    accelerator.sync_gradients
                    and (global_step + 1) % config.experiment.log_grad_norm_every == 0
                    and accelerator.is_main_process
                ):
                    log_grad_norm(loss_module, accelerator, global_step + 1)

                discriminator_optimizer.zero_grad(set_to_none=True)

        # Everything below runs only on real optimizer steps (i.e. after a
        # full gradient-accumulation cycle), so global_step counts updates,
        # not micro-batches.
        if accelerator.sync_gradients:
            if config.training.use_ema:
                ema_model.step(model.parameters())
            batch_time_meter.update(time.time() - end)
            end = time.time()

            # ---- Periodic scalar logging ----
            if (global_step + 1) % config.experiment.log_every == 0:
                samples_per_second_per_gpu = (
                    config.training.gradient_accumulation_steps * config.training.per_gpu_batch_size / batch_time_meter.val
                )

                lr = lr_scheduler.get_last_lr()[0]
                logger.info(
                    f"Data (t): {data_time_meter.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
                    f"Batch (t): {batch_time_meter.val:0.4f} "
                    f"LR: {lr:0.6f} "
                    f"Step: {global_step + 1} "
                    f"Total Loss: {autoencoder_logs['train/total_loss']:0.4f} "
                    f"Recon Loss: {autoencoder_logs['train/reconstruction_loss']:0.4f} "
                )
                logs = {
                    "lr": lr,
                    "lr/generator": lr,
                    "samples/sec/gpu": samples_per_second_per_gpu,
                    "time/data_time": data_time_meter.val,
                    "time/batch_time": batch_time_meter.val,
                }
                logs.update(autoencoder_logs)
                logs.update(discriminator_logs)
                accelerator.log(logs, step=global_step + 1)

                del autoencoder_logs, discriminator_logs, logs
                gc.collect()

                batch_time_meter.reset()
                data_time_meter.reset()

            # ---- Periodic checkpointing ----
            if (global_step + 1) % config.experiment.save_every == 0:
                save_path = save_checkpoint(
                    model, config.experiment.output_dir, accelerator, global_step + 1, logger=logger)
                accelerator.wait_for_everyone()

            # ---- Periodic reconstruction visualization (rank 0 only,
            # temporarily swapping in the EMA weights when enabled) ----
            if (global_step + 1) % config.experiment.generate_every == 0:
                if accelerator.is_main_process:
                    if config.training.get("use_ema", False):
                        ema_model.store(model.parameters())
                        ema_model.copy_to(model.parameters())

                    reconstruct_images(
                        model,
                        images[:config.training.num_generated_images],
                        fnames[:config.training.num_generated_images],
                        accelerator,
                        global_step + 1,
                        config.experiment.output_dir,
                        logger=logger,
                        config=config,
                        pretrained_tokenizer=pretrained_tokenizer
                    )

                    if config.training.get("use_ema", False):
                        ema_model.restore(model.parameters())

                accelerator.wait_for_everyone()

            # ---- Periodic evaluation (EMA weights when enabled, otherwise
            # the raw weights) ----
            if eval_dataloader is not None and (global_step + 1) % config.experiment.eval_every == 0:
                logger.info(f"Computing metrics on the validation set.")
                if config.training.get("use_ema", False):
                    ema_model.store(model.parameters())
                    ema_model.copy_to(model.parameters())
                    eval_scores = eval_reconstruction(
                        config,
                        model,
                        eval_dataloader,
                        accelerator,
                        evaluator,
                        pretrained_tokenizer=pretrained_tokenizer
                    )
                    logger.info(
                        f"EMA EVALUATION "
                        f"Step: {global_step + 1} "
                    )
                    logger.info(pprint.pformat(eval_scores))
                    if accelerator.is_main_process:
                        eval_log = {f'ema_eval/'+k: v for k, v in eval_scores.items()}
                        accelerator.log(eval_log, step=global_step + 1)
                    # Redundant re-check (use_ema is known true here); kept
                    # as-is for behavior parity.
                    if config.training.get("use_ema", False):
                        ema_model.restore(model.parameters())
                else:
                    eval_scores = eval_reconstruction(
                        config,
                        model,
                        eval_dataloader,
                        accelerator,
                        evaluator,
                        pretrained_tokenizer=pretrained_tokenizer
                    )

                    logger.info(
                        f"Non-EMA EVALUATION "
                        f"Step: {global_step + 1} "
                    )
                    logger.info(pprint.pformat(eval_scores))
                    if accelerator.is_main_process:
                        eval_log = {f'eval/'+k: v for k, v in eval_scores.items()}
                        accelerator.log(eval_log, step=global_step + 1)

                accelerator.wait_for_everyone()

            global_step += 1

            if global_step >= config.training.max_train_steps:
                accelerator.print(
                    f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
                )
                break

    return global_step
|
|
|
|
@torch.no_grad()
def eval_reconstruction(
    config,
    model,
    eval_loader,
    accelerator,
    evaluator,
    pretrained_tokenizer=None
):
    """Run reconstruction evaluation over `eval_loader` and return the metrics.

    Puts the model in eval mode, reconstructs every batch, and accumulates
    results into `evaluator`. With more than one process, scalar metrics are
    gathered and averaged across ranks. The model is restored to train mode
    before returning.
    """
    model.eval()
    evaluator.reset_metrics()
    local_model = accelerator.unwrap_model(model)

    accelerator.wait_for_everyone()

    for batch in eval_loader:
        images = batch["image"].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )

        original_images = torch.clone(images)
        additional_args = {}
        if config.model.get("eval_with_attention", False):
            additional_args["key_attention_mask"] = batch["attention_mask"].to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
        reconstructed_images, model_dict = local_model(images, **additional_args)

        if pretrained_tokenizer is not None:
            # Model outputs proxy-code logits; decode them back to pixels.
            reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))
        # Quantize to valid 8-bit pixel values for fair metric computation.
        reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0)
        reconstructed_images = torch.round(reconstructed_images * 255.0) / 255.0
        original_images = torch.clamp(original_images, 0.0, 1.0)

        # Pass codebook indices when available so usage/entropy metrics work.
        if isinstance(model_dict, dict):
            evaluator.update(original_images, reconstructed_images.squeeze(2), model_dict["min_encoding_indices"])
        else:
            evaluator.update(original_images, reconstructed_images.squeeze(2), None)

    accelerator.wait_for_everyone()

    local_results = evaluator.result()

    if accelerator.num_processes > 1:
        # Average scalar metrics across ranks; non-scalars pass through.
        gathered_results = {}
        for key, value in local_results.items():
            if isinstance(value, (int, float)):
                value_tensor = torch.tensor(value, device=accelerator.device)
                gathered_values = accelerator.gather(value_tensor)
                gathered_results[key] = gathered_values.mean().item()
            else:
                gathered_results[key] = value

        accelerator.wait_for_everyone()
        model.train()
        return gathered_results
    else:
        model.train()
        return local_results
|
|
|
|
@torch.no_grad()
def reconstruct_images(model, original_images, fnames, accelerator,
                       global_step, output_dir, logger, config=None,
                       pretrained_tokenizer=None):
    """Reconstruct a small batch of images and log/save the visualizations.

    Encodes `original_images` under autocast (matching the accelerator's
    mixed-precision setting) and decodes at the originals' spatial size,
    builds side-by-side viz images, logs them to wandb or tensorboard per
    config, and writes PNGs under `<output_dir>/train_images`. Restores
    model.train() before returning.
    """
    logger.info("Reconstructing images...")
    original_images = torch.clone(original_images)
    # assumes original_images is (batch, channels, height, width)
    _, _, height, width = original_images.shape
    model.eval()
    # Match the autocast dtype to the accelerator's mixed-precision mode.
    dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        dtype = torch.bfloat16

    with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"):
        enc_tokens, encoder_dict = accelerator.unwrap_model(model).encode(original_images)
    reconstructed_images = accelerator.unwrap_model(model).decode(enc_tokens, height=height, width=width)
    if pretrained_tokenizer is not None:
        # Model outputs proxy-code logits; decode them back to pixels.
        reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))

    images_for_saving, images_for_logging = make_viz_from_samples(
        original_images,
        reconstructed_images
    )
    # Route the visualization to whichever tracker the run is using.
    if config.training.enable_wandb:
        accelerator.get_tracker("wandb").log_images(
            {f"Train Reconstruction": images_for_saving},
            step=global_step
        )
    else:
        accelerator.get_tracker("tensorboard").log_images(
            {"Train Reconstruction": images_for_logging}, step=global_step
        )
    # Also write the reconstructions to disk for offline inspection.
    root = Path(output_dir) / "train_images"
    os.makedirs(root, exist_ok=True)
    for i,img in enumerate(images_for_saving):
        filename = f"{global_step:08}_s-{i:03}-{fnames[i]}.png"
        path = os.path.join(root, filename)
        img.save(path)

    model.train()
|
|
|
|
def save_checkpoint(model, output_dir, accelerator, global_step, logger) -> Path:
    """Write model weights, step metadata, and accelerator state to a
    ``checkpoint-<step>`` directory under ``output_dir``; return that path.
    """
    checkpoint_dir = Path(output_dir) / f"checkpoint-{global_step}"

    # get_state_dict must run on every rank (it may gather sharded weights),
    # but only rank 0 writes the unwrapped weights and metadata.
    state_dict = accelerator.get_state_dict(model)
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained_weight(
            checkpoint_dir / "unwrapped_model",
            save_function=accelerator.save,
            state_dict=state_dict,
        )
        json.dump({"global_step": global_step},
                  (checkpoint_dir / "metadata.json").open("w+"))
        logger.info(f"Saved state to {checkpoint_dir}")

    accelerator.save_state(checkpoint_dir)
    return checkpoint_dir
|
|
|
|
def load_checkpoint(checkpoint_path: Path, accelerator, logger, strict=True):
    """Restore accelerator state from ``checkpoint_path`` and return the
    global step recorded in its metadata.json.
    """
    logger.info(f"Load checkpoint from {checkpoint_path}")

    accelerator.load_state(checkpoint_path, strict=strict)

    metadata_file = checkpoint_path / "metadata.json"
    with metadata_file.open("r") as handle:
        global_step = int(json.load(handle)["global_step"])

    logger.info(f"Resuming at global_step {global_step}")
    return global_step
|
|
|
|
def log_grad_norm(model, accelerator, global_step):
    """Log the size-normalized L2 norm of every parameter gradient.

    Parameters without a gradient are skipped; each value is logged under
    the "grad_norm/<param-name>" key.
    """
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        grad = grad.detach().data
        normalized_norm = (grad.norm(p=2) / grad.numel()).item()
        accelerator.log({"grad_norm/" + name: normalized_norm}, step=global_step)
|
|