| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Tuple, Optional |
| from einops import rearrange |
| from .wan_video_camera_controller import SimpleAdapter |
| try: |
| import flash_attn_interface |
| FLASH_ATTN_3_AVAILABLE = True |
| except ModuleNotFoundError: |
| FLASH_ATTN_3_AVAILABLE = False |
|
|
| try: |
| import flash_attn |
| FLASH_ATTN_2_AVAILABLE = True |
| except ModuleNotFoundError: |
| FLASH_ATTN_2_AVAILABLE = False |
|
|
| try: |
| from sageattention import sageattn |
| SAGE_ATTN_AVAILABLE = True |
| except ModuleNotFoundError: |
| SAGE_ATTN_AVAILABLE = False |
| |
| |
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
    """Run multi-head attention with the best backend that was importable.

    Backend priority (skipped entirely when ``compatibility_mode`` is true):
    FlashAttention-3, FlashAttention-2, SageAttention, then PyTorch SDPA.

    Args:
        q, k, v: tensors laid out as ``(b, s, n * d)``; they are split into
            ``num_heads`` heads before the kernel call.
        num_heads: number of attention heads ``n``.
        compatibility_mode: when true, always use PyTorch SDPA.

    Returns:
        Attention output merged back to ``(b, s, n * d)``.
    """
    use_fa3 = (not compatibility_mode) and FLASH_ATTN_3_AVAILABLE
    use_fa2 = (not compatibility_mode) and (not use_fa3) and FLASH_ATTN_2_AVAILABLE

    if use_fa3 or use_fa2:
        # The flash-attn kernels are called with a sequence-major layout.
        q, k, v = (rearrange(t, "b s (n d) -> b s n d", n=num_heads) for t in (q, k, v))
        if use_fa3:
            out = flash_attn_interface.flash_attn_func(q, k, v)
            # Some builds return a tuple; the attention output is its first item.
            if isinstance(out, tuple):
                out = out[0]
        else:
            out = flash_attn.flash_attn_func(q, k, v)
        return rearrange(out, "b s n d -> b s (n d)", n=num_heads)

    # SageAttention and SDPA are called with a head-major layout.
    q, k, v = (rearrange(t, "b s (n d) -> b n s d", n=num_heads) for t in (q, k, v))
    if (not compatibility_mode) and SAGE_ATTN_AVAILABLE:
        out = sageattn(q, k, v)
    else:
        out = F.scaled_dot_product_attention(q, k, v)
    return rearrange(out, "b n s d -> b s (n d)", n=num_heads)
|
|
def scaled_dot_product_attention_with_mask(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    attn_mask: Optional[torch.Tensor],
):
    """Multi-head attention through PyTorch SDPA with an optional mask.

    Always uses PyTorch SDPA because FlashAttention variants may not support
    arbitrary masks.

    Args:
        q, k, v: tensors laid out as ``(B, S, num_heads * head_dim)``.
        num_heads: number of attention heads.
        attn_mask: ``None``, or a 3D ``(B, Sq, Sk)`` / 4D ``(B, H, Sq, Sk)``
            mask. A 3D mask gets a singleton head axis inserted. Bool masks
            use ``False`` for disallowed positions; any other dtype is an
            additive bias (0 for allowed, ``-inf`` for disallowed).

    Returns:
        Tensor of shape ``(B, Sq, num_heads * head_dim)``.

    Raises:
        ValueError: if ``attn_mask`` is neither 3D nor 4D.
    """
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)

    if attn_mask is None:
        x = F.scaled_dot_product_attention(q, k, v)
    else:
        if attn_mask.dim() == 3:
            mask = attn_mask.unsqueeze(1)
        elif attn_mask.dim() == 4:
            mask = attn_mask
        else:
            raise ValueError(f"attn_mask must be 3D or 4D, got shape={attn_mask.shape}")
        if mask.dtype != torch.bool:
            # SDPA documents additive masks as having the same dtype as the
            # query; align here so mixed-precision callers do not trip on it.
            mask = mask.to(dtype=q.dtype, device=q.device)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
|
|
class MaskedCrossAttention(nn.Module):
    """Cross-attention with explicit attention mask support (used by IMCA).

    Queries come from ``x_q``; keys and values come from ``x_kv``. Queries and
    keys are RMS-normalised after their linear projections.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Projection layers are created in q, k, v, o order.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        """Attend from ``x_q`` to ``x_kv`` under ``attn_mask`` and project the result."""
        queries = self.norm_q(self.q(x_q))
        keys = self.norm_k(self.k(x_kv))
        values = self.v(x_kv)
        attended = scaled_dot_product_attention_with_mask(
            queries,
            keys,
            values,
            num_heads=self.num_heads,
            attn_mask=attn_mask,
        )
        return self.o(attended)
|
|
|
|
class SharedTimestepAdaptivePromptEnhancement(nn.Module):
    """STAPE: I = I + (m_t + alpha1) * CrossAttn(I, T)."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.attn = CrossAttention(dim, num_heads, eps=eps, has_image_input=False)
        # Learnable gate offset m_t, zero-initialised.
        self.mt = nn.Parameter(torch.zeros(1, dim))

    def forward(self, instance_tokens: torch.Tensor, caption_tokens: torch.Tensor, alpha1: torch.Tensor):
        """Enhance instance tokens with caption information.

        Args:
            instance_tokens: ``(B, F, Nins, D)``.
            caption_tokens: passed to the cross-attention as its context.
            alpha1: gate term added to ``m_t``. Accepted layouts are anything
                broadcastable to ``(B, D)``, or ``(B, 1, D)``, or
                ``(B, F * Nins, D)``.

        Returns:
            Tensor of shape ``(B, F, Nins, D)``.

        Raises:
            ValueError: if the gate cannot be aligned with the flattened
                ``(B, F * Nins, D)`` instance tokens.
        """
        B, F_, Nins, D = instance_tokens.shape
        flat = instance_tokens.reshape(B, F_ * Nins, D)
        delta = self.attn(flat, caption_tokens)

        gate = self.mt.to(dtype=flat.dtype, device=flat.device) + alpha1
        if gate.dim() == 2:
            # (B, D) -> (B, 1, D) so the gate broadcasts over the token axis.
            gate = gate.unsqueeze(1)
        elif gate.dim() != 3:
            raise ValueError(f"alpha1 must yield a 2D or 3D gate, got shape={tuple(gate.shape)}")
        # A 3D gate is used as-is. Unsqueezing it (as was done before) would
        # broadcast the result to (B, B, N, D) and break the reshape below.
        if gate.shape[1] not in (1, F_ * Nins):
            raise ValueError(
                f"alpha1 token axis must be 1 or {F_ * Nins}, got {gate.shape[1]}"
            )

        flat = flat + gate * delta
        return flat.reshape(B, F_, Nins, D)
|
|
|
|
class InstanceAwareMaskedCrossAttention(nn.Module):
    """IMCA: masked cross-attention from visual tokens to instance prompt tokens."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.attn = MaskedCrossAttention(dim, num_heads, eps=eps)

    def forward(self, visual_tokens: torch.Tensor, instance_tokens: torch.Tensor, attn_mask: torch.Tensor):
        """Attend per frame from visual tokens (queries) to instance tokens (keys/values).

        Args:
            visual_tokens: (B, F*HW, D)
            instance_tokens: (B, F, Nins, D)
            attn_mask: (B, F, Nins, HW). A bool mask uses True where a visual
                token may attend to an instance token. Any other dtype is cast
                to the visual dtype and used as an additive bias (0 keeps a
                pair, -inf blocks it); it is NOT interpreted as a 0/1 indicator.

        Returns:
            (B, F*HW, D). Visual tokens whose mask row blocks every instance
            token receive an all-zero output.

        Raises:
            ValueError: if the last mask dimension differs from HW.
        """
        B, Nv, D = visual_tokens.shape
        _, F_, Nins, _ = instance_tokens.shape
        HW = Nv // F_

        if attn_mask.shape[-1] != HW:
            raise ValueError(f"attn_mask last dim must be HW={HW}, got {attn_mask.shape[-1]}")

        # (B, F, Nins, HW) -> (B, F, HW, Nins): queries are the visual tokens.
        M = attn_mask.permute(0, 1, 3, 2).contiguous()

        target_dtype = visual_tokens.dtype
        if M.dtype == torch.bool:
            # Rows with no allowed key would become all -inf, and softmax over
            # such a row is NaN. Unmask them here and zero their output below.
            no_instance = ~M.any(dim=-1, keepdim=True)
            M = M | no_instance
            sdpa_mask = torch.where(
                M,
                torch.zeros((), device=M.device, dtype=target_dtype),
                torch.full((), float("-inf"), device=M.device, dtype=target_dtype),
            )
        else:
            sdpa_mask = M.to(dtype=target_dtype)
            no_instance = torch.isneginf(sdpa_mask).all(dim=-1, keepdim=True)
            sdpa_mask = sdpa_mask.masked_fill(no_instance, 0.0)

        # Fold frames into the batch axis so attention stays within a frame.
        V_bf = visual_tokens.reshape(B * F_, HW, D)
        I_bf = instance_tokens.reshape(B * F_, Nins, D)
        sdpa_mask_bf = sdpa_mask.reshape(B * F_, HW, Nins)

        out = self.attn(V_bf, I_bf, sdpa_mask_bf)
        # Tokens outside every instance region contribute nothing.
        out = out.masked_fill(no_instance.reshape(B * F_, HW, 1), 0.0)
        return out.reshape(B, F_ * HW, D)
|
|
|
|
|
|
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply an affine modulation: ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
|
|
|
def sinusoidal_embedding_1d(dim, position):
    """Build a 1D sinusoidal embedding for each entry of ``position``.

    Args:
        dim: embedding width; ``dim // 2`` frequencies are used, so the output
            has ``2 * (dim // 2)`` columns (cosines first, then sines).
        position: 1D tensor of positions / timesteps.

    Returns:
        Tensor of shape ``(len(position), 2 * (dim // 2))``. The dtype follows
        ``position`` when it is floating point; for integer positions the
        default float dtype is used, because casting back to an integer dtype
        would truncate every value.
    """
    half = dim // 2
    # Angles are computed in float64 for accuracy.
    inv_freq = torch.pow(
        10000, -torch.arange(half, dtype=torch.float64, device=position.device).div(half))
    sinusoid = torch.outer(position.type(torch.float64), inv_freq)
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    out_dtype = position.dtype if position.is_floating_point() else torch.get_default_dtype()
    return x.to(out_dtype)
|
|
|
|
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute rotary tables for the frame, height and width axes.

    The two spatial axes each get ``dim // 3`` channels; the frame axis gets
    whatever remains of ``dim``.

    Returns:
        Tuple ``(f_freqs_cis, h_freqs_cis, w_freqs_cis)``.
    """
    spatial_dim = dim // 3
    temporal_dim = dim - 2 * spatial_dim
    return (
        precompute_freqs_cis(temporal_dim, end, theta),
        precompute_freqs_cis(spatial_dim, end, theta),
        precompute_freqs_cis(spatial_dim, end, theta),
    )
|
|
|
|
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute unit-magnitude complex rotations for positions ``0..end-1``.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose angle at
        ``[p, i]`` is ``p / theta ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # polar(abs=1, angle) gives exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def rope_apply(x, freqs, num_heads):
    """Rotate ``x`` by the complex rotary factors in ``freqs``.

    ``x`` is split into ``num_heads`` heads, adjacent channel pairs are viewed
    as complex numbers in float64, multiplied by ``freqs``, and the result is
    flattened back to ``(b, s, n * d)`` in the input dtype.
    """
    heads = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq_len, n_heads = heads.shape[0], heads.shape[1], heads.shape[2]
    pairs = heads.to(torch.float64).reshape(batch, seq_len, n_heads, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    merged = torch.view_as_real(rotated).flatten(2)
    return merged.to(heads.dtype)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last axis (with ``eps`` for stability)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, then scale."""
        input_dtype = x.dtype
        normalized = self.norm(x.float())
        return normalized.to(input_dtype) * self.weight
|
|
|
|
class AttentionModule(nn.Module):
    """Parameter-free module wrapper around ``flash_attention``."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        """Run attention on ``q``, ``k``, ``v`` with this module's head count."""
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
|
|
|
class SelfAttention(nn.Module):
    """Self-attention with RMS-normalised queries/keys and rotary embedding."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Projection layers are created in q, k, v, o order.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        """Attend ``x`` to itself; ``freqs`` holds the rotary factors for q and k."""
        queries = self.norm_q(self.q(x))
        keys = self.norm_k(self.k(x))
        values = self.v(x)
        # Rotary embedding is applied to queries and keys only.
        queries = rope_apply(queries, freqs, self.num_heads)
        keys = rope_apply(keys, freqs, self.num_heads)
        attended = self.attn(queries, keys, values)
        return self.o(attended)
|
|
|
|
class CrossAttention(nn.Module):
    """Cross-attention to text context, with an optional image-context branch.

    Args:
        dim: model width.
        num_heads: number of attention heads.
        eps: epsilon for the RMS norms on queries and keys.
        has_image_input: when true, the leading context tokens are treated as
            image tokens and attended through separate key/value projections.
        num_image_tokens: how many leading context tokens belong to the image
            branch. Defaults to 257, the value that used to be hard-coded.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        eps: float = 1e-6,
        has_image_input: bool = False,
        num_image_tokens: int = 257,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Plain int attribute: does not affect the state dict.
        self.num_image_tokens = num_image_tokens

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Attend from ``x`` to context ``y``.

        With ``has_image_input``, ``y`` is split along dim 1 into the first
        ``num_image_tokens`` image tokens and the remaining text tokens; the
        two attention results are summed before the output projection.
        """
        if self.has_image_input:
            img = y[:, :self.num_image_tokens]
            ctx = y[:, self.num_image_tokens:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)

        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            # The image branch reuses the text-branch queries.
            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            x = x + y
        return self.o(x)
|
|
|
|
class GateModule(nn.Module):
    """Parameter-free gated residual: ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        """Add ``residual`` scaled by ``gate`` onto ``x``."""
        gated = gate * residual
        return x + gated
|
|
|
|
class DiTBlock(nn.Module):
    """DiT block: modulated self-attention, optional IMCA, cross-attention, FFN.

    Args:
        has_image_input: forwarded to the cross-attention layer.
        dim: model width.
        num_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward network.
        eps: epsilon for the norm layers.
        enable_instancev: when true, an instance-aware masked cross-attention
            (IMCA) branch with a zero-initialised scalar gate ``mv`` is added.
        stape: optional STAPE module applied to the instance tokens before
            IMCA. It is stored as given, so one instance may be shared by
            several blocks.
    """

    # Sub-layers copied from cross-attention into IMCA at construction time.
    _IMCA_INIT_LAYERS = ("q", "k", "v", "o", "norm_q", "norm_k")

    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        enable_instancev: bool = False,
        stape: Optional[SharedTimestepAdaptivePromptEnhancement] = None,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.enable_instancev = enable_instancev
        self.stape = stape

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)

        if enable_instancev:
            self.imca = InstanceAwareMaskedCrossAttention(dim, num_heads, eps=eps)
            # Zero-initialised gate: the IMCA branch starts as a no-op.
            self.mv = nn.Parameter(torch.zeros(1))
            self.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        else:
            self.imca = None
            self.mv = None
            self.norm_imca = None

        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

        if enable_instancev and self.imca is not None:
            self._init_imca_from_cross_attention()

    def _init_imca_from_cross_attention(self):
        """Best-effort copy of the cross-attention weights into the IMCA layers.

        A failure is reported as a warning instead of being swallowed, and the
        IMCA layers then keep whatever values they already hold.
        """
        import warnings

        try:
            for name in self._IMCA_INIT_LAYERS:
                source = getattr(self.cross_attn, name)
                getattr(self.imca.attn, name).load_state_dict(source.state_dict())
        except Exception as exc:  # deliberately broad: initialisation is optional
            warnings.warn(
                f"Could not initialise IMCA from cross-attention weights: {exc!r}",
                RuntimeWarning,
            )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        t_mod: torch.Tensor,
        freqs: torch.Tensor,
        instance_tokens: Optional[torch.Tensor] = None,
        instance_attn_mask: Optional[torch.Tensor] = None,
        empty_instance_tokens: Optional[torch.Tensor] = None,
        saug_drop_prob: float = 0.0,
    ):
        """Run the block.

        Args:
            x: (B, F*H*W, D)
            context: global caption tokens T after embedding (B, Nctx, D)
            t_mod: timestep modulation; a 4D tensor is treated as per-token
                modulation and chunked on dim 2, otherwise it is chunked on dim 1.
            freqs: rotary factors for self-attention.
            instance_tokens: I (B, F, Nins, D) after embedding
            instance_attn_mask: M (B, F, Nins, H*W) bool/float
            empty_instance_tokens: used for SAUG unconditional branch (same shape as instance_tokens)
            saug_drop_prob: probability (float or tensor) of swapping in
                ``empty_instance_tokens`` during training.

        Returns:
            Tensor with the same shape as ``x``.
        """
        has_seq = t_mod.dim() == 4
        chunk_dim = 2 if has_seq else 1

        mods = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=chunk_dim)
        if has_seq:
            mods = tuple(m.squeeze(2) for m in mods)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        # 1) Modulated self-attention.
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))

        # 2) Optional instance-aware masked cross-attention.
        use_instances = (
            self.enable_instancev
            and self.imca is not None
            and instance_tokens is not None
            and instance_attn_mask is not None
        )
        if use_instances:
            instance_tokens_use = instance_tokens
            if self.training and empty_instance_tokens is not None:
                # Resolve the probability only when it can matter; converting a
                # device tensor forces a host sync.
                if isinstance(saug_drop_prob, torch.Tensor):
                    saug_p = float(saug_drop_prob.detach().cpu().item())
                else:
                    saug_p = float(saug_drop_prob)
                if saug_p > 0.0 and torch.rand((), device=x.device) < saug_p:
                    instance_tokens_use = empty_instance_tokens

            if self.stape is not None:
                # Without per-token modulation gate_msa is (B, 1, D). Drop the
                # singleton axis so STAPE receives a (B, D) gate; a 3D gate
                # would broadcast across the batch axis there.
                alpha1 = gate_msa if has_seq else gate_msa.squeeze(1)
                instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=alpha1)

            imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask)
            x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out

        # 3) Cross-attention to the caption context.
        x = x + self.cross_attn(self.norm3(x), context)

        # 4) Modulated feed-forward network.
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x
|
|
|
|
class MLP(torch.nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection.

    Args:
        in_dim: input feature width.
        out_dim: output feature width.
        has_pos_emb: when true, a zero-initialised learned positional
            embedding of shape ``(1, 514, in_dim)`` is added to the input.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        self.proj = torch.nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim)
        )
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            # Width follows in_dim (it used to be fixed at 1280), so the
            # embedding always matches the features it is added to.
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, in_dim)))

    def forward(self, x):
        """Project ``x``, adding the positional embedding first when enabled."""
        if self.has_pos_emb:
            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x)
|
|
|
|
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear patch projection.

    Args:
        dim: model width.
        out_dim: output channels per patch element.
        patch_size: 3-tuple; the linear layer emits
            ``out_dim * prod(patch_size)`` features per token.
        eps: epsilon for the LayerNorm.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        """Modulate and project ``x``.

        Args:
            x: ``(B, S, dim)`` tokens.
            t_mod: ``(B, S, dim)`` for per-token modulation, or ``(B, dim)``
                for one modulation vector per sample.

        Returns:
            Tensor of shape ``(B, S, out_dim * prod(patch_size))``.
        """
        modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.dim() == 3:
            shift, scale = (modulation.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            if t_mod.dim() == 2:
                # (B, dim) -> (B, 1, dim). Without this, (1, 2, dim) + (B, dim)
                # lines the batch axis up with the shift/scale axis, which is
                # wrong for B == 2 and fails for B > 2.
                t_mod = t_mod.unsqueeze(1)
            shift, scale = (modulation + t_mod).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
|
|
|
|
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer with optional InstanceV conditioning.

    The model patch-embeds a latent video with a 3D convolution, runs a stack
    of ``DiTBlock`` layers conditioned on text (and optionally image and
    instance prompts), and projects the tokens back to latent patches.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        enable_instancev: bool = False,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents

        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))

        self.enable_instancev = enable_instancev
        if enable_instancev:
            # One STAPE module, handed to every block below.
            self.stape = SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads, eps=eps)
        else:
            self.stape = None

        self.blocks = nn.ModuleList([
            DiTBlock(
                has_image_input=has_image_input,
                dim=dim,
                num_heads=num_heads,
                ffn_dim=ffn_dim,
                eps=eps,
                enable_instancev=enable_instancev,
                stape=self.stape,
            )
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # Tuple of (frame, height, width) rotary tables; not registered as buffers.
        self.freqs = precompute_freqs_cis_3d(head_dim)

        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Patch-embed ``x`` with the 3D convolution.

        Returns the convolution output unflattened; callers flatten it to
        tokens themselves. When a control adapter and camera latents are both
        present, the adapter output is added and only the first batch element
        is kept (with a batch axis of size 1).
        """
        x = self.patch_embedding(x)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            x = [u + v for u, v in zip(x, y_camera)]
            x = x[0].unsqueeze(0)
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Fold ``(b, f*h*w, patch_volume * c)`` tokens back into ``(b, c, F, H, W)``.

        ``grid_size`` supplies ``(f, h, w)``, the number of patches per axis.
        """
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    @staticmethod
    def _checkpoint_block(block, block_args, offload: bool):
        """Run ``block(*block_args)`` under non-reentrant activation checkpointing.

        With ``offload`` the saved tensors are kept on CPU via ``save_on_cpu``.
        """
        if offload:
            with torch.autograd.graph.save_on_cpu():
                return torch.utils.checkpoint.checkpoint(block, *block_args, use_reentrant=False)
        return torch.utils.checkpoint.checkpoint(block, *block_args, use_reentrant=False)

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                instance_prompt_tokens: Optional[torch.Tensor] = None,
                instance_attn_mask: Optional[torch.Tensor] = None,
                empty_instance_prompt_tokens: Optional[torch.Tensor] = None,
                saug_drop_prob: float = 0.0,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        """Predict the model output for latent ``x`` at ``timestep``.

        Args:
            x: latent video fed to the 3D patch embedding.
            timestep: 1D tensor of timesteps, one per sample.
            context: text features of width ``text_dim``.
            instance_prompt_tokens: optional instance prompt features; embedded
                with the same text embedding as ``context``.
            instance_attn_mask: optional instance mask forwarded to each block.
            empty_instance_prompt_tokens: optional "empty prompt" features for
                SAUG; defaults to zeros shaped like the embedded instance
                tokens when InstanceV inputs are given.
            saug_drop_prob: forwarded to each block.
            clip_feature: image features, used when ``has_image_input``.
            y: extra latent channels concatenated to ``x`` on dim 1 when
                ``has_image_input``.
            use_gradient_checkpointing: checkpoint each block during training.
            use_gradient_checkpointing_offload: additionally keep checkpointed
                tensors on CPU.
            **kwargs: accepted and ignored.

        Returns:
            Tensor produced by ``unpatchify``: ``(b, c, F, H, W)``.
        """
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        # Instance prompts share the caption text embedding.
        instance_tokens = None
        if instance_prompt_tokens is not None:
            instance_tokens = self.text_embedding(instance_prompt_tokens)

        empty_instance_tokens = None
        if empty_instance_prompt_tokens is not None:
            empty_instance_tokens = self.text_embedding(empty_instance_prompt_tokens)

        if self.enable_instancev and (instance_tokens is not None) and (empty_instance_tokens is None):
            empty_instance_tokens = torch.zeros_like(instance_tokens)

        if self.has_image_input:
            x = torch.cat([x, y], dim=1)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)

        # patchify returns the unflattened (b, c, f, h, w) feature map, so the
        # grid size is read from it here before flattening to tokens.
        x = self.patchify(x)
        f, h, w = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()

        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        # Arguments shared by every block; only x changes between blocks.
        block_args = (context, t_mod, freqs)
        use_instancev_inputs = (
            self.enable_instancev
            and (instance_tokens is not None)
            and (instance_attn_mask is not None)
        )
        if use_instancev_inputs:
            block_args = block_args + (
                instance_tokens, instance_attn_mask, empty_instance_tokens, saug_drop_prob,
            )

        checkpointing = self.training and use_gradient_checkpointing
        for block in self.blocks:
            if checkpointing:
                x = self._checkpoint_block(
                    block, (x,) + block_args, use_gradient_checkpointing_offload)
            else:
                x = block(x, *block_args)

        # (B, dim) -> (B, 1, dim) so the head modulation broadcasts over tokens
        # for any batch size.
        x = self.head(x, t.unsqueeze(1))
        x = self.unpatchify(x, (f, h, w))
        return x
|
|
|
|
def apply_saug(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor:
    """Combine predictions with Spatially-Aware Unconditional Guidance (SAUG).

    Paper Eq. (7): ``eps_tilde = (1 + w) * eps_cond - w * eps_uncond``, where
    ``eps_uncond`` is predicted with *empty instance prompts* but the same
    spatial masks.
    """
    cond_term = (1.0 + w) * eps_cond
    uncond_term = w * eps_uncond
    return cond_term - uncond_term
|
|
|
|