import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True except ModuleNotFoundError: FLASH_ATTN_3_AVAILABLE = False try: import flash_attn FLASH_ATTN_2_AVAILABLE = True except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False try: from sageattention import sageattn SAGE_ATTN_AVAILABLE = True except ModuleNotFoundError: SAGE_ATTN_AVAILABLE = False def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): if compatibility_mode: q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) elif FLASH_ATTN_3_AVAILABLE: q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) x = flash_attn_interface.flash_attn_func(q, k, v) if isinstance(x,tuple): x = x[0] x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif FLASH_ATTN_2_AVAILABLE: q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) x = flash_attn.flash_attn_func(q, k, v) x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif SAGE_ATTN_AVAILABLE: q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) x = sageattn(q, k, v) x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) else: q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) return x def scaled_dot_product_attention_with_mask( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attn_mask: Optional[torch.Tensor], ): """Always uses PyTorch SDPA because FlashAttention variants may not support arbitrary masks. Args: q,k,v: (B, S, D) attn_mask: float mask broadcastable to (B, num_heads, Sq, Sk) with 0 for allowed, -inf for disallowed or bool mask broadcastable where False indicates disallowed. """ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) if attn_mask is not None: # Make it broadcastable to (B, n, Sq, Sk) if attn_mask.dtype == torch.bool: mask = attn_mask else: mask = attn_mask if attn_mask.dim() == 3: mask = mask.unsqueeze(1) elif attn_mask.dim() != 4: raise ValueError(f"attn_mask must be 3D or 4D, got shape={attn_mask.shape}") x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) else: x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) return x class MaskedCrossAttention(nn.Module): """Cross-attention with explicit attention mask support (used by IMCA).""" def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) self.norm_q = RMSNorm(dim, eps=eps) self.norm_k = RMSNorm(dim, eps=eps) def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, attn_mask: Optional[torch.Tensor]): q = self.norm_q(self.q(x_q)) k = self.norm_k(self.k(x_kv)) v = self.v(x_kv) x = scaled_dot_product_attention_with_mask(q, k, v, num_heads=self.num_heads, attn_mask=attn_mask) return self.o(x) class SharedTimestepAdaptivePromptEnhancement(nn.Module): """STAPE: I = I + (m_t + alpha1) * CrossAttn(I, T).""" def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() self.attn = CrossAttention(dim, num_heads, eps=eps, has_image_input=False) # Dins-dimensional learnable residual gate (initialized to 0 for stability) self.mt = nn.Parameter(torch.zeros(1, dim)) def forward(self, instance_tokens: torch.Tensor, caption_tokens: torch.Tensor, alpha1: torch.Tensor): # instance_tokens: (B, F, Nins, D) caption_tokens: (B, Nctx, D) alpha1: (B, D) B, F_, Nins, D = instance_tokens.shape I = instance_tokens.reshape(B, F_ * Nins, D) delta = self.attn(I, caption_tokens) # (B, F*Nins, D) gate = (self.mt.to(dtype=I.dtype, device=I.device) + alpha1).unsqueeze(1) # (B, 1, D) I = I + gate * delta return I.reshape(B, F_, Nins, D) class InstanceAwareMaskedCrossAttention(nn.Module): """IMCA: masked cross-attention from visual tokens to instance prompt tokens.""" def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() self.attn = MaskedCrossAttention(dim, num_heads, eps=eps) def forward(self, visual_tokens: torch.Tensor, instance_tokens: torch.Tensor, attn_mask: torch.Tensor): """Args: visual_tokens: (B, F*HW, D) instance_tokens: (B, F, Nins, D) attn_mask: (B, F, Nins, HW) bool OR float, where True/1 means instance-token attends this visual token. Returns: (B, F*HW, D) """ B, Nv, D = visual_tokens.shape _, F_, Nins, _ = instance_tokens.shape HW = Nv // F_ V = visual_tokens.reshape(B, F_, HW, D) I = instance_tokens # Convert mask to (B*F, HW, Nins) with 0 / -inf M = attn_mask if M.shape[-1] != HW: raise ValueError(f"attn_mask last dim must be HW={HW}, got {M.shape[-1]}") # (B,F,Nins,HW) -> (B,F,HW,Nins) M = M.permute(0, 1, 3, 2).contiguous() # 使用与 visual_tokens 相同的 dtype(通常是 bfloat16) target_dtype = visual_tokens.dtype if M.dtype == torch.bool: sdpa_mask = torch.where(M, torch.zeros((), device=M.device, dtype=target_dtype), torch.full((), float("-inf"), device=M.device, dtype=target_dtype)) else: # assume already 0/-inf or similar sdpa_mask = M.to(dtype=target_dtype) # Merge batch and frame V_bf = V.reshape(B * F_, HW, D) I_bf = I.reshape(B * F_, Nins, D) sdpa_mask_bf = sdpa_mask.reshape(B * F_, HW, Nins) out = self.attn(V_bf, I_bf, sdpa_mask_bf) return out.reshape(B, F_ * HW, D) def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): return (x * (1 + scale) + shift) def sinusoidal_embedding_1d(dim, position): sinusoid = torch.outer(position.type(torch.float64), torch.pow( 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x.to(position.dtype) def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): # 3d rope precompute f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) return f_freqs_cis, h_freqs_cis, w_freqs_cis def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): # 1d rope precompute freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) [: (dim // 2)].double() / dim)) freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) def forward(self, x): dtype = x.dtype return self.norm(x.float()).to(dtype) * self.weight class AttentionModule(nn.Module): def __init__(self, num_heads): super().__init__() self.num_heads = num_heads def forward(self, q, k, v): x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) return x class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) self.norm_q = RMSNorm(dim, eps=eps) self.norm_k = RMSNorm(dim, eps=eps) self.attn = AttentionModule(self.num_heads) def forward(self, x, freqs): q = self.norm_q(self.q(x)) k = self.norm_k(self.k(x)) v = self.v(x) q = rope_apply(q, freqs, self.num_heads) k = rope_apply(k, freqs, self.num_heads) x = self.attn(q, k, v) return self.o(x) class CrossAttention(nn.Module): def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) self.norm_q = RMSNorm(dim, eps=eps) self.norm_k = RMSNorm(dim, eps=eps) self.has_image_input = has_image_input if has_image_input: self.k_img = nn.Linear(dim, dim) self.v_img = nn.Linear(dim, dim) self.norm_k_img = RMSNorm(dim, eps=eps) self.attn = AttentionModule(self.num_heads) def forward(self, x: torch.Tensor, y: torch.Tensor): if self.has_image_input: img = y[:, :257] ctx = y[:, 257:] else: ctx = y q = self.norm_q(self.q(x)) k = self.norm_k(self.k(ctx)) v = self.v(ctx) x = self.attn(q, k, v) if self.has_image_input: k_img = self.norm_k_img(self.k_img(img)) v_img = self.v_img(img) y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) x = x + y return self.o(x) class GateModule(nn.Module): def __init__(self,): super().__init__() def forward(self, x, gate, residual): return x + gate * residual class DiTBlock(nn.Module): def __init__( self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_instancev: bool = False, stape: Optional[SharedTimestepAdaptivePromptEnhancement] = None, ): super().__init__() self.dim = dim self.num_heads = num_heads self.ffn_dim = ffn_dim self.enable_instancev = enable_instancev self.stape = stape self.self_attn = SelfAttention(dim, num_heads, eps) self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input) # IMCA is inserted between self-attention and cross-attention as a residual branch if enable_instancev: self.imca = InstanceAwareMaskedCrossAttention(dim, num_heads, eps=eps) # zero-initialized gated parameter m_v (paper Eq. 4) self.mv = nn.Parameter(torch.zeros(1)) self.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) else: self.imca = None self.mv = None self.norm_imca = None self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) self.norm3 = nn.LayerNorm(dim, eps=eps) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim), ) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.gate = GateModule() # Better initialization for IMCA: copy weights from the native cross-attention (paper discussion) if enable_instancev and self.imca is not None: self._init_imca_from_cross_attention() def _init_imca_from_cross_attention(self): # copy q,k,v,o and norms try: self.imca.attn.q.load_state_dict(self.cross_attn.q.state_dict()) self.imca.attn.k.load_state_dict(self.cross_attn.k.state_dict()) self.imca.attn.v.load_state_dict(self.cross_attn.v.state_dict()) self.imca.attn.o.load_state_dict(self.cross_attn.o.state_dict()) self.imca.attn.norm_q.load_state_dict(self.cross_attn.norm_q.state_dict()) self.imca.attn.norm_k.load_state_dict(self.cross_attn.norm_k.state_dict()) except Exception: # if anything mismatches, skip silently (keeps compatibility) pass def forward( self, x: torch.Tensor, context: torch.Tensor, t_mod: torch.Tensor, freqs: torch.Tensor, instance_tokens: Optional[torch.Tensor] = None, instance_attn_mask: Optional[torch.Tensor] = None, empty_instance_tokens: Optional[torch.Tensor] = None, saug_drop_prob: float = 0.0, ): """Args: x: (B, F*H*W, D) context: global caption tokens T after embedding (B, Nctx, D) instance_tokens: I (B, F, Nins, D) after embedding instance_attn_mask: M (B, F, Nins, H*W) bool/float empty_instance_tokens: used for SAUG unconditional branch (same shape as instance_tokens) """ has_seq = len(t_mod.shape) == 4 chunk_dim = 2 if has_seq else 1 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod ).chunk(6, dim=chunk_dim) if has_seq: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), ) # 1) Self-attention (paper Eq. 3) input_x = modulate(self.norm1(x), shift_msa, scale_msa) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) # 2) IMCA (paper Eq. 4) + STAPE (paper Eq. 6) if self.enable_instancev and (self.imca is not None) and (instance_tokens is not None) and (instance_attn_mask is not None): # SAUG training-time drop: keep spatial masks but empty the instance prompts with probability p if isinstance(saug_drop_prob, torch.Tensor): saug_p = float(saug_drop_prob.detach().cpu().item()) else: saug_p = float(saug_drop_prob) if self.training and saug_p > 0.0 and empty_instance_tokens is not None: if torch.rand((), device=x.device) < saug_p: instance_tokens_use = empty_instance_tokens else: instance_tokens_use = instance_tokens else: instance_tokens_use = instance_tokens # STAPE is shared across blocks (paper) if self.stape is not None: # reuse one AdaLN modulation vector as alpha1 (paper) alpha1 = gate_msa # (B, D) instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=alpha1) # IMCA: masked cross-attn from visual tokens to instance tokens imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask) x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out # 3) Native cross-attention with global caption tokens (paper Eq. 5) x = x + self.cross_attn(self.norm3(x), context) # 4) FFN input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) x = self.gate(x, gate_mlp, self.ffn(input_x)) return x class MLP(torch.nn.Module): def __init__(self, in_dim, out_dim, has_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( nn.LayerNorm(in_dim), nn.Linear(in_dim, in_dim), nn.GELU(), nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim) ) self.has_pos_emb = has_pos_emb if has_pos_emb: self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) def forward(self, x): if self.has_pos_emb: x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) return self.proj(x) class Head(nn.Module): def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): super().__init__() self.dim = dim self.patch_size = patch_size self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, t_mod): if len(t_mod.shape) == 3: shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) else: shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) x = (self.head(self.norm(x) * (1 + scale) + shift)) return x class WanModel(torch.nn.Module): def __init__( self, dim: int, in_dim: int, ffn_dim: int, out_dim: int, text_dim: int, freq_dim: int, eps: float, patch_size: Tuple[int, int, int], num_heads: int, num_layers: int, has_image_input: bool, enable_instancev: bool = False, has_image_pos_emb: bool = False, has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, seperated_timestep: bool = False, require_vae_embedding: bool = True, require_clip_embedding: bool = True, fuse_vae_embedding_in_latents: bool = False, ): super().__init__() self.dim = dim self.in_dim = in_dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size self.seperated_timestep = seperated_timestep self.require_vae_embedding = require_vae_embedding self.require_clip_embedding = require_clip_embedding self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) self.text_embedding = nn.Sequential( nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim) ) self.time_embedding = nn.Sequential( nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim) ) self.time_projection = nn.Sequential( nn.SiLU(), nn.Linear(dim, dim * 6)) self.enable_instancev = enable_instancev if enable_instancev: # STAPE is shared across all DiT blocks (paper Section 4.2) self.stape = SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads, eps=eps) else: self.stape = None self.blocks = nn.ModuleList([ DiTBlock( has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps, enable_instancev=enable_instancev, stape=self.stape, ) for _ in range(num_layers) ]) self.head = Head(dim, out_dim, patch_size, eps) head_dim = dim // num_heads self.freqs = precompute_freqs_cis_3d(head_dim) if has_image_input: self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 if has_ref_conv: self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) self.has_image_pos_emb = has_image_pos_emb self.has_ref_conv = has_ref_conv if add_control_adapter: self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) else: self.control_adapter = None def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): x = self.patch_embedding(x) if self.control_adapter is not None and control_camera_latents_input is not None: y_camera = self.control_adapter(control_camera_latents_input) x = [u + v for u, v in zip(x, y_camera)] x = x[0].unsqueeze(0) return x def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): return rearrange( x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', f=grid_size[0], h=grid_size[1], w=grid_size[2], x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] ) def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, instance_prompt_tokens: Optional[torch.Tensor] = None, instance_attn_mask: Optional[torch.Tensor] = None, empty_instance_prompt_tokens: Optional[torch.Tensor] = None, saug_drop_prob: float = 0.0, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, ): t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) context = self.text_embedding(context) # Instance prompt tokens (paper Section 4.1): encode each instance prompt with the same text embedding layer if instance_prompt_tokens is not None: instance_tokens = self.text_embedding(instance_prompt_tokens) else: instance_tokens = None if empty_instance_prompt_tokens is not None: empty_instance_tokens = self.text_embedding(empty_instance_prompt_tokens) else: empty_instance_tokens = None # If SAUG unconditional tokens are not provided but InstanceV is enabled, default to zeros. # (For best results, provide the pretrained tokens as described in the paper.) if self.enable_instancev and (instance_tokens is not None) and (empty_instance_tokens is None): empty_instance_tokens = torch.zeros_like(instance_tokens) if self.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) x, (f, h, w) = self.patchify(x) freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward for block in self.blocks: use_instancev_inputs = (self.enable_instancev and (instance_tokens is not None) and (instance_attn_mask is not None)) if self.training and use_gradient_checkpointing: if use_instancev_inputs: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, instance_tokens, instance_attn_mask, empty_instance_tokens, torch.tensor(float(saug_drop_prob), device=x.device, dtype=x.dtype), use_reentrant=False, ) else: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, instance_tokens, instance_attn_mask, empty_instance_tokens, torch.tensor(float(saug_drop_prob), device=x.device, dtype=x.dtype), use_reentrant=False, ) else: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) else: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) else: if use_instancev_inputs: x = block(x, context, t_mod, freqs, instance_tokens, instance_attn_mask, empty_instance_tokens, saug_drop_prob) else: x = block(x, context, t_mod, freqs) x = self.head(x, t) x = self.unpatchify(x, (f, h, w)) return x def apply_saug(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor: """Spatially-Aware Unconditional Guidance (SAUG), paper Eq. (7): eps_tilde = (1 + w) * eps_cond - w * eps_uncond where eps_uncond is predicted with *empty instance prompts* but the same spatial masks. """ return (1.0 + w) * eps_cond - w * eps_uncond