# World_Model/URSA/diffnext/models/autoencoders/modeling_utils.py
# Uploaded by BryanW using the upload-large-folder tool (commit d403233, verified).
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""AutoEncoder utilities."""
from diffusers.models.modeling_outputs import BaseOutput
import torch
class DecoderOutput(BaseOutput):
    """Output of decoding method.

    Wraps the decoded tensor in a ``diffusers`` ``BaseOutput`` so callers can
    access it as ``out.sample`` or ``out["sample"]``.
    """

    # Decoded tensor produced by the autoencoder's decoder.
    # Exact shape depends on the decoder configuration — not visible here.
    sample: torch.Tensor
class IdentityDistribution(object):
    """Degenerate (deterministic) stand-in for a Gaussian posterior.

    Stores the latent tensor verbatim and returns it unchanged on every
    ``sample`` call, mirroring the ``DiagonalGaussianDistribution`` API.
    """

    def __init__(self, z):
        # Keep the raw latents; there is nothing to split or transform.
        self.parameters = z

    def sample(self, generator=None):
        """Return the stored latents; ``generator`` exists only for API parity."""
        return self.parameters
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian posterior parameterized by channel-concatenated stats.

    ``z`` carries the mean in its first half of channels and the log-variance
    in the second half. When the channel count is odd, the trailing channel is
    replicated so ``z`` widens to an even ``2 * (C - 1)`` channels before the
    mean/log-variance split.
    """

    def __init__(self, z):
        # Remember the raw latents plus their placement, so sampled tensors
        # can be cast back after computing in float32.
        self.parameters = z
        self.device, self.dtype = z.device, z.dtype
        if z.size(1) % 2:
            # Odd channel count: tile the last channel out to C - 2 copies,
            # giving 2 * (C - 1) channels in total (an even split).
            pad_shape = (-1, z.shape[1] - 2) + (-1,) * (z.dim() - 2)
            z = torch.cat([z, z[:, -1:].expand(pad_shape)], 1)
        # Split into statistics in float32 for numerical stability.
        self.mean, self.logvar = z.float().chunk(2, dim=1)
        # Clamp log-variance to keep exp() finite and non-degenerate.
        self.logvar = self.logvar.clamp(-30.0, 20.0)
        self.std = self.logvar.mul(0.5).exp_()
        self.var = self.logvar.exp()

    def sample(self, generator=None) -> torch.Tensor:
        """Draw a reparameterized sample: ``mean + std * eps``.

        The result is cast back to the original latents' device and dtype.
        """
        noise = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.mean.device,
            dtype=self.mean.dtype,
        )
        sampled = noise.mul_(self.std).add_(self.mean)
        return sampled.to(device=self.device, dtype=self.dtype)
class TilingMixin(object):
    """Mixin that encodes/decodes long videos as overlapping temporal tiles.

    A single pass over a long clip can exceed the 2147483647 (2**31 - 1)
    element limit — e.g. ``(1, 256, 17, 480, 768)`` fits while
    ``(1, 256, 17, 576, 1024)`` does not — so the time axis is processed in
    fixed-size windows and the results are stitched back together.

    Note: windows that would run past the end of the clip are skipped, so
    full coverage relies on the caller providing a compatible frame count.
    """

    def __init__(self, sample_min_t=17, latent_min_t=5, sample_ovr_t=1, latent_ovr_t=0):
        # Window length and temporal overlap in pixel (sample) space.
        self.sample_min_t = sample_min_t
        self.sample_ovr_t = sample_ovr_t
        # Window length and temporal overlap in latent space.
        self.latent_min_t = latent_min_t
        self.latent_ovr_t = latent_ovr_t

    def tiled_encoder(self, x) -> torch.Tensor:
        """Encode ``x``, windowing along dim 2 when it exceeds ``sample_min_t``."""
        # Images (4D) and short clips go through in a single pass.
        if x.dim() == 4 or x.size(2) <= self.sample_min_t:
            return self.encoder(x)
        total = x.shape[2]
        stride = self.sample_min_t - self.sample_ovr_t
        pieces = []
        for start in range(0, total, stride):
            stop = start + self.sample_min_t
            if stop > total:
                continue  # incomplete trailing window is dropped
            tile = self.encoder(x[:, :, start:stop])
            # All tiles after the first lose their leading latent overlap.
            pieces.append(tile if not pieces else tile[:, :, self.latent_ovr_t:])
        return torch.cat(pieces, dim=2)

    def tiled_decoder(self, x, **kwargs) -> torch.Tensor:
        """Decode ``x``, windowing along dim 2 when it exceeds ``latent_min_t``."""
        if x.dim() == 4 or x.size(2) <= self.latent_min_t:
            return self.decoder(x, **kwargs)
        total = x.shape[2]
        stride = self.latent_min_t - self.latent_ovr_t
        pieces = []
        for start in range(0, total, stride):
            stop = start + self.latent_min_t
            if stop > total:
                continue  # incomplete trailing window is dropped
            tile = self.decoder(x[:, :, start:stop], **kwargs)
            # All tiles after the first lose their leading sample overlap.
            pieces.append(tile if not pieces else tile[:, :, self.sample_ovr_t:])
        return torch.cat(pieces, dim=2)
class HybridMixin(object):
    """Mixin that routes inputs to an image or a video forward pass."""

    def forward(self, x) -> torch.Tensor:
        """Dispatch 4D inputs to ``forward_image``; all others to ``forward_video``."""
        if x.dim() == 4:
            return self.forward_image(x)
        return self.forward_video(x)