# coding=utf-8
# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import math
import re
from collections import defaultdict
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
from transformers.utils.import_utils import (
is_torch_available,
is_torchdynamo_compiling,
is_vision_available,
)
if is_torch_available():
import torch
import torch.nn as nn
import torch.nn.functional as F
if is_vision_available():
from PIL.Image import Image
else:
Image = None
from perceptron.tensorstream.ops import (
compute_mrope_pos_tensor,
modality_mask,
reconstruct_tensor_stream_from_compact_dict,
tensor_stream_token_view,
)
from perceptron.tensorstream.ops import (
slice as ts_slice,
)
from perceptron.tensorstream.tensorstream import (
Event,
Stream,
TensorStream,
TextType,
VisionType,
create_stream,
group_streams,
)
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.feature_extraction_utils import BatchFeature
from transformers.generation.utils import GenerationMixin
from transformers.processing_utils import ImagesKwargs
from transformers.image_transforms import group_images_by_shape, reorder_images
from transformers.image_utils import SizeDict
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
)
from transformers.masking_utils import create_masks_for_generate, packed_sequence_mask_function
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import rope_config_validation
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel
from transformers.processing_utils import ProcessorMixin, Unpack
from transformers.utils.generic import TensorType
from transformers.utils.auto_docstring import auto_docstring
# Vision preprocessing constants
from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD
from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
from transformers.models.siglip2.modeling_siglip2 import (
Siglip2Attention,
Siglip2Encoder,
Siglip2EncoderLayer,
)
class IsaacVisionConfig(Siglip2VisionConfig):
"""Vision configuration for Isaac with Pixel Shuffle support.
Extends Siglip2VisionConfig with additional fields for pixel shuffle.
Args:
pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
Spatial downsampling factor used by pixel shuffle; each output token merges `pixel_shuffle_scale_factor**2` neighboring patches.
num_patches (`int`, *optional*, defaults to 256):
Maximum number of learnable positional embeddings to initialize.
"""
model_type = "isaac_vision"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
num_patches=256,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
pixel_shuffle_scale_factor=1,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.num_patches = num_patches
# Add our custom fields
self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
# Ensure a sensible default attention backend
if getattr(self, "_attn_implementation", None) is None:
self._attn_implementation = "sdpa"
class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
patch_size: Optional[int]
max_num_patches: Optional[int]
min_num_patches: Optional[int]
pixel_shuffle_scale: Optional[int]
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
r"""Fast torch-based image processor for Isaac vision inputs."""
MAX_PIXELS = 60_000_000  # 60-megapixel ceiling ≈ 8200 × 7300 px
resample = PILImageResampling.BILINEAR
model_input_names = ["patches", "token_grids"]
valid_kwargs = IsaacImageProcessorKwargs
unused_kwargs = ["size", "do_center_crop", "crop_size"]
do_resize = True
size: Optional[SizeDict] = None
default_to_square: Optional[bool] = None
do_center_crop = False
crop_size: Optional[SizeDict] = None
patch_size: Optional[int] = 16
max_num_patches: Optional[int] = 256
min_num_patches: Optional[int] = None
pixel_shuffle_scale: Optional[int] = 1
do_pad = False
pad_size: Optional[SizeDict] = None
do_rescale = True
rescale_factor = 1 / 255
do_normalize = True
image_mean = list(VISION_MEAN)
image_std = list(VISION_STD)
do_convert_rgb = True
return_tensors = None
data_format = ChannelDimension.FIRST
input_data_format = None
device = None
disable_grouping = False
size_divisor: Optional[int] = None
def __init__(
self,
**kwargs: Unpack[IsaacImageProcessorKwargs],
) -> None:
super().__init__(**kwargs)
pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale)
if pixel_shuffle_scale < 1:
raise ValueError("`pixel_shuffle_scale` must be >= 1")
self.pixel_shuffle_scale = pixel_shuffle_scale
def _validate_preprocess_kwargs(self, **kwargs):
# Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
kwargs.pop("do_resize", None)
kwargs.pop("size", None)
kwargs.pop("do_center_crop", None)
kwargs.pop("crop_size", None)
kwargs.pop("disable_grouping", None)
return super()._validate_preprocess_kwargs(**kwargs)
def resize(
self,
image: torch.Tensor,
size: SizeDict,
interpolation: Optional[Any] = None,
antialias: bool = True,
**kwargs,
) -> torch.Tensor:
if size.height is None or size.width is None:
raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.")
resize_mode: Any = interpolation
if hasattr(resize_mode, "value"):
resize_mode = resize_mode.value
elif hasattr(resize_mode, "name"):
resize_mode = resize_mode.name.lower()
elif resize_mode is None:
resize_mode = "bilinear"
if isinstance(resize_mode, str):
mode_key = resize_mode.lower()
else:
mode_key = resize_mode
resize_kwargs: dict[str, Any] = {}
if mode_key in {"linear", "bilinear", "bicubic", "trilinear"}:
resize_kwargs["align_corners"] = False
return F.interpolate(
image,
size=(size.height, size.width),
mode=resize_mode,
**resize_kwargs,
)
def _preprocess(
self,
images: list[torch.Tensor],
do_resize: bool,
size: Optional[SizeDict],
interpolation: Optional[Any],
do_center_crop: bool,
crop_size: Optional[SizeDict],
do_rescale: Optional[bool],
rescale_factor: Optional[float],
do_normalize: Optional[bool],
image_mean: Optional[Union[float, Sequence[float]]],
image_std: Optional[Union[float, Sequence[float]]],
disable_grouping: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[SizeDict] = None,
*,
patch_size: Optional[int] = None,
max_num_patches: Optional[int] = None,
min_num_patches: Optional[int] = None,
pixel_shuffle_scale: Optional[int] = None,
**kwargs,
) -> BatchFeature:
if do_center_crop:
raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.")
if do_pad:
raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.")
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {}
token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {}
virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}
real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}
for shape, stacked_images in grouped_images.items():
if stacked_images.ndim != 4:
raise ValueError("Expected batched channel-first image tensors.")
batch_size, channels, original_height, original_width = stacked_images.shape
if bool(self.do_convert_rgb) and channels == 1:
stacked_images = stacked_images.repeat(1, 3, 1, 1)
channels = 3
if original_height * original_width > self.MAX_PIXELS:
raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`")
target_height, target_width = get_image_size_for_max_num_patches(
original_height,
original_width,
patch_size,
max_num_patches,
min_num_patches=min_num_patches,
pixel_shuffle_scale=pixel_shuffle_scale,
)
if do_resize:
resize_size = SizeDict(height=target_height, width=target_width)
image_batch = self.resize(
image=stacked_images,
size=resize_size,
interpolation=interpolation,
)
else:
if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0):
raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.")
image_batch = stacked_images
target_height, target_width = original_height, original_width
if do_rescale:
image_batch = self.rescale_and_normalize(
image_batch,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
nhwc_images = image_batch.permute(0, 2, 3, 1)
nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size)
patches = patchify_vision(nhwc_images, patch_size=patch_size)
_, height_tokens, width_tokens, _ = patches.shape
token_grid = (
torch.tensor(
[height_tokens, width_tokens],
dtype=torch.long,
device=patches.device,
)
.unsqueeze(0)
.repeat(batch_size, 1)
)
real_dim = (
torch.tensor(
[1, height_tokens, width_tokens],
dtype=torch.long,
device=patches.device,
)
.unsqueeze(0)
.repeat(batch_size, 1)
)
if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
raise ValueError(
"Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled."
)
virtual_height = height_tokens // pixel_shuffle_scale
virtual_width = width_tokens // pixel_shuffle_scale
virtual_dim = (
torch.tensor(
[1, virtual_height, virtual_width],
dtype=torch.long,
device=patches.device,
)
.unsqueeze(0)
.repeat(batch_size, 1)
)
processed_patches_grouped[shape] = patches
token_grids_grouped[shape] = token_grid
virtual_dims_grouped[shape] = virtual_dim
real_dims_grouped[shape] = real_dim
patches_slices = reorder_images(processed_patches_grouped, grouped_images_index)
token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index)
virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index)
real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index)
patches_tensor = torch.stack(patches_slices, dim=0)
token_grids_tensor = torch.stack(token_grid_slices, dim=0)
virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0)
real_dims_tensor = torch.stack(real_dim_slices, dim=0)
return BatchFeature(
data={
"patches": patches_tensor,
"token_grids": token_grids_tensor,
"virtual_pixel_size": virtual_dims_tensor,
"real_pixel_size": real_dims_tensor,
},
tensor_type=return_tensors,
)
def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]:
"""Return a mask function that blocks cross-document attention from packed ``cu_seqlens``.
The returned callable matches the signature expected by ``masking_utils`` mask factories and
yields ``True`` only when query/key positions belong to the same packed segment.
"""
if cu_seqlens is None:
return None
if cu_seqlens.numel() < 2:
return None
seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
if seq_sizes.numel() == 0:
return None
total_tokens = int(seq_sizes.sum().item())
seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes)
packed_sequence_mask = seg_ids.view(1, total_tokens)
return packed_sequence_mask_function(packed_sequence_mask)
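# Worked example (illustrative only): for two packed documents of lengths 3 and 2,
# cu_seqlens = [0, 3, 5] gives seq_sizes = [3, 2] and segment ids [0, 0, 0, 1, 1], so the
# returned mask function admits attention only between positions sharing a segment id.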
def ensure_document_attention_mask(
attention_mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
total_tokens: int,
dtype: torch.dtype,
device: torch.device,
*,
return_mask_function: bool = False,
) -> Optional[Union[torch.Tensor, Callable]]:
"""Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``.
``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise
``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask
has been removed in favor of the callable-based path.
"""
if attention_mask is not None:
return attention_mask
if cu_seqlens is None:
return None
if return_mask_function:
return document_mask_function_from_cu_seqlens(cu_seqlens)
return None
class IsaacVisionEmbeddings(nn.Module):
"""Adapter around SigLIP2 vision embeddings that consumes packed patch sequences."""
def __init__(self, config: IsaacVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size * self.patch_size,
out_features=self.embed_dim,
)
self.num_patches = config.num_patches
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Parameter(
torch.empty(
self.position_embedding_size,
self.position_embedding_size,
self.embed_dim,
)
)
nn.init.normal_(self.position_embedding)
def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes)
if packed_pixel_values is None:
return seq_patches.new_zeros((0, self.embed_dim))
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))
positional_embeddings = self.position_embedding
resized_positional_embeddings = self.resize_positional_embeddings(
positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1]
)
embeddings = patch_embeds + resized_positional_embeddings
return self._unpack_from_batch(embeddings, seq_lengths)
@staticmethod
def resize_positional_embeddings(
positional_embeddings: torch.Tensor,
spatial_shapes: torch.LongTensor,
max_length: int,
) -> torch.Tensor:
"""
Resize positional embeddings to image-specific size and pad to a fixed size.
Args:
positional_embeddings (`torch.Tensor`):
Position embeddings of shape (height, width, embed_dim)
spatial_shapes (`torch.LongTensor`):
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized positional embeddings to
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
"""
batch_size = spatial_shapes.shape[0]
embed_dim = positional_embeddings.shape[-1]
source_dtype = positional_embeddings.dtype
resulted_positional_embeddings = torch.empty(
(batch_size, max_length, embed_dim),
device=positional_embeddings.device,
dtype=source_dtype,
)
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
# Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
if positional_embeddings.device.type == "cpu":
positional_embeddings = positional_embeddings.to(torch.float32)
for i in range(batch_size):
# (1, dim, height, width) -> (1, dim, target_height, target_width)
height, width = spatial_shapes[i]
resized_embeddings = F.interpolate(
positional_embeddings,
size=(height, width),
mode="bilinear",
align_corners=False,
antialias=True,
)
# (1, dim, target_height, target_width) -> (target_height * target_width, dim)
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
# Cast to original dtype
resized_embeddings = resized_embeddings.to(source_dtype)
resulted_positional_embeddings[i, : height * width] = resized_embeddings
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
return resulted_positional_embeddings
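# Shape sketch (illustrative): with the learned (16, 16, embed_dim) grid and spatial_shapes = [[8, 12]],
# the grid is bilinearly resized to (8, 12), flattened to (96, embed_dim), and written into the first 96
# rows of the (1, max_length, embed_dim) output; the remaining rows are filled with the first resized
# embedding as padding.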
def _pack_to_batch(
self,
seq_patches: torch.Tensor,
spatial_shapes: torch.Tensor,
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
if seq_patches.ndim != 2:
raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).")
if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2:
raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).")
seq_lengths = spatial_shapes.long().prod(dim=-1)
total_patches = int(seq_lengths.sum().item())
if total_patches != seq_patches.size(0):
raise ValueError(
"Mismatch between packed patches and spatial shapes: got "
f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}."
)
batch_size = spatial_shapes.size(0)
if batch_size == 0:
return None, seq_lengths
max_length = int(seq_lengths.max().item())
patch_dim = seq_patches.size(-1)
device = seq_patches.device
packed_pixel_values = seq_patches.new_zeros((batch_size, max_length, patch_dim), device=device)
start = 0
for batch_idx, length in enumerate(seq_lengths.tolist()):
if length == 0:
continue
end = start + length
packed_pixel_values[batch_idx, :length] = seq_patches[start:end]
start = end
return packed_pixel_values, seq_lengths
def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
output_chunks: list[torch.Tensor] = []
for batch_idx, length in enumerate(seq_lengths.tolist()):
if length == 0:
continue
output_chunks.append(embeddings[batch_idx, :length])
if not output_chunks:
return embeddings.new_zeros((0, embeddings.size(-1)))
return torch.cat(output_chunks, dim=0)
class IsaacVisionAttention(Siglip2Attention):
"""Custom attention that supports variable-length sequences with flash attention."""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[torch.Tensor] = None,
output_attentions: bool = False,
is_causal: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
**kwargs,
):
# Ignore unused arguments for interface compatibility
_ = position_ids
_ = past_key_value
_ = is_causal
kwargs.pop("output_hidden_states", None)
kwargs.pop("return_dict", None)
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
if not queries.is_contiguous():
queries = queries.contiguous()
if not keys.is_contiguous():
keys = keys.contiguous()
if not values.is_contiguous():
values = values.contiguous()
L = seq_length  # fallback length for varlen attention when cu_seqlens/max_seqlen are absent
if max_seqlen is not None:
max_q = max_k = int(max_seqlen)
else:
max_q = max_k = self._max_from_cu(cu_seqlens, L)
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
if self.config._attn_implementation != "sdpa":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
dropout = 0.0 if not self.training else self.dropout
attention_kwargs: dict[str, Any] = {
"is_causal": False,
"scaling": self.scale,
"dropout": dropout,
}
if cu_seqlens is not None:
attention_kwargs["cu_seq_lens_q"] = cu_seqlens
attention_kwargs["cu_seq_lens_k"] = cu_seqlens
if max_seqlen is not None:
attention_kwargs["max_length_q"] = max_q
attention_kwargs["max_length_k"] = max_k
if output_attentions:
attention_kwargs["output_attentions"] = True
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
**attention_kwargs,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
# Align projection inputs with parameter dtype to avoid mixed-dtype matmul errors
out_proj_dtype = self.out_proj.weight.dtype
if attn_output.dtype != out_proj_dtype:
attn_output = attn_output.to(out_proj_dtype)
attn_output = self.out_proj(attn_output)
if attn_output.dtype != hidden_states.dtype:
attn_output = attn_output.to(hidden_states.dtype)
return attn_output, attn_weights
@staticmethod
def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int:
if cu is None or cu.numel() < 2:
return fallback
return int((cu[1:] - cu[:-1]).max().item())
class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
"""Isaac vision encoder layer with variable-length attention."""
def __init__(self, config: IsaacVisionConfig):
super().__init__(config)
self.self_attn = IsaacVisionAttention(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
output_attentions: bool = False,
**kwargs: Unpack[TransformersKwargs],
):
r"""
cu_seqlens (`torch.Tensor`, *optional*):
Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive
entries gives each document's token count and enables block-diagonal attention masking for packed batches.
max_seqlen (`int`, *optional*):
Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary
buffers for packed variable-length attention.
"""
attention_mask = ensure_document_attention_mask(
attention_mask,
cu_seqlens,
hidden_states.size(1),
hidden_states.dtype,
hidden_states.device,
return_mask_function=False,
)
# Run attention directly so variable-length metadata reaches FlashAttention.
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
attn_outputs = self.self_attn(
hidden_states,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
output_attentions=output_attentions,
**kwargs,
)
if isinstance(attn_outputs, tuple):
attn_output, attn_weights = attn_outputs
else:
attn_output, attn_weights = attn_outputs, None
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
if output_attentions:
return hidden_states, attn_weights
return hidden_states
class IsaacVisionEncoder(Siglip2Encoder):
"""Encoder using Isaac encoder layers with variable-length attention support."""
def __init__(self, config: IsaacVisionConfig):
super().__init__(config)
self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
@can_return_tuple
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
attention_mask = ensure_document_attention_mask(
attention_mask,
cu_seqlens,
inputs_embeds.size(1),
inputs_embeds.dtype,
inputs_embeds.device,
return_mask_function=False,
)
hidden_states = inputs_embeds
kwargs.update(
{
"max_seqlen": max_seqlen,
"cu_seqlens": cu_seqlens,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
}
)
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)
return BaseModelOutput(last_hidden_state=hidden_states)
def create_pixel_shuffle_index_map(
seq_sizes: torch.Tensor,
token_grids: torch.Tensor,
scale_factor: int = 1,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Build a gather-index map that tells us, for every *output* token after
pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.
Args
----
seq_sizes : (num_images,) - #patches in each image (row-major order)
token_grids : (num_images,2) - (height, width) for every image
scale_factor : spatial down-scale factor (≥2)
device : (optional) overrides `seq_sizes.device`
Returns
-------
gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
gather_idx[i, j] is the *flat* index into the *original*
packed sequence for the j-th sub-patch that forms the
i-th output token.
"""
if device is None:
device = seq_sizes.device
scale_factor = int(scale_factor)
if scale_factor < 2:
raise ValueError("`scale_factor` must be ≥ 2")
# Safety: all spatial dims must be divisible by the scale factor
# Cannot run under torch compile fullgraph mode hence
if not is_torchdynamo_compiling():
if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()):
raise AssertionError(
"Every (H,W) in `token_grids` must be divisible by "
f"scale_factor={scale_factor}, got {token_grids.tolist()}"
)
gather_chunks: list[torch.Tensor] = []
tok_offset = 0
for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False):
# Build the (H, W) grid of flat indices for this image
grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset
grid = grid.view(h, w) # (H, W)
# -------- identical ordering to the fixed-resolution pixel-shuffle routine --------
# Step 1: split width into blocks of scale_factor
grid = grid.view(h, w // scale_factor, scale_factor) # (H, W/scale_factor, scale_factor)
# Step 2: now split height into blocks of scale_factor
grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
# (H/scale_factor, scale_factor, W/scale_factor, scale_factor)
# Step 3: final permutation to (H/scale_factor, W/scale_factor, scale_factor, scale_factor)
grid = grid.permute(0, 2, 1, 3).contiguous() # (H/scale_factor, W/scale_factor, scale_factor, scale_factor)
# Step 4: each (scale_factor, scale_factor) block forms one output token
gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor))
# (H*W / scale_factor**2, scale_factor**2)
tok_offset += seq_len
# Concatenate over all images in the packed batch
gather_idx = torch.cat(gather_chunks, dim=0) # (Σ_i HᵢWᵢ/scale_factor**2, scale_factor**2)
return gather_idx
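# Worked example (illustrative): a single 4x4 patch grid with scale_factor=2 produces
# gather_idx = [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]],
# i.e. every output token gathers one 2x2 block of the original row-major patch sequence.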
def pixel_shuffle_varlen(
x: torch.Tensor,
token_grids: torch.Tensor,
scale_factor: int = 1,
) -> torch.Tensor:
r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.
Args:
x (`torch.Tensor`):
Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
produced by stacking image patches.
token_grids (`torch.Tensor`):
Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
corresponding to each image segment inside `x`.
scale_factor (`int`, *optional*, defaults to 1):
Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a
single embedding channel-group.
Returns:
`torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
`(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)`
if the singleton batch dimension was present.
Raises:
ValueError: If more than one batch item is provided.
"""
keep_batch_dim = x.dim() == 3
if keep_batch_dim:
if x.size(0) != 1:
raise AssertionError("Packed sequence is expected to have batch_size == 1")
x_ = x.squeeze(0) # (seq, embed)
else:
x_ = x # (seq, embed)
embed_dim = x_.size(-1)
scale_factor = int(scale_factor)
# Calculate seq_sizes from token_grids
seq_sizes = torch.prod(token_grids, dim=-1)
# Build index map and gather in one go
gather_idx = create_pixel_shuffle_index_map(
seq_sizes=seq_sizes,
token_grids=token_grids,
scale_factor=scale_factor,
device=x_.device,
) # (new_seq, scale_factor**2)
# Gather → (new_seq, scale_factor**2, embed_dim)
gathered = x_[gather_idx] # fancy indexing keeps gradient
# Merge the scale_factor**2 group dimension into channels to finish the shuffle
out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor)
# Restore batch dimension if needed
if keep_batch_dim:
out = out.unsqueeze(0)
return out
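# Shape sketch (illustrative): a packed sequence of shape (1, 1024, hidden) covering one 32x32 grid
# becomes (1, 256, hidden * 4) with scale_factor=2 -- four times fewer tokens, four times wider embeddings.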
class IsaacVisionTransformer(nn.Module):
_supports_sdpa = True
def __init__(self, config: IsaacVisionConfig):
super().__init__()
self.config = config
self.embeddings = IsaacVisionEmbeddings(config)
self.encoder = IsaacVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
seq_patches, token_grids = packed_seq_patches
seq_sizes = torch.prod(token_grids, dim=-1)
# Get embeddings from packed sequence
hidden_states = self.embeddings(seq_patches, token_grids)
# Add a pseudo batch dimension for the encoder
hidden_states = hidden_states.unsqueeze(0)
# Generate cumulative sequence lengths for variable-length attention
cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device)
cu_seqlens[1:] = seq_sizes.cumsum(0)
max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0
# Pass through encoder with variable-length attention parameters
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
return_dict=True,
)
hidden_states = encoder_outputs.last_hidden_state
# Apply final layer normalization
hidden_states = self.post_layernorm(hidden_states)
hidden_states = pixel_shuffle_varlen(
x=hidden_states,
token_grids=token_grids,
scale_factor=self.pixel_shuffle_scale_factor,
)
# Remove the pseudo batch dimension we added earlier
hidden_states = hidden_states.squeeze(0)
# Return the full sequence of embeddings
return hidden_states
class IsaacMultiModalProjector(nn.Module):
def __init__(self, config: IsaacConfig):
super().__init__()
self.vision_hidden_size = config.vision_config.hidden_size * (
config.vision_config.pixel_shuffle_scale_factor**2
)
self.backbone_hidden_size = config.hidden_size
self.linear_1 = nn.Linear(self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False)
self.silu = nn.SiLU()
self.linear_2 = nn.Linear(4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.silu(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class IsaacVisionEmbedding(nn.Module):
"""Vision embedding wrapper exposing tower and projector."""
_supports_sdpa = True
def __init__(self, config: IsaacConfig):
super().__init__()
vision_cfg = config.vision_config
self.vision_tower = IsaacVisionTransformer(vision_cfg)
self.multimodal_projector = IsaacMultiModalProjector(config)
def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
hidden_states = self.vision_tower(vision_tokens)
return self.multimodal_projector(hidden_states)
def get_scaled_image_size(
scale: float,
original_size: int,
patch_size: int,
pixel_shuffle_scale: int,
) -> int:
scaled_size = scale * original_size
divisor = patch_size * pixel_shuffle_scale
scaled_size = math.ceil(scaled_size / divisor) * divisor
scaled_size = max(divisor, scaled_size)
return int(scaled_size)
def get_image_size_for_max_num_patches(
image_height: int,
image_width: int,
patch_size: int,
max_num_patches: int,
min_num_patches: Optional[int] = None,
eps: float = 1e-5,
pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
r"""Compute a target resolution whose patch grid satisfies patching parametrization.
Args:
image_height (`int`):
Height in pixels of the source image prior to any resizing.
image_width (`int`):
Width in pixels of the source image prior to any resizing.
patch_size (`int`):
Size of the square patch used by the vision encoder.
max_num_patches (`int`):
Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
min_num_patches (`int`, *optional*):
Lower bound on the number of patches. When provided the image will be scaled up if necessary.
eps (`float`, *optional*, defaults to 1e-5):
Convergence tolerance for the internal binary search used to determine the target dimensions.
pixel_shuffle_scale (`int`, *optional*, defaults to 1):
Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
Returns:
`tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
and respect both the maximum and optional minimum patch-count constraints.
"""
# Ensure divisibility
divisor = patch_size * pixel_shuffle_scale
adjusted_height = math.ceil(image_height / divisor) * divisor
adjusted_height = max(divisor, adjusted_height)
adjusted_width = math.ceil(image_width / divisor) * divisor
adjusted_width = max(divisor, adjusted_width)
num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
if min_num_patches is not None and num_patches < min_num_patches:
# Scale up
scale_min, scale_max = 1.0, 100.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches >= min_num_patches:
scale_max = scale
else:
scale_min = scale
scale = scale_max
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
return target_height, target_width
elif num_patches <= max_num_patches:
return adjusted_height, adjusted_width
else:
# Scale down
scale_min, scale_max = eps / 10, 1.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches <= max_num_patches:
scale_min = scale
else:
scale_max = scale
scale = scale_min
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
return target_height, target_width
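# Worked example (illustrative, assuming patch_size=16, max_num_patches=256, pixel_shuffle_scale=1):
# a 512x512 image exceeds the budget (32 * 32 = 1024 patches), so the binary search scales it down to
# 256x256, a 16x16 grid of exactly 256 patches; a 200x200 image is only rounded up to 208x208 (169 patches).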
def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
r"""Convert normalized images into flattened ViT-style patches.
Args:
image (`torch.Tensor`):
Tensor of shape `(num_images, height, width, channels)`.
patch_size (`int`):
Edge length of the square patches.
Returns:
`torch.Tensor`:
Patch tensor of shape `(num_images, height // patch_size, width // patch_size, channels * patch_size**2)`
where each position stores the flattened pixels belonging to that patch.
Raises:
ValueError: If `height` or `width` is not divisible by `patch_size`.
"""
num_images, height, width, channels = image.shape
if height % patch_size or width % patch_size:
raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")
patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels)
patches = patches.permute(0, 1, 3, 2, 4, 5)
patches = patches.reshape(
num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size
)
return patches
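# Shape sketch (illustrative): a normalized batch of shape (1, 224, 224, 3) with patch_size=16
# becomes (1, 14, 14, 768), where each of the 14x14 positions holds 3 * 16 * 16 = 768 flattened pixels.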
class IsaacConfig(PretrainedConfig):
"""Configuration class for Isaac multimodal model.
This configuration corresponds to checkpoints such as
[Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base).
"""
model_type = "isaac"
sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
image_processor_type = "IsaacImageProcessor"
def __init__(
self,
vision_config: Optional[IsaacVisionConfig] = None,
text_config: Optional[Union[Qwen3Config, dict]] = None,
vision_rescale_factor: float = 1 / 255,
max_sequence_length: int = 16384,
vision_token: str = "<image>",
**kwargs,
):
self._rope_parameters: Optional[dict[str, Any]] = None
self._rope_scaling: Optional[dict[str, Any]] = None
attn_implementation = kwargs.get("attn_implementation")
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
else:
self.text_config = text_config
super().__init__(**kwargs)
if self._rope_scaling is None:
self._rope_scaling = getattr(self.text_config, "rope_scaling", None)
else:
self.text_config.rope_scaling = self._rope_scaling
# Keep rope parameters alias in sync with upstream expectations
self._rope_parameters = self._rope_scaling
# Mirror frequently accessed Qwen3 attributes at the composite config level for BC.
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
self.num_hidden_layers = self.text_config.num_hidden_layers
self.num_attention_heads = self.text_config.num_attention_heads
self.head_dim = self.text_config.head_dim
self.hidden_act = self.text_config.hidden_act
self.use_cache = self.text_config.use_cache
self.rope_theta = self.text_config.rope_parameters["rope_theta"]
# Validate rotary parameters now that they have been mirrored locally.
rope_config_validation(self)
self.layer_types = getattr(self.text_config, "layer_types", None)
layer_type_validation(self.layer_types, self.num_hidden_layers)
# Handle vision config - either dict or IsaacVisionConfig instance
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif isinstance(vision_config, IsaacVisionConfig):
self.vision_config = vision_config
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
# Propagate user-requested attention backend to the vision sub-config when provided.
if attn_implementation is not None:
if isinstance(attn_implementation, dict):
vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None))
else:
vision_attn = attn_implementation
if vision_attn is not None:
self.vision_config._attn_implementation = vision_attn
# Vision normalization parameters
self.vision_rescale_factor = float(vision_rescale_factor)
# Processing parameters
self.max_sequence_length = max_sequence_length
self.vision_token = vision_token
@property
def rope_scaling(self):
if hasattr(self, "text_config") and self.text_config is not None:
return getattr(self.text_config, "rope_scaling", None)
return self._rope_scaling
@rope_scaling.setter
def rope_scaling(self, value):
self._rope_scaling = value
if hasattr(self, "text_config") and self.text_config is not None:
self.text_config.rope_scaling = value
@property
def rope_parameters(self) -> dict[str, Any] | None:
"""Alias introduced upstream for rope scaling dictionaries."""
value = self._rope_parameters
if value is None:
value = self.rope_scaling
if value is None:
return {"rope_type": "default"}
return value
@rope_parameters.setter
def rope_parameters(self, value: dict[str, Any] | None) -> None:
self._rope_parameters = value
self.rope_scaling = value
def to_dict(self):
output = super().to_dict()
# Ensure nested configs round-trip through dict serialization
if hasattr(self, "text_config") and self.text_config is not None:
output["text_config"] = self.text_config.to_dict()
if hasattr(self, "vision_config") and self.vision_config is not None:
output["vision_config"] = self.vision_config.to_dict()
return output
# ============================================================================
# Processor Components
# ============================================================================
def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event:
r"""Wrap a text into an `Event` compatible with the multimodal TensorStream.
Args:
tokenizer (`AutoTokenizer`):
Tokenizer used to convert text into model vocabulary ids.
text (`str`):
Plain-text fragment to encode.
time (`float`, *optional*, defaults to 0.0):
Timeline coordinate associated with the event. Both start and end times use the same value because text
segments are instantaneous in the scheduler.
Returns:
`Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching
metadata so that downstream processors can compute modality-specific embeddings.
"""
tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0)
# Calculate dimensions for the event
num_tokens = len(tokens)
dims_virtual = [num_tokens, 1] # [sequence_length, 1]
dims_real = dims_virtual.copy()
# Ensure tokens has the right shape for tensor_stream_token_view
# It expects a 2D tensor where sum(dim=-1) gives the token IDs
if tokens.dim() == 1:
tokens = tokens.unsqueeze(-1)
return Event(
data=tokens,
type=TextType.text,
time=(time, time),
dims_virtual=dims_virtual,
dims_real=dims_real,
idx_range=(0, num_tokens),
)
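# Usage sketch (illustrative): create_text_event(tokenizer, "Hello", time=0.0) returns an Event whose
# data tensor has shape (num_tokens, 1) and whose dims_virtual is [num_tokens, 1], ready to be scheduled
# alongside vision events in a Stream.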
# ============================================================================
# Processor
# ============================================================================
class IsaacProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("IsaacImageProcessorFast",)
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
image_processor,
tokenizer,
*,
vision_token: str = "<image>",
max_sequence_length: int = 16384,
rescale_factor: Optional[float] = None,
config: Optional[Union[IsaacConfig, dict]] = None,
) -> None:
if tokenizer is None:
raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.")
if isinstance(config, dict):
config = IsaacConfig(**config)
if config is not None:
max_sequence_length = config.max_sequence_length
vision_token = config.vision_token
rescale_factor = config.vision_rescale_factor
resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255)
if config is not None:
config.vision_rescale_factor = resolved_rescale_factor
self.image_processor = image_processor
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
self.config = config
# Mirror tokenizer chat template so ProcessorMixin.apply_chat_template works.
self.chat_template = getattr(self.tokenizer, "chat_template", None)
self.vision_token = vision_token
self.max_sequence_length = max_sequence_length
def build_event_stream_simple(
self,
text: str,
images: Optional[list[Image]] = None,
) -> Stream:
events = []
# Process text and images
# Find all occurrences of vision token
pattern = re.escape(self.vision_token)
parts = re.split(f"({pattern})", text) # Keep the delimiter in the result
image_idx = 0
for current_time, part in enumerate(parts):
if part == self.vision_token:
# Replace vision token with image event
if images is None or image_idx >= len(images):
raise ValueError("Encountered vision token without a corresponding image.")
features = self.image_processor(
images=images[image_idx],
return_tensors=TensorType.PYTORCH,
)
patches = features["patches"][0] # (H_tokens, W_tokens, embed)
virtual_dims = features["virtual_pixel_size"][0].tolist()
real_dims = features["real_pixel_size"][0].tolist()
vision_event = Event(
data=patches.reshape(-1, patches.shape[-1]),
type=VisionType.image,
time=(current_time, current_time),
dims_virtual=virtual_dims,
dims_real=real_dims,
idx_range=(0, math.prod(virtual_dims)),
)
events.append(vision_event)
image_idx += 1
elif part: # Non-empty text part
text_event = create_text_event(self.tokenizer, part, time=current_time)
events.append(text_event)
# Create the stream; schedule=True lets the scheduler order events by time and modality priority
return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True)
def __call__(
self,
text: Union[str, list[str]],
images: Optional[Union[Image, list[Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs,
) -> BatchFeature:
"""
Process text and images into TensorStream format.
Args:
text: Input text or list of texts with vision tokens
images: PIL image or list of images (optional)
return_tensors: Format for output tensors
Returns:
BatchFeature with input_ids and tensor_stream
"""
# Normalize inputs to lists
if isinstance(text, str):
texts = [text]
else:
texts = text
if images is not None:
if isinstance(images, Image):
images_list = [images]
else:
images_list = images
else:
images_list = None
if len(texts) != 1:
raise ValueError("IsaacProcessor currently supports batch_size=1")
if images_list is not None:
# Count vision tokens in text to validate image count
vision_token_count = texts[0].count(self.vision_token)
if vision_token_count != len(images_list):
raise ValueError(
f"Number of {self.vision_token} tokens in text ({vision_token_count}) "
f"must match number of images ({len(images_list)})"
)
# Build event stream
stream = self.build_event_stream_simple(
text=texts[0],
images=images_list,
)
# Create TensorStream
tensor_stream = TensorStream([stream])
# Slice to max length if needed
_, T = tensor_stream.shape
if T > self.max_sequence_length:
tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T)
# Get token view
tokens = tensor_stream_token_view(tensor_stream)
if return_tensors in (TensorType.PYTORCH, "pt"):
input_ids = torch.as_tensor(tokens, dtype=torch.long)
else:
input_ids = tokens
data = {
"input_ids": input_ids,
"tensor_stream": tensor_stream,
}
return BatchFeature(data=data)
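# Usage sketch (illustrative; the variable names below are hypothetical):
#   processor = IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=tokenizer)
#   batch = processor(text="Describe this image: <image>", images=[pil_image])
#   batch["input_ids"]     -> (1, seq_len) token ids for the interleaved text/vision stream
#   batch["tensor_stream"] -> TensorStream consumed directly by IsaacModel.forward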
# ============================================================================
# Model
# ============================================================================
def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
r"""Create 3D positional indices for token input.
Args:
input_ids (`torch.Tensor`):
Tensor of shape `(batch_size, seq_len)` containing token ids.
Returns:
`torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the
1D position so it can be consumed by the 3-axis MRoPE rotary embedding.
"""
batch_size, seq_length = input_ids.shape
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE
return position_ids
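# Worked example (illustrative): input_ids of shape (1, 4) yields position_ids
# [[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]]] -- the same 1D position replicated
# across the temporal/height/width MRoPE axes.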
class IsaacRotaryEmbedding(nn.Module):
EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"}
def __init__(self, config: IsaacConfig, device=None):
super().__init__()
rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config
rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {}
sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS}
config_for_rope = copy.copy(rope_source_cfg)
config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None
init_device = device if device is not None and getattr(device, "type", None) != "meta" else None
self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device)
rotary_half_dim = self._qwen_rotary.inv_freq.shape[0]
self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim)
self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size
@staticmethod
def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]:
if section is None:
weights = (2, 1, 1)
base = [rotary_half_dim * w // sum(weights) for w in weights]
base[0] += rotary_half_dim - sum(base)
return base
section = [int(v) for v in section]
if len(section) != 3:
raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)")
if sum(section) != rotary_half_dim:
raise ValueError(
f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {section}."
)
return section
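# Default split sketch (illustrative): with rotary_half_dim=64 and no explicit mrope_section,
# the (2, 1, 1) weighting yields [32, 16, 16] -- half the rotary dimensions follow the temporal
# axis and a quarter each follow height and width.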
def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
split_sections = tuple(self.mrope_section * 2)
chunks = tensor.split(split_sections, dim=-1)
return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
@property
def inv_freq(self) -> torch.Tensor:
return self._qwen_rotary.inv_freq
def forward(
self,
position_ids: torch.Tensor,
modality_tensor: torch.Tensor,
hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if position_ids.ndim != 3 or position_ids.size(-1) != 3:
raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE")
if modality_tensor.shape != position_ids.shape[:2]:
raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`")
if hidden_states is None:
batch, seq_len, _ = position_ids.shape
hidden_states = torch.zeros(
batch,
seq_len,
self.hidden_size,
dtype=torch.float32,
device=position_ids.device,
)
with torch.no_grad():
pos = position_ids.clone()
image_value = VisionType.image.value if VisionType is not None else 1
not_spatial = modality_tensor != image_value
if not_spatial.any():
data_1d = pos[not_spatial][..., 0].unsqueeze(-1)
pos[not_spatial] = data_1d.expand(-1, pos.shape[-1])
pos_axes = pos.permute(2, 0, 1).contiguous()
cos_axes, sin_axes = self._qwen_rotary(hidden_states, pos_axes)
cos_axes = cos_axes.to(hidden_states.dtype)
sin_axes = sin_axes.to(hidden_states.dtype)
cos_combined = self._combine_axes(cos_axes)
sin_combined = self._combine_axes(sin_axes)
return cos_combined, sin_combined
class IsaacModel(Qwen3PreTrainedModel):
supports_gradient_checkpointing = True
_can_compile_fullgraph = False
_supports_flex_attn = False
# Expose tied-weights mapping even if empty for base model tests.
all_tied_weights_keys: dict[str, str] = {}
def __init__(self, config: IsaacConfig):
Qwen3PreTrainedModel.__init__(self, config)
text_cfg_source = config.text_config
text_cfg = copy.deepcopy(text_cfg_source)
self.text_model = AutoModel.from_config(text_cfg)
# Ensure downstream callers observe the composed config
self.text_model.config = config
self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)
if config.vision_config is None:
raise ValueError("IsaacConfig should always have vision_config")
self.vision_tower = IsaacVisionTransformer(config.vision_config)
self.multimodal_projector = IsaacMultiModalProjector(config)
# Dispatch table for TensorStream balanced embedding (text + vision)
self.embed_fns = {
TextType: self.embed_text_tokens,
VisionType: self.embed_vision,
}
# Keep track of config attributes that downstream utilities may query directly on the model.
self.max_sequence_length = config.max_sequence_length
self.vision_rescale_factor = config.vision_rescale_factor
self.vision_token = config.vision_token
# Initialize weights and parallel plans (including tp_plan from the text model)
self.post_init()
# Respect config-specified gradient checkpointing
if getattr(config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module) -> None:
self.text_model.set_input_embeddings(value)
vocab_size = getattr(value, "num_embeddings", None)
if vocab_size is not None:
self.config.vocab_size = vocab_size
if hasattr(self.config, "text_config"):
self.config.text_config.vocab_size = vocab_size
self.text_model.config.vocab_size = vocab_size
@property
def embed_tokens(self) -> nn.Module:
return self.text_model.embed_tokens
@embed_tokens.setter
def embed_tokens(self, value: nn.Module) -> None:
self.text_model.embed_tokens = value
@property
def layers(self) -> nn.ModuleList:
return self.text_model.layers
@property
def norm(self) -> nn.Module:
return self.text_model.norm
@property
def vision_model(self) -> nn.Module:
return self.vision_tower
def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
"""Embed text tokens, squeezing singleton dimensions."""
# Text events are shaped as (..., 1); squeeze the singleton index dim
h = self.text_model.embed_tokens(token_ids)
if h.dim() >= 2 and h.size(-2) == 1:
h = h[..., 0, :]
return h
def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Embed vision tokens using the vision encoder."""
# vision tokens is (seq_patches, token_grids)
return self.multimodal_projector(self.vision_tower(vision_tokens))
def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
"""
Embed each modality stream independently, preserving the original TensorStream
structure.
"""
flat_stream = tensor_stream.flat_stream()
per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False)
per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()}
# Collect per-event grids for vision tokens (H, W like dims sans time)
token_grids = defaultdict(list)
for stream in tensor_stream.streams:
for event in stream:
token_grids[event.type].append(event.dims(virtual=False))
embedded_compact = {}
for stream_type, modality_payload_tensor in per_modality_compact_stream.items():
if stream_type.modality == VisionType:
# Build a (N_events, 2) grid tensor with spatial dims only
grids = token_grids.get(stream_type, [])
if len(grids) == 0:
input_tensor = modality_payload_tensor
else:
token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:]
input_tensor = (modality_payload_tensor, token_grids_tensor)
embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor)
else:
embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor)
# Reconstruct a TensorStream with embedded payloads and compact
embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact)
h = embedded_ts.compact() # (B, T, D)
return h
@auto_docstring
@check_model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
tensor_stream: Optional[TensorStream] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
modality_tensor: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple | BaseModelOutputWithPast:
"""
Forward pass with MRoPE position embeddings.
Computes position embeddings once and passes them through all layers.
Args:
tensor_stream (`TensorStream`, *optional*):
Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with
`input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor`
if they are not supplied.
modality_tensor (`torch.LongTensor`, *optional*):
Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing
values from `TextType`/`VisionType`. Automatically built from `tensor_stream` or `input_ids` when
omitted.
"""
text_value = TextType.text.value if TextType is not None else 0
# Get inputs
if tensor_stream is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
if tensor_stream is None and input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
# Resolve the input source (TensorStream takes precedence over token ids).
if tensor_stream is not None:
inputs_embeds = self.embed_stream(tensor_stream)
elif input_ids is not None:
inputs_embeds = self.text_model.embed_tokens(input_ids)
elif inputs_embeds is None:
raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds")
batch_size, seq_len = inputs_embeds.shape[:2]
# Ensure cache exists when requested
if use_cache and past_key_values is None:
cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config
past_key_values = DynamicCache(config=cache_config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
# Normalize modality tensor
if modality_tensor is None:
if tensor_stream is not None:
modality_tensor = modality_mask(tensor_stream)
else:
modality_tensor = torch.full(
(batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long
)
else:
modality_tensor = modality_tensor.to(dtype=torch.long)
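# Align a user-supplied modality tensor with the embedded sequence length (truncate or right-pad with its last value)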
if modality_tensor.shape[1] != seq_len:
if modality_tensor.shape[1] > seq_len:
modality_tensor = modality_tensor[:, :seq_len]
else:
pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1])
modality_tensor = torch.cat([modality_tensor, pad], dim=1)
# Normalize position ids
if position_ids is None:
if tensor_stream is not None:
position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3)
else:
position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1)
# Expand 2D position ids (from generic padding tests or decode cache positions) to 3D MRoPE coords
if position_ids.ndim == 2:
position_ids = position_ids.to(device=inputs_embeds.device)
position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)
# Align lengths so rotary embedding sees matching shapes
if position_ids.shape[1] != seq_len:
start_positions = position_ids[:, :1, 0]
position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1)
position_ids = position_ids + start_positions
position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)
# Compute MRoPE position embeddings if we have custom rotary_emb
cos, sin = self.rotary_emb(
position_ids,
modality_tensor,
hidden_states=inputs_embeds,
)
cos = cos.to(inputs_embeds.dtype)
sin = sin.to(inputs_embeds.dtype)
# Attention kernels (e.g. flash attention) expect (batch, seq_len) position_ids; keep the 3D coords only for the rotary phases
decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids
# Prepare attention mask
if not isinstance(attention_mask, dict):
attention_mask = create_masks_for_generate(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
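# create_masks_for_generate may return a dict keyed by attention type (e.g. full vs. sliding window);
# each decoder layer then picks the mask matching its own `attention_type` below.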
is_attention_mask_dict = isinstance(attention_mask, dict)
# Initialize hidden states
hidden_states = inputs_embeds
all_attentions = [] if output_attentions else None
for decoder_layer in self.text_model.layers:
layer_attention_mask = (
attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_attention_mask,
position_ids=decoder_position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=(cos, sin),
output_attentions=output_attentions,
**kwargs,
)
layer_outputs_is_tuple = isinstance(layer_outputs, tuple)
hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs
if output_attentions and layer_outputs_is_tuple:
all_attentions.append(layer_outputs[1])
# Final layer norm
hidden_states = self.text_model.norm(hidden_states)
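# Only the final post-norm hidden state is surfaced; per-layer hidden states are not collected in this implementation.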
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=(hidden_states,),
attentions=tuple(all_attentions) if output_attentions else None,
)
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
"""Isaac multimodal model for conditional generation."""
config_class = IsaacConfig
_can_compile_fullgraph = False
_tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
def __init__(self, config: IsaacConfig):
config.pad_token_id = 0
super().__init__(config)
self.model = IsaacModel(config) # Use our custom model
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them.
self.rope_deltas = None
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
tensor_stream: Optional[TensorStream] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple | CausalLMOutputWithPast:
r"""
Forward pass for conditional generation supporting both standard inputs and TensorStream.
tensor_stream (`TensorStream`, *optional*):
Packed multimodal stream of text and vision tokens that already encodes spatial metadata. When provided,
the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of
`input_ids`.
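Example (minimal sketch, illustration only; it exercises the argument combinations accepted by this method and
assumes `model`, `input_ids`, `attention_mask`, and `tensor_stream` were built beforehand, e.g. via `IsaacProcessor`):
```python
>>> # Text-only call behaves like a standard causal LM forward.
>>> out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
>>> out.loss, out.logits.shape
>>> # Multimodal call: pass the packed TensorStream instead of token ids.
>>> out = model(tensor_stream=tensor_stream)
```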
"""
# Don't compute embeddings here - let the model handle it
if tensor_stream is not None:
input_ids = None
if input_ids is None and inputs_embeds is None and tensor_stream is None:
raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.")
text_value = TextType.text.value if TextType is not None else 0
if tensor_stream is None:
if input_ids is not None:
batch_size, seq_len = input_ids.shape
input_device = input_ids.device
else:
batch_size, seq_len = inputs_embeds.shape[:2]
input_device = inputs_embeds.device
# Build position ids when none were supplied: 3D MRoPE coords from the tensor_stream if present, 1D otherwise.
# During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far
# cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream.
if position_ids is None:
if tensor_stream is not None:
position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
elif input_ids is None:
dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long)
position_ids = compute_position_ids_input_ids(dummy_ids)
else:
position_ids = compute_position_ids_input_ids(input_ids)
rope_delta = 0
if cache_position is not None and self.rope_deltas is not None:
# Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue
# rotating in lockstep across generation steps.
rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
if not isinstance(rope_delta, int): # when rope_delta is the int 0 there is nothing to broadcast
rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0)
position_ids = position_ids.add(rope_delta)
if attention_mask is None and tensor_stream is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long)
if tensor_stream is not None:
modality_tensor = modality_mask(tensor_stream)
else:
modality_tensor = torch.full(
(batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long
)
outputs = self.model(
input_ids=input_ids,
tensor_stream=tensor_stream,
attention_mask=attention_mask,
position_ids=position_ids,
modality_tensor=modality_tensor,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions if output_attentions else None,
)
def set_input_embeddings(self, value: nn.Module) -> None:
self.model.set_input_embeddings(value)
vocab_size = getattr(value, "num_embeddings", None)
if vocab_size is not None:
self.config.vocab_size = vocab_size
self.model.config.vocab_size = vocab_size
if hasattr(self.model, "text_model"):
self.model.text_model.config.vocab_size = vocab_size
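# Resize lm_head when the vocabulary changed and re-tie it to the updated input embeddings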
if self.lm_head.weight.shape[0] != vocab_size:
self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
if hasattr(self.model, "embed_tokens"):
self.lm_head.weight = self.model.text_model.embed_tokens.weight
def get_rope_index(
self,
input_ids: Optional[torch.Tensor],
tensor_stream: Optional[TensorStream],
attention_mask: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute MRoPE position ids from a TensorStream (or 1D fallback).
Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE.
rope_deltas is (B,1) used to advance positions in decode.
"""
# tensor_stream present: compute 3D coords
if tensor_stream is None and input_ids is None:
raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices")
if tensor_stream is not None:
pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3)
else:
pos_3d = compute_position_ids_input_ids(input_ids)
B, L, _ = pos_3d.shape
# Max position per batch across the 3 planes and sequence dimension: (B,)
m_per_batch = pos_3d.amax(dim=(1, 2))
# Sequence lengths per batch: (B,)
if attention_mask is None:
seq_lens = torch.full_like(m_per_batch, L)
else:
seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device)
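# rope_deltas = (max MRoPE coordinate + 1) - visible sequence length; decode steps add it to
# cache_position so rotary positions continue from where the prefill left off (see `forward`).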
rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1)
return pos_3d, rope_deltas
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[list[torch.FloatTensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
tensor_stream: Optional[TensorStream] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
use_cache: bool = True,
**kwargs,
) -> dict[str, Any]:
"""
Prepare inputs for generation, handling TensorStream inputs properly.
"""
if cache_position is None:
seq_length = None
device = None
if input_ids is not None:
seq_length = input_ids.shape[1]
device = input_ids.device
elif inputs_embeds is not None:
seq_length = inputs_embeds.shape[1]
device = inputs_embeds.device
elif tensor_stream is not None:
_, seq_length = tensor_stream.shape
device = tensor_stream.device
if seq_length is not None:
# prepare_inputs_for_generation may be invoked outside `generate`, so synthesize the
# same cache positions that GenerationMixin would have created during prefill.
cache_position = torch.arange(seq_length, dtype=torch.long, device=device)
# Call parent preparation
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
**kwargs,
)
cache_position = model_inputs.get("cache_position", cache_position)
# Handle TensorStream only for the prefill step
first_step = cache_position is None or cache_position[0] == 0
if tensor_stream is not None and first_step:
model_inputs["tensor_stream"] = tensor_stream
# Let forward rebuild MRoPE coordinates from the TensorStream
model_inputs["position_ids"] = None
else:
model_inputs["tensor_stream"] = None
# TensorStream decode path: preserve rotary offsets from prefill
if tensor_stream is not None and not first_step and self.rope_deltas is not None:
model_inputs["position_ids"] = None
return model_inputs
# For decode steps, synthesize position_ids that continue from the cache offsets
if model_inputs.get("position_ids") is None and cache_position is not None and not first_step:
batch_size = 1
if model_inputs.get("input_ids") is not None:
batch_size = model_inputs["input_ids"].shape[0]
elif model_inputs.get("inputs_embeds") is not None:
batch_size = model_inputs["inputs_embeds"].shape[0]
pos_ids = cache_position.view(1, -1).expand(batch_size, -1)
pos_ids = pos_ids.unsqueeze(-1).expand(-1, -1, 3)
model_inputs["position_ids"] = pos_ids
return model_inputs
@classmethod
def can_generate(cls) -> bool:
return True
def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor:
"""Compute residuals for P-frames to stay in sync with the training pipeline."""
if not any(is_p_frame):
return frames
frame_indices = torch.arange(len(is_p_frame), device=frames.device)
i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device)
last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1
p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)]
frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]]
return frames
__all__ = [
"IsaacConfig",
"IsaacModel",
"IsaacPreTrainedModel", # noqa: F822
"IsaacForConditionalGeneration",
"IsaacImageProcessorFast",
"IsaacProcessor",
]