| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """3D transformer model for NOVA.""" |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| from diffnext.models.diffusion_mlp import DiffusionMLP |
| from diffnext.models.embeddings import PosEmbed, VideoPosEmbed, RotaryEmbed3D |
| from diffnext.models.embeddings import MaskEmbed, MotionEmbed, TextEmbed, LabelEmbed |
| from diffnext.models.normalization import AdaLayerNorm |
| from diffnext.models.transformers.transformer_nova_base import Transformer3DModel |
| from diffnext.models.vision_transformer import VisionTransformer |
| from diffnext.utils.registry import Registry |
|
|
| VIDEO_ENCODERS = Registry("video_encoders") |
| IMAGE_ENCODERS = Registry("image_encoders") |
| IMAGE_DECODERS = Registry("image_decoders") |
|
|
|
|
| @VIDEO_ENCODERS.register("vit_d16w768", depth=16, embed_dim=768, num_heads=12) |
| @VIDEO_ENCODERS.register("vit_d16w1024", depth=16, embed_dim=1024, num_heads=16) |
| @VIDEO_ENCODERS.register("vit_d16w1536", depth=16, embed_dim=1536, num_heads=16) |
| def video_encoder(depth, embed_dim, num_heads, patch_size, image_size, image_dim): |
| return VisionTransformer(**locals()) |
|
|
|
|
| @IMAGE_ENCODERS.register("vit_d32w768", depth=32, embed_dim=768, num_heads=12) |
| @IMAGE_ENCODERS.register("vit_d32w1024", depth=32, embed_dim=1024, num_heads=16) |
| @IMAGE_ENCODERS.register("vit_d32w1536", depth=32, embed_dim=1536, num_heads=16) |
| def image_encoder(depth, embed_dim, num_heads, patch_size, image_size, image_dim): |
| return VisionTransformer(**locals()) |
|
|
|
|
| @IMAGE_DECODERS.register("mlp_d3w1280", depth=3, embed_dim=1280) |
| @IMAGE_DECODERS.register("mlp_d6w768", depth=6, embed_dim=768) |
| @IMAGE_DECODERS.register("mlp_d6w1024", depth=6, embed_dim=1024) |
| @IMAGE_DECODERS.register("mlp_d6w1536", depth=6, embed_dim=1536) |
| def image_decoder(depth, embed_dim, patch_size, image_dim, cond_dim): |
| return DiffusionMLP(**locals()) |
|
|
|
|
| class NOVATransformer3DModel(Transformer3DModel, ModelMixin, ConfigMixin): |
| """3D transformer model for NOVA.""" |
|
|
| @register_to_config |
| def __init__( |
| self, |
| image_dim=None, |
| image_size=None, |
| image_stride=None, |
| text_token_dim=None, |
| text_token_len=None, |
| image_base_size=None, |
| video_base_size=None, |
| video_mixer_rank=None, |
| rotary_pos_embed=False, |
| arch=("", "", ""), |
| ): |
| image_size = (image_size,) * 2 if isinstance(image_size, int) else image_size |
| image_size = tuple(v // image_stride for v in image_size) |
| image_args = {"image_dim": image_dim, "patch_size": 15 // image_stride + 1} |
| video_args = {**image_args, "patch_size": image_args["patch_size"] * 2} |
| video_encoder = VIDEO_ENCODERS.get(arch[0])(image_size=image_size, **video_args) |
| image_encoder = IMAGE_ENCODERS.get(arch[1])(image_size=image_size, **image_args) |
| image_decoder = IMAGE_DECODERS.get(arch[2])(cond_dim=image_encoder.embed_dim, **image_args) |
| if rotary_pos_embed: |
| video_pos_embed = RotaryEmbed3D(video_encoder.rope.dim, video_base_size[1:]) |
| image_pos_embed = RotaryEmbed3D(image_encoder.rope.dim, image_base_size) |
| else: |
| video_pos_embed = VideoPosEmbed(video_encoder.embed_dim, video_base_size) |
| image_encoder.pos_embed = PosEmbed(image_encoder.embed_dim, image_base_size) |
| image_pos_embed = image_pos_embed if rotary_pos_embed else None |
| if video_mixer_rank: |
| video_mixer_rank = max(video_mixer_rank, 0) |
| video_encoder.mixer = AdaLayerNorm(video_encoder.embed_dim, video_mixer_rank, eps=None) |
| if text_token_dim: |
| text_embed = TextEmbed(text_token_dim, image_encoder.embed_dim, text_token_len) |
| super(NOVATransformer3DModel, self).__init__( |
| video_encoder=video_encoder, |
| image_encoder=image_encoder, |
| image_decoder=image_decoder, |
| mask_embed=MaskEmbed(image_encoder.embed_dim), |
| text_embed=text_embed if text_token_dim else None, |
| label_embed=LabelEmbed(image_encoder.embed_dim) if not text_token_dim else None, |
| video_pos_embed=video_pos_embed, |
| image_pos_embed=image_pos_embed, |
| motion_embed=MotionEmbed(video_encoder.embed_dim) if video_base_size[0] > 1 else None, |
| ) |
|
|