PencilFolder / diffsynth /models /wan_video_dit_instance.py
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
"""
Wan Video DiT with instance-aware control (T5 semantics + bbox/mask).
This refactor keeps the original Wan DiT backbone while integrating:
- Instance tokens: `<class> is <state>` text (T5) + instance_id embedding.
- Mask-guided cross-attention: per-patch gating via bbox- or mask-projected hints.
- Backward compatibility: still accepts id-based class/state embeddings and pixel masks.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
# -----------------------------------------------------------------------------
# Common utils
# -----------------------------------------------------------------------------
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode: bool = False):
    """Multi-head attention over head-packed tensors, dispatched to the best available kernel.

    Args:
        q, k, v: tensors laid out as (batch, seq, num_heads * head_dim).
        num_heads: number of heads packed into the last dimension.
        compatibility_mode: force PyTorch's ``scaled_dot_product_attention``.

    Backend priority when not in compatibility mode: FlashAttention-3, FlashAttention-2,
    SageAttention, then PyTorch SDPA. No mask and no causal flag is passed to any backend.

    Returns:
        The attention output in the same (batch, seq, num_heads * head_dim) layout.
    """
    if compatibility_mode:
        backend = "sdpa"
    elif FLASH_ATTN_3_AVAILABLE:
        backend = "fa3"
    elif FLASH_ATTN_2_AVAILABLE:
        backend = "fa2"
    elif SAGE_ATTN_AVAILABLE:
        backend = "sage"
    else:
        backend = "sdpa"

    # FlashAttention kernels are called with (b, s, n, d); SDPA and SageAttention (as called here)
    # are called with (b, n, s, d).
    if backend in ("fa3", "fa2"):
        split_heads, merge_heads = "b s (n d) -> b s n d", "b s n d -> b s (n d)"
    else:
        split_heads, merge_heads = "b s (n d) -> b n s d", "b n s d -> b s (n d)"
    q, k, v = [rearrange(t, split_heads, n=num_heads) for t in (q, k, v)]

    if backend == "fa3":
        out = flash_attn_interface.flash_attn_func(q, k, v)
        # The FA3 interface may return a tuple; the attention output is its first element.
        if isinstance(out, tuple):
            out = out[0]
    elif backend == "fa2":
        out = flash_attn.flash_attn_func(q, k, v)
    elif backend == "sage":
        out = sageattn(q, k, v)
    else:
        out = F.scaled_dot_product_attention(q, k, v)
    return rearrange(out, merge_heads, n=num_heads)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply AdaLN-style modulation ``x * (1 + scale) + shift`` with broadcasting."""
    scaled = x * (1 + scale)
    return scaled + shift
def sinusoidal_embedding_1d(dim, position):
    """Return (len(position), 2 * (dim // 2)) sinusoidal embeddings in ``position``'s dtype.

    The first half of the columns holds cosines and the second half sines of
    ``position * 10000 ** (-j / (dim // 2))`` for ``j`` in ``[0, dim // 2)``.
    The intermediate computation is done in float64.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=1).to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Split a head dimension across (frame, height, width) RoPE axes.

    Height and width each get ``dim // 3`` channels and the frame axis gets the remainder, so the
    three parts sum to ``dim``. Returns a tuple of three tables from ``precompute_freqs_cis``
    in (frame, height, width) order.
    """
    spatial_dim = dim // 3
    frame_dim = dim - 2 * spatial_dim
    return tuple(precompute_freqs_cis(axis_dim, end, theta) for axis_dim in (frame_dim, spatial_dim, spatial_dim))
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Build an (end, dim // 2) complex128 table of unit phasors for rotary embeddings.

    Entry ``[pos, j]`` equals ``exp(i * pos * theta ** (-2j / dim))``.
    """
    even_channels = torch.arange(0, dim, 2)[: (dim // 2)].double()
    inv_freq = 1.0 / (theta ** (even_channels / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
    """Rotate query/key channels by complex RoPE frequencies.

    Args:
        x: (batch, seq, num_heads * head_dim) tensor. Consecutive channel pairs are treated as
            the real/imaginary parts of complex numbers.
        freqs: complex tensor broadcastable against (batch, seq, num_heads, head_dim // 2).
            WanModel passes a tensor of shape (seq, 1, C).
        num_heads: number of heads packed into the last dimension of ``x``.

    The rotation is computed in float64. The result is cast back to the input dtype and
    returned in the (batch, seq, num_heads * head_dim) layout.
    """
    per_head = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq_len, n_heads = per_head.shape[:3]
    as_pairs = per_head.to(torch.float64).reshape(batch, seq_len, n_heads, -1, 2)
    rotated = torch.view_as_complex(as_pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(per_head.dtype)
# -----------------------------------------------------------------------------
# Core blocks
# -----------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale.

    Normalisation runs in float32. The result is cast back to the input dtype and then
    multiplied by ``weight``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added before the square root)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalised = self.norm(x.float())
        return normalised.to(x.dtype) * self.weight
class AttentionModule(nn.Module):
    """nn.Module wrapper around ``flash_attention`` with a fixed head count."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        """q, k, v are head-packed (batch, seq, num_heads * head_dim) tensors."""
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries/keys and rotary position embeddings."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Creation order fixes state_dict names and the RNG-dependent initialisation order.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        """Attend over ``x`` (batch, seq, dim), applying RoPE ``freqs`` to queries and keys."""
        heads = self.num_heads
        query = rope_apply(self.norm_q(self.q(x)), freqs, heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, heads)
        value = self.v(x)
        return self.o(self.attn(query, key, value))
class CrossAttention(nn.Module):
    """Cross-attention from video tokens to the conditioning context.

    With ``has_image_input`` the first 257 context tokens are treated as image tokens
    (WanModel prepends the projected CLIP features). They are attended through separate
    key/value projections, and the two attention outputs are summed before the output projection.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """x: (batch, seq, dim) queries; y: (batch, ctx_len, dim) context."""
        num_img_tokens = 257
        text_ctx = y[:, num_img_tokens:] if self.has_image_input else y
        query = self.norm_q(self.q(x))
        out = self.attn(query, self.norm_k(self.k(text_ctx)), self.v(text_ctx))
        if self.has_image_input:
            img = y[:, :num_img_tokens]
            img_key = self.norm_k_img(self.k_img(img))
            out = out + flash_attention(query, img_key, self.v_img(img), num_heads=self.num_heads)
        return self.o(out)
class GateModule(nn.Module):
    """Gated residual: returns ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
class MaskGuidedCrossAttention(nn.Module):
    """Cross-attention from patch tokens to instance tokens, restricted by per-patch instance masks.

    Each patch attends only to the instance tokens whose mask covers it. A mask value ``m`` is
    added to the attention logits as ``log(max(m, 1e-6))``, so instances with ``m == 0`` receive
    a bias of about -13.8 and are strongly suppressed. Fractional masks act as soft weights.

    Patches covered by no instance receive no update (the output equals the input there).
    Without this, the identical log-bias on every instance would cancel inside the softmax, and
    such patches would attend uniformly to all instances.

    The residual branch is scaled by a learnable gate initialised to zero, so the module starts
    out as the identity.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.gate = nn.Parameter(torch.zeros(1))  # zero-init for stability

    @staticmethod
    def _check_masks(instance_masks: torch.Tensor, B: int, N: int, L: int) -> None:
        """Raise ValueError unless ``instance_masks`` has shape (B, N, L)."""
        if instance_masks.shape != (B, N, L):
            raise ValueError(f"instance_masks shape mismatch, expect (B,N,L)=({B},{N},{L}), got {tuple(instance_masks.shape)}")

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, *S, h*d) -> (B, h, *S, d): split the channel axis and move heads after the batch."""
        t = t.reshape(*t.shape[:-1], self.num_heads, t.shape[-1] // self.num_heads)
        return t.movedim(-2, 1)

    def _masked_attention(
        self,
        x: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        instance_masks: torch.Tensor,
    ) -> torch.Tensor:
        """Masked softmax attention plus gated residual.

        Args:
            x: (B, L, D) patch tokens.
            k, v: (B, h, N, d) when shared by all positions, or (B, h, L, N, d) per position.
            instance_masks: (B, N, L) coverage of each patch by each instance.
        Returns:
            (B, L, D) updated tokens. Positions covered by no instance are returned unchanged.
        """
        q = self._split_heads(self.to_q(self.norm(x)))  # (B, h, L, d)
        per_position = k.ndim == 5
        if per_position:
            sim = torch.einsum("bhld,bhlnd->bhln", q, k) * self.scale
        else:
            sim = torch.einsum("bhld,bhnd->bhln", q, k) * self.scale
        mask = instance_masks.transpose(1, 2).unsqueeze(1).to(dtype=sim.dtype)  # (B, 1, L, N)
        # Log-mask trick: m == 0 becomes a large negative logit bias; soft masks scale the weight.
        attn = (sim + torch.log(mask.clamp(min=1e-6))).softmax(dim=-1)
        if per_position:
            out = torch.einsum("bhln,bhlnd->bhld", attn, v)
        else:
            out = torch.einsum("bhln,bhnd->bhld", attn, v)
        out = out.movedim(1, 2).flatten(2)  # (B, L, h*d)
        # A softmax cannot express "attend to nothing", so uncovered patches are zeroed explicitly.
        covered = (mask[:, 0] > 0).any(dim=-1, keepdim=True).to(dtype=out.dtype)  # (B, L, 1)
        return x + self.gate * self.to_out(out) * covered

    def _attend(self, x: torch.Tensor, instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor:
        """Attend to instance tokens shared by every position: instance_tokens is (B, N, D)."""
        B, L, _ = x.shape
        N = instance_tokens.shape[1]
        if N == 0:
            return x
        self._check_masks(instance_masks, B, N, L)
        k = self._split_heads(self.to_k(instance_tokens))  # (B, h, N, d)
        v = self._split_heads(self.to_v(instance_tokens))
        return self._masked_attention(x, k, v, instance_masks)

    def forward(self, x: torch.Tensor, instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor:
        """Apply mask-guided instance attention to ``x`` (B, L, D).

        instance_tokens supports:
          - (B, N, D): static tokens for the whole sequence
          - (B, T, N, D): tokens per patch-time (sequence assumed laid out as T contiguous chunks)
          - (B, L, N, D): tokens per token position (used for sequence parallel chunking)
        instance_masks is always (B, N, L).

        Raises:
            ValueError: on unsupported token rank, mask shape mismatch, or L not divisible by T.
        """
        B, L, _ = x.shape
        if instance_tokens.ndim == 3:
            return self._attend(x, instance_tokens, instance_masks)
        if instance_tokens.ndim != 4:
            raise ValueError(f"instance_tokens must be 3D or 4D, got {tuple(instance_tokens.shape)}")
        if instance_tokens.shape[1] == L:
            # Per-token instance tokens: (B, L, N, D).
            N = instance_tokens.shape[2]
            self._check_masks(instance_masks, B, N, L)
            if N == 0:
                return x
            k = self._split_heads(self.to_k(instance_tokens))  # (B, h, L, N, d)
            v = self._split_heads(self.to_v(instance_tokens))
            return self._masked_attention(x, k, v, instance_masks)
        # Per-time instance tokens: (B, T, N, D); each time step owns a contiguous chunk of L // T tokens.
        T = instance_tokens.shape[1]
        if T == 0 or L % T != 0:
            raise ValueError(f"Token length L={L} must be divisible by T={T} for per-time instance tokens.")
        hw = L // T
        chunks = [
            self._attend(x[:, t * hw:(t + 1) * hw], instance_tokens[:, t], instance_masks[:, :, t * hw:(t + 1) * hw])
            for t in range(T)
        ]
        return torch.cat(chunks, dim=1)
class DiTBlock(nn.Module):
    """One Wan DiT layer with optional instance control.

    The layer runs, in order:
      1. AdaLN-modulated self-attention.
      2. Text/image cross-attention.
      3. Optional mask-guided instance cross-attention.
      4. AdaLN-modulated feed-forward network.
    """

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        # Creation order fixes state_dict names and the RNG-dependent initialisation order.
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        self.instance_cross_attn = MaskGuidedCrossAttention(dim, num_heads, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, instance_tokens=None, instance_masks=None):
        """
        Args:
            x: (B, L, dim) patch tokens.
            context: cross-attention context tokens.
            t_mod: time modulation. Either (B, 6, dim), or a 4-D per-token variant whose
                dim 2 has size 6 (presumably (B, L, 6, dim)).
            freqs: RoPE table for self-attention.
            instance_tokens, instance_masks: optional inputs for the instance cross-attention.
                The instance branch runs only when both are provided.
        """
        per_token_time = t_mod.ndim == 4
        modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        pieces = modulation.chunk(6, dim=2 if per_token_time else 1)
        if per_token_time:
            pieces = tuple(p.squeeze(2) for p in pieces)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = pieces

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(attn_in, freqs))

        x = x + self.cross_attn(self.norm3(x), context)

        if instance_tokens is not None and instance_masks is not None:
            x = self.instance_cross_attn(x, instance_tokens, instance_masks)

        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return self.gate(x, gate_mlp, self.ffn(ffn_in))
class MLP(torch.nn.Module):
    """Projection MLP with an optional learned positional embedding.

    The layer stack is LayerNorm -> Linear -> GELU -> Linear -> LayerNorm. When ``has_pos_emb``
    is set, a (1, 514, 1280) embedding is added to the input first, so the input must broadcast
    against that shape.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if self.has_pos_emb:
            pos = self.emb_pos.to(dtype=x.dtype, device=x.device)
            x = x + pos
        return self.proj(x)
class Head(nn.Module):
    """Output head: AdaLN-modulated LayerNorm followed by a projection to patch values."""

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)

    def forward(self, x, t_mod):
        """
        Args:
            x: (B, L, dim) tokens.
            t_mod: (B, dim) per-sample time embedding, or (B, L, dim) per-token time embedding.
        Returns:
            (B, L, out_dim * prod(patch_size)).
        """
        modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.ndim == 3:
            # (1, 1, 2, dim) + (B, L, 1, dim) -> (B, L, 2, dim)
            shift, scale = (modulation.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            # (B, dim) -> (B, 1, dim), so the batch axis broadcasts against (1, 2, dim) instead of
            # colliding with the two modulation rows (which mixed samples for B == 2 and failed for B > 2).
            t = t_mod.unsqueeze(1) if t_mod.ndim == 2 else t_mod
            shift, scale = (modulation + t).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
class InstanceFeatureExtractor(nn.Module):
    """Build per-time instance tokens from instance ids and state-phrase text embeddings.

    Each instance has S candidate states, and each state is described by a text embedding of the
    phrase ``"<class> is <state>"``. At every time step the projected state embeddings are
    averaged with the given (non-negative) state weights. The result is fused with an
    instance-id embedding.

    Inputs:
        state_text_embeds_multi: (B, N, S, text_dim)
        state_weights: (B, N, F, S), one weight per frame and state (negatives clamped to 0)
    Output:
        (B, T, N, D) tokens. T == F, or T == num_time_patches when given; in that case frames are
        average-pooled into patch-time buckets.
    """

    def __init__(
        self,
        num_instance_ids: int = 1000,
        embedding_dim: int = 1280,
        hidden_dim: int = 1280,
        text_dim: int = 4096,
    ):
        super().__init__()
        # Creation order and attribute names fix state_dict keys and initialisation order.
        self.inst_id_emb = nn.Embedding(num_instance_ids, hidden_dim, padding_idx=0)
        self.text_proj = nn.Sequential(
            nn.Linear(int(text_dim), hidden_dim, bias=False),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.LayerNorm(hidden_dim),
        )
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )

    @staticmethod
    def _pool_time_to_patches(weights: torch.Tensor, num_time_patches: int) -> torch.Tensor:
        """Average (B, N, F, S) per-frame weights into (B, N, T, S) patch-time buckets.

        Frame ``t`` goes to bucket ``floor(t * T / F)``. Buckets that receive no frame stay zero.
        The input is returned as-is when F == T.
        """
        if weights.ndim != 4:
            raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(weights.shape)}")
        batch, num_inst, num_frames, num_states = weights.shape
        num_buckets = int(num_time_patches)
        if num_buckets <= 0:
            raise ValueError("num_time_patches must be > 0")
        if num_frames == num_buckets:
            return weights
        frame_pos = torch.arange(num_frames, device=weights.device, dtype=torch.float32)
        ratio = num_buckets / max(float(num_frames), 1.0)
        bucket = (frame_pos * ratio).floor().clamp(0, num_buckets - 1).long()
        bucket = bucket.view(1, 1, num_frames, 1).expand(batch, num_inst, num_frames, num_states)
        out_shape = (batch, num_inst, num_buckets, num_states)
        totals = weights.new_zeros(out_shape).scatter_add_(2, bucket, weights)
        counts = weights.new_zeros(out_shape).scatter_add_(2, bucket, torch.ones_like(weights))
        return totals / counts.clamp(min=1.0)

    @staticmethod
    def _validate(instance_ids, state_text_embeds_multi, state_weights) -> None:
        """Raise ValueError when inputs are missing or their shapes disagree."""
        if state_text_embeds_multi is None:
            raise ValueError("state_text_embeds_multi is required.")
        if state_weights is None:
            raise ValueError("state_weights is required.")
        if state_text_embeds_multi.ndim != 4:
            raise ValueError(f"state_text_embeds_multi must be (B,N,S,D), got {tuple(state_text_embeds_multi.shape)}")
        if state_weights.ndim != 4:
            raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(state_weights.shape)}")
        B, N, S, _ = state_text_embeds_multi.shape
        if instance_ids.shape[:2] != (B, N):
            raise ValueError(f"instance_ids must be (B,N)=({B},{N}), got {tuple(instance_ids.shape)}")
        if state_weights.shape[0] != B or state_weights.shape[1] != N or state_weights.shape[-1] != S:
            raise ValueError(f"state_weights must be (B,N,F,S)=({B},{N},F,{S}), got {tuple(state_weights.shape)}")

    def forward(
        self,
        instance_ids: torch.Tensor,
        state_text_embeds_multi: torch.Tensor,
        state_weights: torch.Tensor,
        num_time_patches: Optional[int] = None,
    ):
        """Return (B, T, N, D) instance tokens.

        Ids are wrapped modulo the embedding table size before lookup.
        """
        self._validate(instance_ids, state_text_embeds_multi, state_weights)
        semantics = self.text_proj(state_text_embeds_multi)  # (B, N, S, H)
        w = state_weights.to(dtype=semantics.dtype, device=semantics.device).clamp(min=0)
        if num_time_patches is not None:
            w = self._pool_time_to_patches(w, int(num_time_patches))
        w = w.unsqueeze(-1)  # (B, N, T, S, 1)
        weighted_sum = (semantics.unsqueeze(2) * w).sum(dim=3)  # (B, N, T, H)
        sem_time = weighted_sum / w.sum(dim=3).clamp(min=1e-6)
        wrapped_ids = instance_ids % self.inst_id_emb.num_embeddings
        id_feat = self.inst_id_emb(wrapped_ids).to(dtype=sem_time.dtype, device=sem_time.device)  # (B, N, H)
        id_time = id_feat.unsqueeze(2).expand(-1, -1, sem_time.shape[2], -1)
        tokens = self.fusion(torch.cat([sem_time, id_time], dim=-1))  # (B, N, T, D)
        return tokens.permute(0, 2, 1, 3).contiguous()  # (B, T, N, D)
# -----------------------------------------------------------------------------
# Main model
# -----------------------------------------------------------------------------
class WanModel(torch.nn.Module):
    """Wan video DiT backbone with instance-aware control.

    Besides the standard Wan inputs (latents, timestep, text context, optional CLIP/VAE image
    conditioning), ``forward`` accepts per-instance ids, per-state text embeddings, per-frame
    state weights and bounding boxes. These are turned into instance tokens
    (``InstanceFeatureExtractor``) and flattened patch masks (``process_masks``). Every
    ``DiTBlock`` consumes them through its mask-guided instance cross-attention.

    ``num_class_ids``, ``num_state_ids`` and ``state_feature_dim`` are accepted for backward
    compatibility but are not used.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
        # instance control
        num_class_ids: int = 200,
        num_state_ids: int = 100,
        num_instance_ids: int = 1000,
        state_feature_dim: int = 256,
        instance_text_dim: Optional[int] = 4096,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None
        # The instance text width defaults to the main text width when not given explicitly.
        instance_text_dim = int(text_dim) if instance_text_dim is None else int(instance_text_dim)
        self.instance_encoder = InstanceFeatureExtractor(
            num_instance_ids=num_instance_ids,
            embedding_dim=dim,
            hidden_dim=dim,
            text_dim=instance_text_dim,
        )
        self.instance_text_dim = instance_text_dim

    # ------------------------------ patch helpers ------------------------------ #
    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Embed a (B, C, F, H, W) tensor into patch tokens.

        If a control adapter exists and ``control_camera_latents_input`` is given, the adapter
        output (or its first element, when it returns a list/tuple) is added to the patch
        embeddings.

        Returns:
            tokens: (B, L, D)
            grid_size: (F_p, H_p, W_p)
        """
        x = self.patch_embedding(x)  # (B, D, F_p, H_p, W_p)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            if isinstance(y_camera, (list, tuple)):
                x = x + y_camera[0]
            else:
                x = x + y_camera
        grid_size = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Fold (B, L, prod(patch_size) * C) head output back into (B, C, F, H, W)."""
        return rearrange(
            x, "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2],
        )

    # ------------------------------ masks ------------------------------ #
    def process_masks(
        self,
        grid_size,
        image_size: Tuple[int, int, int],
        bboxes: torch.Tensor,
        bbox_mask: Optional[torch.Tensor] = None,
    ):
        """Rasterise instance bounding boxes onto the patch grid (bbox-only path).

        Args:
            grid_size: (f_p, h_p, w_p) patch grid.
            image_size: (F, H, W) coordinate space of the boxes; only H and W are used.
            bboxes: (B, N, F, 4) or (B, N, 4), xyxy in the coordinate space of ``image_size``.
            bbox_mask: optional existence mask, (B, N, F), (B, N, 1) or (B, N).
        Returns:
            (B, N, L) flattened patch mask.
        Raises:
            ValueError: if ``bboxes`` is None or has an unsupported shape.
        """
        if bboxes is None:
            raise ValueError("bboxes must be provided for instance control.")
        return self._bboxes_to_masks(bboxes, bbox_mask, grid_size, image_size)

    def _bboxes_to_masks(
        self,
        bboxes: torch.Tensor,
        bbox_mask: Optional[torch.Tensor],
        grid_size: Tuple[int, int, int],
        image_size: Tuple[int, int, int],
    ):
        """Build a (B, N, f_p * h_p * w_p) float32 mask on ``bboxes.device``.

        Coordinate handling:
          - Boxes are clipped to [0, W] x [0, H] of ``image_size``. Empty boxes and boxes whose
            ``bbox_mask`` entry is false are skipped.
          - Box edges are projected with ratios w_p / W and h_p / H, rounded outwards, and always
            cover at least one patch.
          - Box frame ``t`` maps to patch-time ``floor(t * f_p / F_bbox)``. Boxes that land on the
            same patch-time are unioned.
        """
        f_p, h_p, w_p = grid_size
        _, H_img, W_img = image_size
        if bboxes.ndim == 3:  # (B, N, 4) -> same box for every patch-time
            bboxes = bboxes.unsqueeze(2).expand(-1, -1, f_p, -1)
        if bboxes.ndim != 4 or bboxes.shape[-1] != 4:
            raise ValueError(f"bboxes must be (B,N,F,4) or (B,N,4); got {tuple(bboxes.shape)}")
        B, N, f_bbox = (int(s) for s in bboxes.shape[:3])

        if bbox_mask is None:
            bbox_mask = torch.ones((B, N, f_bbox), dtype=torch.bool)
        else:
            if bbox_mask.ndim == 2:
                bbox_mask = bbox_mask.unsqueeze(-1)
            if bbox_mask.ndim != 3:
                raise ValueError(f"bbox_mask must be (B,N,F) or (B,N,1); got {tuple(bbox_mask.shape)}")
            # (B, N, 1) means one existence flag for all frames; previously indexing t >= 1 failed.
            if bbox_mask.shape[2] == 1:
                bbox_mask = bbox_mask.expand(-1, -1, f_bbox)

        # Move everything to host once instead of syncing the device for every scalar read.
        valid = bbox_mask.to(dtype=torch.bool).cpu().tolist()
        boxes = bboxes.detach().to(device="cpu", dtype=torch.float64).tolist()

        mask = torch.zeros((B, N, f_p, h_p, w_p), dtype=torch.float32)
        w_scale = w_p / max(float(W_img), 1.0)
        h_scale = h_p / max(float(H_img), 1.0)
        for b in range(B):
            for n in range(N):
                for t in range(f_bbox):
                    if not valid[b][n][t]:
                        continue
                    x0, y0, x1, y1 = boxes[b][n][t]
                    x0 = max(0, min(x0, W_img))
                    x1 = max(0, min(x1, W_img))
                    y0 = max(0, min(y0, H_img))
                    y1 = max(0, min(y1, H_img))
                    if x1 <= x0 or y1 <= y0:
                        continue
                    px0 = int(math.floor(x0 * w_scale))
                    py0 = int(math.floor(y0 * h_scale))
                    px1 = int(math.ceil(x1 * w_scale))
                    py1 = int(math.ceil(y1 * h_scale))
                    px0 = max(0, min(px0, w_p - 1))
                    py0 = max(0, min(py0, h_p - 1))
                    px1 = max(px0 + 1, min(px1, w_p))
                    py1 = max(py0 + 1, min(py1, h_p))
                    pt = min(int(math.floor(t * f_p / max(f_bbox, 1))), f_p - 1)
                    mask[b, n, pt, py0:py1, px0:px1] = 1.0
        return mask.flatten(2).to(bboxes.device)  # (B, N, f_p * h_p * w_p)

    # ------------------------------ forward ------------------------------ #
    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        clip_feature: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        # instance inputs (bbox-based)
        instance_ids: Optional[torch.Tensor] = None,  # (B, N)
        instance_state_text_embeds_multi: Optional[torch.Tensor] = None,  # (B, N, S, text_dim)
        instance_state_weights: Optional[torch.Tensor] = None,  # (B, N, F, S) weights per frame
        instance_bboxes: Optional[torch.Tensor] = None,  # (B, N, F, 4)
        **kwargs,
    ):
        """Run the DiT on ``x`` and return the head output unpatchified to (B, out_dim, F, H, W).

        Instance control is active only when ids, bboxes, state text embeddings and state weights
        are all given and N > 0. Extra keyword arguments are accepted and ignored.
        """
        # Timestep embedding
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        # Text embedding
        context = self.text_embedding(context)
        # Image conditioning
        if self.has_image_input:
            x = torch.cat([x, y], dim=1)  # (B, Cx+Cy, F, H, W)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)
        # NOTE(review): the bbox coordinate space is taken from this input tensor's own
        # (F, H, W) (presumably latent resolution). If callers pass boxes in pixel coordinates,
        # they are clipped and scaled wrongly — confirm the caller's convention.
        orig_frames, orig_h, orig_w = x.shape[2:]
        x, (f, h, w) = self.patchify(x)
        grid_size = (f, h, w)

        # Instance control
        inst_tokens = None
        inst_mask_flat = None
        has_instance = (
            instance_ids is not None
            and instance_bboxes is not None
            and instance_state_text_embeds_multi is not None
            and instance_state_weights is not None
            and instance_ids.shape[1] > 0
        )
        if has_instance:
            inst_tokens = self.instance_encoder(
                instance_ids=instance_ids,
                state_text_embeds_multi=instance_state_text_embeds_multi,
                state_weights=instance_state_weights,
                num_time_patches=f,
            )  # (B, f, N, D): one token set per patch-time
            inst_mask_flat = self.process_masks(
                grid_size,
                image_size=(orig_frames, orig_h, orig_w),
                bboxes=instance_bboxes,
            )  # (B, N, f*h*w)

        # RoPE table over the (f, h, w) grid, flattened to (L, 1, C)
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        def create_custom_forward_with_instance(module):
            def custom_forward(x, context, t_mod, freqs, instance_tokens, instance_masks):
                return module(x, context, t_mod, freqs, instance_tokens=instance_tokens, instance_masks=instance_masks)
            return custom_forward

        use_instance = inst_tokens is not None and inst_mask_flat is not None
        if use_instance:
            make_forward = create_custom_forward_with_instance
            extra_inputs = (inst_tokens, inst_mask_flat)
        else:
            make_forward = create_custom_forward
            extra_inputs = ()

        def checkpointed(block, hidden):
            # Instance tensors are passed as explicit checkpoint inputs, not captured by closure.
            return torch.utils.checkpoint.checkpoint(
                make_forward(block),
                hidden, context, t_mod, freqs, *extra_inputs,
                use_reentrant=False,
            )

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x = checkpointed(block, x)
                else:
                    x = checkpointed(block, x)
            elif use_instance:
                x = block(x, context, t_mod, freqs, instance_tokens=inst_tokens, instance_masks=inst_mask_flat)
            else:
                x = block(x, context, t_mod, freqs)

        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x