# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # -------------------------------------------------------- # References: # GLIDE: https://github.com/openai/glide-text2im # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py # -------------------------------------------------------- import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math from timm.models.vision_transformer import PatchEmbed, Attention, Mlp def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) ################################################################################# # Embedding Layers for Timesteps and Class Labels # ################################################################################# class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class LabelEmbedder(nn.Module): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. """ def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) self.num_classes = num_classes self.dropout_prob = dropout_prob def token_drop(self, labels, force_drop_ids=None): """ Drops labels to enable classifier-free guidance. """ if force_drop_ids is None: drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob else: drop_ids = force_drop_ids == 1 labels = torch.where(drop_ids, self.num_classes, labels) return labels def forward(self, labels, train, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (train and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings ################################################################################# # Core DiT Model # ################################################################################# class DiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
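    A minimal usage sketch (the sizes below are illustrative only):

        block = DiTBlock(hidden_size=384, num_heads=6)
        x = torch.randn(2, 16, 384)    # (N, T, D) patch tokens
        c = torch.randn(2, 384)        # (N, D) conditioning, e.g. t_emb + y_emb
        out = block(x, c)              # (2, 16, 384)

    The adaLN_modulation MLP maps c to per-channel shift/scale/gate triples for the
    attention and MLP branches; DiT.initialize_weights() zeroes this MLP so each
    block starts out as an identity mapping.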
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def forward(self, x, c): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x class CrossAttentionDiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning and cross-attention for text/CLIP embeddings. """ def __init__(self, hidden_size, num_heads, context_dim=None, mlp_ratio=4.0, **block_kwargs): super().__init__() # Self-attention self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) # Cross-attention self.norm_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.cross_attn = CrossAttention( query_dim=hidden_size, context_dim=context_dim if context_dim is not None else hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, dropout=0.0 ) # MLP self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) # AdaLN modulation (9 parameters: 3 for self-attn, 3 for cross-attn, 3 for mlp) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True) ) def forward(self, x, c, context=None): """ x: (N, T, D) - spatial tokens c: (N, D) - timestep + class conditioning context: (N, L, context_dim) - CLIP text embeddings """ shift_msa, scale_msa, gate_msa, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(c).chunk(9, dim=1) # Self-attention x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) # Cross-attention if context is not None: x = x + gate_ca.unsqueeze(1) * self.cross_attn( modulate(self.norm_context(x), shift_ca, scale_ca), context=context ) # MLP x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x class FinalLayer(nn.Module): """ The final layer of DiT. """ def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class DiT(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__( self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, num_classes=1000, learn_sigma=False, ): super().__init__() self.learn_sigma = learn_sigma self.in_channels = in_channels self.out_channels = in_channels * 2 if learn_sigma else in_channels self.patch_size = patch_size self.num_heads = num_heads self.hidden_size = hidden_size self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) self.t_embedder = TimestepEmbedder(hidden_size) self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) num_patches = self.x_embedder.num_patches # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) self.blocks = nn.ModuleList([ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) ]) self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) self.initialize_weights() def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) nn.init.constant_(self.x_embedder.proj.bias, 0) # Initialize label embedding table: nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) # Zero-out adaLN modulation layers in DiT blocks: for block in self.blocks: nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) def unpatchify(self, x): """ x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ c = self.out_channels p = self.x_embedder.patch_size[0] h = w = int(x.shape[1] ** 0.5) assert h * w == x.shape[1] x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) return imgs def forward(self, x, t, y): """ Forward pass of DiT. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N,) tensor of class labels """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(t) # (N, D) y = self.y_embedder(y, self.training) # (N, D) c = t + y # (N, D) for block in self.blocks: x = block(x, c) # (N, T, D) x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_cfg(self, x, t, y, cfg_scale): """ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 
""" # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = torch.cat([half, half], dim=0) model_out = self.forward(combined, t, y) # For exact reproducibility reasons, we apply classifier-free guidance on only # three channels by default. The standard approach to cfg applies it to all channels. # This can be done by uncommenting the following line and commenting-out the line following that. # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] eps, rest = model_out[:, :3], model_out[:, 3:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb ################################################################################# # Stable Diffusion U-Net # ################################################################################# def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. 
""" half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding class Upsample(nn.Module): """A 2D upsampling layer with an optional convolution.""" def __init__(self, channels, use_conv, dims=2): super().__init__() self.channels = channels self.use_conv = use_conv self.dims = dims if use_conv: self.conv = nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect') def forward(self, x): assert x.shape[1] == self.channels if self.dims == 2: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x class Downsample(nn.Module): """A 2D downsampling layer with an optional convolution.""" def __init__(self, channels, use_conv, dims=2): super().__init__() self.channels = channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = nn.Conv2d(channels, channels, 3, stride=stride, padding=1, padding_mode='reflect') else: self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class ResBlock(nn.Module): """ A residual block that can optionally change the number of channels. """ def __init__( self, channels, emb_channels, dropout=0.0, out_channels=None, use_conv=False, dims=2, use_checkpoint=False, up=False, down=False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.in_layers = nn.Sequential( nn.GroupNorm(32, channels), nn.SiLU(), nn.Conv2d(channels, self.out_channels, 3, padding=1, padding_mode='reflect'), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( nn.SiLU(), nn.Linear(emb_channels, self.out_channels), ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1, padding_mode='reflect'), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = nn.Conv2d(channels, self.out_channels, 3, padding=1, padding_mode='reflect') else: self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) def forward(self, x, emb): """ Apply the block to a Tensor, conditioned on a timestep embedding. """ if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) if emb is not None: emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. 
""" def __init__(self, channels, num_head_channels=1): super().__init__() self.channels = channels self.num_heads = channels // num_head_channels self.norm = nn.GroupNorm(32, channels) self.qkv = nn.Conv1d(channels, channels * 3, 1) self.attention = QKVAttention(self.num_heads) self.proj_out = nn.Conv1d(channels, channels, 1) def forward(self, x): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) h = self.attention(qkv) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) class QKVAttention(nn.Module): """ A module which performs QKV attention. """ def __init__(self, n_heads): super().__init__() self.n_heads = n_heads def forward(self, qkv): """ Apply QKV attention. """ bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) q, k, v = qkv.chunk(3, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = torch.einsum( "bct,bcs->bts", (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) class CrossAttention(nn.Module): """ Cross attention block for text conditioning (used in SD 1.5). """ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = context_dim if context_dim is not None else query_dim self.scale = dim_head ** -0.5 self.heads = heads self.query_dim = query_dim self.context_dim = context_dim self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None): """ x: (B, N, D) where N is spatial dimension context: (B, M, context_dim) text embeddings or None """ h = self.heads # If context is None and context_dim != query_dim, skip cross-attention # (return identity) because we can't use x as context due to dimension mismatch if context is None: if self.context_dim != self.query_dim: # Skip cross-attention when context is None and dimensions don't match return x context = x q = self.to_q(x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], h, -1).transpose(1, 2), (q, k, v)) sim = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale attn = sim.softmax(dim=-1) out = torch.einsum('bhij,bhjd->bhid', attn, v) out = out.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1) return self.to_out(out) class SpatialTransformer(nn.Module): """ Spatial Transformer block used in Stable Diffusion 1.5. Contains self-attention and cross-attention for text conditioning. 
""" def __init__( self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None, use_checkpoint=False, ): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList([ BasicTransformerBlock( inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim ) for d in range(depth) ]) self.proj_out = nn.Linear(inner_dim, in_channels) def forward(self, x, context=None): """ x: (B, C, H, W) spatial feature map context: (B, seq_len, context_dim) text embeddings or None """ b, c, h, w = x.shape x_in = x x = self.norm(x) x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) # (B, H*W, C) x = self.proj_in(x) for block in self.transformer_blocks: x = block(x, context=context) x = self.proj_out(x) x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) # (B, C, H, W) return x + x_in class BasicTransformerBlock(nn.Module): """ Basic transformer block with self-attention and cross-attention. """ def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None): super().__init__() self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # Self-attention self.ff = FeedForward(dim, dropout=dropout) self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # Cross-attention self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) def forward(self, x, context=None): # Self-attention x = x + self.attn1(self.norm1(x)) # Cross-attention x = x + self.attn2(self.norm2(x), context=context) # Feed-forward x = x + self.ff(self.norm3(x)) return x class FeedForward(nn.Module): """ Feed-forward network in transformer block. """ def __init__(self, dim, dropout=0.0, mult=4.0): super().__init__() inner_dim = int(dim * mult) self.net = nn.Sequential( nn.Linear(dim, inner_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(inner_dim, dim), nn.Dropout(dropout), ) def forward(self, x): return self.net(x) class UNetModel(nn.Module): """ U-Net model for diffusion (compatible with Stable Diffusion 1.5 and 2.0). Complete implementation with Spatial Transformer support for cross-attention. 
- SD1.5: context_dim=768 (CLIP ViT-L/14) - SD2.0: context_dim=1024 (OpenCLIP ViT-H/14) """ def __init__( self, in_channels=4, out_channels=None, model_channels=320, attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4), num_head_channels=1, # SD 1.5 uses 1 use_spatial_transformer=True, # Enable for SD 1.5 use_linear_in_transformer=False, transformer_depth=1, context_dim=768, # CLIP text embedding dimension num_classes=1000, learn_sigma=False, class_dropout_prob=0.1, use_pos_conditioning=False, # Enable position conditioning via condition_net ): super().__init__() if out_channels is None: out_channels = in_channels * 2 if learn_sigma else in_channels self.in_channels = in_channels self.out_channels = out_channels self.model_channels = model_channels self.num_classes = num_classes self.learn_sigma = learn_sigma self.class_dropout_prob = class_dropout_prob self.use_spatial_transformer = use_spatial_transformer self.context_dim = context_dim self.use_pos_conditioning = use_pos_conditioning # Position condition network (optional): 将 pos_info (2维) 转换为 context_dim 维的特征 if use_pos_conditioning: pos_dim = 2 condition_net_dim = self.context_dim if self.context_dim is not None else 768 self.condition_net = nn.Sequential( nn.Linear(pos_dim, condition_net_dim), nn.SiLU(), nn.Linear(condition_net_dim, condition_net_dim), nn.SiLU(), nn.Linear(condition_net_dim, condition_net_dim), ) else: self.condition_net = None # Time embedding time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( nn.Linear(model_channels, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) # Class embedding self.label_embedder = LabelEmbedder(num_classes, time_embed_dim, class_dropout_prob) # Input block self.input_blocks = nn.ModuleList([ nn.Sequential(nn.Conv2d(in_channels, model_channels, 3, padding=1, padding_mode='reflect')) ]) input_block_chans = [model_channels] ch = model_channels ds = 1 # Calculate number of heads for spatial transformer num_heads = model_channels // num_head_channels for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout=0.0, out_channels=int(mult * model_channels), dims=2, use_checkpoint=False, ) ] ch = int(mult * model_channels) # Add attention layer if ds in attention_resolutions: if use_spatial_transformer and context_dim is not None: # Use SpatialTransformer for SD 1.5 layers.append( SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ) ) else: # Fallback to simple attention block layers.append(AttentionBlock(ch, num_head_channels=num_head_channels)) self.input_blocks.append(nn.ModuleList(layers)) input_block_chans.append(ch) if level != len(channel_mult) - 1: self.input_blocks.append(Downsample(ch, use_conv=True, dims=2)) input_block_chans.append(ch) ds *= 2 # Middle block if use_spatial_transformer and context_dim is not None: self.middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) else: self.middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), AttentionBlock(ch, num_head_channels=num_head_channels), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) # Output blocks self.output_blocks = 
nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout=0.0, out_channels=int(model_channels * mult), dims=2, use_checkpoint=False, ) ] ch = int(model_channels * mult) # Add attention layer if ds in attention_resolutions: if use_spatial_transformer and context_dim is not None: layers.append( SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ) ) else: layers.append(AttentionBlock(ch, num_head_channels=num_head_channels)) if level and i == num_res_blocks: layers.append(Upsample(ch, use_conv=True, dims=2)) ds //= 2 self.output_blocks.append(nn.ModuleList(layers)) # Output layer self.out = nn.Sequential( nn.GroupNorm(32, ch), nn.SiLU(), nn.Conv2d(ch, out_channels, 3, padding=1, padding_mode='reflect'), ) # Initialize time embedding nn.init.zeros_(self.time_embed[0].weight) nn.init.zeros_(self.time_embed[2].weight) def forward(self, x, t, y, condition=None, cross_attn_context=None, pos_context=None): t_emb = timestep_embedding(t, self.model_channels) t_emb = self.time_embed(t_emb) y_emb = self.label_embedder(y, self.training) emb = t_emb + y_emb # Handle position conditioning: convert pos_context to cross_attn_context if needed # If both CLIP embedding and pos_context are provided, concatenate them if pos_context is not None and self.condition_net is not None: pos_embed = self.condition_net(pos_context) # (N, context_dim) pos_embed = pos_embed.unsqueeze(1) # (N, 1, context_dim) if cross_attn_context is not None: # Concatenate CLIP embedding and pos embedding: (N, L, context_dim) + (N, 1, context_dim) -> (N, L+1, context_dim) cross_attn_context = torch.cat([cross_attn_context, pos_embed], dim=1) else: cross_attn_context = pos_embed hs = [] h = x.type(self.input_blocks[0][0].weight.dtype) for module in self.input_blocks: # Unify ModuleList and single module handling layers = module if isinstance(module, (nn.ModuleList, list)) else [module] for layer in layers: if isinstance(layer, ResBlock): h = layer(h, emb) elif isinstance(layer, SpatialTransformer): h = layer(h, context=cross_attn_context) else: h = layer(h) hs.append(h) # Middle block h = self.middle_block[0](h, emb) h = self.middle_block[1](h, context=cross_attn_context) if isinstance(self.middle_block[1], SpatialTransformer) else self.middle_block[1](h) h = self.middle_block[2](h, emb) # Output blocks for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) # Unify ModuleList and single module handling layers = module if isinstance(module, (nn.ModuleList, list)) else [module] for layer in layers: if isinstance(layer, ResBlock): h = layer(h, emb) elif isinstance(layer, SpatialTransformer): h = layer(h, context=cross_attn_context) else: h = layer(h) # Output h = h.type(x.dtype) return self.out(h) # ConditionalUNet is now merged into UNetModel # Use UNetModel with use_pos_conditioning=True instead # Keeping this alias for backward compatibility class ConditionalUNet(UNetModel): def __init__(self, *args, **kwargs): kwargs['use_pos_conditioning'] = True super().__init__(*args, **kwargs) def zero_module(module): for p in module.parameters(): p.detach().zero_() return module class ControlNetUNet(nn.Module): def __init__( self, in_channels=4, out_channels=None, model_channels=320, attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4), num_head_channels=1, use_spatial_transformer=True, 
use_linear_in_transformer=False, transformer_depth=1, context_dim=768, num_classes=1000, learn_sigma=False, class_dropout_prob=0.1, condition_channels=4, ): super().__init__() if out_channels is None: out_channels = in_channels * 2 if learn_sigma else in_channels self.in_channels = in_channels self.out_channels = out_channels self.model_channels = model_channels self.num_classes = num_classes self.learn_sigma = learn_sigma self.class_dropout_prob = class_dropout_prob self.use_spatial_transformer = use_spatial_transformer self.context_dim = context_dim self.condition_channels = condition_channels # Position condition network pos_dim = 2 self.condition_net = nn.Sequential( nn.Linear(pos_dim, context_dim), nn.SiLU(), nn.Linear(context_dim, context_dim), nn.SiLU(), nn.Linear(context_dim, context_dim), ) # Time embedding time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( nn.Linear(model_channels, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) # Class embedding self.label_embedder = LabelEmbedder(num_classes, time_embed_dim, class_dropout_prob) # ============================================================================ # ControlNet (trainable copy of encoder) # ============================================================================ self.controlnet_input_blocks = nn.ModuleList([ nn.Sequential(nn.Conv2d(condition_channels, model_channels, 3, padding=1, padding_mode='reflect')) ]) controlnet_block_chans = [model_channels] ch = model_channels ds = 1 num_heads = model_channels // num_head_channels for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout=0.0, out_channels=int(mult * model_channels), dims=2, use_checkpoint=False, ) ] ch = int(mult * model_channels) if ds in attention_resolutions: if use_spatial_transformer and context_dim is not None: layers.append( SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ) ) else: layers.append(AttentionBlock(ch, num_head_channels=num_head_channels)) self.controlnet_input_blocks.append(nn.ModuleList(layers)) controlnet_block_chans.append(ch) if level != len(channel_mult) - 1: self.controlnet_input_blocks.append(Downsample(ch, use_conv=True, dims=2)) controlnet_block_chans.append(ch) ds *= 2 # ControlNet Middle block if use_spatial_transformer and context_dim is not None: self.controlnet_middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) else: self.controlnet_middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), AttentionBlock(ch, num_head_channels=num_head_channels), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) # Zero convolutions for ControlNet outputs self.controlnet_zero_convs = nn.ModuleList([]) for block_ch in controlnet_block_chans: self.controlnet_zero_convs.append(zero_module(nn.Conv2d(block_ch, block_ch, 1))) # Middle block zero conv self.controlnet_middle_zero_conv = zero_module(nn.Conv2d(ch, ch, 1)) # ============================================================================ # UNet # ============================================================================ self.input_blocks = nn.ModuleList([ nn.Sequential(nn.Conv2d(in_channels, 
model_channels, 3, padding=1, padding_mode='reflect')) ]) input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout=0.0, out_channels=int(mult * model_channels), dims=2, use_checkpoint=False, ) ] ch = int(mult * model_channels) if ds in attention_resolutions: if use_spatial_transformer and context_dim is not None: layers.append( SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ) ) else: layers.append(AttentionBlock(ch, num_head_channels=num_head_channels)) self.input_blocks.append(nn.ModuleList(layers)) input_block_chans.append(ch) if level != len(channel_mult) - 1: self.input_blocks.append(Downsample(ch, use_conv=True, dims=2)) input_block_chans.append(ch) ds *= 2 # Middle block if use_spatial_transformer and context_dim is not None: self.middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) else: self.middle_block = nn.ModuleList([ ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), AttentionBlock(ch, num_head_channels=num_head_channels), ResBlock(ch, time_embed_dim, dropout=0.0, dims=2, use_checkpoint=False), ]) # Output blocks self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout=0.0, out_channels=int(model_channels * mult), dims=2, use_checkpoint=False, ) ] ch = int(model_channels * mult) if ds in attention_resolutions: if use_spatial_transformer and context_dim is not None: layers.append( SpatialTransformer( ch, num_heads, num_head_channels, depth=transformer_depth, dropout=0.0, context_dim=context_dim ) ) else: layers.append(AttentionBlock(ch, num_head_channels=num_head_channels)) if level and i == num_res_blocks: layers.append(Upsample(ch, use_conv=True, dims=2)) ds //= 2 self.output_blocks.append(nn.ModuleList(layers)) # Output layer self.out = nn.Sequential( nn.GroupNorm(32, ch), nn.SiLU(), nn.Conv2d(ch, out_channels, 3, padding=1, padding_mode='reflect'), ) # Initialize time embedding nn.init.zeros_(self.time_embed[0].weight) nn.init.zeros_(self.time_embed[2].weight) def forward(self, x, t, y, condition=None, pos_context=None): """ x: (N, C, H, W) - noisy latent t: (N,) - timesteps y: (N,) - class labels condition: (N, condition_channels, H, W) - condition image latent pos_context: (N, 4) - position info """ # Time and class embedding t_emb = timestep_embedding(t, self.model_channels) t_emb = self.time_embed(t_emb) y_emb = self.label_embedder(y, self.training) emb = t_emb + y_emb # Position context for cross-attention cross_attn_context = None if pos_context is not None: cross_attn_context = self.condition_net(pos_context) cross_attn_context = cross_attn_context.unsqueeze(1) # ============================================================================ # ControlNet forward pass # ============================================================================ controlnet_outputs = [] if condition is not None: h_ctrl = condition.type(self.controlnet_input_blocks[0][0].weight.dtype) ctrl_idx = 0 for module in self.controlnet_input_blocks: # Unify ModuleList and single 
module handling layers = module if isinstance(module, (nn.ModuleList, list)) else [module] for layer in layers: if isinstance(layer, ResBlock): h_ctrl = layer(h_ctrl, emb) elif isinstance(layer, SpatialTransformer): h_ctrl = layer(h_ctrl, context=cross_attn_context) else: h_ctrl = layer(h_ctrl) # Apply zero convolution and store controlnet_outputs.append(self.controlnet_zero_convs[ctrl_idx](h_ctrl)) ctrl_idx += 1 # ControlNet middle block h_ctrl = self.controlnet_middle_block[0](h_ctrl, emb) h_ctrl = self.controlnet_middle_block[1](h_ctrl, context=cross_attn_context) if isinstance(self.controlnet_middle_block[1], SpatialTransformer) else self.controlnet_middle_block[1](h_ctrl) h_ctrl = self.controlnet_middle_block[2](h_ctrl, emb) controlnet_middle_output = self.controlnet_middle_zero_conv(h_ctrl) else: controlnet_outputs = [0] * len(self.controlnet_zero_convs) controlnet_middle_output = 0 # ============================================================================ # UNet forward pass (add ControlNet output) # ============================================================================ hs = [] h = x.type(self.input_blocks[0][0].weight.dtype) ctrl_idx = 0 for module in self.input_blocks: # Unify ModuleList and single module handling layers = module if isinstance(module, (nn.ModuleList, list)) else [module] for layer in layers: if isinstance(layer, ResBlock): h = layer(h, emb) elif isinstance(layer, SpatialTransformer): h = layer(h, context=cross_attn_context) else: h = layer(h) # Add ControlNet output h = h + controlnet_outputs[ctrl_idx] ctrl_idx += 1 hs.append(h) # Middle block (add ControlNet middle output) h = self.middle_block[0](h, emb) h = self.middle_block[1](h, context=cross_attn_context) if isinstance(self.middle_block[1], SpatialTransformer) else self.middle_block[1](h) h = self.middle_block[2](h, emb) h = h + controlnet_middle_output # Output blocks (decoder) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) # Unify ModuleList and single module handling layers = module if isinstance(module, (nn.ModuleList, list)) else [module] for layer in layers: if isinstance(layer, ResBlock): h = layer(h, emb) elif isinstance(layer, SpatialTransformer): h = layer(h, context=cross_attn_context) else: h = layer(h) h = h.type(x.dtype) return self.out(h) ################################################################################# # DiT Configs # ################################################################################# def DiT_XL_2(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) def DiT_XL_4(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) def DiT_XL_8(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) def DiT_L_2(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) def DiT_L_4(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) def DiT_L_8(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) def DiT_B_2(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) def DiT_B_4(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) def DiT_B_8(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) def DiT_S_2(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) def DiT_S_4(**kwargs): return DiT(depth=12, hidden_size=384, 
patch_size=4, num_heads=6, **kwargs) def DiT_S_8(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) class ConditionalDiT(DiT): """ Unified conditional DiT that supports multiple conditioning types: - condition: (N, C, H, W) latent images - added via PatchEmbed - pos_context: (N, 4) position info - added to timestep embedding - clip_embed: (N, L, context_dim) CLIP embeddings - via cross-attention """ def __init__(self, *args, context_dim=None, out_channels=None, **kwargs): # Extract depth before calling super().__init__ to know how many blocks to create depth = kwargs.get('depth', 28) mlp_ratio = kwargs.get('mlp_ratio', 4.0) super().__init__(*args, **kwargs) # Override out_channels if explicitly provided (for CLIPDiT with concat input) if out_channels is not None: self.out_channels = out_channels # Recreate final layer with correct output channels self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels) # Zero-out the newly created final layer (critical for training stability) nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) # Condition embedding (for latent images) self.condition_embedder = PatchEmbed( self.x_embedder.img_size, self.x_embedder.patch_size, self.in_channels, self.hidden_size, bias=True ) # Position context network pos_dim = 2 self.position_net = nn.Sequential( nn.Linear(pos_dim, self.hidden_size), nn.SiLU(), nn.Linear(self.hidden_size, self.hidden_size), nn.SiLU(), nn.Linear(self.hidden_size, self.hidden_size), ) # CLIP cross-attention support self.context_dim = context_dim if context_dim is not None: # Replace DiT blocks with CrossAttentionDiT blocks for CLIP conditioning self.blocks = nn.ModuleList([ CrossAttentionDiTBlock(self.hidden_size, self.num_heads, context_dim=context_dim, mlp_ratio=mlp_ratio) for _ in range(depth) ]) # Re-initialize the new blocks for block in self.blocks: nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) def forward(self, x, t, y, condition=None, pos_context=None, clip_embed=None): """ Forward pass of unified Conditional DiT. 
x: (N, C, H, W) tensor of spatial inputs (target images) t: (N,) tensor of diffusion timesteps y: (N,) tensor of class labels condition: (N, C, H, W) tensor of condition images (domain1 latents) - optional pos_context: (N, 4) tensor of position info - optional clip_embed: (N, L, context_dim) tensor of CLIP embeddings - optional """ x = self.x_embedder(x) + self.pos_embed # Add condition embedding if provided if condition is not None: condition_embed = self.condition_embedder(condition) x = x + condition_embed t = self.t_embedder(t) y = self.y_embedder(y, self.training) c = t + y # Add position context if provided if pos_context is not None: pos_embed = self.position_net(pos_context) # (N, hidden_size) c = c + pos_embed # Transformer blocks (with cross-attention if clip_embed is provided) for block in self.blocks: if isinstance(block, CrossAttentionDiTBlock) and clip_embed is not None: x = block(x, c, context=clip_embed) else: x = block(x, c) x = self.final_layer(x, c) x = self.unpatchify(x) return x # CLIPConditionedDiT is now merged into ConditionalDiT # Use ConditionalDiT with context_dim parameter for CLIP conditioning # Keeping this alias for backward compatibility class CLIPConditionedDiT(ConditionalDiT): def __init__(self, *args, context_dim=1024, **kwargs): super().__init__(*args, context_dim=context_dim, **kwargs) def forward(self, x, t, y, clip_embed=None, **kwargs): """Backward compatibility: map clip_embed to the unified interface""" return super().forward(x, t, y, clip_embed=clip_embed, **kwargs) def forward_with_cfg(self, x, t, y, clip_embed, cfg_scale): """ Forward pass with classifier-free guidance. """ half = x[: len(x) // 2] combined = torch.cat([half, half], dim=0) model_out = self.forward(combined, t, y, clip_embed=clip_embed) eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) ################################################################################# # SD2.0 UNet Configs # ################################################################################# def UNet_SD2_B(**kwargs): return UNetModel( model_channels=320, num_res_blocks=2, channel_mult=(1, 2, 4, 4), attention_resolutions=(4, 2, 1), num_head_channels=64, context_dim=1024, # SD2.0 使用 OpenCLIP **kwargs ) def UNet_SD2_L(**kwargs): return UNetModel( model_channels=512, num_res_blocks=3, channel_mult=(1, 1, 2, 2, 4, 4), attention_resolutions=(8, 4, 2, 1), num_head_channels=64, context_dim=1024, **kwargs ) def UNet_SD2_XL(**kwargs): return UNetModel( model_channels=1024, num_res_blocks=4, channel_mult=(1, 2, 4, 4), attention_resolutions=(4, 2, 1), num_head_channels=64, context_dim=1024, **kwargs ) def ConditionalUNet_SD2_B(**kwargs): default_kwargs = { 'out_channels': 4, 'learn_sigma': False, 'context_dim': 1024, # SD2.0 'use_pos_conditioning': True, # Enable position conditioning } default_kwargs.update(kwargs) return UNetModel( model_channels=320, num_res_blocks=2, channel_mult=(1, 2, 4, 4), attention_resolutions=(4, 2, 1), num_head_channels=64, **default_kwargs ) def ConditionalUNet_SD2_L(**kwargs): default_kwargs = { 'out_channels': 4, 'learn_sigma': False, 'context_dim': 1024, 'use_pos_conditioning': True, # Enable position conditioning } default_kwargs.update(kwargs) return UNetModel( model_channels=512, num_res_blocks=3, channel_mult=(1, 1, 2, 2, 4, 4), 
attention_resolutions=(8, 4, 2, 1), num_head_channels=64, **default_kwargs ) def ControlNetUNet_SD2_B(**kwargs): """ControlNetUNet Base configuration for Stable Diffusion 2.0""" default_kwargs = { 'out_channels': 4, 'learn_sigma': False, 'condition_channels': 4, 'context_dim': 1024, # SD2.0 } default_kwargs.update(kwargs) return ControlNetUNet( model_channels=320, num_res_blocks=2, channel_mult=(1, 2, 4, 4), attention_resolutions=(4, 2, 1), num_head_channels=64, **default_kwargs ) def ControlNetUNet_SD2_L(**kwargs): """ControlNetUNet Large configuration for Stable Diffusion 2.0""" default_kwargs = { 'out_channels': 4, 'learn_sigma': False, 'condition_channels': 4, 'context_dim': 1024, } default_kwargs.update(kwargs) return ControlNetUNet( model_channels=512, num_res_blocks=3, channel_mult=(1, 1, 2, 2, 4, 4), attention_resolutions=(8, 4, 2, 1), num_head_channels=64, **default_kwargs ) ################################################################################# # KL-f8 VAE Implementation # ################################################################################# class Encoder(nn.Module): """ VAE Encoder for KL-f8 (8x downsampling) """ def __init__( self, in_channels=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[], dropout=0.0, resamp_with_conv=True, resolution=256, z_channels=4, double_z=True, use_mid_attn=False, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.use_mid_attn = use_mid_attn # Downsampling self.conv_in = nn.Conv2d(in_channels, ch, 3, stride=1, padding=1) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResBlock( block_in, block_out, dropout=dropout, out_channels=block_out, dims=2, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttentionBlock(block_in)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv, dims=2) curr_res = curr_res // 2 self.down.append(down) # Middle self.mid = nn.Module() self.mid.block_1 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # optional attention layer (default not used, save memory) if self.use_mid_attn: self.mid.attn_1 = AttentionBlock(block_in) else: self.mid.attn_1 = nn.Identity() self.mid.block_2 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # End self.norm_out = nn.GroupNorm(32, block_in) self.conv_out = nn.Conv2d( block_in, 2 * z_channels if double_z else z_channels, 3, stride=1, padding=1, padding_mode='reflect' ) def forward(self, x): # Downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], emb=None) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # Middle h = hs[-1] h = self.mid.block_1(h, emb=None) h = self.mid.attn_1(h) h = self.mid.block_2(h, emb=None) # End h = self.norm_out(h) h = F.silu(h) h = self.conv_out(h) return h class Decoder(nn.Module): """ VAE Decoder for KL-f8 (8x upsampling) """ def 
__init__(
        self,
        out_ch=3,
        ch=128,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
        resamp_with_conv=True,
        resolution=256,
        z_channels=4,
        use_mid_attn=False,  # controls whether the middle block uses self-attention
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.use_mid_attn = use_mid_attn

        # Compute in_ch_mult and block_in at lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, 3, stride=1, padding=1)

        # Middle
        self.mid = nn.Module()
        self.mid.block_1 = ResBlock(
            block_in, block_in, dropout=dropout, out_channels=block_in, dims=2,
        )
        # optional attention layer (not used by default, to save memory)
        if self.use_mid_attn:
            self.mid.attn_1 = AttentionBlock(block_in)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = ResBlock(
            block_in, block_in, dropout=dropout, out_channels=block_in, dims=2,
        )

        # Upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResBlock(
                        block_in, block_out, dropout=dropout, out_channels=block_out, dims=2,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttentionBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv, dims=2)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        # End
        self.norm_out = nn.GroupNorm(32, block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, 3, stride=1, padding=1, padding_mode='reflect')

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # Middle
        h = self.mid.block_1(h, emb=None)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, emb=None)

        # Upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, emb=None)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # End
        h = self.norm_out(h)
        h = F.silu(h)
        h = self.conv_out(h)
        return h


class ConditionalEncoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        ch=128,
        ch_mult=(1, 2, 4, 4, 8),
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
        resamp_with_conv=True,
        resolution=256,
        z_channels=4,
        double_z=True,
        context_dim=768,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.context_dim = context_dim

        # Downsampling
        self.conv_in = nn.Conv2d(in_channels, ch, 3, stride=1, padding=1, padding_mode='reflect')
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResBlock(
                        block_in, block_out, dropout=dropout, out_channels=block_out, dims=2,
                    )
                )
                block_in = block_out
                # add cross-attention in the lowest-resolution level (i_level == num_resolutions - 1)
                if curr_res in attn_resolutions:
                    attn.append(AttentionBlock(block_in))
                elif i_level == self.num_resolutions - 1 and i_block == self.num_res_blocks - 1:
                    # add a SpatialTransformer (cross-attention) in the last block
                    num_heads = max(1, block_in // 64)
                    attn.append(
SpatialTransformer( in_channels=block_in, n_heads=num_heads, d_head=64, depth=1, dropout=dropout, context_dim=context_dim, ) ) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv, dims=2) curr_res = curr_res // 2 self.down.append(down) # Middle - replace self-attention with SpatialTransformer self.mid = nn.Module() self.mid.block_1 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # replace self-attention with cross-attention num_heads = max(1, block_in // 64) self.mid.attn_1 = SpatialTransformer( in_channels=block_in, n_heads=num_heads, d_head=64, depth=1, dropout=dropout, context_dim=context_dim, ) self.mid.block_2 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # End self.norm_out = nn.GroupNorm(32, block_in) self.conv_out = nn.Conv2d( block_in, 2 * z_channels if double_z else z_channels, 3, stride=1, padding=1, padding_mode='reflect' ) def forward(self, x, context=None): if context is not None and len(context.shape) == 2: context = context.unsqueeze(1) # (B, context_dim) -> (B, 1, context_dim) # Downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], emb=None) if i_block < len(self.down[i_level].attn): attn_module = self.down[i_level].attn[i_block] if isinstance(attn_module, SpatialTransformer): h = attn_module(h, context=context) else: h = attn_module(h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # Middle h = hs[-1] h = self.mid.block_1(h, emb=None) h = self.mid.attn_1(h, context=context) # cross-attention h = self.mid.block_2(h, emb=None) # End h = self.norm_out(h) h = F.silu(h) h = self.conv_out(h) return h class ConditionalDecoder(nn.Module): def __init__( self, out_ch=3, ch=128, ch_mult=(1, 2, 4, 4, 8), num_res_blocks=2, attn_resolutions=[], dropout=0.0, resamp_with_conv=True, resolution=256, z_channels=4, context_dim=768, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.context_dim = context_dim # Compute in_ch_mult and block_in at lowest resolution block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) # z to block_in self.conv_in = nn.Conv2d(z_channels, block_in, 3, stride=1, padding=1, padding_mode='reflect') # Middle - replace self-attention with SpatialTransformer self.mid = nn.Module() self.mid.block_1 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # replace self-attention with cross-attention num_heads = max(1, block_in // 64) self.mid.attn_1 = SpatialTransformer( in_channels=block_in, n_heads=num_heads, d_head=64, depth=1, dropout=dropout, context_dim=context_dim, ) self.mid.block_2 = ResBlock( block_in, block_in, dropout=dropout, out_channels=block_in, dims=2, ) # Upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( ResBlock( block_in, block_out, dropout=dropout, out_channels=block_out, dims=2, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttentionBlock(block_in)) elif i_level == self.num_resolutions - 1 and i_block == 0: num_heads = max(1, block_in // 64) 
attn.append( SpatialTransformer( in_channels=block_in, n_heads=num_heads, d_head=64, depth=1, dropout=dropout, context_dim=context_dim, ) ) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv, dims=2) curr_res = curr_res * 2 self.up.insert(0, up) self.norm_out = nn.GroupNorm(32, block_in) self.conv_out = nn.Conv2d(block_in, out_ch, 3, stride=1, padding=1, padding_mode='reflect') def forward(self, z, context=None): if context is not None and len(context.shape) == 2: context = context.unsqueeze(1) h = self.conv_in(z) h = self.mid.block_1(h, emb=None) h = self.mid.attn_1(h, context=context) h = self.mid.block_2(h, emb=None) for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, emb=None) if i_block < len(self.up[i_level].attn): attn_module = self.up[i_level].attn[i_block] if isinstance(attn_module, SpatialTransformer): h = attn_module(h, context=context) else: h = attn_module(h) if i_level != 0: h = self.up[i_level].upsample(h) h = self.norm_out(h) h = F.silu(h) h = self.conv_out(h) return h class DiagonalGaussianDistribution: def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like(self.mean) def sample(self): x = self.mean + self.std * torch.randn_like(self.mean) return x def kl(self, other=None): if self.deterministic: return torch.Tensor([0.0]) else: if other is None: # KL divergence with standard normal return 0.5 * torch.sum( torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3], ) else: # KL divergence with another Gaussian return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3], ) def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims, ) def mode(self): return self.mean class AutoencoderKL(nn.Module): """ KL-regularized Autoencoder for VAE Supports variable downsampling factors (f8, f16, etc.) 
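    A minimal encode/decode round trip (a sketch; the small ch is only to keep the
    example light, and a ch_mult of length 4 gives 8x downsampling):

        vae = AutoencoderKL(embed_dim=4, ch=32, ch_mult=(1, 2, 4, 4), resolution=256)
        x = torch.randn(1, 3, 256, 256)
        recon, posterior = vae(x)              # recon: (1, 3, 256, 256)
        z = posterior.sample()                 # (1, 4, 32, 32) latent
        kl = posterior.kl()                    # (1,) KL against a standard normal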
""" def __init__( self, embed_dim=4, in_channels=3, out_ch=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[], dropout=0.0, resolution=256, z_channels=4, double_z=True, use_mid_attn=False, ): super().__init__() self.embed_dim = embed_dim self.encoder = Encoder( in_channels=in_channels, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resolution=resolution, z_channels=z_channels, double_z=double_z, use_mid_attn=use_mid_attn, ) self.decoder = Decoder( out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resolution=resolution, z_channels=z_channels, use_mid_attn=use_mid_attn, ) # Latent space projection self.quant_conv = nn.Conv2d(2 * z_channels if double_z else z_channels, 2 * embed_dim, 1) self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1) def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z): z = self.post_quant_conv(z) dec = self.decoder(z) return dec def forward(self, x, sample_posterior=True): posterior = self.encode(x) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) return dec, posterior class ConditionalAutoencoderKL(nn.Module): def __init__( self, embed_dim=4, in_channels=3, out_ch=3, ch=128, ch_mult=(1, 2, 4, 4, 8), num_res_blocks=2, attn_resolutions=[], dropout=0.0, resolution=256, z_channels=4, double_z=True, context_dim=256, pos_dim=2, ): super().__init__() self.embed_dim = embed_dim self.context_dim = context_dim self.pos_dim = pos_dim self.condition_net = nn.Sequential( nn.Linear(pos_dim, context_dim), nn.SiLU(), nn.Linear(context_dim, context_dim), nn.SiLU(), nn.Linear(context_dim, context_dim), ) # Conditional Encoder with cross-attention self.encoder = ConditionalEncoder( in_channels=in_channels, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resolution=resolution, z_channels=z_channels, double_z=double_z, context_dim=context_dim, ) # Conditional Decoder with cross-attention self.decoder = ConditionalDecoder( out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resolution=resolution, z_channels=z_channels, context_dim=context_dim, ) # Latent space projection self.quant_conv = nn.Conv2d(2 * z_channels if double_z else z_channels, 2 * embed_dim, 1) self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1) def encode(self, x, pos_context=None): """ Args: x: (B, C, H, W) 输入图像 pos_context: (B, 4) 位置信息 [x_c_norm, y_c_norm, x_max_norm, y_max_norm] """ context = None if pos_context is not None: context = self.condition_net(pos_context) # (B, context_dim) h = self.encoder(x, context=context) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z, pos_context=None): context = None if pos_context is not None: context = self.condition_net(pos_context) # (B, context_dim) z = self.post_quant_conv(z) dec = self.decoder(z, context=context) return dec def forward(self, x, sample_posterior=True, pos_context=None): posterior = self.encode(x, pos_context=pos_context) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z, pos_context=pos_context) return dec, posterior ################################################################################# # VAE Configs # 
#################################################################################
#                                  VAE Configs                                  #
#################################################################################

def AutoencoderKL_f8(**kwargs):
    defaults = {
        'embed_dim': 4,
        'z_channels': 4,
        'ch': 64,
        'ch_mult': (1, 2, 4, 4),
        'num_res_blocks': 1,
        'attn_resolutions': [],
        'dropout': 0.0,
        'double_z': True,
        'use_mid_attn': False,
    }
    defaults.update(kwargs)
    return AutoencoderKL(**defaults)


def AutoencoderKL_f16(**kwargs):
    defaults = {
        'embed_dim': 4,
        'z_channels': 4,
        'ch': 128,
        'ch_mult': (1, 2, 4, 4, 8),
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0,
        'double_z': True,
        'use_mid_attn': False,
    }
    defaults.update(kwargs)
    return AutoencoderKL(**defaults)


def AutoencoderKL_f4(**kwargs):
    defaults = {
        'embed_dim': 4,
        'z_channels': 4,
        'ch': 128,
        'ch_mult': (1, 2, 4),
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0,
        'double_z': True,
    }
    defaults.update(kwargs)
    return AutoencoderKL(**defaults)


def AutoencoderKL_f16_condition(**kwargs):
    defaults = {
        'embed_dim': 4,
        'z_channels': 4,
        'ch': 128,
        'ch_mult': (1, 2, 4, 4, 8),
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0,
        'double_z': True,
        'context_dim': 256,
        'pos_dim': 2,
    }
    defaults.update(kwargs)
    return ConditionalAutoencoderKL(**defaults)
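
# Illustrative usage sketch (added for documentation; not part of the original API):
# encode -> sample -> decode round trip with the f8 config. Shapes assume the default
# 256x256 resolution; ch_mult=(1, 2, 4, 4) downsamples three times, so the latent is
# embed_dim x 32 x 32.
def _example_kl_f8_roundtrip():
    vae = AutoencoderKL_f8()
    x = torch.randn(1, 3, 256, 256)
    posterior = vae.encode(x)          # DiagonalGaussianDistribution over latents
    z = posterior.sample()             # (1, 4, 32, 32)
    recon = vae.decode(z)              # (1, 3, 256, 256)
    kl = posterior.kl().mean()         # KL regularization term for the VAE loss
    return recon, kl
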
DiT_models = {
    # DiT
    'DiT-XL/2': DiT_XL_2,  'DiT-XL/4': DiT_XL_4,  'DiT-XL/8': DiT_XL_8,
    'DiT-L/2':  DiT_L_2,   'DiT-L/4':  DiT_L_4,   'DiT-L/8':  DiT_L_8,
    'DiT-B/2':  DiT_B_2,   'DiT-B/4':  DiT_B_4,   'DiT-B/8':  DiT_B_8,
    'DiT-S/2':  DiT_S_2,   'DiT-S/4':  DiT_S_4,   'DiT-S/8':  DiT_S_8,

    # Unified conditional DiT model (supports condition, pos_context, and/or clip_embed)
    'ConditionalDiT-XL/2': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-XL/4': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-XL/8': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-L/2': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-L/4': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-L/8': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, learn_sigma=False, **kwargs),
    'ConditionalDiT-B/2': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, learn_sigma=False, **kwargs),
    'ConditionalDiT-B/4': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, learn_sigma=False, **kwargs),
    'ConditionalDiT-B/8': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, learn_sigma=False, **kwargs),
    'ConditionalDiT-S/2': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, learn_sigma=False, **kwargs),
    'ConditionalDiT-S/4': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, learn_sigma=False, **kwargs),
    'ConditionalDiT-S/8': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, learn_sigma=False, **kwargs),

    # CLIP-conditioned DiT (uses ConditionalDiT with context_dim for cross-attention)
    # context_dim=1024 corresponds to OpenCLIP ViT-H/14 (used in SD2.0), but the architecture is DiT, not UNet
    'CLIPDiT-XL/2': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-XL/4': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-XL/8': lambda **kwargs: ConditionalDiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-L/2': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-L/4': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-L/8': lambda **kwargs: ConditionalDiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-B/2': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-B/4': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-B/8': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-S/2': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-S/4': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, context_dim=1024, learn_sigma=False, **kwargs),
    'CLIPDiT-S/8': lambda **kwargs: ConditionalDiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, context_dim=1024, learn_sigma=False, **kwargs),

    # UNet for SD2.0 (context_dim=1024)
    'UNet-SD2-B': UNet_SD2_B,
    'UNet-SD2-L': UNet_SD2_L,
    'UNet-SD2-XL': UNet_SD2_XL,
    # conditional UNet for SD2.0
    'ConditionalUNet-SD2-B': ConditionalUNet_SD2_B,
    'ConditionalUNet-SD2-L': ConditionalUNet_SD2_L,
    # ControlNet UNet for SD2.0
    'ControlNetUNet-SD2-B': ControlNetUNet_SD2_B,
    'ControlNetUNet-SD2-L': ControlNetUNet_SD2_L,

    # VAE
    'VAE-KL-f4': AutoencoderKL_f4,
    'VAE-KL-f8': AutoencoderKL_f8,
    'VAE-KL-f16': AutoencoderKL_f16,
    # Conditional VAE
    'VAE-KL-f16-condition': AutoencoderKL_f16_condition,
}
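
# Minimal smoke test (illustrative sketch; not part of the original file). It only
# exercises the VAE entries of the registry, whose constructors are defined in this
# module; the DiT / UNet entries rely on constructors defined earlier in the file.
if __name__ == "__main__":
    vae = DiT_models['VAE-KL-f8']()
    x = torch.randn(1, 3, 256, 256)
    recon, posterior = vae(x)
    # Expected with the default f8 config: (1, 3, 256, 256) and (1, 4, 32, 32)
    print(recon.shape, posterior.mode().shape)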