| import torch | |
| import torch.nn as nn | |
| from diffusers import ModelMixin, ConfigMixin | |
| from diffusers.configuration_utils import register_to_config | |
| class ConditionalEmbedder(ModelMixin, ConfigMixin): | |
| """ | |
| Patchifies VAE-encoded conditions (source video or reference image) | |
| into the DiT hidden dimension space via a Conv3d layer. | |
| """ | |
| def __init__( | |
| self, | |
| in_dim: int = 48, | |
| dim: int = 3072, | |
| patch_size: list = [1, 2, 2], | |
| zero_init: bool = True, | |
| ref_pad_first: bool = False, | |
| ): | |
| super().__init__() | |
| kernel_size = tuple(patch_size) | |
| self.patch_embedding = nn.Conv3d( | |
| in_dim, dim, kernel_size=kernel_size, stride=kernel_size | |
| ) | |
| self.ref_pad_first = ref_pad_first | |
| if zero_init: | |
| nn.init.zeros_(self.patch_embedding.weight) | |
| nn.init.zeros_(self.patch_embedding.bias) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.patch_embedding(x) | |