| from model.unet import ScaleAt |
| from model.latentnet import * |
| from diffusion.resample import UniformSampler |
| from diffusion.diffusion import space_timesteps |
| from typing import Tuple |
|
|
| from torch.utils.data import DataLoader |
|
|
| from config_base import BaseConfig |
| from dataset import * |
| from diffusion import * |
| from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule |
| from model import * |
| from choices import * |
| from multiprocessing import get_context |
| import os |
| from dataset_util import * |
| from torch.utils.data.distributed import DistributedSampler |
|
|
# Registry mapping a dataset name (``TrainConfig.data_name``) to its on-disk
# location; looked up by ``TrainConfig.data_path``. Every entry is a relative
# path passed through ``os.path.expanduser``, so a ``~``-prefixed location can
# be substituted here without touching any other code.
data_paths = {
    key: os.path.expanduser(location)
    for key, location in (
        ('ffhqlmdb256', 'datasets/ffhq256.lmdb'),
        ('celeba', 'datasets/celeba'),
        ('celebalmdb', 'datasets/celeba.lmdb'),
        ('celebahq', 'datasets/celebahq256.lmdb'),
        ('horse256', 'datasets/horse256.lmdb'),
        ('bedroom256', 'datasets/bedroom256.lmdb'),
        # Attribute / annotation files rather than image stores.
        ('celeba_anno', 'datasets/celeba_anno/list_attr_celeba.txt'),
        ('celebahq_anno',
         'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
        ('celeba_relight', 'datasets/celeba_hq_light/celeba_light.txt'),
    )
}
|
|
|
|
@dataclass
class PretrainConfig(BaseConfig):
    """A named reference to a pretrained artifact (presumably a checkpoint).

    Used as the type of ``TrainConfig.pretrain`` and
    ``TrainConfig.continue_from``.
    """
    # Identifier for the artifact — NOTE(review): how it is used is not
    # visible in this file; confirm with the consumer.
    name: str
    # Location of the artifact — presumably a filesystem path; confirm.
    path: str
| path: str |
|
|
|
|
@dataclass
class TrainConfig(BaseConfig):
    """Hyper-parameters for a training run plus factories built from them.

    Besides plain fields, the class builds diffusion configs, the timestep
    sampler, datasets, data loaders and the network config from the current
    field values.
    """

    # --- training / manipulation ---
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    # Forwarded to the latent diffusion config in _make_latent_diffusion_conf.
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0
    # Multiplied into batch_size by the batch_size_effective property.
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    # Falls back to batch_size in __post_init__ when left unset/falsy.
    batch_size_eval: int = None

    # --- image diffusion (used by _make_diffusion_conf) ---
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False

    # --- latent diffusion (used by _make_latent_diffusion_conf) ---
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    # Step count used by make_latent_eval_diffusion_conf.
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'

    # --- data / optimisation ---
    # Key into the module-level data_paths registry.
    data_name: str = ''
    # Falls back to data_name in __post_init__ when left unset/falsy.
    data_val_name: str = None
    # Only 'beatgans' is supported by _make_diffusion_conf.
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0

    # --- network (consumed by make_model_conf) ---
    # model_conf / model_type are populated by make_model_conf.
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    # Selects the latent network config built in make_model_conf.
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None

    # --- runtime / bookkeeping ---
    # Default worker count for make_loader.
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    # Encoder output width; also the latent MLP's num_channels.
    style_ch: int = 512
    # Step count used by make_eval_diffusion_conf.
    T_eval: int = 1_000
    # Only 'uniform' is supported by make_T_sampler.
    T_sampler: str = 'uniform'
    # Total diffusion steps: length of the beta schedule and timestep range.
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None
    eval_path: str = None
    # Parent directory of logdir.
    base_dir: str = 'checkpoints'
    # When true, data_path is redirected under data_cache_dir.
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    # Root for fid_cache and generate_dir.
    work_cache_dir: str = os.path.expanduser('~/mycache')
    # Run name; used in logdir and generate_dir.
    name: str = ''

    def __post_init__(self):
        """Fill evaluation defaults from their training counterparts."""
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Scale per-device settings by the total device count, in place.

        Multiplies the train/eval batch sizes and the eval, EMA-eval and
        sampling intervals by ``num_gpus * num_nodes``.

        Args:
            num_gpus: GPUs per node.
            num_nodes: number of nodes.

        Returns:
            ``self``, so the call can be chained.
        """
        factor = num_gpus * num_nodes
        self.eval_ema_every_samples *= factor
        self.eval_every_samples *= factor
        self.sample_every_samples *= factor
        self.batch_size *= factor
        self.batch_size_eval *= factor
        # NOTE(review): save_every_samples and total_samples are not scaled
        # here — confirm this is intended.
        return self

    @property
    def batch_size_effective(self):
        """Samples per optimiser step: ``batch_size * accum_batches``."""
        return self.batch_size * self.accum_batches

    @property
    def fid_cache(self):
        """Directory for cached evaluation images.

        Keyed by dataset name, image size and number of evaluation images.
        """
        return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'

    @property
    def data_path(self):
        """Location of the ``data_name`` dataset.

        When ``use_cache_dataset`` is set, the registered path is passed
        through ``use_cached_dataset_path`` with a per-dataset directory
        under ``data_cache_dir``, and its result is returned.

        Raises:
            KeyError: if ``data_name`` is not a key of ``data_paths``.
        """
        path = data_paths[self.data_name]
        if self.use_cache_dataset and path is not None:
            path = use_cached_dataset_path(
                path, f'{self.data_cache_dir}/{self.data_name}')
        return path

    @property
    def logdir(self):
        """Run directory: ``{base_dir}/{name}``."""
        return f'{self.base_dir}/{self.name}'

    @property
    def generate_dir(self):
        """Output directory for generated images of this run."""
        return f'{self.work_cache_dir}/gen_images/{self.name}'

    @staticmethod
    def _section_counts(gen_type, T):
        """Translate a generative type and step count into the
        ``section_counts`` argument of ``space_timesteps``.

        Raises:
            NotImplementedError: for generative types other than ddpm/ddim.
        """
        if gen_type == GenerativeType.ddpm:
            # A single section containing T steps.
            return [T]
        if gen_type == GenerativeType.ddim:
            # String form 'ddim<T>' selecting DDIM-style spacing.
            return f'ddim{T}'
        raise NotImplementedError()

    def _make_diffusion_conf(self, T=None):
        """Build the image-space diffusion config respaced to ``T`` steps.

        The beta schedule always spans ``self.T`` steps; ``T`` only selects
        how many of them are kept via ``space_timesteps``.

        Args:
            T: number of respaced steps. ``None`` falls back to ``self.T``
                (previously ``None`` produced the invalid counts ``[None]``
                or ``'ddimNone'``).

        Raises:
            NotImplementedError: if ``diffusion_type`` is not 'beatgans' or
                ``beatgans_gen_type`` is neither ddpm nor ddim.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the latent-space diffusion config respaced to ``T`` steps.

        Uses the ``latent_*`` fields and always ``ModelType.ddpm`` as the
        model type. The beta schedule spans ``self.T`` steps.

        Args:
            T: number of respaced steps. ``None`` falls back to ``self.T``.

        Raises:
            NotImplementedError: if ``latent_gen_type`` is neither ddpm nor
                ddim.
        """
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Number of output channels of the denoising network (always 3)."""
        return 3

    def make_T_sampler(self):
        """Build the training timestep sampler selected by ``T_sampler``.

        Raises:
            NotImplementedError: for anything other than 'uniform'.
        """
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        raise NotImplementedError()

    def make_diffusion_conf(self):
        """Image diffusion config for training (``self.T`` steps)."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Image diffusion config for evaluation (``self.T_eval`` steps)."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent diffusion config for training (``self.T`` steps)."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent diffusion config for evaluation (``self.latent_T_eval``
        steps)."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, path=None, **kwargs):
        """Instantiate the dataset named by ``data_name``.

        Args:
            path: dataset location; a falsy value uses ``self.data_path``.
            **kwargs: forwarded to the dataset constructor.

        Raises:
            NotImplementedError: for unsupported ``data_name`` values.
        """
        # ``self.data_path`` is only evaluated inside a matching branch, so
        # an unsupported name raises NotImplementedError rather than the
        # KeyError data_path would raise.
        if self.data_name == 'ffhqlmdb256':
            return FFHQlmdb(path=path or self.data_path,
                            image_size=self.img_size,
                            **kwargs)
        elif self.data_name == 'horse256':
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'bedroom256':
            # Reuses the Horse_lmdb reader for the bedroom LMDB.
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'celebalmdb':
            return CelebAlmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              original_resolution=None,
                              crop_d2c=True,
                              **kwargs)
        else:
            raise NotImplementedError()

    def make_loader(self,
                    dataset,
                    shuffle: bool,
                    num_worker: int = None,
                    drop_last: bool = True,
                    batch_size: int = None,
                    parallel: bool = False):
        """Wrap ``dataset`` in a ``DataLoader``.

        Args:
            dataset: dataset to iterate.
            shuffle: whether to shuffle. When a ``DistributedSampler`` is used
                it is forwarded to the sampler, otherwise to the loader.
            num_worker: number of loader worker processes. ``None`` uses
                ``self.num_workers``; ``0`` loads in the main process.
            drop_last: loader-level ``drop_last`` (final partial batch).
            batch_size: batch size; a falsy value uses ``self.batch_size``.
            parallel: when true and ``distributed`` is initialised, shard the
                dataset across ranks with a ``DistributedSampler``.
        """
        if parallel and distributed.is_initialized():
            # The sampler always drops the tail so the data divides evenly
            # across ranks; ``drop_last`` above only governs the final
            # partial batch within a rank.
            sampler = DistributedSampler(dataset,
                                         shuffle=shuffle,
                                         drop_last=True)
        else:
            sampler = None

        workers = self.num_workers if num_worker is None else num_worker
        return DataLoader(
            dataset,
            batch_size=batch_size or self.batch_size,
            sampler=sampler,
            # DataLoader rejects shuffle=True together with a sampler. Compare
            # against None: DistributedSampler defines __len__, so it is
            # falsy for an empty dataset.
            shuffle=False if sampler is not None else shuffle,
            num_workers=workers,
            pin_memory=True,
            drop_last=drop_last,
            # DataLoader raises ValueError if a multiprocessing context is
            # supplied with num_workers == 0.
            multiprocessing_context=get_context('fork') if workers > 0 else None,
        )

    def make_model_conf(self):
        """Build the network config selected by ``model_name``.

        Side effects: sets ``self.model_type`` and ``self.model_conf``.

        Returns:
            The newly built ``self.model_conf``.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` or
                ``net_latent_net_type``.
        """
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                image_size=self.img_size,
                in_channels=3,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.
                net_beatgans_resnet_use_zero_module,
            )
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder

            # Optional MLP operating on the encoder's style_ch-wide output.
            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError()

            self.model_conf = BeatGANsAutoencConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                image_size=self.img_size,
                in_channels=3,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.
                net_beatgans_resnet_use_zero_module,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf
|
|