# Source: PencilFolder / diffsynth/models/wan_video_dit_instancev.py
# Uploaded by PencilHu using huggingface_hub (commit 1146a67, verified).
import contextlib
import math
import warnings
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .wan_video_camera_controller import SimpleAdapter

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
    """Unmasked multi-head attention over packed (B, S, num_heads * head_dim) tensors.

    Backend priority:
      * ``compatibility_mode=True``: PyTorch SDPA, unconditionally.
      * otherwise: FlashAttention-3, then FlashAttention-2, then SageAttention, then PyTorch SDPA,
        depending on which optional packages were importable.

    Returns:
        Tensor packed back to (B, S_q, num_heads * head_dim).
    """
    def split_heads(pattern: str):
        return [rearrange(t, pattern, n=num_heads) for t in (q, k, v)]

    use_fast_kernels = not compatibility_mode

    # FlashAttention kernels take a sequence-first layout (b, s, n, d).
    if use_fast_kernels and FLASH_ATTN_3_AVAILABLE:
        qh, kh, vh = split_heads("b s (n d) -> b s n d")
        out = flash_attn_interface.flash_attn_func(qh, kh, vh)
        # Some FA3 builds return (out, softmax_lse).
        if isinstance(out, tuple):
            out = out[0]
        return rearrange(out, "b s n d -> b s (n d)", n=num_heads)
    if use_fast_kernels and FLASH_ATTN_2_AVAILABLE:
        qh, kh, vh = split_heads("b s (n d) -> b s n d")
        out = flash_attn.flash_attn_func(qh, kh, vh)
        return rearrange(out, "b s n d -> b s (n d)", n=num_heads)

    # SageAttention and SDPA take a head-first layout (b, n, s, d).
    qh, kh, vh = split_heads("b s (n d) -> b n s d")
    if use_fast_kernels and SAGE_ATTN_AVAILABLE:
        out = sageattn(qh, kh, vh)
    else:
        out = F.scaled_dot_product_attention(qh, kh, vh)
    return rearrange(out, "b n s d -> b s (n d)", n=num_heads)
def scaled_dot_product_attention_with_mask(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    attn_mask: Optional[torch.Tensor],
):
    """Multi-head attention through PyTorch SDPA with an optional explicit mask.

    Always uses ``F.scaled_dot_product_attention``, because the FlashAttention /
    SageAttention paths may not support arbitrary masks.

    Args:
        q: (B, Sq, D) with D = num_heads * head_dim.
        k, v: (B, Sk, D).
        num_heads: number of heads D is split into.
        attn_mask: optional mask of shape (B, Sq, Sk) or 4-D (B, heads-or-1, Sq, Sk).
            A float mask is *added* to the attention logits (0 = allowed, -inf = blocked).
            A bool mask marks allowed positions with True and blocked ones with False.
            A 3-D mask is broadcast over heads.

    Returns:
        (B, Sq, D)

    Raises:
        ValueError: if ``attn_mask`` is neither 3-D nor 4-D.
    """
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
    if attn_mask is None:
        x = F.scaled_dot_product_attention(q, k, v)
    else:
        # bool and float masks are both passed through to SDPA unchanged; only the rank is adapted.
        if attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(1)  # (B, Sq, Sk) -> (B, 1, Sq, Sk), broadcast over heads
        elif attn_mask.dim() != 4:
            raise ValueError(f"attn_mask must be 3D or 4D, got shape={attn_mask.shape}")
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
class MaskedCrossAttention(nn.Module):
    """Cross-attention with an explicit attention mask (used by IMCA).

    Queries come from ``x_q`` and keys/values from ``x_kv``. Q and K are RMS-normalised after
    projection, and the mask is forwarded to ``scaled_dot_product_attention_with_mask``.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Projections are created in this exact order so parameter initialisation is reproducible.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        query = self.norm_q(self.q(x_q))
        key = self.norm_k(self.k(x_kv))
        value = self.v(x_kv)
        attended = scaled_dot_product_attention_with_mask(
            query, key, value, num_heads=self.num_heads, attn_mask=attn_mask
        )
        return self.o(attended)
class SharedTimestepAdaptivePromptEnhancement(nn.Module):
    """STAPE: I = I + (m_t + alpha1) * CrossAttn(I, T).

    Instance tokens I cross-attend to caption tokens T. The result is added back through a gate
    made of a learnable, zero-initialised vector ``mt`` plus the timestep-dependent ``alpha1``.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.attn = CrossAttention(dim, num_heads, eps=eps, has_image_input=False)
        # Dins-dimensional learnable residual gate (initialized to 0 for stability)
        self.mt = nn.Parameter(torch.zeros(1, dim))

    def forward(self, instance_tokens: torch.Tensor, caption_tokens: torch.Tensor, alpha1: torch.Tensor):
        """Args:
            instance_tokens: (B, F, Nins, D)
            caption_tokens: (B, Nctx, D)
            alpha1: per-sample gate of shape (B, D) or (B, 1, D).
        Returns:
            (B, F, Nins, D)
        Raises:
            ValueError: if ``alpha1`` is not (B, D) or (B, 1, D).
        """
        B, F_, Nins, D = instance_tokens.shape
        I = instance_tokens.reshape(B, F_ * Nins, D)
        delta = self.attn(I, caption_tokens)  # (B, F*Nins, D)
        # Normalise alpha1 to (B, 1, D) so the gate broadcasts over the token axis only.
        # The previous `.unsqueeze(1)` on a (B, 1, D) alpha1 produced (B, 1, 1, D), which
        # broadcast against (B, N, D) to (B, B, N, D) and broke reshape for B > 1.
        if alpha1.dim() == 2:
            alpha1 = alpha1.unsqueeze(1)
        elif alpha1.dim() != 3 or alpha1.shape[1] != 1:
            raise ValueError(
                f"alpha1 must have shape (B, D) or (B, 1, D), got {tuple(alpha1.shape)}"
            )
        gate = self.mt.to(dtype=I.dtype, device=I.device) + alpha1  # (B, 1, D)
        I = I + gate * delta
        return I.reshape(B, F_, Nins, D)
class InstanceAwareMaskedCrossAttention(nn.Module):
    """IMCA: masked cross-attention from visual tokens to instance prompt tokens.

    Attention is computed per frame. A visual token of frame f may only attend to the instance
    tokens of frame f that its mask allows.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.attn = MaskedCrossAttention(dim, num_heads, eps=eps)

    def forward(self, visual_tokens: torch.Tensor, instance_tokens: torch.Tensor, attn_mask: torch.Tensor):
        """Args:
            visual_tokens: (B, F*HW, D)
            instance_tokens: (B, F, Nins, D)
            attn_mask: (B, F, Nins, HW). A bool mask uses True = the instance token and visual
                token may interact. A float mask is used as an additive bias
                (0 = allowed, -inf = blocked).
        Returns:
            (B, F*HW, D). Visual tokens that no instance token may attend to receive a zero
            output, so they get no IMCA contribution.
        Raises:
            ValueError: if the visual sequence length is not a multiple of F, or the mask's
                last dim is not HW.
        """
        B, Nv, D = visual_tokens.shape
        _, F_, Nins, _ = instance_tokens.shape
        if Nv % F_ != 0:
            raise ValueError(f"visual token count {Nv} is not divisible by the number of frames {F_}")
        HW = Nv // F_
        if attn_mask.shape[-1] != HW:
            raise ValueError(f"attn_mask last dim must be HW={HW}, got {attn_mask.shape[-1]}")
        # (B, F, Nins, HW) -> (B, F, HW, Nins): one row of instance keys per visual query.
        M = attn_mask.permute(0, 1, 3, 2).contiguous()
        # Use the same dtype as visual_tokens (e.g. bfloat16) so SDPA accepts the float mask.
        target_dtype = visual_tokens.dtype
        if M.dtype == torch.bool:
            allowed = M
            sdpa_mask = torch.where(M, torch.zeros((), device=M.device, dtype=target_dtype),
                                    torch.full((), float("-inf"), device=M.device, dtype=target_dtype))
        else:
            # assume already an additive 0 / -inf style bias
            sdpa_mask = M.to(dtype=target_dtype)
            allowed = sdpa_mask != float("-inf")
        # A query row where every key is -inf makes softmax return NaN. Through the residual
        # `x + m_v * out` this poisons x even while m_v == 0, because 0 * NaN = NaN. Such rows
        # (visual tokens outside every instance mask) are unmasked for the attention call and
        # their output is zeroed afterwards.
        has_instance = allowed.any(dim=-1, keepdim=True)  # (B, F, HW, 1)
        sdpa_mask = sdpa_mask.masked_fill(~has_instance, 0.0)
        # Merge batch and frame so each frame attends only to its own instance tokens.
        V_bf = visual_tokens.reshape(B * F_, HW, D)
        I_bf = instance_tokens.reshape(B * F_, Nins, D)
        sdpa_mask_bf = sdpa_mask.reshape(B * F_, HW, Nins)
        out = self.attn(V_bf, I_bf, sdpa_mask_bf)
        out = out * has_instance.reshape(B * F_, HW, 1).to(out.dtype)
        return out.reshape(B, F_ * HW, D)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """AdaLN modulation: scale ``x`` by ``(1 + scale)``, then add ``shift`` (with broadcasting)."""
    scaled = x * (1 + scale)
    return scaled + shift
def sinusoidal_embedding_1d(dim, position):
    """1-D sinusoidal embedding ``[cos(p * w_i), sin(p * w_i)]`` with ``w_i = 10000 ** (-i / (dim // 2))``.

    Computed in float64 and cast back to ``position.dtype``. For a 1-D ``position`` the result
    has shape (len(position), 2 * (dim // 2)).
    """
    half = dim // 2
    exponent = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponent)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """3-D RoPE tables: split ``dim`` into frame / height / width sub-dimensions.

    Height and width each get ``dim // 3``. The frame axis gets the remainder,
    ``dim - 2 * (dim // 3)``.

    Returns:
        Tuple ``(frame_table, height_table, width_table)`` from ``precompute_freqs_cis``.
    """
    spatial_dim = dim // 3
    temporal_dim = dim - 2 * spatial_dim
    return tuple(
        precompute_freqs_cis(sub_dim, end, theta)
        for sub_dim in (temporal_dim, spatial_dim, spatial_dim)
    )
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """1-D RoPE table of unit complex rotations ``exp(i * pos * w_j)`` with ``w_j = theta ** (-2j / dim)``.

    Returns a complex tensor of shape (end, dim // 2). The angles are computed in float64, so
    the result is complex128 (not complex64).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
    """Rotate packed (B, S, num_heads * head_dim) features by the complex RoPE table ``freqs``.

    Consecutive feature pairs are treated as complex numbers (in float64), multiplied by
    ``freqs``, and packed back to the input dtype.
    """
    heads = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq, n = heads.shape[0], heads.shape[1], heads.shape[2]
    pairs = heads.to(torch.float64).reshape(batch, seq, n, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(heads.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable per-channel weight.

    The statistics are computed in float32. The normalised value is cast back to the input
    dtype before the weight is applied.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normed = self.norm(x.float())
        return normed.to(x.dtype) * self.weight
class AttentionModule(nn.Module):
    """``nn.Module`` wrapper that calls ``flash_attention`` with a fixed head count."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        heads = self.num_heads
        return flash_attention(q=q, k=k, v=v, num_heads=heads)
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised Q/K and rotary position embedding."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Created in this order so parameter initialisation is reproducible.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        heads = self.num_heads
        query = rope_apply(self.norm_q(self.q(x)), freqs, heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, heads)
        value = self.v(x)
        return self.o(self.attn(query, key, value))
class CrossAttention(nn.Module):
    """Multi-head cross-attention from ``x`` to a context sequence ``y``.

    With ``has_image_input=True``, the first 257 tokens of ``y`` are treated as image tokens.
    Those tokens get their own key/value projections, and their attention output is added to
    the attention over the remaining (text) tokens.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        text_tokens = y[:, 257:] if self.has_image_input else y
        query = self.norm_q(self.q(x))
        out = self.attn(query, self.norm_k(self.k(text_tokens)), self.v(text_tokens))
        if self.has_image_input:
            image_tokens = y[:, :257]
            image_key = self.norm_k_img(self.k_img(image_tokens))
            image_value = self.v_img(image_tokens)
            out = out + flash_attention(query, image_key, image_value, num_heads=self.num_heads)
        return self.o(out)
class GateModule(nn.Module):
    """Gated residual add: returns ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
class DiTBlock(nn.Module):
    """One Wan DiT block with optional InstanceV branches.

    ``forward`` applies, in order:
      1. AdaLN-modulated self-attention.
      2. STAPE + IMCA as a gated residual, only when InstanceV is enabled and instance tokens and
         masks are given.
      3. Cross-attention over ``context``.
      4. AdaLN-modulated FFN.
    """

    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        enable_instancev: bool = False,
        stape: Optional[SharedTimestepAdaptivePromptEnhancement] = None,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.enable_instancev = enable_instancev
        # May be a module instance shared with other blocks (WanModel passes one STAPE to all).
        self.stape = stape
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        # IMCA is inserted between self-attention and cross-attention as a residual branch
        if enable_instancev:
            self.imca = InstanceAwareMaskedCrossAttention(dim, num_heads, eps=eps)
            # zero-initialized gated parameter m_v (paper Eq. 4)
            self.mv = nn.Parameter(torch.zeros(1))
            self.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        else:
            self.imca = None
            self.mv = None
            self.norm_imca = None
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()
        # Better initialization for IMCA: copy weights from the native cross-attention (paper discussion)
        if enable_instancev and self.imca is not None:
            self._init_imca_from_cross_attention()

    def _init_imca_from_cross_attention(self):
        """Copy q/k/v/o projections and Q/K norms from ``cross_attn`` into IMCA.

        NOTE(review): this runs inside ``__init__``, before any checkpoint is loaded, so it copies
        the freshly initialised cross-attention weights. If the intent is to start IMCA from
        *pretrained* cross-attention weights, call this again after loading them.
        """
        for name in ("q", "k", "v", "o", "norm_q", "norm_k"):
            source = getattr(self.cross_attn, name)
            target = getattr(self.imca.attn, name)
            try:
                target.load_state_dict(source.state_dict())
            except RuntimeError as err:
                # Best-effort: keep IMCA's own init for this module, but do not hide the mismatch.
                warnings.warn(f"IMCA init: could not copy cross_attn.{name} ({err}); keeping default init.")

    def _select_instance_tokens(self, instance_tokens, empty_instance_tokens, saug_drop_prob, device):
        """SAUG training-time drop: with probability p, use the empty instance prompts (masks are kept).

        The decision is sampled independently on every call of this block. A caller that wants one
        decision per forward pass should make it upstream and pass ``saug_drop_prob=0``.
        """
        # float() also accepts a 0-dim tensor (e.g. when the value was passed through checkpointing).
        saug_p = float(saug_drop_prob)
        if self.training and saug_p > 0.0 and empty_instance_tokens is not None:
            if torch.rand((), device=device) < saug_p:
                return empty_instance_tokens
        return instance_tokens

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        t_mod: torch.Tensor,
        freqs: torch.Tensor,
        instance_tokens: Optional[torch.Tensor] = None,
        instance_attn_mask: Optional[torch.Tensor] = None,
        empty_instance_tokens: Optional[torch.Tensor] = None,
        saug_drop_prob: float = 0.0,
    ):
        """Args:
            x: (B, F*H*W, D)
            context: embedded context tokens (B, Nctx, D). When ``cross_attn.has_image_input``
                is set, the first 257 tokens are image tokens and the rest are caption tokens T.
            t_mod: timestep modulation, (B, 6, D), or 4-D with the 6 chunks on dim 2.
            freqs: RoPE table for the self-attention.
            instance_tokens: I (B, F, Nins, D) after embedding
            instance_attn_mask: M (B, F, Nins, H*W) bool/float
            empty_instance_tokens: used for SAUG unconditional branch (same shape as instance_tokens)
            saug_drop_prob: training-time probability of swapping in ``empty_instance_tokens``
                (float or 0-dim tensor).
        Returns:
            (B, F*H*W, D)
        """
        has_seq = t_mod.dim() == 4
        chunk_dim = 2 if has_seq else 1
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
            )
        # 1) Self-attention (paper Eq. 3)
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        # 2) IMCA (paper Eq. 4) + STAPE (paper Eq. 6)
        use_instancev = (
            self.enable_instancev
            and self.imca is not None
            and instance_tokens is not None
            and instance_attn_mask is not None
        )
        if use_instancev:
            instance_tokens_use = self._select_instance_tokens(
                instance_tokens, empty_instance_tokens, saug_drop_prob, x.device
            )
            # STAPE is shared across blocks (paper)
            if self.stape is not None:
                # STAPE attends to caption tokens only. With image input, the first 257 context
                # tokens are CLIP image tokens (the same split CrossAttention.forward uses).
                caption_tokens = context[:, 257:] if self.cross_attn.has_image_input else context
                # reuse one AdaLN modulation vector as alpha1 (paper); (B, 1, D) for a (B, 6, D) t_mod
                instance_tokens_use = self.stape(instance_tokens_use, caption_tokens, alpha1=gate_msa)
            # IMCA: masked cross-attn from visual tokens to instance tokens
            imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask)
            x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out
        # 3) Native cross-attention with global caption tokens (paper Eq. 5)
        x = x + self.cross_attn(self.norm3(x), context)
        # 4) FFN
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x
class MLP(torch.nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection (used for CLIP image features).

    With ``has_pos_emb``, a zero-initialised learnable (1, 514, 1280) table is added to the input
    before projection.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if not self.has_pos_emb:
            return self.proj(x)
        pos = self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x + pos)
class Head(nn.Module):
    """Final AdaLN-modulated projection from hidden tokens to flattened patch outputs.

    Output last dim is ``out_dim * prod(patch_size)``.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        """Args:
            x: (B, L, dim) tokens.
            t_mod: per-sample (B, dim) or per-token (B, L, dim) timestep embedding.
        Returns:
            (B, L, out_dim * prod(patch_size))
        """
        modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.dim() == 3:
            # Per-token: (1, 1, 2, D) + (B, L, 1, D) -> (B, L, 2, D)
            shift, scale = (modulation.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            # Per-sample: make t_mod (B, 1, D) so it broadcasts against (1, 2, D) per sample.
            # Adding a raw (B, D) here broadcasts B against the "2" axis: it fails for B > 2 and,
            # for B == 2, mixes sample 0 into shift and sample 1 into scale.
            if t_mod.dim() == 2:
                t_mod = t_mod.unsqueeze(1)
            shift, scale = (modulation + t_mod).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
class WanModel(torch.nn.Module):
    """Wan video DiT with optional InstanceV conditioning (STAPE + IMCA + SAUG).

    When ``enable_instancev`` is True:
      * a single STAPE module is shared by every block;
      * each block gets an IMCA branch;
      * instance prompts are embedded with the same ``text_embedding`` as the caption.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        enable_instancev: bool = False,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.enable_instancev = enable_instancev
        if enable_instancev:
            # STAPE is shared across all DiT blocks (paper Section 4.2)
            self.stape = SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads, eps=eps)
        else:
            self.stape = None
        self.blocks = nn.ModuleList([
            DiTBlock(
                has_image_input=has_image_input,
                dim=dim,
                num_heads=num_heads,
                ffn_dim=ffn_dim,
                eps=eps,
                enable_instancev=enable_instancev,
                stape=self.stape,
            )
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Embed latents into patch features of shape (B, dim, f, h, w).

        When a control adapter is configured and camera latents are given, the adapter output is
        added per sample. Samples are paired with ``zip``, so the result has
        min(len(x), len(y_camera)) samples. Previously only the first sample was kept.
        """
        x = self.patch_embedding(x)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            x = torch.stack([u + v for u, v in zip(x, y_camera)])
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def _saug_drop(self, instance_tokens, empty_instance_tokens, saug_drop_prob):
        """SAUG training-time drop, decided once per forward pass and per sample.

        With probability ``saug_drop_prob``, a sample's instance prompts are replaced by the empty
        prompts (spatial masks are kept). This mirrors how the unconditional branch is evaluated
        at inference, where every block sees the empty prompts.
        """
        p = float(saug_drop_prob)
        if not self.training or p <= 0.0 or empty_instance_tokens is None:
            return instance_tokens
        batch = instance_tokens.shape[0]
        drop = torch.rand(batch, device=instance_tokens.device) < p
        drop = drop.view(batch, *([1] * (instance_tokens.dim() - 1)))
        return torch.where(drop, empty_instance_tokens.to(instance_tokens.dtype), instance_tokens)

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                instance_prompt_tokens: Optional[torch.Tensor] = None,
                instance_attn_mask: Optional[torch.Tensor] = None,
                empty_instance_prompt_tokens: Optional[torch.Tensor] = None,
                saug_drop_prob: float = 0.0,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        """Predict the model output for latent video ``x``.

        Args:
            x: latents (B, C, F, H, W).
            timestep: (B,) diffusion timesteps.
            context: caption text features before ``text_embedding``.
            instance_prompt_tokens / empty_instance_prompt_tokens: per-instance text features
                before ``text_embedding``. Empty prompts default to zeros when InstanceV is on.
            instance_attn_mask: (B, F, Nins, H*W) mask in patch-token space.
            saug_drop_prob: training-time SAUG drop probability, applied once per forward.
            clip_feature, y: image-conditioning inputs, used when ``has_image_input``.
            kwargs: ``control_camera_latents_input`` is forwarded to ``patchify`` if given.
        Returns:
            (B, out_dim, F, H, W) after unpatchify.
        """
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        # Instance prompt tokens (paper Section 4.1): encode each instance prompt with the same text embedding layer
        instance_tokens = (
            self.text_embedding(instance_prompt_tokens) if instance_prompt_tokens is not None else None
        )
        empty_instance_tokens = (
            self.text_embedding(empty_instance_prompt_tokens) if empty_instance_prompt_tokens is not None else None
        )
        # If SAUG unconditional tokens are not provided but InstanceV is enabled, default to zeros.
        # (For best results, provide the pretrained <extra_id> tokens as described in the paper.)
        if self.enable_instancev and (instance_tokens is not None) and (empty_instance_tokens is None):
            empty_instance_tokens = torch.zeros_like(instance_tokens)
        use_instancev_inputs = (
            self.enable_instancev and (instance_tokens is not None) and (instance_attn_mask is not None)
        )
        if use_instancev_inputs:
            instance_tokens = self._saug_drop(instance_tokens, empty_instance_tokens, saug_drop_prob)

        if self.has_image_input:
            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)

        # patchify returns a single (B, dim, f, h, w) tensor; flatten it into tokens here.
        x = self.patchify(x, kwargs.get("control_camera_latents_input"))
        f, h, w = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()

        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        # The SAUG drop was already applied above, so blocks get probability 0 (no per-block re-drop).
        block_extra_args = (
            (instance_tokens, instance_attn_mask, empty_instance_tokens, 0.0) if use_instancev_inputs else ()
        )
        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                offload_ctx = (
                    torch.autograd.graph.save_on_cpu()
                    if use_gradient_checkpointing_offload
                    else contextlib.nullcontext()
                )
                with offload_ctx:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x, context, t_mod, freqs, *block_extra_args,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs, *block_extra_args)

        x = self.head(x, t)
        return self.unpatchify(x, (f, h, w))
def apply_saug(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor:
    """Spatially-Aware Unconditional Guidance (SAUG), paper Eq. (7).

    Computes ``(1 + w) * eps_cond - w * eps_uncond``. ``eps_uncond`` is the prediction made with
    empty instance prompts but the same spatial masks as ``eps_cond``.
    """
    cond_weight = 1.0 + w
    return cond_weight * eps_cond - w * eps_uncond