from collections import deque
from dataclasses import dataclass

import torch
import numpy as np


@dataclass
class ChunkCum:
    cum: int
    image_grid_thw: tuple[int, int, int] | None = None
    video_grid_thw: tuple[int, int, int] | None = None


def _visual_token_cums(
    sequence_idx: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    video_3d_pooling: bool = True,
    **kwargs,
) -> list[ChunkCum]:
    """Walk one sequence and emit a ChunkCum per indivisible span.

    Text tokens cost 1; an image costs all of its merged tokens; with 3D video
    pooling, every focus_size consecutive frames (plus any text tokens that
    fall between them) form one unit, so a chunk boundary can never split a
    temporal pooling group.
    """
    cums: deque[ChunkCum] = deque()
    video_idx = 0
    frame_idx = 0
    image_idx = 0
    token_idx = 0
    in_video = False
    cum = 0
    sequence = input_ids[sequence_idx].tolist()
    while token_idx < len(sequence):
        token = sequence[token_idx]
        if token == image_token_id:
            assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
            _, h, w = image_grid_thw[image_idx].tolist()
            num_tokens = h * w // (merge_size ** 2)
            cums.append(ChunkCum(
                cum=num_tokens,
                image_grid_thw=(1, h, w),
                video_grid_thw=None,
            ))
            token_idx += num_tokens
            image_idx += 1
        elif token == video_token_id:
            assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used"
            t, h, w = video_grid_thw[video_idx].tolist()
            num_tokens = h * w // (merge_size ** 2)
            if video_3d_pooling:
                assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}"
                cum += num_tokens
                if (frame_idx + 1) % focus_size == 0:
                    cums.append(ChunkCum(
                        cum=cum,
                        image_grid_thw=None,
                        video_grid_thw=(focus_size, h, w),
                    ))
                    cum = 0
                    in_video = False
                else:
                    in_video = True
            else:
                # 2D pooling: each frame is independent (like an image)
                cums.append(ChunkCum(
                    cum=num_tokens,
                    image_grid_thw=None,
                    video_grid_thw=(1, h, w),
                ))
            frame_idx += 1
            if frame_idx == t:
                video_idx += 1
                frame_idx = 0
            token_idx += num_tokens
        else:
            if not in_video:
                cums.append(ChunkCum(cum=1, image_grid_thw=None, video_grid_thw=None))
            else:
                # Text token inside an open temporal pooling group: absorb it
                # into the running video cum so the group stays indivisible.
                cum += 1
            token_idx += 1
    return list(cums)


def visual_token_cums(
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    video_3d_pooling: bool = True,
    **kwargs,
) -> list[list[ChunkCum]]:
    """Compute the ChunkCum list for every sequence in the batch."""
    return [
        _visual_token_cums(
            sequence_idx=i,
            input_ids=input_ids,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            merge_size=merge_size,
            focus_size=focus_size,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            video_3d_pooling=video_3d_pooling,
        )
        for i in range(input_ids.shape[0])
    ]
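
# A minimal usage sketch (hypothetical token ids: 151655/151656 are assumed
# image/video placeholder ids here, not taken from any specific tokenizer).
# With merge_size=2, a (1, 8, 8) image grid collapses to 8*8 // 2**2 = 16
# placeholder tokens, which become a single indivisible ChunkCum:
#
#   ids = torch.tensor([[7, 7] + [151655] * 16 + [7]])
#   cums = visual_token_cums(
#       input_ids=ids, image_token_id=151655, video_token_id=151656,
#       merge_size=2, focus_size=2,
#       image_grid_thw=torch.tensor([[1, 8, 8]]), video_grid_thw=None,
#   )
#   # -> [[ChunkCum(cum=1), ChunkCum(cum=1),
#   #      ChunkCum(cum=16, image_grid_thw=(1, 8, 8)), ChunkCum(cum=1)]]
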
@dataclass
class Chunk:
    start: int
    end: int
    image_grid_thws: list[tuple[int, int, int]]
    video_grid_thws: list[tuple[int, int, int]]
    # Precomputed indices for raster -> block-order reordering (computed once, reused across layers)
    image_fold_idx: torch.Tensor | None = None   # (total_image_tokens,) long
    image_focus_idx: torch.Tensor | None = None  # (total_image_regions,) long; None for 'mean'
    video_fold_idx: torch.Tensor | None = None
    video_focus_idx: torch.Tensor | None = None


def _compute_image_fold_indices(
    image_grid_thws: list[tuple[int, int, int]],
    merge_size: int,
    focus_size: int,
    focus_pool: str,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Precompute raster->block-order gather indices for all images in a chunk."""
    if not image_grid_thws:
        return None, None
    fold_parts: list[torch.Tensor] = []
    focus_parts: list[torch.Tensor] = []
    offset = 0
    for t, h, w in image_grid_thws:
        assert t == 1
        h = h // merge_size
        w = w // merge_size
        bh = h // focus_size
        bw = w // focus_size
        positions = torch.arange(h * w)
        # Raster -> block order: (bh, fs, bw, fs) -> (bh, bw, fs, fs) -> flat
        fold_idx = positions.view(bh, focus_size, bw, focus_size) \
            .permute(0, 2, 1, 3).reshape(-1) + offset
        fold_parts.append(fold_idx)
        if focus_pool == 'start':
            fidx = positions.view(h, w)[::focus_size, ::focus_size].reshape(-1) + offset
            focus_parts.append(fidx)
        elif focus_pool == 'center':
            c = focus_size // 2
            fidx = positions.view(h, w)[c::focus_size, c::focus_size].reshape(-1) + offset
            focus_parts.append(fidx)
        elif focus_pool == 'end':
            e = focus_size - 1
            fidx = positions.view(h, w)[e::focus_size, e::focus_size].reshape(-1) + offset
            focus_parts.append(fidx)
        # 'mean': no direct gather; computed from block-ordered tensor in model
        offset += h * w
    fold_idx_cat = torch.cat(fold_parts)
    focus_idx_cat = torch.cat(focus_parts) if focus_parts else None
    return fold_idx_cat, focus_idx_cat


def _compute_video_fold_indices(
    video_grid_thws: list[tuple[int, int, int]],
    merge_size: int,
    focus_size: int,
    focus_pool: str,
    video_3d_pooling: bool = True,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Precompute raster->block-order gather indices for all videos in a chunk."""
    if not video_grid_thws:
        return None, None
    fold_parts: list[torch.Tensor] = []
    focus_parts: list[torch.Tensor] = []
    offset = 0
    for t, h, w in video_grid_thws:
        h = h // merge_size
        w = w // merge_size
        bh = h // focus_size
        bw = w // focus_size
        if video_3d_pooling:
            bt = t // focus_size
            positions = torch.arange(t * h * w)
            # Raster -> block order: (bt, fs, bh, fs, bw, fs) -> (bt, bh, bw, fs, fs, fs) -> flat
            fold_idx = positions.view(bt, focus_size, bh, focus_size, bw, focus_size) \
                .permute(0, 2, 4, 1, 3, 5).reshape(-1) + offset
            fold_parts.append(fold_idx)
            if focus_pool == 'start':
                fidx = positions.view(t, h, w)[::focus_size, ::focus_size, ::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
            elif focus_pool == 'center':
                c = focus_size // 2
                fidx = positions.view(t, h, w)[c::focus_size, c::focus_size, c::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
            elif focus_pool == 'end':
                e = focus_size - 1
                fidx = positions.view(t, h, w)[e::focus_size, e::focus_size, e::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
        else:
            # 2D pooling: treat each frame independently (like images)
            positions = torch.arange(t * h * w)
            # Raster -> block order per frame: (t, bh, fs, bw, fs) -> (t, bh, bw, fs, fs) -> flat
            fold_idx = positions.view(t, bh, focus_size, bw, focus_size) \
                .permute(0, 1, 3, 2, 4).reshape(-1) + offset
            fold_parts.append(fold_idx)
            if focus_pool == 'start':
                fidx = positions.view(t, h, w)[:, ::focus_size, ::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
            elif focus_pool == 'center':
                c = focus_size // 2
                fidx = positions.view(t, h, w)[:, c::focus_size, c::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
            elif focus_pool == 'end':
                e = focus_size - 1
                fidx = positions.view(t, h, w)[:, e::focus_size, e::focus_size].reshape(-1) + offset
                focus_parts.append(fidx)
        offset += t * h * w
    fold_idx_cat = torch.cat(fold_parts)
    focus_idx_cat = torch.cat(focus_parts) if focus_parts else None
    return fold_idx_cat, focus_idx_cat
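
# Shape intuition for the gather indices above, as a worked example for the 2D
# (image) case, assuming a single image whose post-merge grid is h=w=4 with
# focus_size=2:
#
#   fold_idx  = [0, 1, 4, 5,  2, 3, 6, 7,  8, 9, 12, 13,  10, 11, 14, 15]
#   # hidden[fold_idx].view(-1, focus_size**2, dim) then makes every 2x2
#   # spatial block contiguous, so block pooling is a single view + reduce.
#   focus_idx = [0, 2, 8, 10]    # 'start': top-left token of each 2x2 block
#   # 'center' picks [5, 7, 13, 15]; 'mean' has no gather index and is instead
#   # reduced from the block-ordered tensor in the model.
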
def _merge_video_grid_thws(
    thws: list[tuple[int, int, int]],
) -> list[tuple[int, int, int]]:
    """Merge consecutive video grid_thws that share the same (h, w)."""
    if not thws:
        return thws
    merged: list[tuple[int, int, int]] = []
    cur_t, cur_h, cur_w = thws[0]
    for t, h, w in thws[1:]:
        if h == cur_h and w == cur_w:
            cur_t += t
        else:
            merged.append((cur_t, cur_h, cur_w))
            cur_t, cur_h, cur_w = t, h, w
    merged.append((cur_t, cur_h, cur_w))
    return merged


def chunk_tokens(
    max_chunk_size: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    focus_pool: str = 'mean',
    video_3d_pooling: bool = True,
    **kwargs,
) -> list[list[Chunk]]:
    """Greedily split each sequence into chunks of at most max_chunk_size tokens.

    A chunk never splits an image or a video pooling group; a single unit
    larger than max_chunk_size becomes its own oversized chunk. Fold/focus
    gather indices are precomputed per chunk, and every sequence is padded
    with empty trailing chunks so the batch has a uniform chunk count.
    """
    cums = visual_token_cums(
        input_ids=input_ids,
        image_token_id=image_token_id,
        video_token_id=video_token_id,
        merge_size=merge_size,
        focus_size=focus_size,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        video_3d_pooling=video_3d_pooling,
        **kwargs,
    )
    chunked_cums: list[list[Chunk]] = []
    for sequence_cums in cums:
        chunks: list[Chunk] = []
        current_chunk_start = 0
        current_chunk_size = 0
        current_image_grid_thws: list[tuple[int, int, int]] = []
        current_video_grid_thws: list[tuple[int, int, int]] = []

        def _make_chunk(
            start: int,
            end: int,
            img_thws: list[tuple[int, int, int]],
            vid_thws: list[tuple[int, int, int]],
        ) -> Chunk:
            merged_vid = _merge_video_grid_thws(vid_thws)
            img_fold, img_focus = _compute_image_fold_indices(img_thws, merge_size, focus_size, focus_pool)
            vid_fold, vid_focus = _compute_video_fold_indices(merged_vid, merge_size, focus_size, focus_pool, video_3d_pooling)
            return Chunk(
                start=start,
                end=end,
                image_grid_thws=img_thws,
                video_grid_thws=merged_vid,
                image_fold_idx=img_fold,
                image_focus_idx=img_focus,
                video_fold_idx=vid_fold,
                video_focus_idx=vid_focus,
            )

        for cum in sequence_cums:
            # Flush the current chunk before overflowing it; the guard on
            # current_chunk_size avoids emitting an empty chunk when a single
            # unit alone exceeds max_chunk_size.
            if current_chunk_size > 0 and current_chunk_size + cum.cum > max_chunk_size:
                chunks.append(_make_chunk(
                    current_chunk_start,
                    current_chunk_start + current_chunk_size,
                    current_image_grid_thws,
                    current_video_grid_thws,
                ))
                current_chunk_start += current_chunk_size
                current_chunk_size = 0
                current_image_grid_thws = []
                current_video_grid_thws = []
            if cum.image_grid_thw is not None:
                current_image_grid_thws.append(cum.image_grid_thw)
            if cum.video_grid_thw is not None:
                current_video_grid_thws.append(cum.video_grid_thw)
            current_chunk_size += cum.cum
        if current_chunk_size > 0:
            chunks.append(_make_chunk(
                current_chunk_start,
                current_chunk_start + current_chunk_size,
                current_image_grid_thws,
                current_video_grid_thws,
            ))
        chunked_cums.append(chunks)
    # Pad every sequence with empty trailing chunks to the batch maximum.
    num_chunks = max(len(chunks) for chunks in chunked_cums)
    for chunks in chunked_cums:
        while len(chunks) < num_chunks:
            prev_end = chunks[-1].end if chunks else 0
            chunks.append(Chunk(
                start=prev_end,
                end=prev_end,
                image_grid_thws=[],
                video_grid_thws=[],
            ))
    return chunked_cums
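
if __name__ == "__main__":
    # Smoke test with made-up inputs: the token ids, grid, and max_chunk_size
    # are hypothetical illustration values, not tied to any particular model.
    # One 16-token image sits between two runs of 4 text tokens; with
    # max_chunk_size=12 the image cannot share a chunk with either run.
    ids = torch.tensor([[7] * 4 + [151655] * 16 + [7] * 4])
    per_sequence_chunks = chunk_tokens(
        max_chunk_size=12,
        input_ids=ids,
        image_token_id=151655,
        video_token_id=151656,
        merge_size=2,
        focus_size=2,
        image_grid_thw=torch.tensor([[1, 8, 8]]),
        video_grid_thw=None,
    )
    for chunk in per_sequence_chunks[0]:
        # Expected: (0, 4, []), (4, 20, [(1, 8, 8)]), (20, 24, [])
        print(chunk.start, chunk.end, chunk.image_grid_thws)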