from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import randn_tensor
| |
|
| |
|
class AutoencoderKLNextStep(AutoencoderKL):
    """`AutoencoderKL` variant with optional latent post-processing.

    Extends the stock diffusers KL autoencoder with three extra config knobs:

    - ``deterministic``: the encoder posterior is built with
      ``DiagonalGaussianDistribution(..., deterministic=True)`` so sampling
      collapses onto the mean.
    - ``normalize_latents``: the posterior *mean* (only the mean — the logvar
      half is left untouched) is layer-normalized over the channel dimension
      before the posterior is constructed.
    - ``patch_size``: when set, the mean is patchified before the layer norm
      and unpatchified afterwards, so normalization runs over
      ``c * patch_size**2`` grouped channels per spatial patch.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        # Forward every base-class argument unchanged; the three extra knobs
        # below are recorded in the model config by @register_to_config.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold non-overlapping ``p``x``p`` spatial patches into channels.

        Maps ``(b, c, h, w)`` to ``(b, c * p**2, h // p, w // p)`` where
        ``p = self.patch_size``. Assumes ``h`` and ``w`` are divisible by
        ``p`` (otherwise the trailing rows/columns are silently dropped by
        the integer division).
        """
        b, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(b, c, h_, p, w_, p)
        # Bring the two intra-patch axes (p, q) next to the channel axis so
        # the following reshape merges them into channels.
        x = torch.einsum("bchpwq->bcpqhw", x)
        x = x.reshape(b, c * p ** 2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`patchify`.

        Maps ``(b, c * p**2, h_, w_)`` back to ``(b, c, h_ * p, w_ * p)``
        where ``p = self.patch_size``. ``unpatchify(patchify(x))`` recovers
        ``x`` exactly for divisible spatial sizes.
        """
        b, _, h_, w_ = x.shape
        p = self.patch_size
        c = x.shape[1] // (p ** 2)

        x = x.reshape(b, c, p, p, h_, w_)
        # Interleave the intra-patch axes back between the spatial axes —
        # the exact reverse of the permutation used in patchify.
        x = torch.einsum("bcpqhw->bchpwq", x)
        x = x.reshape(b, c, h_ * p, w_ * p)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode a batch of images into a latent posterior distribution.

        Args:
            x: Input image batch of shape ``(b, c, h, w)``.
            return_dict: If ``True`` (default), return an
                :class:`AutoencoderKLOutput`; otherwise a one-tuple holding
                the posterior.

        Returns:
            The posterior as ``AutoencoderKLOutput(latent_dist=...)`` or
            ``(posterior,)``. When ``self.normalize_latents`` is set, the
            posterior mean has been layer-normalized (patch-wise if
            ``self.patch_size`` is set); the logvar is never modified.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Slice the batch to bound peak memory (base-class feature).
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        # The encoder emits mean and logvar stacked along channels.
        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.patch_size is not None:
            # Group p*p spatial neighborhoods into channels so the layer norm
            # below normalizes each patch's c*p*p values jointly.
            mean = self.patchify(mean)
        if self.normalize_latents:
            # F.layer_norm normalizes over the trailing dims, so move
            # channels last, normalize, and move them back.
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
        if self.patch_size is not None:
            mean = self.unpatchify(mean)
        h = torch.cat([mean, logvar], dim=1).contiguous()
        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Encode ``sample`` to a latent, optionally perturb it, and decode.

        Args:
            sample: Input image batch.
            sample_posterior: If ``True``, draw ``z`` from the posterior;
                otherwise use the posterior mode (mean).
            return_dict: If ``True`` (default), return a
                :class:`DecoderOutput`; otherwise a one-tuple of the
                reconstruction.
            generator: Optional RNG used for posterior sampling.
            noise_strength: If positive, add Gaussian noise to ``z`` scaled
                by a per-sample factor drawn from ``Uniform(0, noise_strength)``.

        Returns:
            ``DecoderOutput(sample=dec)`` or ``(dec,)``.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if noise_strength > 0.0:
            # One noise magnitude per batch element, broadcast over (c, h, w).
            p = torch.distributions.Uniform(0, noise_strength)
            # Bug fix: randn_tensor was referenced without being imported,
            # so any call with noise_strength > 0 raised NameError. It is now
            # imported from diffusers.utils.torch_utils at the top of the file.
            z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
                z.shape, device=z.device, dtype=z.dtype
            )
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
| |
|