| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """AutoEncoder utilities.""" |
|
|
| from diffusers.models.modeling_outputs import BaseOutput |
| import torch |
|
|
|
|
class DecoderOutput(BaseOutput):
    """Output container returned by a decoding method.

    Subclasses diffusers' ``BaseOutput``, so the field below is accessible
    both as an attribute and via dict-style indexing.
    """

    # Decoded output tensor; shape/dtype are determined by the decoder that
    # produced it (not constrained here).
    sample: torch.Tensor
|
|
|
|
class IdentityDistribution(object):
    """Degenerate "distribution" that returns its input unchanged.

    Drop-in stand-in for a Gaussian posterior: it exposes the same
    ``parameters`` attribute and ``sample`` method, but sampling is a
    deterministic passthrough of the latent it was built from.
    """

    def __init__(self, z):
        # Keep the raw latent as-is; no mean/variance split is performed.
        self.parameters = z

    def sample(self, generator=None):
        # ``generator`` is accepted only for interface parity with the
        # stochastic distributions; it is ignored.
        return self.parameters
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian posterior parameterized by a packed latent.

    The input ``z`` holds mean and log-variance concatenated along the
    channel dimension (dim 1). Statistics are computed in float32; samples
    are cast back to the original device/dtype of ``z``.
    """

    def __init__(self, z):
        self.parameters = z
        self.device = z.device
        self.dtype = z.dtype
        if z.size(1) % 2:
            # Odd channel count: append C-2 replicas of the last channel so
            # the channel dim becomes even (2C-2) before the split. After
            # chunking, logvar is then the last input channel broadcast
            # across all C-1 logvar channels.
            # NOTE(review): this appends C-2 copies, not 1 — looks like a
            # deliberate shared-logvar scheme; confirm intended.
            filler = z[:, -1:].expand((-1, z.shape[1] - 2) + (-1,) * (z.dim() - 2))
            z = torch.cat([z, filler], 1)
        mean, logvar = z.float().chunk(2, dim=1)
        self.mean = mean
        # Clamp for numerical stability before exponentiating.
        self.logvar = logvar.clamp(-30.0, 20.0)
        self.std = self.logvar.mul(0.5).exp_()
        self.var = self.logvar.exp()

    def sample(self, generator=None) -> torch.Tensor:
        # Reparameterized draw: mean + std * eps, computed in float32 on the
        # stats' device, then cast back to the input latent's device/dtype.
        stats_device = self.mean.device
        stats_dtype = self.mean.dtype
        eps = torch.randn(self.mean.shape, generator=generator, device=stats_device, dtype=stats_dtype)
        drawn = self.mean + self.std * eps
        return drawn.to(device=self.device, dtype=self.dtype)
|
|
|
|
| class TilingMixin(object): |
| """Base class for input tiling. |
| |
| Shape hints: |
| |
| print(torch.Size((1, 256, 17, 480, 768)).numel() < 2147483647, "Supported") |
| print(torch.Size((1, 256, 17, 576, 1024)).numel() > 2147483647, "Unsupported") |
| |
| """ |
|
|
| def __init__(self, sample_min_t=17, latent_min_t=5, sample_ovr_t=1, latent_ovr_t=0): |
| self.sample_min_t, self.latent_min_t = sample_min_t, latent_min_t |
| self.sample_ovr_t, self.latent_ovr_t = sample_ovr_t, latent_ovr_t |
|
|
| def tiled_encoder(self, x) -> torch.Tensor: |
| if x.dim() == 4 or x.size(2) <= self.sample_min_t: |
| return self.encoder(x) |
| t = x.shape[2] |
| t_start = [i for i in range(0, t, self.sample_min_t - self.sample_ovr_t)] |
| t_slice = [slice(i, i + self.sample_min_t) for i in t_start] |
| t_tiles = [self.encoder(x[:, :, s]) for s in t_slice if s.stop <= t] |
| t_tiles = [x[:, :, self.latent_ovr_t :] if i else x for i, x in enumerate(t_tiles)] |
| return torch.cat(t_tiles, dim=2) |
|
|
| def tiled_decoder(self, x, **kwargs) -> torch.Tensor: |
| if x.dim() == 4 or x.size(2) <= self.latent_min_t: |
| return self.decoder(x, **kwargs) |
| t = x.shape[2] |
| t_start = [i for i in range(0, t, self.latent_min_t - self.latent_ovr_t)] |
| t_slice = [slice(i, i + self.latent_min_t) for i in t_start] |
| t_tiles = [self.decoder(x[:, :, s], **kwargs) for s in t_slice if s.stop <= t] |
| t_tiles = [x[:, :, self.sample_ovr_t :] if i else x for i, x in enumerate(t_tiles)] |
| return torch.cat(t_tiles, dim=2) |
|
|
|
|
class HybridMixin(object):
    """Mixin that routes a forward pass by input rank.

    4-D inputs are treated as images and dispatched to
    ``self.forward_image``; anything else goes to ``self.forward_video``.
    Both handlers are expected to be supplied by the host class.
    """

    def forward(self, x) -> torch.Tensor:
        if x.dim() == 4:
            return self.forward_image(x)
        return self.forward_video(x)
|
|