# --------------------------------------------------------
# References:
# SiT: https://github.com/willisma/SiT
# Lightning-DiT: https://github.com/hustvl/LightningDiT
# JiT: https://github.com/LTH14/JiT
# REPA: https://github.com/sihyun-yu/REPA
# RAE: https://github.com/bytetriper/RAE
# --------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from util.model_util import VisionRotaryEmbeddingFast, get_2d_sincos_pos_embed, RMSNorm


def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation to a token sequence.

    ``shift`` and ``scale`` hold one vector per batch element and are
    broadcast over the token axis (dim 1) of ``x``.
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class BottleneckPatchEmbed(nn.Module):
    """
    Image to Patch Embedding through a low-rank bottleneck.

    A bias-free ``patch_size``-strided convolution projects every patch to
    ``pca_dim`` channels, then a 1x1 convolution lifts it to ``embed_dim``.
    ``forward`` returns tokens of shape (B, num_patches, embed_dim).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
        return x


class NoBottleneckPatchEmbed(nn.Module):
    """
    Image to Patch Embedding with a single strided convolution.

    ``pca_dim`` is accepted but unused; it only keeps the constructor
    signature interchangeable with ``BottleneckPatchEmbed``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal feature of size ``frequency_embedding_size`` is
    passed through a two-layer SiLU MLP to ``hidden_size``.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) float32 Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dimensions with a zero column so the output width is exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Follow the caller's floating dtype (e.g. low-precision timesteps).
        # Integer timesteps keep the float32 embedding: casting sinusoids to
        # an integer dtype would truncate every value to {-1, 0, 1}.
        if t.is_floating_point():
            t_freq = t_freq.to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations.

    The table has ``num_classes + 1`` rows; the extra row (index
    ``num_classes``) is presumably the null/unconditional label for
    classifier-free guidance. No label dropout happens here — callers must
    substitute that index themselves.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings


def scaled_dot_product_attention(query, key, value, dropout_p=0.0):
    """Non-causal, unmasked fused attention (thin wrapper over ``F.scaled_dot_product_attention``)."""
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
    )


class Attention(nn.Module):
    """Multi-head self-attention with optional per-head RMS QK-norm and rotary embedding."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope):
        """
        x: (B, N, C) tokens.
        rope: callable applied to q and k after QK-norm, or None to skip
              rotary position encoding.
        """
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            q = rope(q)
            k = rope(k)

        # nn.Dropout only stores the rate; the fused kernel applies it, and only while training.
        x = scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwiGLUFFN(nn.Module):
    """
    SwiGLU feed-forward network: ``w3(silu(x1) * x2)`` with ``[x1, x2] = w12(x)``.

    ``hidden_dim`` is scaled by 2/3, so the three projections hold about as
    many weights as a two-layer MLP of width ``hidden_dim``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        drop=0.0,
        bias=True
    ) -> None:
        super().__init__()
        hidden_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(self.ffn_dropout(hidden))


class FinalLayer(nn.Module):
    """
    The final layer of JiT: adaLN-modulated RMSNorm followed by a linear projection.

    ``hidden_size_c`` is the width of the conditioning vector (defaults to
    ``hidden_size``).
    """

    def __init__(self, hidden_size, out_channels, hidden_size_c=None):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size_c or hidden_size, 2 * hidden_size, bias=True)
        )

    @torch.compile
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class JiTBlock(nn.Module):
    """
    Pre-norm transformer block with adaLN-Zero-style gating.

    The conditioning vector ``c`` (width ``hidden_size_c``, default
    ``hidden_size``) produces shift/scale/gate triples for the attention
    and MLP branches.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, hidden_size_c=None):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True,
                              attn_drop=attn_drop, proj_drop=proj_drop)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = SwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size_c or hidden_size, 6 * hidden_size, bias=True)
        )

    @torch.compile
    def forward(self, x, c, feat_rope=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), rope=feat_rope)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class JiTCoT(nn.Module):
    """
    Just image Transformer that jointly predicts pixels and DINO features.

    Pixel patches and DINO feature vectors are embedded separately and summed
    into one token per patch. The sequence is then processed by a shared
    transformer trunk, optionally followed by per-modality "decoder head"
    blocks (``dh_depth > 0``). From block ``in_context_start`` on,
    ``in_context_len`` class-embedding tokens are prepended to the sequence.
    """

    def __init__(
        self,
        input_size=256,
        patch_size=16,
        in_channels=3,
        hidden_size=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        num_classes=1000,
        bottleneck_dim=128,
        bottleneck_dim_dino=128,
        in_context_len=32,
        in_context_start=8,
        dino_in_channels=768,
        dh_depth=0,
        dh_hidden_size=1536,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_pixel = in_channels * patch_size * patch_size
        self.out_channels_dino = dino_in_channels
        self.out_channels = self.out_channels_pixel + self.out_channels_dino
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.num_classes = num_classes

        # time and class embed
        self.t_pixel_embedder = TimestepEmbedder(hidden_size)
        self.t_dino_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size)

        # linear embed; the DINO map has one feature vector per pixel patch,
        # so its "image" is (input_size // patch_size)^2 with patch size 1.
        self.pixel_embedder = BottleneckPatchEmbed(input_size, patch_size, in_channels, bottleneck_dim, hidden_size, bias=True)
        self.dino_embedder = BottleneckPatchEmbed(input_size // patch_size, 1, dino_in_channels, bottleneck_dim_dino, hidden_size, bias=True)

        # use fixed sin-cos embedding
        num_patches = self.pixel_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # in-context cls token
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size), requires_grad=True)
            torch.nn.init.normal_(self.in_context_posemb, std=.02)

        # rope: a second instance accounts for the prepended in-context tokens
        # (num_cls_token), used once those tokens are in the sequence.
        half_head_dim = hidden_size // num_heads // 2
        hw_seq_len = input_size // patch_size
        self.feat_rope = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=hw_seq_len,
            num_cls_token=0
        )
        self.feat_rope_incontext = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=hw_seq_len,
            num_cls_token=self.in_context_len
        )

        # transformer; dropout is only enabled in the middle half of the trunk.
        drop_lo, drop_hi = depth // 4, depth // 4 * 3
        self.blocks = nn.ModuleList([
            JiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
                     attn_drop=attn_drop if drop_lo <= i < drop_hi else 0.0,
                     proj_drop=proj_drop if drop_lo <= i < drop_hi else 0.0)
            for i in range(depth)
        ])

        # linear predict
        self.dh_depth = dh_depth
        if self.dh_depth > 0:
            # Head count scales with width so head_dim (and thus the shared RoPE) is unchanged.
            dh_num_heads = num_heads * dh_hidden_size // hidden_size
            self.dh_pixel_proj = nn.Linear(hidden_size, dh_hidden_size)
            self.dh_dino_proj = nn.Linear(hidden_size, dh_hidden_size)
            self.dh_blocks_pixel = nn.ModuleList([
                JiTBlock(dh_hidden_size, dh_num_heads, mlp_ratio=mlp_ratio, hidden_size_c=hidden_size)
                for _ in range(dh_depth)
            ])
            self.dh_blocks_dino = nn.ModuleList([
                JiTBlock(dh_hidden_size, dh_num_heads, mlp_ratio=mlp_ratio, hidden_size_c=hidden_size)
                for _ in range(dh_depth)
            ])
            self.final_layer_pixel = FinalLayer(dh_hidden_size, self.out_channels_pixel, hidden_size_c=hidden_size)
            self.final_layer_dino = FinalLayer(dh_hidden_size, self.out_channels_dino, hidden_size_c=hidden_size)
        else:
            self.final_layer = FinalLayer(hidden_size, self.out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linears, set fixed sin-cos pos-embed, and zero adaLN / output layers."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.pixel_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        for embedder in (self.pixel_embedder, self.dino_embedder):
            w1 = embedder.proj1.weight.data
            nn.init.xavier_uniform_(w1.view([w1.shape[0], -1]))
            w2 = embedder.proj2.weight.data
            nn.init.xavier_uniform_(w2.view([w2.shape[0], -1]))
            nn.init.constant_(embedder.proj2.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLPs:
        for t_embedder in (self.t_pixel_embedder, self.t_dino_embedder):
            nn.init.normal_(t_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers so every block starts as identity:
        gated_blocks = list(self.blocks)
        if self.dh_depth > 0:
            gated_blocks += list(self.dh_blocks_pixel) + list(self.dh_blocks_dino)
        for block in gated_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        final_layers = [self.final_layer] if self.dh_depth == 0 else [self.final_layer_pixel, self.final_layer_dino]
        for final_layer in final_layers:
            nn.init.constant_(final_layer.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(final_layer.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(final_layer.linear.weight, 0)
            nn.init.constant_(final_layer.linear.bias, 0)

    def unpatchify(self, x, p, c):
        """
        Fold a square token grid back into an image-like tensor.

        x: (N, T, p**2 * c) with T = h * w and h == w
        p: patch size of this stream
        c: channels per output pixel
        returns: (N, c, h * p, w * p)
        """
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x, t, y, m=None):
        """
        Jointly predict pixel and DINO-feature outputs.

        x: pair (x_pixel, x_dino); x_pixel is (N, C, H, W) and x_dino is
           (N, D, H/p, W/p) with D = dino_in_channels, p = patch_size.
        t: pair (t_pixel, t_dino) of (N,) timesteps, embedded separately.
        y: (N,) class labels.
        m: optional pair (m_pixel, m_dino) of boolean (N, num_patches) masks;
           where a mask is True that modality's patch embedding is zeroed.
        Returns [output_pixel, output_dino] with the same shapes as x.
        """
        # class and time embeddings
        t_emb_pixel = self.t_pixel_embedder(t[0])
        t_emb_dino = self.t_dino_embedder(t[1])
        y_emb = self.y_embedder(y)
        c = t_emb_pixel + t_emb_dino + y_emb  # TODO: try tokens?

        # forward JiT: both modalities share one token per patch
        x_pixel, x_dino = x
        if m is None:
            x = self.pixel_embedder(x_pixel) + self.dino_embedder(x_dino)
        else:
            x = self.pixel_embedder(x_pixel) * (~m[0][..., None]) + self.dino_embedder(x_dino) * (~m[1][..., None])
        x += self.pos_embed

        in_context_added = False
        for i, block in enumerate(self.blocks):
            # in-context
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens += self.in_context_posemb
                x = torch.cat([in_context_tokens, x], dim=1)
                in_context_added = True
            x = block(x, c, self.feat_rope if i < self.in_context_start else self.feat_rope_incontext)

        # Drop the in-context tokens only if they were actually prepended
        # (e.g. not when in_context_start >= depth), so patch tokens survive.
        if in_context_added:
            x = x[:, self.in_context_len:]

        if self.dh_depth == 0:
            output_pixel, output_dino = self.final_layer(x, c).split(
                [self.out_channels_pixel, self.out_channels_dino], dim=-1)
        else:
            dh_feats_pixel = self.dh_pixel_proj(x)
            dh_feats_dino = self.dh_dino_proj(x)
            for block in self.dh_blocks_pixel:
                dh_feats_pixel = block(dh_feats_pixel, c, self.feat_rope)
            for block in self.dh_blocks_dino:
                dh_feats_dino = block(dh_feats_dino, c, self.feat_rope)
            output_pixel = self.final_layer_pixel(dh_feats_pixel, c)
            output_dino = self.final_layer_dino(dh_feats_dino, c)

        output_pixel = self.unpatchify(output_pixel, self.patch_size, self.in_channels)
        output_dino = self.unpatchify(output_dino, 1, self.out_channels_dino)
        output = [output_pixel, output_dino]

        return output


def JiTCoT_S_16(**kwargs):
    return JiTCoT(depth=12, hidden_size=384, num_heads=6,
                  bottleneck_dim=128, in_context_len=32, in_context_start=4, patch_size=16, **kwargs)


def JiTCoT_B_16(**kwargs):
    return JiTCoT(depth=12, hidden_size=768, num_heads=12,
                  bottleneck_dim=128, in_context_len=32, in_context_start=4, patch_size=16, **kwargs)


def JiTCoT_B_32(**kwargs):
    return JiTCoT(depth=12, hidden_size=768, num_heads=12,
                  bottleneck_dim=128, in_context_len=32, in_context_start=4, patch_size=32, **kwargs)


def JiTCoT_L_16(**kwargs):
    return JiTCoT(depth=24, hidden_size=1024, num_heads=16,
                  bottleneck_dim=128, in_context_len=32, in_context_start=8, patch_size=16, **kwargs)


def JiTCoT_LM_16(**kwargs):
    return JiTCoT(depth=20, hidden_size=1024, num_heads=16,
                  bottleneck_dim=128, in_context_len=32, in_context_start=8, patch_size=16, **kwargs)


def JiTCoT_L_32(**kwargs):
    return JiTCoT(depth=24, hidden_size=1024, num_heads=16,
                  bottleneck_dim=128, in_context_len=32, in_context_start=8, patch_size=32, **kwargs)


def JiTCoT_H_16(**kwargs):
    return JiTCoT(depth=32, hidden_size=1280, num_heads=16,
                  bottleneck_dim=256, in_context_len=32, in_context_start=10, patch_size=16, **kwargs)


def JiTCoT_H_32(**kwargs):
    return JiTCoT(depth=32, hidden_size=1280, num_heads=16,
                  bottleneck_dim=256, in_context_len=32, in_context_start=10, patch_size=32, **kwargs)


JiTCoT_models = {
    'JiTCoT-S/16': JiTCoT_S_16,
    'JiTCoT-B/16': JiTCoT_B_16,
    'JiTCoT-B/32': JiTCoT_B_32,
    'JiTCoT-L/16': JiTCoT_L_16,
    'JiTCoT-LM/16': JiTCoT_LM_16,
    'JiTCoT-L/32': JiTCoT_L_32,
    'JiTCoT-H/16': JiTCoT_H_16,
    'JiTCoT-H/32': JiTCoT_H_32,
}

# TODO halve heads