| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import math |
| import re |
| from collections import defaultdict |
| from collections.abc import Callable, Sequence |
| from typing import Any, Optional, Union |
|
|
| from transformers.utils.import_utils import ( |
| is_torch_available, |
| is_torchdynamo_compiling, |
| is_vision_available, |
| ) |
|
|
|
|
| if is_torch_available(): |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| if is_vision_available(): |
| from PIL.Image import Image |
| else: |
| Image = None |
|
|
|
|
| from perceptron.tensorstream.ops import ( |
| compute_mrope_pos_tensor, |
| modality_mask, |
| reconstruct_tensor_stream_from_compact_dict, |
| tensor_stream_token_view, |
| ) |
| from perceptron.tensorstream.ops import ( |
| slice as ts_slice, |
| ) |
| from perceptron.tensorstream.tensorstream import ( |
| Event, |
| Stream, |
| TensorStream, |
| TextType, |
| VisionType, |
| create_stream, |
| group_streams, |
| ) |
|
|
| from transformers.cache_utils import DynamicCache |
| from transformers.configuration_utils import PretrainedConfig, layer_type_validation |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.generation.utils import GenerationMixin |
| from transformers.processing_utils import ImagesKwargs |
| from transformers.image_transforms import group_images_by_shape, reorder_images |
| from transformers.image_utils import SizeDict |
| from transformers.image_processing_utils_fast import ( |
| BaseImageProcessorFast, |
| ) |
| from transformers.image_utils import ( |
| ChannelDimension, |
| PILImageResampling, |
| ) |
| from transformers.masking_utils import create_masks_for_generate, packed_sequence_mask_function |
| from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import rope_config_validation |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from transformers.models.auto.modeling_auto import AutoModel |
| from transformers.models.auto.tokenization_auto import AutoTokenizer |
| from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel |
| from transformers.processing_utils import ProcessorMixin, Unpack |
| from transformers.utils.generic import TensorType |
| from transformers.utils.auto_docstring import auto_docstring |
|
|
| |
| from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN |
| from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD |
| from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs |
| from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling |
| from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
| from transformers.models.siglip2.modeling_siglip2 import ( |
| Siglip2Attention, |
| Siglip2Encoder, |
| Siglip2EncoderLayer, |
| ) |
|
|
|
|
class IsaacVisionConfig(Siglip2VisionConfig):
    """Configuration of the Isaac vision tower: a `Siglip2VisionConfig` plus a pixel-shuffle factor.

    The parent constructor receives only `**kwargs`; every named argument below is then
    written onto the instance under the same attribute name.

    Args:
        pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
            Pixel-shuffle factor stored on the configuration.
        num_patches (`int`, *optional*, defaults to 256):
            Maximum number of learnable positional embeddings to initialize.
    """

    model_type = "isaac_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_patches=256,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        pixel_shuffle_scale_factor=1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Explicit arguments are applied after the parent initialiser, so they win over
        # whatever the parent set. The insertion order mirrors the attribute order used so far.
        explicit_fields = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_channels": num_channels,
            "patch_size": patch_size,
            "attention_dropout": attention_dropout,
            "layer_norm_eps": layer_norm_eps,
            "hidden_act": hidden_act,
            "num_patches": num_patches,
            "pixel_shuffle_scale_factor": pixel_shuffle_scale_factor,
        }
        for field_name, field_value in explicit_fields.items():
            setattr(self, field_name, field_value)

        # Select SDPA when no attention implementation has been chosen yet.
        if getattr(self, "_attn_implementation", None) is None:
            self._attn_implementation = "sdpa"
|
|
|
|
class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
    """Optional keyword arguments accepted by the Isaac image processor on top of `ImagesKwargs`."""

    # Side length of a square patch.
    patch_size: Optional[int]
    # Upper / lower bounds on the number of patches per image.
    max_num_patches: Optional[int]
    min_num_patches: Optional[int]
    # Pixel-shuffle factor the patch grid must be divisible by.
    pixel_shuffle_scale: Optional[int]
|
|
|
|
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
    r"""Fast torch-based image processor for Isaac vision inputs."""

    # Upper bound on `height * width` of an input image; larger images are rejected.
    MAX_PIXELS = 60_000_000

    resample = PILImageResampling.BILINEAR
    model_input_names = ["patches", "token_grids"]
    valid_kwargs = IsaacImageProcessorKwargs
    unused_kwargs = ["size", "do_center_crop", "crop_size"]

    do_resize = True
    size: Optional[SizeDict] = None
    default_to_square: Optional[bool] = None
    do_center_crop = False
    crop_size: Optional[SizeDict] = None
    patch_size: Optional[int] = 16
    max_num_patches: Optional[int] = 256
    min_num_patches: Optional[int] = None
    pixel_shuffle_scale: Optional[int] = 1
    do_pad = False
    pad_size: Optional[SizeDict] = None
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = list(VISION_MEAN)
    image_std = list(VISION_STD)
    do_convert_rgb = True
    return_tensors = None
    data_format = ChannelDimension.FIRST
    input_data_format = None
    device = None
    disable_grouping = False
    size_divisor: Optional[int] = None

    def __init__(
        self,
        **kwargs: Unpack[IsaacImageProcessorKwargs],
    ) -> None:
        super().__init__(**kwargs)

        # Normalise the pixel-shuffle factor to a positive int (None means "no shuffle", i.e. 1).
        pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale)
        if pixel_shuffle_scale < 1:
            raise ValueError("`pixel_shuffle_scale` must be >= 1")
        self.pixel_shuffle_scale = pixel_shuffle_scale

    def _validate_preprocess_kwargs(self, **kwargs):
        # These options are dropped before delegating to the parent validator.
        for ignored in ("do_resize", "size", "do_center_crop", "crop_size", "disable_grouping"):
            kwargs.pop(ignored, None)
        return super()._validate_preprocess_kwargs(**kwargs)

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        interpolation: Optional[Any] = None,
        antialias: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Resize `image` to `(size.height, size.width)` with `torch.nn.functional.interpolate`.

        Args:
            image: Tensor forwarded to `F.interpolate` (called from `_preprocess` with a
                batched channel-first tensor).
            size: Target size; both `height` and `width` must be set.
            interpolation: Interpolation mode. Objects exposing `.value` or `.name` are
                unwrapped; `None` means `"bilinear"`.
            antialias: Accepted for signature compatibility; it is not forwarded to
                `F.interpolate`.

        Raises:
            ValueError: If `size.height` or `size.width` is `None`.
        """
        if size.height is None or size.width is None:
            raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.")

        if interpolation is None:
            resize_mode: Any = "bilinear"
        elif hasattr(interpolation, "value"):
            resize_mode = interpolation.value
        elif hasattr(interpolation, "name"):
            resize_mode = interpolation.name.lower()
        else:
            resize_mode = interpolation

        # Use the lower-cased name both for the `align_corners` decision and for the
        # actual call, so "BILINEAR" and "bilinear" behave the same.
        if isinstance(resize_mode, str):
            resize_mode = resize_mode.lower()

        resize_kwargs: dict[str, Any] = {}
        if resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}:
            resize_kwargs["align_corners"] = False

        return F.interpolate(
            image,
            size=(size.height, size.width),
            mode=resize_mode,
            **resize_kwargs,
        )

    def _preprocess(
        self,
        images: list[torch.Tensor],
        do_resize: bool,
        size: Optional[SizeDict],
        interpolation: Optional[Any],
        do_center_crop: bool,
        crop_size: Optional[SizeDict],
        do_rescale: Optional[bool],
        rescale_factor: Optional[float],
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, Sequence[float]]],
        image_std: Optional[Union[float, Sequence[float]]],
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[SizeDict] = None,
        *,
        patch_size: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        min_num_patches: Optional[int] = None,
        pixel_shuffle_scale: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """Resize, rescale/normalize and patchify images, returning patches plus grid metadata.

        Returns:
            `BatchFeature` with keys `patches`, `token_grids` (`[height_tokens, width_tokens]`
            per image), `virtual_pixel_size` (`[1, h // scale, w // scale]`) and
            `real_pixel_size` (`[1, h, w]`).

        Raises:
            ValueError: If center-cropping or padding is requested, an image exceeds
                `MAX_PIXELS`, a group is not a 4-D tensor, the image size is not divisible
                by `patch_size` while resizing is disabled, or the patch grid is not
                divisible by `pixel_shuffle_scale`.
        """
        if do_center_crop:
            raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.")
        if do_pad:
            raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.")

        # Fall back to the instance configuration when called without explicit values;
        # otherwise `None` would reach the modulo / division arithmetic below.
        patch_size = self.patch_size if patch_size is None else patch_size
        max_num_patches = self.max_num_patches if max_num_patches is None else max_num_patches
        min_num_patches = self.min_num_patches if min_num_patches is None else min_num_patches
        pixel_shuffle_scale = self.pixel_shuffle_scale if pixel_shuffle_scale is None else pixel_shuffle_scale

        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}

        for shape, stacked_images in grouped_images.items():
            if stacked_images.ndim != 4:
                raise ValueError("Expected batched channel-first image tensors.")

            batch_size, channels, original_height, original_width = stacked_images.shape

            # Expand single-channel images to three channels when RGB conversion is enabled.
            if bool(self.do_convert_rgb) and channels == 1:
                stacked_images = stacked_images.repeat(1, 3, 1, 1)
                channels = 3

            if original_height * original_width > self.MAX_PIXELS:
                raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`")

            if do_resize:
                # The target size is only needed (and only computed) when resizing.
                target_height, target_width = get_image_size_for_max_num_patches(
                    original_height,
                    original_width,
                    patch_size,
                    max_num_patches,
                    min_num_patches=min_num_patches,
                    pixel_shuffle_scale=pixel_shuffle_scale,
                )
                image_batch = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=target_height, width=target_width),
                    interpolation=interpolation,
                )
            else:
                if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0):
                    raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.")
                image_batch = stacked_images

            # `rescale_and_normalize` receives both flags, so it must also run when only
            # normalization is requested; gating on `do_rescale` alone skipped it.
            if do_rescale or do_normalize:
                image_batch = self.rescale_and_normalize(
                    image_batch,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                )

            # Channels-last layout for the patchifier; no image is flagged as a P-frame.
            nhwc_images = image_batch.permute(0, 2, 3, 1)
            nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size)

            patches = patchify_vision(nhwc_images, patch_size=patch_size)
            _, height_tokens, width_tokens, _ = patches.shape

            def per_image(values: list[int]) -> torch.Tensor:
                # One identical long-typed row per image in this shape group.
                row = torch.tensor(values, dtype=torch.long, device=patches.device)
                return row.unsqueeze(0).repeat(batch_size, 1)

            token_grid = per_image([height_tokens, width_tokens])
            real_dim = per_image([1, height_tokens, width_tokens])

            if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
                raise ValueError(
                    "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled."
                )
            virtual_dim = per_image(
                [1, height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale]
            )

            processed_patches_grouped[shape] = patches
            token_grids_grouped[shape] = token_grid
            virtual_dims_grouped[shape] = virtual_dim
            real_dims_grouped[shape] = real_dim

        # Restore the caller's image order, then stack along a new leading dimension.
        patches_slices = reorder_images(processed_patches_grouped, grouped_images_index)
        token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index)
        virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index)
        real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index)

        return BatchFeature(
            data={
                "patches": torch.stack(patches_slices, dim=0),
                "token_grids": torch.stack(token_grid_slices, dim=0),
                "virtual_pixel_size": torch.stack(virtual_dim_slices, dim=0),
                "real_pixel_size": torch.stack(real_dim_slices, dim=0),
            },
            tensor_type=return_tensors,
        )
|
|
|
|
def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]:
    """Build a same-segment mask function from cumulative sequence lengths.

    Consecutive differences of ``cu_seqlens`` give the segment lengths; every token is labelled
    with the index of its segment and the labels are handed to ``packed_sequence_mask_function``.

    Returns:
        The callable produced by ``packed_sequence_mask_function``, or ``None`` when
        ``cu_seqlens`` is ``None``, has fewer than two elements, or yields no segment.
    """
    if cu_seqlens is None or cu_seqlens.numel() < 2:
        return None

    segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    num_segments = segment_lengths.numel()
    if num_segments == 0:
        return None

    token_count = int(segment_lengths.sum().item())
    segment_index = torch.arange(num_segments, device=cu_seqlens.device)
    # One segment id per token, laid out as a single packed row.
    token_segment_ids = torch.repeat_interleave(segment_index, segment_lengths)
    return packed_sequence_mask_function(token_segment_ids.view(1, token_count))
|
|
|
|
def ensure_document_attention_mask(
    attention_mask: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    total_tokens: int,
    dtype: torch.dtype,
    device: torch.device,
    *,
    return_mask_function: bool = False,
) -> Optional[Union[torch.Tensor, Callable]]:
    """Pick the attention mask to use for a packed sequence.

    An explicit ``attention_mask`` is always returned unchanged. Without one, a callable built
    by ``document_mask_function_from_cu_seqlens`` is returned only when ``cu_seqlens`` is given
    and ``return_mask_function=True``; in every other case the result is ``None``.

    ``total_tokens``, ``dtype`` and ``device`` are accepted but not used.
    """
    if attention_mask is not None:
        return attention_mask

    wants_callable = return_mask_function and cu_seqlens is not None
    return document_mask_function_from_cu_seqlens(cu_seqlens) if wants_callable else None
|
|
|
|
class IsaacVisionEmbeddings(nn.Module):
    """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.

    Patches of all images arrive concatenated along the first dimension. They are padded into
    a per-image batch, linearly embedded, summed with positional embeddings resized to each
    image's patch grid, and concatenated back into one packed sequence.
    """

    def __init__(self, config: IsaacVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # Each flattened patch (channels * patch_size**2 values) is projected linearly.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        # Learnable positional grid of side int(sqrt(num_patches)).
        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Parameter(
            torch.empty(
                self.position_embedding_size,
                self.position_embedding_size,
                self.embed_dim,
            )
        )
        nn.init.normal_(self.position_embedding)

    def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
        """Embed packed patches.

        Args:
            seq_patches: `(total_patches, patch_dim)` packed patches.
            spatial_shapes: `(num_images, 2)` rows of `(height_tokens, width_tokens)`.

        Returns:
            `(total_patches, embed_dim)` tensor; it has zero rows when there are no images.
        """
        packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes)
        if packed_pixel_values is None:
            return seq_patches.new_zeros((0, self.embed_dim))

        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))

        resized_positional_embeddings = self.resize_positional_embeddings(
            self.position_embedding, spatial_shapes, max_length=packed_pixel_values.shape[1]
        )

        embeddings = patch_embeds + resized_positional_embeddings
        return self._unpack_from_batch(embeddings, seq_lengths)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim). Rows of
            images with an empty grid are filled with zeros.
        """
        batch_size = spatial_shapes.shape[0]
        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        resulted_positional_embeddings = torch.empty(
            (batch_size, max_length, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation.
        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # On CPU the interpolation is always run in float32.
        if positional_embeddings.device.type == "cpu":
            positional_embeddings = positional_embeddings.to(torch.float32)

        # Plain Python ints are used for sizes and slice bounds.
        for i, (height, width) in enumerate(spatial_shapes.tolist()):
            num_tokens = height * width
            if num_tokens == 0:
                # An empty grid cannot be interpolated and has no "first" embedding to pad
                # with; its row is pure padding.
                resulted_positional_embeddings[i] = 0
                continue

            resized_embeddings = F.interpolate(
                positional_embeddings,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # (1, embed_dim, height, width) -> (height * width, embed_dim)
            resized_embeddings = resized_embeddings.reshape(embed_dim, num_tokens).transpose(0, 1)

            # Cast back to the parameter dtype.
            resized_embeddings = resized_embeddings.to(source_dtype)

            resulted_positional_embeddings[i, :num_tokens] = resized_embeddings
            # Padding positions reuse the first resized embedding.
            resulted_positional_embeddings[i, num_tokens:] = resized_embeddings[0]

        return resulted_positional_embeddings

    def _pack_to_batch(
        self,
        seq_patches: torch.Tensor,
        spatial_shapes: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """Scatter the packed patches into a zero-padded `(num_images, max_length, patch_dim)` batch.

        Returns:
            `(packed_pixel_values, seq_lengths)`; the first item is `None` when there are no images.

        Raises:
            ValueError: On wrong input ranks or when the patch count does not match the shapes.
        """
        if seq_patches.ndim != 2:
            raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).")
        if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2:
            raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).")

        seq_lengths = spatial_shapes.long().prod(dim=-1)
        total_patches = int(seq_lengths.sum().item())
        if total_patches != seq_patches.size(0):
            raise ValueError(
                "Mismatch between packed patches and spatial shapes: got "
                f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}."
            )

        batch_size = spatial_shapes.size(0)
        if batch_size == 0:
            return None, seq_lengths

        max_length = int(seq_lengths.max().item())
        patch_dim = seq_patches.size(-1)

        packed_pixel_values = seq_patches.new_zeros(
            (batch_size, max_length, patch_dim), device=seq_patches.device
        )

        start = 0
        for batch_idx, length in enumerate(seq_lengths.tolist()):
            if length == 0:
                continue
            end = start + length
            packed_pixel_values[batch_idx, :length] = seq_patches[start:end]
            start = end

        return packed_pixel_values, seq_lengths

    def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
        """Drop the padding and concatenate the per-image rows back into one packed sequence."""
        output_chunks = [
            embeddings[batch_idx, :length]
            for batch_idx, length in enumerate(seq_lengths.tolist())
            if length != 0
        ]

        if not output_chunks:
            return embeddings.new_zeros((0, embeddings.size(-1)))

        return torch.cat(output_chunks, dim=0)
|
|
|
|
class IsaacVisionAttention(Siglip2Attention):
    """Custom attention that supports variable-length (packed) sequences.

    Flash-style backends receive `cu_seq_lens_*` / `max_length_*` keyword arguments. For every
    other backend an additive block-diagonal mask is derived from `cu_seqlens` when the caller
    did not supply a mask, so tokens of different packed documents never attend to each other.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        is_causal: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        **kwargs,
    ):
        """Run non-causal self-attention over `hidden_states` of shape `(batch, seq, embed_dim)`.

        Args:
            hidden_states: Input activations.
            attention_mask: Optional mask forwarded to the attention backend.
            position_ids, past_key_value, is_causal: Accepted for interface compatibility, unused.
            output_attentions: When true, `output_attentions=True` is forwarded to the backend.
            cu_seqlens: Cumulative document lengths of the packed sequence.
            max_seqlen: Longest document length; derived from `cu_seqlens` when omitted.

        Returns:
            `(attn_output, attn_weights)` as produced by the backend, after the output projection.
        """
        # Options some callers forward that have no meaning at this level.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        batch_size, seq_length, embed_dim = hidden_states.shape
        head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # (batch, seq, embed) -> (batch, heads, seq, head_dim), made contiguous.
        queries = self.q_proj(hidden_states).view(head_shape).transpose(1, 2).contiguous()
        keys = self.k_proj(hidden_states).view(head_shape).transpose(1, 2).contiguous()
        values = self.v_proj(hidden_states).view(head_shape).transpose(1, 2).contiguous()

        attn_implementation = self.config._attn_implementation
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
        if attn_implementation != "sdpa":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

        dropout = 0.0 if not self.training else self.dropout
        attention_kwargs: dict[str, Any] = {
            "is_causal": False,
            "scaling": self.scale,
            "dropout": dropout,
        }

        if cu_seqlens is not None:
            attention_kwargs["cu_seq_lens_q"] = cu_seqlens
            attention_kwargs["cu_seq_lens_k"] = cu_seqlens
            if max_seqlen is None:
                # Varlen kernels need the longest document length next to `cu_seqlens`;
                # the fallback is the full sequence length (not the batch size).
                max_seqlen = self._max_from_cu(cu_seqlens, seq_length)
            if attention_mask is None and "flash" not in str(attn_implementation):
                # Backends without varlen support ignore `cu_seqlens`, so the document
                # boundaries have to be expressed through the mask instead.
                attention_mask = self._block_diagonal_mask(
                    cu_seqlens, seq_length, queries.dtype, queries.device
                )
        if max_seqlen is not None:
            attention_kwargs["max_length_q"] = int(max_seqlen)
            attention_kwargs["max_length_k"] = int(max_seqlen)
        if output_attentions:
            attention_kwargs["output_attentions"] = True

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            **attention_kwargs,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()

        # Match the projection weights' dtype, then return in the input dtype.
        out_proj_dtype = self.out_proj.weight.dtype
        if attn_output.dtype != out_proj_dtype:
            attn_output = attn_output.to(out_proj_dtype)

        attn_output = self.out_proj(attn_output)
        if attn_output.dtype != hidden_states.dtype:
            attn_output = attn_output.to(hidden_states.dtype)

        return attn_output, attn_weights

    @staticmethod
    def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int:
        """Return the largest difference between consecutive `cu` entries, or `fallback`."""
        if cu is None or cu.numel() < 2:
            return fallback
        return int((cu[1:] - cu[:-1]).max().item())

    @staticmethod
    def _block_diagonal_mask(
        cu: torch.Tensor,
        seq_length: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Build a `(1, 1, seq, seq)` additive mask that blocks attention across documents.

        Entries are 0 where query and key lie in the same document and the most negative
        finite value of `dtype` elsewhere. Returns `None` when `cu` describes at most one
        document, because there is nothing to separate.
        """
        if cu.numel() <= 2:
            return None
        positions = torch.arange(seq_length, device=device)
        # Document id of a position = number of document end offsets that are <= position.
        boundaries = cu[1:].to(device=device, dtype=positions.dtype)
        document_ids = torch.bucketize(positions, boundaries, right=True)
        same_document = document_ids[:, None] == document_ids[None, :]
        mask = torch.zeros((seq_length, seq_length), dtype=dtype, device=device)
        mask.masked_fill_(~same_document, torch.finfo(dtype).min)
        return mask[None, None]
|
|
|
|
class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
    """SigLIP2 encoder layer whose self-attention module is `IsaacVisionAttention`."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        self.self_attn = IsaacVisionAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        cu_seqlens (`torch.Tensor`, *optional*):
            Cumulative document lengths of the packed sequence; forwarded to the attention module.
        max_seqlen (`int`, *optional*):
            Longest document length referenced by `cu_seqlens`; forwarded to the attention module.

        Returns `hidden_states`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        attention_mask = ensure_document_attention_mask(
            attention_mask,
            cu_seqlens,
            hidden_states.size(1),
            hidden_states.dtype,
            hidden_states.device,
            return_mask_function=False,
        )

        # Pre-norm self-attention with a residual connection.
        attn_result = self.self_attn(
            self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
            **kwargs,
        )
        if isinstance(attn_result, tuple):
            attn_output, attn_weights = attn_result
        else:
            attn_output, attn_weights = attn_result, None
        after_attention = hidden_states + attn_output

        # Pre-norm MLP with a residual connection.
        layer_output = after_attention + self.mlp(self.layer_norm2(after_attention))

        return (layer_output, attn_weights) if output_attentions else layer_output
|
|
|
|
class IsaacVisionEncoder(Siglip2Encoder):
    """Encoder using Isaac encoder layers with variable-length attention support."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Run all encoder layers over a packed sequence.

        Args:
            inputs_embeds: Input embeddings; dimension 1 is used as the sequence length.
            attention_mask: Optional mask handed unchanged to every layer.
            cu_seqlens: Cumulative document lengths of the packed sequence.
            max_seqlen: Longest document length referenced by `cu_seqlens`.
            output_attentions: When truthy, per-layer attention weights are returned.
            output_hidden_states: When truthy, the input of every layer plus the final
                output are returned.
            return_dict: Forwarded to the layers like the other options.

        Returns:
            `BaseModelOutput` with `last_hidden_state` and, when requested,
            `hidden_states` / `attentions` tuples (otherwise `None`).
        """
        attention_mask = ensure_document_attention_mask(
            attention_mask,
            cu_seqlens,
            inputs_embeds.size(1),
            inputs_embeds.dtype,
            inputs_embeds.device,
            return_mask_function=False,
        )

        # Build the per-layer keyword arguments on a copy so the caller's dict is untouched.
        layer_kwargs = dict(kwargs)
        layer_kwargs.update(
            {
                "max_seqlen": max_seqlen,
                "cu_seqlens": cu_seqlens,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if all_hidden_states is not None:
                all_hidden_states += (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **layer_kwargs,
            )
            # A layer returns `(hidden_states, attn_weights)` when attentions are requested.
            # Feeding that tuple into the next layer would fail, so unpack it here.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
                if all_attentions is not None:
                    all_attentions += (layer_outputs[1],)
            else:
                hidden_states = layer_outputs

        if all_hidden_states is not None:
            all_hidden_states += (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes     : (num_images,)  - #patches in each image (row-major order)
    token_grids   : (num_images,2) - (height, width) for every image
    scale_factor  : spatial down-scale factor (≥1); 1 yields the identity map
    device        : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token. Has zero rows when there are no images.

    Raises
    ------
    ValueError     : if `scale_factor` < 1
    AssertionError : if a grid side is not divisible by `scale_factor`
                     (not checked while torchdynamo is compiling)
    """
    if device is None:
        device = seq_sizes.device

    scale_factor = int(scale_factor)
    # The default value 1 is a valid (identity) shuffle; only non-positive factors are rejected.
    if scale_factor < 1:
        raise ValueError("`scale_factor` must be ≥ 1")

    # The divisibility check is data dependent, so it is skipped under torchdynamo.
    if not is_torchdynamo_compiling():
        if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()):
            raise AssertionError(
                "Every (H,W) in `token_grids` must be divisible by "
                f"scale_factor={scale_factor}, got {token_grids.tolist()}"
            )

    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0

    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False):
        # Flat indices of this image's tokens inside the packed sequence, as an (h, w) grid.
        grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset
        grid = grid.view(h, w)

        # Split both axes into (block index, offset inside block) ...
        grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
        # ... and bring the two block indices to the front so each block is contiguous.
        grid = grid.permute(0, 2, 1, 3).contiguous()
        # One row per output token, listing its scale_factor**2 source tokens.
        gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor))

        tok_offset += seq_len

    if not gather_chunks:
        # `torch.cat` rejects an empty list; zero images map to zero output tokens.
        return torch.empty((0, scale_factor * scale_factor), dtype=torch.int64, device=device)

    return torch.cat(gather_chunks, dim=0)
|
|
|
|
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
            produced by stacking image patches.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
            corresponding to each image segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2`
            neighboring patches into a single embedding channel-group. With a factor of 1 the input is returned
            as is.

    Returns:
        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
        `(new_seq_len, hidden_size * scale_factor**2)` when the input was 2D, or
        `(1, new_seq_len, hidden_size * scale_factor**2)` if the singleton batch dimension was present, where
        `new_seq_len` is the number of rows of the gather index.

    Raises:
        AssertionError: If a 3D input has more than one batch item.
    """
    keep_batch_dim = x.dim() == 3
    if keep_batch_dim:
        if x.size(0) != 1:
            raise AssertionError("Packed sequence is expected to have batch_size == 1")
        x_ = x.squeeze(0)
    else:
        x_ = x

    scale_factor = int(scale_factor)
    if scale_factor == 1:
        # Nothing to merge: the default factor is the identity and needs no index map.
        return x

    embed_dim = x_.size(-1)

    # Number of tokens each image contributes to the packed sequence.
    seq_sizes = torch.prod(token_grids, dim=-1)

    # For every output token, the flat indices of the input tokens it merges.
    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=seq_sizes,
        token_grids=token_grids,
        scale_factor=scale_factor,
        device=x_.device,
    )

    # (new_seq_len, scale_factor**2, embed_dim)
    gathered = x_[gather_idx]

    # Concatenate the merged tokens along the channel dimension.
    out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor)

    if keep_batch_dim:
        out = out.unsqueeze(0)
    return out
|
|
|
|
class IsaacVisionTransformer(nn.Module):
    """Vision tower: packed patch embeddings -> encoder -> layer norm -> pixel shuffle."""

    _supports_sdpa = True

    def __init__(self, config: IsaacVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = IsaacVisionEmbeddings(config)
        self.encoder = IsaacVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
        """Encode a `(seq_patches, token_grids)` pair and return the pixel-shuffled features."""
        patches, grids = packed_seq_patches
        tokens_per_image = torch.prod(grids, dim=-1)

        # Embed the packed patches and add the singleton batch dimension the encoder receives.
        embedded = self.embeddings(patches, grids).unsqueeze(0)

        # Cumulative token offsets (leading zero) and the longest image, for varlen attention.
        cu_seqlens = torch.zeros(tokens_per_image.size(0) + 1, dtype=torch.int32, device=embedded.device)
        cu_seqlens[1:] = tokens_per_image.cumsum(0)
        if tokens_per_image.numel() > 0:
            max_seqlen = int(tokens_per_image.max().item())
        else:
            max_seqlen = 0

        encoded = self.encoder(
            inputs_embeds=embedded,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            return_dict=True,
        ).last_hidden_state

        normalized = self.post_layernorm(encoded)

        shuffled = pixel_shuffle_varlen(
            x=normalized,
            token_grids=grids,
            scale_factor=self.pixel_shuffle_scale_factor,
        )
        # Drop the singleton batch dimension again.
        return shuffled.squeeze(0)
|
|
class IsaacMultiModalProjector(nn.Module):
    """Bias-free two-layer MLP (Linear -> SiLU -> Linear) applied to vision features.

    The input width is the vision hidden size times `pixel_shuffle_scale_factor**2`, the
    intermediate width is four times that, and the output width is `config.hidden_size`.
    """

    def __init__(self, config: IsaacConfig):
        super().__init__()
        vision_config = config.vision_config
        shuffle_area = vision_config.pixel_shuffle_scale_factor**2
        self.vision_hidden_size = vision_config.hidden_size * shuffle_area
        self.backbone_hidden_size = config.hidden_size

        expanded_size = 4 * self.vision_hidden_size
        self.linear_1 = nn.Linear(self.vision_hidden_size, expanded_size, bias=False)
        self.silu = nn.SiLU()
        self.linear_2 = nn.Linear(expanded_size, self.backbone_hidden_size, bias=False)

    def forward(self, image_features):
        """Return `linear_2(silu(linear_1(image_features)))`."""
        return self.linear_2(self.silu(self.linear_1(image_features)))
|
|
|
|
class IsaacVisionEmbedding(nn.Module):
    """Vision embedding wrapper exposing tower and projector."""

    _supports_sdpa = True

    def __init__(self, config: IsaacConfig):
        super().__init__()
        # The tower is configured from the vision sub-config, the projector from the full config.
        self.vision_tower = IsaacVisionTransformer(config.vision_config)
        self.multimodal_projector = IsaacMultiModalProjector(config)

    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Run the vision tower on `vision_tokens`, then project its output."""
        tower_features = self.vision_tower(vision_tokens)
        projected = self.multimodal_projector(tower_features)
        return projected
|
|
|
|
def get_scaled_image_size(
    scale: float,
    original_size: int,
    patch_size: int,
    pixel_shuffle_scale: int,
) -> int:
    """Scale one image dimension and round it up to a multiple of `patch_size * pixel_shuffle_scale`.

    The result is never smaller than one such multiple.
    """
    divisor = patch_size * pixel_shuffle_scale
    num_blocks = math.ceil((scale * original_size) / divisor)
    return int(max(divisor, num_blocks * divisor))
|
|
|
|
def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: Optional[int] = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Pick a target resolution whose patch grid fits the requested patch budget.

    Args:
        image_height (`int`):
            Height in pixels of the source image prior to any resizing.
        image_width (`int`):
            Width in pixels of the source image prior to any resizing.
        patch_size (`int`):
            Size of the square patch used by the vision encoder.
        max_num_patches (`int`):
            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
        min_num_patches (`int`, *optional*):
            Lower bound on the number of patches. When provided the image will be scaled up if necessary.
        eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search that determines the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

    Returns:
        `tuple[int, int]`: Height and width (in pixels), both multiples of `patch_size * pixel_shuffle_scale`.
        The upscaling search is bounded to scales in `[1, 100]` and does not re-check `max_num_patches`.
    """

    def size_at(scale: float) -> tuple[int, int]:
        # Snap both dimensions to the patch/pixel-shuffle grid at the given scale.
        return (
            get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale),
            get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale),
        )

    def patches_at(scale: float) -> float:
        height, width = size_at(scale)
        return (height / patch_size) * (width / patch_size)

    # Patch count of the unscaled image once it is rounded up to the grid.
    divisor = patch_size * pixel_shuffle_scale
    adjusted_height = max(divisor, math.ceil(image_height / divisor) * divisor)
    adjusted_width = max(divisor, math.ceil(image_width / divisor) * divisor)
    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)

    if min_num_patches is not None and num_patches < min_num_patches:
        # Too few patches: bisect for the smallest scale that reaches the minimum.
        low, high = 1.0, 100.0
        while (high - low) >= eps:
            mid = (low + high) / 2
            if patches_at(mid) >= min_num_patches:
                high = mid
            else:
                low = mid
        return size_at(high)

    if num_patches <= max_num_patches:
        return adjusted_height, adjusted_width

    # Too many patches: bisect for the largest scale that stays within the maximum.
    low, high = eps / 10, 1.0
    while (high - low) >= eps:
        mid = (low + high) / 2
        if patches_at(mid) <= max_num_patches:
            low = mid
        else:
            high = mid
    return size_at(low)
|
|
|
|
def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""Cut a batch of channel-last images into flattened square patches.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(num_images, height, width, channels)`.
        patch_size (`int`):
            Edge length of the square patches.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(num_images, height // patch_size, width // patch_size,
            channels * patch_size * patch_size)`; each grid cell holds the pixels of one patch.

    Raises:
        ValueError: If `height` or `width` is not divisible by `patch_size`.
    """
    num_images, height, width, channels = image.shape
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")

    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # (N, rows, p, cols, p, C) -> (N, rows, cols, p, p, C): bring the two in-patch axes next to each other.
    tiled = image.reshape(num_images, grid_rows, patch_size, grid_cols, patch_size, channels)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5)
    return tiled.reshape(num_images, grid_rows, grid_cols, channels * patch_size * patch_size)
|
|
|
|
class IsaacConfig(PretrainedConfig):
    """Configuration class for Isaac multimodal model.

    This configuration corresponds to checkpoints such as
    [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base).

    Args:
        vision_config (`IsaacVisionConfig` or `dict`, *optional*):
            Vision tower configuration. A default `IsaacVisionConfig` is built when omitted.
        text_config (`Qwen3Config` or `dict`, *optional*):
            Text backbone configuration. A default `Qwen3Config` is built when omitted.
        vision_rescale_factor (`float`, *optional*, defaults to `1 / 255`):
            Stored as a float on the config for the image pipeline.
        max_sequence_length (`int`, *optional*, defaults to 16384):
            Stored as-is; read by the processor and the model.
        vision_token (`str`, *optional*, defaults to `"<image>"`):
            Placeholder string that marks an image position in the prompt text.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
    image_processor_type = "IsaacImageProcessor"

    def __init__(
        self,
        vision_config: Optional[IsaacVisionConfig] = None,
        text_config: Optional[Union[Qwen3Config, dict]] = None,
        vision_rescale_factor: float = 1 / 255,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        **kwargs,
    ):
        self._rope_parameters: Optional[dict[str, Any]] = None
        # Initialise the backing field up front: it is read below and by the `rope_scaling`
        # property, and would otherwise only exist if a rope kwarg triggered the setter.
        self._rope_scaling: Optional[dict[str, Any]] = None
        attn_implementation = kwargs.get("attn_implementation")

        # The text config must exist before `super().__init__` so the rope setters can mirror into it.
        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()
        else:
            # An already-built config instance (allowed by the signature) is used as-is.
            self.text_config = text_config

        super().__init__(**kwargs)

        # Keep the top-level rope settings and the text config's in agreement.
        if self._rope_scaling is None:
            self._rope_scaling = getattr(self.text_config, "rope_scaling", None)
        else:
            self.text_config.rope_scaling = self._rope_scaling

        self._rope_parameters = self._rope_scaling

        # Mirror frequently accessed text-backbone attributes at the top level.
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size
        self.num_hidden_layers = self.text_config.num_hidden_layers
        self.num_attention_heads = self.text_config.num_attention_heads
        self.head_dim = self.text_config.head_dim
        self.hidden_act = self.text_config.hidden_act
        self.use_cache = self.text_config.use_cache
        self.rope_theta = self.text_config.rope_parameters["rope_theta"]

        rope_config_validation(self)

        self.layer_types = getattr(self.text_config, "layer_types", None)
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif isinstance(vision_config, IsaacVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        # Forward the requested attention implementation to the vision tower. A dict may carry a
        # "vision_config" entry, with the "" entry as fallback.
        if attn_implementation is not None:
            if isinstance(attn_implementation, dict):
                vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None))
            else:
                vision_attn = attn_implementation
            if vision_attn is not None:
                self.vision_config._attn_implementation = vision_attn

        self.vision_rescale_factor = float(vision_rescale_factor)

        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token

    @property
    def rope_scaling(self):
        """Rope scaling dict, read from the text config when one is attached."""
        if hasattr(self, "text_config") and self.text_config is not None:
            return getattr(self.text_config, "rope_scaling", None)
        return self._rope_scaling

    @rope_scaling.setter
    def rope_scaling(self, value):
        self._rope_scaling = value
        if hasattr(self, "text_config") and self.text_config is not None:
            self.text_config.rope_scaling = value

    @property
    def rope_parameters(self) -> dict[str, Any] | None:
        """Alias introduced upstream for rope scaling dictionaries."""
        value = self._rope_parameters
        if value is None:
            value = self.rope_scaling
        if value is None:
            return {"rope_type": "default"}
        return value

    @rope_parameters.setter
    def rope_parameters(self, value: dict[str, Any] | None) -> None:
        self._rope_parameters = value
        self.rope_scaling = value

    def to_dict(self):
        """Serialize to a dict, always expanding the nested text and vision configs."""
        output = super().to_dict()
        if hasattr(self, "text_config") and self.text_config is not None:
            output["text_config"] = self.text_config.to_dict()
        if hasattr(self, "vision_config") and self.vision_config is not None:
            output["vision_config"] = self.vision_config.to_dict()
        return output
|
|
|
|
| |
| |
| |
|
|
|
|
def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event:
    r"""Tokenize a text fragment and wrap it in an `Event` for the multimodal TensorStream.

    Args:
        tokenizer (`AutoTokenizer`):
            Tokenizer used to convert text into model vocabulary ids.
        text (`str`):
            Plain-text fragment to encode.
        time (`float`, *optional*, defaults to 0.0):
            Timeline coordinate associated with the event. The same value is used for both the start and the end
            time.

    Returns:
        `Event`: Text event carrying a `(num_tokens, 1)` tensor of token ids, with virtual and real dims both set
        to `[num_tokens, 1]` and an index range covering every token.
    """
    encoded = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    token_ids = encoded.squeeze(0)

    # Measure the length before adding the trailing axis.
    num_tokens = len(token_ids)
    if token_ids.dim() == 1:
        token_ids = token_ids.unsqueeze(-1)

    dims_virtual = [num_tokens, 1]
    dims_real = list(dims_virtual)

    return Event(
        data=token_ids,
        type=TextType.text,
        time=(time, time),
        dims_virtual=dims_virtual,
        dims_real=dims_real,
        idx_range=(0, num_tokens),
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class IsaacProcessor(ProcessorMixin):
    """Processor that turns a prompt with vision placeholders (plus images) into a `TensorStream`.

    Only a single prompt per call is supported.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("IsaacImageProcessorFast",)
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor,
        tokenizer,
        *,
        vision_token: str = "<image>",
        max_sequence_length: int = 16384,
        rescale_factor: Optional[float] = None,
        config: Optional[Union[IsaacConfig, dict]] = None,
    ) -> None:
        """Store the sub-processors and resolve settings, letting `config` override the keyword values.

        Raises:
            ValueError: If `tokenizer` is `None`.
        """
        if tokenizer is None:
            raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.")

        if isinstance(config, dict):
            config = IsaacConfig(**config)

        # A config, when given, takes precedence over the individual keyword arguments.
        if config is not None:
            max_sequence_length = config.max_sequence_length
            vision_token = config.vision_token
            rescale_factor = config.vision_rescale_factor

        resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255)

        # Write the resolved value back so the config always holds a float.
        if config is not None:
            config.vision_rescale_factor = resolved_rescale_factor

        self.image_processor = image_processor

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self.config = config

        # Reuse the tokenizer's chat template when it has one.
        self.chat_template = getattr(self.tokenizer, "chat_template", None)

        self.vision_token = vision_token
        self.max_sequence_length = max_sequence_length

    def build_event_stream_simple(
        self,
        text: str,
        images: Optional[list[Image]] = None,
    ) -> Stream:
        """Split `text` on the vision token and emit one event per text fragment or image.

        Each event's time is the index of its fragment in the split result. Empty text fragments
        produce no event.

        Raises:
            ValueError: If a vision token has no corresponding entry in `images`.
        """
        events = []

        # The capturing group keeps the vision tokens themselves in the split output.
        pattern = re.escape(self.vision_token)
        parts = re.split(f"({pattern})", text)

        image_idx = 0
        for current_time, part in enumerate(parts):
            if part == self.vision_token:
                if images is None or image_idx >= len(images):
                    raise ValueError("Encountered vision token without a corresponding image.")

                features = self.image_processor(
                    images=images[image_idx],
                    return_tensors=TensorType.PYTORCH,
                )

                # A single image was processed, so take the first entry of each output.
                patches = features["patches"][0]
                virtual_dims = features["virtual_pixel_size"][0].tolist()
                real_dims = features["real_pixel_size"][0].tolist()

                vision_event = Event(
                    data=patches.reshape(-1, patches.shape[-1]),
                    type=VisionType.image,
                    time=(current_time, current_time),
                    dims_virtual=virtual_dims,
                    dims_real=real_dims,
                    idx_range=(0, math.prod(virtual_dims)),
                )
                events.append(vision_event)
                image_idx += 1
            elif part:
                text_event = create_text_event(self.tokenizer, part, time=current_time)
                events.append(text_event)

        return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True)

    def __call__(
        self,
        text: Union[str, list[str]],
        images: Optional[Union[Image, list[Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images into TensorStream format.

        Args:
            text: Input text or list of texts with vision tokens
            images: PIL image or list of images (optional)
            return_tensors: Format for output tensors

        Returns:
            BatchFeature with input_ids and tensor_stream

        Raises:
            ValueError: If more than one text is given, or if images are given and their count does not
                match the number of vision tokens in the text.
        """
        if isinstance(text, str):
            texts = [text]
        else:
            texts = text

        if images is not None:
            # `Image` is None when PIL is unavailable; `isinstance(x, None)` would raise TypeError.
            if Image is not None and isinstance(images, Image):
                images_list = [images]
            else:
                images_list = images
        else:
            images_list = None

        if len(texts) != 1:
            raise ValueError("IsaacProcessor currently supports batch_size=1")
        if images_list is not None:
            vision_token_count = texts[0].count(self.vision_token)
            if vision_token_count != len(images_list):
                raise ValueError(
                    f"Number of {self.vision_token} tokens in text ({vision_token_count}) "
                    f"must match number of images ({len(images_list)})"
                )

        stream = self.build_event_stream_simple(
            text=texts[0],
            images=images_list,
        )

        tensor_stream = TensorStream([stream])

        # Keep only the last `max_sequence_length` positions when the stream is too long.
        _, T = tensor_stream.shape
        if T > self.max_sequence_length:
            tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T)

        tokens = tensor_stream_token_view(tensor_stream)
        if return_tensors in (TensorType.PYTORCH, "pt"):
            input_ids = torch.as_tensor(tokens, dtype=torch.long)
        else:
            input_ids = tokens

        data = {
            "input_ids": input_ids,
            "tensor_stream": tensor_stream,
        }

        return BatchFeature(data=data)
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    r"""Create 3D positional indices for token input.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of shape `(batch_size, seq_len)` containing token ids. Only its shape and device are used.

    Returns:
        `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the
        1D position so it can be consumed by the 3-axis MRoPE rotary embedding. The result is an expanded
        (non-contiguous) view; clone it before writing in place.
    """
    batch_size, seq_length = input_ids.shape
    positions = torch.arange(seq_length, device=input_ids.device)
    # Explicit sizes instead of `-1`: `view(1, -1)` is ambiguous, and raises, for a zero-length sequence.
    return positions.view(1, seq_length, 1).expand(batch_size, seq_length, 3)
|
|
|
|
class IsaacRotaryEmbedding(nn.Module):
    """MRoPE rotary embedding built on top of the Qwen2.5-VL rotary module.

    Cos/sin tables are computed per position axis by the wrapped module and then
    stitched together along the feature dimension according to `mrope_section`.
    """

    # Rope-scaling keys handled here and therefore hidden from the wrapped rotary module.
    EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"}

    def __init__(self, config: IsaacConfig, device=None):
        super().__init__()

        if hasattr(config, "get_text_config"):
            rope_source_cfg = config.get_text_config()
        else:
            rope_source_cfg = config
        rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {}

        # Give the wrapped module a shallow config copy without the MRoPE-only keys.
        base_scaling = {key: value for key, value in rope_scaling.items() if key not in self.EXTRA_ROPE_KEYS}
        config_for_rope = copy.copy(rope_source_cfg)
        config_for_rope.rope_scaling = base_scaling or None

        # A meta device (or no device) is not forwarded to the wrapped module.
        on_meta = getattr(device, "type", None) == "meta"
        init_device = None if (device is None or on_meta) else device
        self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device)

        rotary_half_dim = self._qwen_rotary.inv_freq.shape[0]
        self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim)
        self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size

    @staticmethod
    def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]:
        """Validate an explicit three-way split, or derive a 2:1:1 split of `rotary_half_dim`.

        Raises:
            ValueError: If `section` does not have three entries or does not sum to `rotary_half_dim`.
        """
        if section is not None:
            resolved = list(map(int, section))
            if len(resolved) != 3:
                raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)")
            if sum(resolved) != rotary_half_dim:
                raise ValueError(
                    f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {resolved}."
                )
            return resolved

        weights = (2, 1, 1)
        total_weight = sum(weights)
        resolved = [(rotary_half_dim * weight) // total_weight for weight in weights]
        # Integer division can lose a remainder; give it to the first section.
        resolved[0] += rotary_half_dim - sum(resolved)
        return resolved

    def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pick, for every feature chunk, the row of `tensor` belonging to that chunk's axis.

        `tensor` is indexed by axis on its first dimension; the last dimension is split into the
        `mrope_section` sizes repeated twice, and chunk `i` is taken from axis `i % 3`.
        """
        section_sizes = tuple(self.mrope_section * 2)
        selected = []
        for index, chunk in enumerate(tensor.split(section_sizes, dim=-1)):
            selected.append(chunk[index % 3])
        return torch.cat(selected, dim=-1)

    @property
    def inv_freq(self) -> torch.Tensor:
        """Inverse frequencies of the wrapped rotary module."""
        return self._qwen_rotary.inv_freq

    def forward(
        self,
        position_ids: torch.Tensor,
        modality_tensor: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute combined MRoPE cos/sin tables.

        Args:
            position_ids (`torch.Tensor`): Positions of shape `(batch, seq_len, 3)`.
            modality_tensor (`torch.Tensor`): Modality ids of shape `(batch, seq_len)`.
            hidden_states (`torch.Tensor`, *optional*): Passed to the wrapped rotary module and used for the
                output dtype. A float32 zero tensor is substituted when omitted.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: The combined cos and sin tensors.

        Raises:
            ValueError: If the input shapes do not match the layout described above.
        """
        if not (position_ids.ndim == 3 and position_ids.size(-1) == 3):
            raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE")
        if modality_tensor.shape != position_ids.shape[:2]:
            raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`")

        if hidden_states is None:
            batch, seq_len, _ = position_ids.shape
            hidden_states = torch.zeros(
                batch,
                seq_len,
                self.hidden_size,
                dtype=torch.float32,
                device=position_ids.device,
            )

        with torch.no_grad():
            positions = position_ids.clone()
            image_value = VisionType.image.value if VisionType is not None else 1
            non_image = modality_tensor != image_value
            if non_image.any():
                # Non-image tokens use their first coordinate on all three axes.
                first_axis = positions[non_image][..., 0].unsqueeze(-1)
                positions[non_image] = first_axis.expand(-1, positions.shape[-1])

            # The wrapped module expects the axis dimension first.
            axis_major = positions.permute(2, 0, 1).contiguous()

        cos_axes, sin_axes = self._qwen_rotary(hidden_states, axis_major)

        target_dtype = hidden_states.dtype
        cos_combined = self._combine_axes(cos_axes.to(target_dtype))
        sin_combined = self._combine_axes(sin_axes.to(target_dtype))

        return cos_combined, sin_combined
|
|
|
|
class IsaacModel(Qwen3PreTrainedModel):
    """Isaac backbone: a text decoder fed with text and vision embeddings, using MRoPE positions."""

    supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    _supports_flex_attn = False
    all_tied_weights_keys: dict[str, str] = {}

    def __init__(self, config: IsaacConfig):
        Qwen3PreTrainedModel.__init__(self, config)

        # Build the text model from a private copy of the text config.
        text_cfg_source = config.text_config
        text_cfg = copy.deepcopy(text_cfg_source)
        self.text_model = AutoModel.from_config(text_cfg)
        # The text model is then pointed at the full Isaac config.
        self.text_model.config = config

        self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

        if config.vision_config is None:
            raise ValueError("IsaacConfig should always have vision_config")

        self.vision_tower = IsaacVisionTransformer(config.vision_config)
        self.multimodal_projector = IsaacMultiModalProjector(config)

        # Dispatch table from modality class to embedding function.
        self.embed_fns = {
            TextType: self.embed_text_tokens,
            VisionType: self.embed_vision,
        }

        self.max_sequence_length = config.max_sequence_length
        self.vision_rescale_factor = config.vision_rescale_factor
        self.vision_token = config.vision_token

        self.post_init()

        if getattr(config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embeddings and keep every `vocab_size` copy in sync."""
        self.text_model.set_input_embeddings(value)
        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is not None:
            self.config.vocab_size = vocab_size
            if hasattr(self.config, "text_config"):
                self.config.text_config.vocab_size = vocab_size
            self.text_model.config.vocab_size = vocab_size

    @property
    def embed_tokens(self) -> nn.Module:
        return self.text_model.embed_tokens

    @embed_tokens.setter
    def embed_tokens(self, value: nn.Module) -> None:
        self.text_model.embed_tokens = value

    @property
    def layers(self) -> nn.ModuleList:
        return self.text_model.layers

    @property
    def norm(self) -> nn.Module:
        return self.text_model.norm

    @property
    def vision_model(self) -> nn.Module:
        """The vision tower. It is registered directly on this module as `vision_tower`."""
        return self.vision_tower

    def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed text tokens, squeezing singleton dimensions."""
        h = self.text_model.embed_tokens(token_ids)
        # Drop a singleton second-to-last axis, e.g. (..., 1, hidden) -> (..., hidden).
        if h.dim() >= 2 and h.size(-2) == 1:
            h = h[..., 0, :]
        return h

    def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Embed vision tokens using the vision encoder."""
        return self.multimodal_projector(self.vision_tower(vision_tokens))

    def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
        """
        Embed each modality stream independently, preserving the original TensorStream
        structure.
        """
        flat_stream = tensor_stream.flat_stream()
        per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False)
        per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()}

        # Collect the real (non-virtual) dims of every event, grouped by event type.
        token_grids = defaultdict(list)
        for stream in tensor_stream.streams:
            for event in stream:
                token_grids[event.type].append(event.dims(virtual=False))

        embedded_compact = {}
        for stream_type, modality_payload_tensor in per_modality_compact_stream.items():
            if stream_type.modality == VisionType:
                grids = token_grids.get(stream_type, [])
                if len(grids) == 0:
                    input_tensor = modality_payload_tensor
                else:
                    # Drop the first dim entry of each grid (presumably the temporal axis; confirm).
                    token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:]
                    input_tensor = (modality_payload_tensor, token_grids_tensor)
                embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor)
            else:
                embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor)

        # Rebuild a stream with the original layout, now carrying embeddings, and compact it.
        embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact)
        h = embedded_ts.compact()
        return h

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        tensor_stream: Optional[TensorStream] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        modality_tensor: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        """
        Forward pass with MRoPE position embeddings.

        Computes position embeddings once and passes them through all layers.

        Args:
            tensor_stream (`TensorStream`, *optional*):
                Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with
                `input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor`
                if they are not supplied.
            modality_tensor (`torch.LongTensor`, *optional*):
                Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing
                values from `TextType`/`VisionType`. Automatically built from `tensor_stream` or `input_ids` when
                omitted.
        """

        text_value = TextType.text.value if TextType is not None else 0

        # Reject conflicting input sources.
        if tensor_stream is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
        if tensor_stream is None and input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        # Resolve the input embeddings; a tensor stream takes precedence over input ids.
        if tensor_stream is not None:
            inputs_embeds = self.embed_stream(tensor_stream)
        elif input_ids is not None:
            inputs_embeds = self.text_model.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds")

        batch_size, seq_len = inputs_embeds.shape[:2]

        # Create a cache on demand when caching is requested.
        if use_cache and past_key_values is None:
            cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config
            past_key_values = DynamicCache(config=cache_config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)

        # Default modality: taken from the stream when available, otherwise all text.
        if modality_tensor is None:
            if tensor_stream is not None:
                modality_tensor = modality_mask(tensor_stream)
            else:
                modality_tensor = torch.full(
                    (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long
                )
        else:
            modality_tensor = modality_tensor.to(dtype=torch.long)

        # Truncate, or pad by repeating the last modality, to match the embedded sequence length.
        if modality_tensor.shape[1] != seq_len:
            if modality_tensor.shape[1] > seq_len:
                modality_tensor = modality_tensor[:, :seq_len]
            else:
                pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1])
                modality_tensor = torch.cat([modality_tensor, pad], dim=1)

        # Default positions: from the stream when available, otherwise from the cache position.
        if position_ids is None:
            if tensor_stream is not None:
                position_ids = compute_mrope_pos_tensor(tensor_stream)
            else:
                position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1)

        # Lift 2D positions to the three-axis layout the rotary embedding requires.
        if position_ids.ndim == 2:
            position_ids = position_ids.to(device=inputs_embeds.device)
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        # On a length mismatch, rebuild consecutive positions starting at each row's first position.
        if position_ids.shape[1] != seq_len:
            start_positions = position_ids[:, :1, 0]
            position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1)
            position_ids = position_ids + start_positions
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        # Rotary tables are computed once and shared by all decoder layers.
        cos, sin = self.rotary_emb(
            position_ids,
            modality_tensor,
            hidden_states=inputs_embeds,
        )
        cos = cos.to(inputs_embeds.dtype)
        sin = sin.to(inputs_embeds.dtype)

        # The decoder layers and mask builder receive only the first position axis.
        decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids

        # A dict is treated as ready-made per-layer-type masks; anything else is converted here.
        if not isinstance(attention_mask, dict):
            attention_mask = create_masks_for_generate(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=decoder_position_ids,
            )

        is_attention_mask_dict = isinstance(attention_mask, dict)

        hidden_states = inputs_embeds
        all_attentions = [] if output_attentions else None

        for decoder_layer in self.text_model.layers:
            layer_attention_mask = (
                attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask
            )
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=layer_attention_mask,
                position_ids=decoder_position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                output_attentions=output_attentions,
                **kwargs,
            )

            # Layers may return either a bare tensor or a tuple led by the hidden states.
            layer_outputs_is_tuple = isinstance(layer_outputs, tuple)
            hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs
            if output_attentions and layer_outputs_is_tuple:
                all_attentions.append(layer_outputs[1])

        hidden_states = self.text_model.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=(hidden_states,),
            attentions=tuple(all_attentions) if output_attentions else None,
        )
|
|
|
|
| class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): |
| """Isaac multimodal model for conditional generation.""" |
|
|
| config_class = IsaacConfig |
| _can_compile_fullgraph = False |
| _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} |
| all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} |
|
|
| def __init__(self, config: IsaacConfig): |
| config.pad_token_id = 0 |
| super().__init__(config) |
| self.model = IsaacModel(config) |
| self.vocab_size = config.vocab_size |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| |
| self.rope_deltas = None |
|
|
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    tensor_stream: Optional[TensorStream] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> tuple | CausalLMOutputWithPast:
    r"""
    Forward pass for conditional generation supporting both standard inputs and TensorStream.

    tensor_stream (`TensorStream`, *optional*):
        Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided,
        the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of
        `input_ids`; any `input_ids` passed alongside it are dropped.

    When `position_ids` is not given it is built from the stream (which also refreshes `self.rope_deltas`), from
    `input_ids`, or from zero-filled placeholder ids shaped like `inputs_embeds`. Only positions built from
    `input_ids` are shifted by `cache_position[0] + self.rope_deltas`, and only when both values are available.

    Returns:
        `CausalLMOutputWithPast` holding the logits, the loss (computed only when `labels` is given) and the cache,
        hidden states and attentions reported by the wrapped model. Attentions are only forwarded when
        `output_attentions` is truthy.

    Raises:
        ValueError: If none of `input_ids`, `inputs_embeds` and `tensor_stream` is provided.
    """
    # The stream takes precedence: the wrapped model embeds it, so token ids must not be forwarded as well.
    if tensor_stream is not None:
        input_ids = None
    if input_ids is None and inputs_embeds is None and tensor_stream is None:
        raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.")

    text_value = TextType.text.value if TextType is not None else 0

    # Batch geometry is only needed (and only defined) on the non-stream paths below.
    if tensor_stream is None:
        if input_ids is not None:
            batch_size, seq_len = input_ids.shape
            input_device = input_ids.device
        else:
            batch_size, seq_len = inputs_embeds.shape[:2]
            input_device = inputs_embeds.device

    if position_ids is None:
        if tensor_stream is not None:
            # Prefill: positions come from the stream; the deltas are cached on the module for later decode steps.
            position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
        elif input_ids is None:
            # Only embeddings were given: placeholder ids of the right shape are enough to build positions.
            dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long)
            position_ids = compute_position_ids_input_ids(dummy_ids)
        else:
            position_ids = compute_position_ids_input_ids(input_ids)

            rope_delta = 0
            if cache_position is not None and self.rope_deltas is not None:
                # Continue from the cached offsets instead of rebuilding the stream at every decode step.
                rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
            if not isinstance(rope_delta, int):
                rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0)
                # `rope_deltas` is (batch, 1) while MRoPE position ids are (batch, seq, 3) (see `get_rope_index`).
                # Without trailing singleton axes the offset would be matched against the sequence axis, which
                # fails or silently yields a (batch, batch, 3) tensor as soon as batch > 1.
                while rope_delta.dim() < position_ids.dim():
                    rope_delta = rope_delta.unsqueeze(-1)

            position_ids = position_ids.add(rope_delta)

    if attention_mask is None and tensor_stream is None:
        attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long)

    if tensor_stream is not None:
        modality_tensor = modality_mask(tensor_stream)
    else:
        # Without a stream every token is treated as text.
        modality_tensor = torch.full(
            (batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long
        )

    outputs = self.model(
        input_ids=input_ids,
        tensor_stream=tensor_stream,
        attention_mask=attention_mask,
        position_ids=position_ids,
        modality_tensor=modality_tensor,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions if output_attentions else None,
    )
|
|
def set_input_embeddings(self, value: nn.Module) -> None:
    """Install new input embeddings and keep vocabulary-dependent state in sync.

    The embeddings are handed to the wrapped model. When `value` exposes `num_embeddings`, that size is written to
    `self.vocab_size`, to this model's config, to the wrapped model's config and, if present, to the text model's
    config. If the output head no longer matches the vocabulary size it is rebuilt on the device and dtype of the
    old head and, when the wrapped model exposes `embed_tokens`, tied to the embedding weights.

    Args:
        value: The embedding module to install.
    """
    self.model.set_input_embeddings(value)

    vocab_size = getattr(value, "num_embeddings", None)
    if vocab_size is None:
        # Nothing to synchronise when the module does not advertise its vocabulary size.
        return

    # `self.vocab_size` is cached at construction time and would otherwise go stale.
    self.vocab_size = vocab_size
    self.config.vocab_size = vocab_size
    self.model.config.vocab_size = vocab_size
    text_model = getattr(self.model, "text_model", None)
    if text_model is not None:
        text_model.config.vocab_size = vocab_size

    if self.lm_head.weight.shape[0] == vocab_size:
        return

    # Rebuild the head where the old one lived so an untied head does not end up on the default device/dtype.
    previous_weight = self.lm_head.weight
    self.lm_head = nn.Linear(
        self.config.hidden_size,
        vocab_size,
        bias=False,
        device=previous_weight.device,
        dtype=previous_weight.dtype,
    )

    if hasattr(self.model, "embed_tokens"):
        # Prefer the text model's table; fall back to the attribute the guard actually checked so the guard and the
        # dereference can no longer disagree.
        embed_tokens = getattr(text_model, "embed_tokens", None)
        if embed_tokens is None:
            embed_tokens = self.model.embed_tokens
        self.lm_head.weight = embed_tokens.weight
|
|
def get_rope_index(
    self,
    input_ids: Optional[torch.Tensor],
    tensor_stream: Optional[TensorStream],
    attention_mask: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build MRoPE position ids and the per-sample offsets used to continue them.

    Positions are taken from `tensor_stream` when one is given and otherwise from `input_ids`.

    Returns:
        A `(position_ids, rope_deltas)` pair. `position_ids` is (B,L,3) for MRoPE; `rope_deltas` is (B,1) and
        equals `max position + 1 - sequence length` for each sample, where the sequence length is the number of
        attended tokens when `attention_mask` is given and `L` otherwise.

    Raises:
        ValueError: If both `tensor_stream` and `input_ids` are `None`.
    """
    if tensor_stream is not None:
        positions = compute_mrope_pos_tensor(tensor_stream)
    elif input_ids is not None:
        positions = compute_position_ids_input_ids(input_ids)
    else:
        raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices")

    # The three-way unpack doubles as a rank check on the position tensor.
    _, total_length, _ = positions.shape

    # Largest coordinate reached by each sample, over every token and every rotary axis.
    max_position = positions.amax(dim=(1, 2))

    if attention_mask is not None:
        attended = attention_mask.eq(1).sum(dim=-1)
        token_counts = attended.to(dtype=max_position.dtype, device=max_position.device)
    else:
        token_counts = torch.full_like(max_position, total_length)

    deltas = max_position + 1 - token_counts
    return positions, deltas.to(dtype=positions.dtype).unsqueeze(1)
|
|
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    tensor_stream: Optional[TensorStream] = None,
    cache_position: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    use_cache: bool = True,
    **kwargs,
) -> dict[str, Any]:
    """
    Assemble the keyword arguments for one generation step, with TensorStream support.

    The stream is forwarded on the first step only (no cache position, or a cache position starting at 0) and
    `position_ids` is cleared there so that `forward` derives positions from the stream. On later steps the stream
    is dropped; `position_ids` is either cleared (a stream was supplied and rope deltas are cached) or, when still
    missing, filled with `cache_position` repeated over the batch and over 3 rotary axes.

    Returns:
        The dictionary produced by the parent implementation, with `tensor_stream` and `position_ids` adjusted.
    """
    if cache_position is None:
        # Use the first available input to size a fresh cache position.
        if input_ids is not None:
            seq_length, device = input_ids.shape[1], input_ids.device
        elif inputs_embeds is not None:
            seq_length, device = inputs_embeds.shape[1], inputs_embeds.device
        elif tensor_stream is not None:
            _, seq_length = tensor_stream.shape
            device = tensor_stream.device
        else:
            seq_length, device = None, None
        if seq_length is not None:
            cache_position = torch.arange(seq_length, dtype=torch.long, device=device)

    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        use_cache=use_cache,
        **kwargs,
    )

    # The parent may have replaced or sliced the cache position.
    cache_position = model_inputs.get("cache_position", cache_position)
    first_step = cache_position is None or bool(cache_position[0] == 0)
    has_stream = tensor_stream is not None

    if first_step:
        if has_stream:
            model_inputs["tensor_stream"] = tensor_stream
            model_inputs["position_ids"] = None
        else:
            model_inputs["tensor_stream"] = None
        return model_inputs

    # Decode steps never carry the stream.
    model_inputs["tensor_stream"] = None

    if has_stream and self.rope_deltas is not None:
        # `forward` rebuilds the positions from the cached rope deltas.
        model_inputs["position_ids"] = None
        return model_inputs

    if model_inputs.get("position_ids") is None:
        reference = model_inputs.get("input_ids")
        if reference is None:
            reference = model_inputs.get("inputs_embeds")
        batch_size = 1 if reference is None else reference.shape[0]
        per_sample = cache_position.view(1, -1).expand(batch_size, -1)
        model_inputs["position_ids"] = per_sample.unsqueeze(-1).expand(-1, -1, 3)

    return model_inputs
|
|
@classmethod
def can_generate(cls) -> bool:
    """Report that this class supports generation; the answer is always `True`."""
    return True
|
|
|
|
| def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: |
| """Compute residuals for P-frames to stay in sync with the training pipeline.""" |
| if not any(is_p_frame): |
| return frames |
|
|
| frame_indices = torch.arange(len(is_p_frame), device=frames.device) |
| i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) |
| last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 |
| p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] |
| frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] |
| return frames |
|
|
|
|
# Names exported by `from <module> import *`; every entry must be defined in this module.
__all__ = [
    "IsaacConfig",
    "IsaacModel",
    "IsaacPreTrainedModel",
    "IsaacForConditionalGeneration",
    "IsaacImageProcessorFast",
    "IsaacProcessor",
]
|
|