| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Before-denoise blocks for WorldEngine modular pipeline.""" |
|
|
| from typing import List, Optional, Union |
|
|
| import PIL.Image |
| import torch |
| from torch import nn, Tensor |
| from tensordict import TensorDict |
| from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask |
|
|
| from diffusers.configuration_utils import FrozenDict |
| from diffusers.image_processor import VaeImageProcessor |
| from diffusers.utils import logging |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.modular_pipelines import ( |
| ModularPipelineBlocks, |
| ModularPipeline, |
| PipelineState, |
| SequentialPipelineBlocks, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| ComponentSpec, |
| ConfigSpec, |
| InputParam, |
| OutputParam, |
| ) |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    The sparse block structure is derived directly from `written`, then the
    element-level rule is supplied via `mask_mod` for partially-valid blocks.

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data

    Returns:
        BlockMask that restricts attention to KV positions where `written`
        is True (no dependence on the query position).
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE
    # Number of sparse blocks needed to cover the KV and Q lengths (ceil div).
    KV_blocks = (L + BS - 1) // BS
    Q_blocks = (T + BS - 1) // BS

    # Pad `written` to a whole number of KV blocks, one row per block.
    written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
        KV_blocks, BS
    )

    # Per-KV-block occupancy: at least one valid token vs. fully valid.
    block_any = written_blocks.any(-1)
    block_all = written_blocks.all(-1)

    # The mask does not depend on the query index, so every Q-block row is
    # the same broadcast of the per-KV-block flags.
    nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks)
    full_bm = block_all[None, :].expand_as(nonzero_bm)
    # Partially-valid blocks still need mask_mod applied element-wise.
    partial_bm = nonzero_bm & ~full_bm

    def dense_to_ordered(dense_mask: torch.Tensor):
        # Convert a dense [Q_blocks, KV_blocks] bool mask into the
        # (counts, ordered-indices) pair expected by BlockMask.from_kv_blocks,
        # adding leading [1, 1] batch/head dims. The stable descending argsort
        # puts the True block indices first, in ascending order.
        num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
        indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
            torch.int32
        )
        return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

    # Blocks that require element-wise masking via mask_mod.
    kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)

    # Blocks that are entirely valid and can skip mask_mod.
    full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)

    def mask_mod(b, h, q, kv):
        # Element-level rule: a KV position may be attended iff it holds data.
        return written[kv]

    bm = BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=mask_mod,
        seq_lengths=(T, L),
        compute_q_blocks=False,
    )

    return bm
|
|
|
|
class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus
    one extra frame (tokens_per_frame) at the tail holding the current frame.

    Buffer layout along the sequence dimension:
      [0, L)            — history ring, `num_buckets` frame-sized buckets
      [L, L + tpf)      — scratch slot that always holds the current frame
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        # B: batch size, H: number of KV heads, L: history capacity in
        # tokens, Dh: head dim, tokens_per_frame: ring bucket size,
        # pinned_dilation: only every `pinned_dilation`-th frame is pinned
        # into the history ring (1 = keep every frame).
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L
        # Total sequence capacity: history ring plus the current-frame slot.
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0

        # Stacked K and V: [2, B, H, capacity, Dh]; excluded from state_dict.
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # Per-token validity flags. The tail (current-frame) slot is always
        # marked written because upsert fills it on every call.
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)

        # Precomputed index helpers: token offsets within one frame, and the
        # absolute indices of the current-frame scratch slot.
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        """Zero the buffer and mark only the current-frame slot as written."""
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Insert one frame of KV and build the matching attention block mask.

        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
            is_frozen: if True, only the scratch slot is refreshed; the
                history ring and `written` flags are left untouched

        Returns:
            (k, v, bm): full-capacity K and V views of the buffer plus the
            BlockMask restricting attention to valid slots.
        """
        T = self.tpf
        t_pos = pos_ids["t_pos"]

        # Cheap invariant checks; skipped when compiled by torch.compile.
        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer has unexpected length (expected L + tokens_per_frame)",
            )
            torch._check(
                (t_pos >= 0).all().item(),
                "t_pos must be non-negative during inference",
            )
            torch._check(
                ((t_pos == t_pos[:, :1]).all()).item(),
                "t_pos must be constant within frame",
            )

        # All positions in a frame share one timestep; take the first.
        frame_t = t_pos[0, 0]

        # Map the frame timestep to a ring bucket; with dilation > 1, only
        # every `pinned_dilation`-th frame lands on a fresh bucket.
        bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T

        # Token indices of the target ring bucket for this frame.
        ring_idx = self.frame_offsets + base

        # Always place the current frame in the tail scratch slot so this
        # forward pass can attend to it, regardless of frozen state.
        self.kv.index_copy_(3, self.current_idx, kv)

        # On a "write step" the target ring bucket is about to be replaced;
        # hide its stale contents from this frame's mask.
        write_step = frame_t.remainder(self.pinned_dilation) == 0
        mask_written = self.written.clone()
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        # Persist into the history ring only when the cache is mutable.
        if not is_frozen:
            # Write steps pin the frame into the ring bucket; otherwise the
            # destination stays the scratch slot (redundant with the copy
            # above, but keeps `dst` shape-static for compilation).
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm
|
|
|
|
class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention."""

    def __init__(self, config, batch_size, dtype):
        super().__init__()

        self.tpf = config.tokens_per_frame

        # History window sizes (in tokens) for the two layer flavors.
        window_local = config.local_window * self.tpf
        window_global = config.global_window * self.tpf

        period = config.global_attn_period
        offset = getattr(config, "global_attn_offset", 0) % period
        kv_heads = getattr(config, "n_kv_heads", config.n_heads)
        head_dim = config.d_model // config.n_heads

        # Every `period`-th layer (shifted by `offset`) gets the global
        # window and pinned dilation; the rest use the local window.
        caches = []
        for idx in range(config.n_layers):
            is_global = (idx - offset) % period == 0
            caches.append(
                LayerKVCache(
                    batch_size,
                    kv_heads,
                    window_global if is_global else window_local,
                    head_dim,
                    dtype,
                    self.tpf,
                    config.global_pinned_dilation if is_global else 1,
                )
            )
        self.layers = nn.ModuleList(caches)

        # A frozen cache serves reads but is never mutated by upsert().
        self._is_frozen = True

    def reset(self):
        """Clear every layer's buffers and freeze the cache."""
        for cache in self.layers:
            cache.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Toggle whether upsert() may mutate the history ring."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Insert one frame of K/V into the given layer and return (k, v, mask)."""
        stacked = torch.stack([k, v], dim=0)
        return self.layers[layer].upsert(stacked, pos_ids, self._is_frozen)
|
|
|
|
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # No pipeline components required by this step.
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # Default sigma schedule used when the caller supplies none.
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        sigmas_input = InputParam(
            "scheduler_sigmas",
            type_hint=List[float],
            description="Custom scheduler sigmas (overrides config)",
        )
        timestamp_input = InputParam(
            "frame_timestamp",
            type_hint=torch.Tensor,
            description="Current frame timestamp",
        )
        return [sigmas_input, timestamp_input]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        sigmas_output = OutputParam(
            "scheduler_sigmas",
            type_hint=torch.Tensor,
            description="Tensor of scheduler sigmas for denoising",
        )
        timestamp_output = OutputParam(
            "frame_timestamp",
            type_hint=torch.Tensor,
            description="Current frame timestamp",
        )
        return [sigmas_output, timestamp_output]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Prefer caller-supplied sigmas; otherwise fall back to the config
        # default, and materialize either as a tensor on the right device.
        raw_sigmas = block_state.scheduler_sigmas
        if raw_sigmas is None:
            raw_sigmas = components.config.scheduler_sigmas
        block_state.scheduler_sigmas = torch.tensor(
            raw_sigmas, device=device, dtype=dtype
        )

        # Normalize the timestamp: default to 0, and lift a plain int into
        # a [1, 1] long tensor; an existing tensor passes through untouched.
        timestamp = block_state.frame_timestamp
        if timestamp is None:
            timestamp = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(timestamp, int):
            timestamp = torch.tensor([[timestamp]], dtype=torch.long, device=device)

        block_state.frame_timestamp = timestamp

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # No pipeline components required by this step.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        cache_input = InputParam(
            "kv_cache",
            type_hint=Optional[StaticKVCache],
            description="Existing KV cache (will be reused if provided)",
        )
        reset_input = InputParam(
            "reset_cache",
            type_hint=bool,
            default=False,
            description="If True, reset the KV cache even if one exists",
        )
        return [cache_input, reset_input]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        cache_output = OutputParam(
            "kv_cache",
            type_hint=StaticKVCache,
            description="KV cache for transformer attention",
        )
        return [cache_output]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        if block_state.kv_cache is None:
            # First call: allocate a fresh cache sized from the transformer
            # config and move it to the execution device.
            fresh_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            )
            block_state.kv_cache = fresh_cache.to(device)
        elif block_state.reset_cache:
            # Caller requested a wipe: clear history but keep the allocation.
            block_state.kv_cache.reset()

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image.

    If an image is provided, it is VAE-encoded and persisted into the KV cache
    via a sigma=0 transformer pass so subsequent frames can attend to it as
    context; the frame timestamp is then advanced. The step always produces
    the latent tensor the denoise loop starts from (fresh noise by default).
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Run the transformer once at sigma=0 to persist the frame in the KV cache.

        NOTE(review): the cache is unfrozen here and never re-frozen in this
        step; presumably a later denoise step calls set_frozen(True) — confirm.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Latent geometry from the pipeline config.
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch

        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        # Latent shape is [B=1, frames=1, C, H, W]; the spatial size is
        # derived from vae_scale_factor and the patch grid, not from the
        # pixel-space height/width above.
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            image = block_state.image
            # Resize/convert to the configured pixel resolution.
            image = components.image_processor.preprocess(
                image,
                height=height,
                width=width,
            )
            # Convert [C, H, W] float in [0, 1] to the [H, W, 3] uint8
            # layout the VAE encoder expects.
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)

            assert image.dtype == torch.uint8, (
                f"Expected uint8 image, got {image.dtype}"
            )

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)  # add the frame dimension

            # Persist the encoded frame into the KV cache as context, then
            # advance to the next frame timestep (in place).
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            block_state.frame_timestamp.add_(1)

        if block_state.use_random_latents or block_state.latents is None:
            # Fix: route the optional `generator` input through randn_tensor
            # so callers actually get deterministic noise — the previous
            # torch.randn call silently ignored it.
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    # Sub-blocks run in declaration order; block_names provide the keys used
    # to address them and correspond one-to-one with block_classes.
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            "  - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            "  - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            "  - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
|
|