"""CASA layers""" import bisect from dataclasses import dataclass from itertools import accumulate from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload from typing import cast as type_cast import torch from transformers.configuration_utils import PretrainedConfig from .utils import StreamingModule, StreamingState, delta_w_factory if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig try: from flash_attn import flash_attn_varlen_func except ImportError: flash_attn_varlen_func = None # type: ignore WindowsComputeKwargs = TypedDict( "WindowsComputeKwargs", { "num_post_image_tokens": int, "num_pre_image_tokens": int, }, total=False, ) def __split_n_merge__( x: torch.Tensor, sample_lengths: list[int], padding_side: Literal["left", "right"] = "right", pad_value: int | float | bool = 0, ) -> torch.Tensor: max_sample_length = max(sample_lengths) pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2)) return torch.stack( [ torch.nn.functional.pad( _x, pad_tuple + (0, max_sample_length - _x.shape[0]) if padding_side == "right" else pad_tuple + (max_sample_length - _x.shape[0], 0), value=pad_value, ) for _x in torch.split(x, sample_lengths, dim=0) ], dim=0, ) @overload def insert_image_tokens( inputs_embeds: torch.Tensor, image_embeds: torch.Tensor | Sequence[torch.Tensor], image_embeds_insertion_points: list[torch.Tensor], recover_batch_dim: Literal[True], attention_mask: torch.Tensor | None = None, padding_side: Literal["left", "right"] = "right", keep_only_attended: bool = False, pad_output: int | float | bool = 0.0, ) -> tuple[ torch.Tensor, None, torch.Tensor | None, torch.Tensor, ]: ... @overload def insert_image_tokens( inputs_embeds: torch.Tensor, image_embeds: torch.Tensor | Sequence[torch.Tensor], image_embeds_insertion_points: list[torch.Tensor], recover_batch_dim: Literal[False], attention_mask: torch.Tensor | None = None, padding_side: Literal["left", "right"] = "right", keep_only_attended: bool = False, pad_output: int | float | bool = 0.0, ) -> tuple[ torch.Tensor, list[int], torch.Tensor | None, torch.Tensor, ]: ... def insert_image_tokens( inputs_embeds: torch.Tensor, image_embeds: torch.Tensor | Sequence[torch.Tensor], image_embeds_insertion_points: list[torch.Tensor], recover_batch_dim: bool = True, attention_mask: torch.Tensor | None = None, padding_side: Literal["left", "right"] = "right", keep_only_attended: bool = False, pad_output: int | float | bool = 0.0, ) -> tuple[ torch.Tensor | torch.Tensor, list[int] | None, torch.Tensor | torch.Tensor | None, torch.Tensor | torch.Tensor, ]: """ Insert image embeddings into text embeddings Args: inputs_embeds (torch.Tensor): (B, S, D) input token embeddings. image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings. image_embeds_insertion_points (list[torch.Tensor]): Insertion indices. attention_mask (torch.Tensor, optional): (B, S) attention mask. padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images. return_indices (bool): Whether to return gather indices or the fused sequence directly. keep_only_attended: This is only applicable when recover_batch_dim is False; whether to remove any non-attended tokens in the whole array. 
In this case, the attention mask returned is **still the original one**, so we can remember which indices have been removed Returns: output (torch.Tensor): (B, S + Ni * Nt) gather indices or (B, S + Ni * Nt, D) fused sequence image_embeds (torch.Tensor): (B, Ni * Nt) image embeds, padded and batch if input was a list attention_mask (torch.Tensor): Same shape, 1 for real tokens, 0 for image and text padding. image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions. """ if isinstance(image_embeds, list) and len(image_embeds) == 0: batch_size, text_seq_length, token_dim = inputs_embeds.shape if recover_batch_dim: return ( inputs_embeds, None, attention_mask, torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool), ) else: flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1] return ( torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])), [text_seq_length] * inputs_embeds.shape[0], attention_mask.flatten() if attention_mask is not None else None, torch.zeros((flattened_seq_length, 1), dtype=torch.bool), ) # Sanity checks if isinstance(image_embeds, torch.Tensor): assert inputs_embeds.shape[-1] == image_embeds.shape[-1] else: assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds) batch_size, text_seq_length, token_dim = inputs_embeds.shape image_seq_length = [x.shape[0] for x in image_embeds] # Flatten insertion points insertion_offset = [] counter, offset_from_text, offset_from_image = 0, 0, 0 for sample in image_embeds_insertion_points: for pt in sample: insertion_offset.append(pt + offset_from_image + offset_from_text) offset_from_image += image_seq_length[counter] counter += 1 offset_from_text += text_seq_length image_insert_positions = [ x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx]) ] # Flatten image embeds if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) else: image_embeds = type_cast(torch.Tensor, image_embeds) image_embeds = torch.reshape(image_embeds, (-1, token_dim)) # Flatten text embeds across batch dim (B x S, D) inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim)) flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length) text_insert_positions = sorted( set(range(flattened_seq_length)).difference(set(image_insert_positions)) ) # Scatter image embeds in the flattened dict # scatter text related stuff output = torch.empty( (flattened_seq_length, token_dim), device=inputs_embeds.device, dtype=inputs_embeds.dtype, ) txt_positions_tensor = torch.Tensor(text_insert_positions).to( dtype=torch.long, device=inputs_embeds.device ) output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds) attention_mask_new: torch.Tensor | None = None if attention_mask is not None: attention_mask_new = torch.ones( (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device ) attention_mask_new.scatter_( 0, txt_positions_tensor, attention_mask.flatten().to(torch.bool) ) # scatter image related stuff image_tokens_mask = torch.zeros( (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device ) img_positions_tensor = torch.Tensor(image_insert_positions).to( device=inputs_embeds.device, dtype=torch.long ) output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds) image_tokens_mask.scatter_(0, img_positions_tensor, True) # Compute expected sample length, taking into account the real batch # i.e. 
    sample_lengths = []
    counter = 0
    for sample_idx, pts in enumerate(image_embeds_insertion_points):
        num_image_tokens = 0
        for _ in pts:
            num_image_tokens += image_seq_length[counter]
            counter += 1
        if keep_only_attended and attention_mask is not None:
            attended_seq_length = int(torch.sum(attention_mask[sample_idx]).cpu().item())
            sample_lengths.append(attended_seq_length + num_image_tokens)
        else:
            sample_lengths.append(text_seq_length + num_image_tokens)

    # For CASA attention, we can keep everything flattened and return
    # the sample_lengths for the blockwise attention
    if not recover_batch_dim:
        if keep_only_attended and attention_mask_new is not None:
            output = output[attention_mask_new]
            image_tokens_mask = image_tokens_mask[attention_mask_new]
        return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]

    # Otherwise, time to pad and reshape
    # Easy case: everything has the same length
    if all(x == sample_lengths[0] for x in sample_lengths):
        output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
        image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
        if attention_mask_new is not None:
            attention_mask_new = torch.reshape(
                attention_mask_new, (batch_size, sample_lengths[0])
            )
    # if there is any size mismatch, we break into a
    # list and pad again
    else:
        # split and merge
        output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
        # note that the extra padding tokens are also marked as image tokens to be removed later
        image_tokens_mask = __split_n_merge__(
            image_tokens_mask, sample_lengths, padding_side, True
        )[:, :, None]
        if attention_mask_new is not None:
            attention_mask_new = __split_n_merge__(
                attention_mask_new, sample_lengths, padding_side, 0
            )

    # Return
    return output, sample_lengths, attention_mask_new, image_tokens_mask


def get_sample_lengths_from_insertion_points(
    image_embeds_insertion_points: list[torch.Tensor],
    image_embeds: torch.Tensor | list[torch.Tensor] | None,
    total_seq_len: int | None = None,
    attention_mask: torch.Tensor | None = None,
    **kwargs: WindowsComputeKwargs,
) -> tuple[list[tuple[int, bool]], list[int]]:
    """Compute sample lengths as if each image insertion point defined a new document
    (as with document IDs)
    """
    num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
    num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
    squashed_samples_lengths = type_cast(
        list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
    )
    if squashed_samples_lengths is not None:
        assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)

    def __insert_next_sample__(
        batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
    ) -> None:
        nonlocal attention_mask
        nonlocal text_sample_lengths, full_sample_lengths
        nonlocal cum_samples_lengths, current_image_offset
        nonlocal last_image_idx, current_image_idx, current_length
        # Add the sample between [last_insrt_pt, insrt_pt] with breaks in
        # between any squashed samples we find on the way
        start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
        added_sample = False
        for end_of_sample in cum_samples_lengths[start_pt:]:
            # we will break out of the loop at the end when end_of_sample == insrt_pt
            end_of_sample = min(end_of_sample, insrt_pt)
            # Add between [last_insrt_pt, end_of_sample]
            current_length = end_of_sample - last_insrt_pt
            if attention_mask is not None:
                current_length -= int(
                    torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
                )
            if current_length > 0:
                added_sample = True
                text_sample_lengths.append(
                    (current_length, end_of_batch_sample and insrt_pt == end_of_sample)
                )
                # add image tokens to current_length
                if current_image_idx > 0 and image_embeds is not None:
                    images_in_sample = [
                        img_idx
                        for img_idx in range(last_image_idx, current_image_idx)
                        if img_idx < len(image_embeds_insertion_points[batch_idx])
                        and last_insrt_pt
                        <= image_embeds_insertion_points[batch_idx][img_idx]
                        < end_of_sample
                    ]
                    if len(images_in_sample) > 0:
                        num_image_tokens = sum(
                            _x.shape[0]
                            for _x in image_embeds[
                                current_image_offset + images_in_sample[0] :
                                current_image_offset + images_in_sample[-1] + 1
                            ]
                        )
                        current_length += num_image_tokens
                full_sample_lengths.append(current_length)
            # prepare for next loop
            last_insrt_pt = end_of_sample
            if end_of_sample == insrt_pt:
                break

        # End of loop: catching a weird edge case where we may end up on a span
        # full of padding tokens, which does not get added due to the current_length > 0 check
        if end_of_batch_sample:
            assert added_sample, "Weird edge case. Don't do that, thank you"
            text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)

    current_image_offset = 0
    text_sample_lengths, full_sample_lengths = [], []
    cum_samples_lengths: list[int] = []
    current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
    for batch_idx, pts in enumerate(image_embeds_insertion_points):
        if squashed_samples_lengths is not None:
            cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
        else:
            assert total_seq_len is not None
            cum_samples_lengths = [total_seq_len]
        for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
            # check if the images are consecutive, in which case we want
            # them to belong to the same window
            if current_image_idx >= 1 and insrt_pt == (
                image_embeds_insertion_points[batch_idx][current_image_idx - 1]
                + num_pre_image_tokens
                + num_post_image_tokens
            ):
                continue
            # Otherwise, we found a new sample
            # not very important but for completeness: the insertion points come *after*
            # the pre-image tokens by design, but for the document-id mask it is more
            # consistent to have them correspond to the same image
            insrt_pt -= num_pre_image_tokens
            # Update text and full sample lengths
            if insrt_pt > last_insrt_pt:
                __insert_next_sample__(
                    batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
                )
            last_image_idx = current_image_idx
            last_insrt_pt = insrt_pt

        # End of batch: add the sample in progress and reset
        current_image_idx += 1
        if cum_samples_lengths[-1] > last_insrt_pt:
            __insert_next_sample__(
                batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
            )
        current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
        current_image_offset += len(pts)

    # Sanity check that the is_eob flags are correctly placed
    assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
        f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
        f" from original batch size ({len(image_embeds_insertion_points)})"
    )
    return text_sample_lengths, full_sample_lengths
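
# Illustrative example (added note, derived from the logic above rather than taken from the
# original source): with a single 10-token sample, one 4-token image inserted at position 2 and
# another at position 7, no padding and no pre/post-image tokens, the "images" windowing above
# would yield
#   text_sample_lengths = [(2, False), (5, False), (3, True)]
#   full_sample_lengths = [2, 9, 7]
# i.e. each insertion point opens a new document, and the image tokens are counted in the
# document that starts at that image.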


class CASAAttentionHandler:
    def __init__(
        self,
        inputs_embeds: torch.Tensor,
        image_embeds: torch.Tensor | list[torch.Tensor],
        image_embeds_insertion_points: list[torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        rope_fn: Callable | None = None,
        windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
        use_asymetric_q_kv: bool = True,
        casa_windows_info: None | dict = None,
    ):
        """Initialize the structure holding the query buffer for CASA attention layers
        (i.e. the **flattened** text + inserted image tokens).

        Note that this structure is shared across all CASA layers and gets updated with the
        current hidden states at every layer; it is merely a buffer used to keep scatter_
        operations in-place as much as possible.

        In this module, the embedding-related values (image_tokens_mask, text_sample_lengths,
        etc.) are stored under the assumption of a tensor that is *flattened* and *without
        padding tokens*. Only the attention mask is kept as-is (text-only, batched, padded) to
        be able to recover the original shapes when needed.
        """
        super().__init__()
        assert windows == "images"  # for inference code release

        # Note 1: Unless overridden, text/full_sample_lengths are defined such that one
        # document = one sample in the batch
        if attention_mask is None:
            text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
        else:
            text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
        (
            full_inputs_embeds,
            full_sample_lengths,
            # Full attention mask is only needed at inference to
            # flatten the KV-Cache and remove padding tokens
            _,
            self.image_tokens_mask,
        ) = insert_image_tokens(
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            image_embeds_insertion_points=image_embeds_insertion_points,
            attention_mask=attention_mask,
            recover_batch_dim=False,
            keep_only_attended=attention_mask is not None,
        )
        assert self.image_tokens_mask.ndim == 2
        self.image_embeds = image_embeds
        self.image_embeds_insertion_points = image_embeds_insertion_points
        self.attention_mask = None if attention_mask is None else attention_mask.bool()
        self.use_asymetric_qkv = use_asymetric_q_kv
        # At inference, we have to use asymmetric QKV for efficiency
        if self.attention_mask is not None:
            self.use_asymetric_qkv = True

        # Build CASA windows
        assert casa_windows_info is not None
        text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
            image_embeds_insertion_points=image_embeds_insertion_points,
            image_embeds=image_embeds,
            total_seq_len=inputs_embeds.shape[1],
            attention_mask=self.attention_mask,
            **casa_windows_info,  # pyright: ignore
        )

        # Sanity checks on the sample lengths
        self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
        self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
        assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
            f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
            f" != full sample lengths {len(self.full_sample_lengths)}"
        )
        if self.attention_mask is None:
            num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
        else:
            num_unpadded_text_tokens = int(
                torch.sum(type_cast(torch.Tensor, attention_mask)).item()
            )
        assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
            f"Sanity check failed; text sample lengths"
            f" {sum(_x[0] for _x in self.text_sample_lengths)} != {num_unpadded_text_tokens}"
        )
        assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
            f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)}"
            f" != {full_inputs_embeds.shape[0]}"
        )

        # Finally we can compute cu_seqlens based on the sample lengths
        self.max_seqlen_q = max(x[0] for x in self.text_sample_lengths)
        self.cu_seqlens_q = self.get_cu_seqlens(
            [x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
        )
        self.max_seqlen_kv = max(self.full_sample_lengths)
        self.cu_seqlens_kv = self.get_cu_seqlens(
            self.full_sample_lengths, device=inputs_embeds.device
        )
        # For inference: we save the lengths of the current documents
        # to trim the KV cache appropriately
        self.current_doc_lengths = self.full_sample_lengths

        # Precompute position embeddings
        self.position_embeds = None
        self.rope_fn = rope_fn
        if self.rope_fn is not None:
            self.position_embeds = self.compute_position_embeddings(
                self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
            )

    @property
    def batch_lengths(self) -> list[int]:
        """Return a (batch_size,) list of integers containing the number of (non-padded)
        text tokens for each sample in the batch"""
        bls = [0]
        for ln, eob in self.text_sample_lengths:
            bls[-1] += ln
            if eob:
                bls.append(0)
        return bls[:-1]

    @property
    def full_batch_lengths(self) -> list[int]:
        """Same as batch_lengths but for text + image tokens"""
        bls = [0]
        for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
            bls[-1] += ln
            if eob:
                bls.append(0)
        return bls[:-1]

    def get_cu_seqlens(
        self, sample_lengths: list[int], device: torch.device | None
    ) -> torch.Tensor:
        """Compute the cumulative sequence lengths (cu_seqlens) for the given sample_lengths"""
        return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
            dtype=torch.int32, device=device
        )

    def compute_position_embeddings(
        self,
        rope_fn: Callable,
        sample_lengths: list[int],
        dummy_for_dtype_and_device: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute info required for position embeddings. Can be overridden, e.g. for Qwen"""
        # option 1: Standard range
        # position_ids = torch.arange(0, full_inputs_embeds.shape[0])
        # option 2: Follows document boundaries
        position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
        return rope_fn(
            dummy_for_dtype_and_device,
            position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
        )

    def get_position_embedding(
        self,
        key: Literal["q", "kv"],
        num_queries: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        if self.position_embeds is None:
            return None
        cos, sin = self.position_embeds
        bls = self.full_batch_lengths
        # For Q, we only want the text-only posembeds
        if key == "q" and self.use_asymetric_qkv:
            bls = self.batch_lengths
            cos = cos[:, ~self.image_tokens_mask[:, 0]]
            sin = sin[:, ~self.image_tokens_mask[:, 0]]
        elif key not in {"q", "kv"}:
            raise ValueError(f"Unknown key for position embedding: {key}")
        # Easy case: training or first step at inference: we use all the posembeds
        if num_queries == 0:
            return cos, sin
        # If num_queries is given, we need to trim for *every sample in the batch*
        cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
        sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
        return torch.cat(cos, dim=1), torch.cat(sin, dim=1)

    def get_full_embeds(
        self, hidden_states: torch.Tensor, norm_fn: Callable | None
    ) -> torch.Tensor:
        """Update the query buffer with the current (attended) hidden states and return the
        flattened text + image sequence.

        :param hidden_states: (b, s, d) tensor input to the CASA attention layer
        """
        assert self.image_embeds is not None
        return insert_image_tokens(
            inputs_embeds=hidden_states,
            image_embeds=(
                self.image_embeds
                if norm_fn is None
                else norm_fn(self.image_embeds)
                if isinstance(self.image_embeds, torch.Tensor)
                else [norm_fn(_x) for _x in self.image_embeds]
            ),
            image_embeds_insertion_points=self.image_embeds_insertion_points,
            attention_mask=self.attention_mask,
            recover_batch_dim=False,
            keep_only_attended=self.attention_mask is not None,
        )[0][None, :, :]

    def recover_text_embeds(
        self,
        hidden_states_out: torch.Tensor,
        hidden_states_in: torch.Tensor,
        update_image_embeddings: bool = False,
    ) -> torch.Tensor:
        """Return the text embeddings from the query buffer, including non-attended tokens
        at inference"""
        if update_image_embeddings and not self.use_asymetric_qkv:
            raise NotImplementedError("Implement image embedding updates for asymmetric QKV")
        # Remove image tokens in the symmetric case
        if not self.use_asymetric_qkv:
            hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]

        # If there is no attention mask, we are in the right-padded case
        # (keep_only_attended = False) and we can directly return the query
        # outputs (which do not contain the image)
        if self.attention_mask is None:
            return hidden_states_out

        # Otherwise, we need to "scatter" back only the text-attended tokens to the original
        # hidden states, which contain the paddings
        num_queries = hidden_states_in.shape[1]

        # Case 1: the padded hidden_states_in is larger than hidden_states_out:
        # we rebatch + pad hidden_states_out before doing the scattering
        if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
            s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
            assert max(_s.shape[0] for _s in s) <= num_queries  # sanity check
            s = [
                torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
                for _s in s
            ]
            return torch.where(
                self.attention_mask[:, -num_queries:, None],
                torch.stack(s),
                hidden_states_in,
            )

        # If both have the same shape, it means hidden_states_in contained no padding,
        # so we can directly return the hidden states out
        return hidden_states_out

    def extend(self, num_tokens: int, offset: int = 0):
        """Extend all necessary values of the handler for inference.

        Note: this implementation currently assumes a single conversation at a time
        (otherwise the image tokens mask would have to change) and that the added tokens
        are attended to"""
        # image embeds are inserted in the first step and stored in the KV cache
        self.image_embeds = None

        # Update attention mask (non-flattened) (assumes all new tokens are attended to)
        if self.attention_mask is not None:
            self.attention_mask = torch.nn.functional.pad(
                self.attention_mask, (0, num_tokens), value=1
            )

        # Update image token mask (assumes only one image/conversation
        # is started at once so that we always extend by zero)
        # Note that the mask is stored flattened to avoid padding, so we have to
        # do something a bit ugly and inefficient here
        imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
        imtokmask = [
            torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask
        ]
        self.image_tokens_mask = torch.cat(imtokmask, dim=0)

        # Recompute cumulative document lengths after assigning the new
        # number of tokens to each sample in the batch
        for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
            if is_eob:
                self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
                self.full_sample_lengths[idx] += num_tokens

        # Recompute cu_seqlens
        # First step: technically this never occurs, but we keep it for completeness
        if offset == 0:
            self.max_seqlen_q = max(x[0] for x in self.text_sample_lengths)
            self.cu_seqlens_q = self.get_cu_seqlens(
                [x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
            )
            self.max_seqlen_kv = max(self.full_sample_lengths)
            self.cu_seqlens_kv = self.get_cu_seqlens(
                self.full_sample_lengths, device=self.cu_seqlens_kv.device
            )
        # Step > 0: the annoying part is that flash_attn_varlen does not accept
        # 0-length documents, so we need to remove documents from the KV cache when they're
        # past their window. In our current setting, this means we only want to keep the
        # latest documents
        else:
            self.max_seqlen_q = num_tokens
            self.cu_seqlens_q = self.get_cu_seqlens(
                [num_tokens for (_, eob) in self.text_sample_lengths if eob],
                device=self.cu_seqlens_q.device,
            )
            final_doc_lengths = [
                ln
                for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
                if eob
            ]
            self.current_doc_lengths = final_doc_lengths
            self.max_seqlen_kv = max(self.current_doc_lengths)
            self.cu_seqlens_kv = self.get_cu_seqlens(
                final_doc_lengths,
                device=self.cu_seqlens_kv.device,
            )

        # Update position embeddings
        if self.rope_fn is not None and self.position_embeds is not None:
            self.position_embeds = self.compute_position_embeddings(
                self.rope_fn,
                self.full_sample_lengths,
                dummy_for_dtype_and_device=self.position_embeds[0],
            )


@dataclass
class CASAAttentionStreamingState(StreamingState):
    """Streaming state for the CASA attention module. Keeps the per-layer KV cache
    and a reference to the shared CASA handler."""

    k: torch.Tensor = None  # pyright: ignore[reportAssignmentType]
    v: torch.Tensor = None  # pyright: ignore[reportAssignmentType]
    recover_batched_trims: list[int] = None  # pyright: ignore[reportAssignmentType]
    casa_handler: CASAAttentionHandler = None  # pyright: ignore[reportAssignmentType]

    def maybe_get_casa_handler(
        self,
        casa_handler: CASAAttentionHandler | None,
        is_first_casa_layer: bool = False,
        num_queries: int = -1,
    ) -> CASAAttentionHandler | None:
        # Set the given CASA handler the first time we reach this
        if self.casa_handler is None:
            self.casa_handler = casa_handler  # pyright: ignore
        # Subsequent calls: we need to extend shapes to accommodate the new tokens;
        # however, because the CASA handler is shared across layers, we only need to do it once
        if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
            # since the CASA handler is shared, we only run its extend step once
            self.casa_handler.extend(num_queries, offset=self.offset)
        return self.casa_handler

    def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
        """Recover batched key/value states with left padding"""
        s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
        mlen = max(_s.shape[1] for _s in s)
        # Remember the added padding so that we can re-flatten KV later
        if self.recover_batched_trims is None:
            self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
        s = [
            torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0)
            for _s in s
        ]
        return torch.cat(s, dim=0)

    def __get_flattened_kv__(
        self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Flatten the cache and remove padding so it can be used with flash_attn_varlen_func"""
        k = self.k if k is None else k
        v = self.v if v is None else v
        assert k is not None and v is not None
        # Since every sample in the batch contributes at least one document,
        # we can use this to check whether we are in streaming mode with dropped docs.
        # If so, we should trim the KV cache accordingly
        if len(self.casa_handler.current_doc_lengths) == len(k):
            k = torch.cat(
                [
                    _k[self.recover_batched_trims[idx] :][-doc_len:]
                    for idx, _k, doc_len in zip(
                        range(len(k)), k, self.casa_handler.current_doc_lengths
                    )
                ]
            )
            v = torch.cat(
                [
                    _v[self.recover_batched_trims[idx] :][-doc_len:]
                    for idx, _v, doc_len in zip(
                        range(len(k)), v, self.casa_handler.current_doc_lengths
                    )
                ]
            )
            return k[None, ...], v[None, ...]
        k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
        v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
        return k[None, ...], v[None, ...]
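
    # Cache layout note (added comment, inferred from the code above): self.k / self.v are kept
    # batched and left-padded with shape (batch, max_doc_len, num_kv_heads, head_dim) because
    # per-sample document lengths differ. __get_flattened_kv__ strips that left padding (using
    # recover_batched_trims) and, once documents have been dropped during streaming, trims each
    # sample to current_doc_lengths, rebuilding the flat (1, total_tokens, num_kv_heads, head_dim)
    # layout expected by flash_attn_varlen_func.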

    def extend_kv(
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Extend the KV cache with the incoming key/value states, keeping the internal cache
        batched, and return the flattened view"""
        assert self.casa_handler is not None
        if self.k is None and self.v is None:
            # Init with batch-padded key and value states
            self.k = self.__recover_batched_kv__(key_states)
            self.v = self.__recover_batched_kv__(value_states)
            return self.__get_flattened_kv__()
        if self.k is not None and self.v is not None:
            # this is during generation; normally there is no padding at this stage,
            # so we can directly reshape the flattened key states
            rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
            self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
            self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
            return self.__get_flattened_kv__()
        raise ValueError("Impossible configuration (k and v updates are desynchronized)")


class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
    def __init__(
        self,
        config: "PretrainedConfig",
        layer_idx: int | None,
        self_attn: torch.nn.Module | None = None,
        input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ):
        super().__init__(CASAAttentionStreamingState)
        self.head_dim = config.head_dim
        self.config = config
        self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
        self.use_delta_w = config.casa_delta_w
        self.q_proj_casa = self.init_from_config_proj("q", config)
        self.k_proj_casa = self.init_from_config_proj("k", config)
        self.v_proj_casa = self.init_from_config_proj("v", config)
        self.o_proj_casa = self.init_from_config_proj("o", config)

        # Delta_w
        self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
        self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
        self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
        self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
        if config.casa_delta_w:
            assert self_attn is not None
            self.set_delta_w(self_attn)

        # Layer norm
        self.norm_fn: Callable | None = None
        if config.xa_norm_on_images:
            assert input_layernorm_fn is not None
            self.norm_fn = input_layernorm_fn

    def init_from_mha(self, self_attn: torch.nn.Module):
        """Copy the self-attention projection weights into the CASA projections"""
        assert self_attn is not None
        with torch.no_grad():
            assert hasattr(self_attn, "q_proj")
            for key in ["q", "k", "v", "o"]:
                src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
                tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
                tgt.weight.copy_(src.weight)
                if tgt.bias is not None and src.bias is not None:
                    tgt.bias.copy_(src.bias)

    def set_delta_w(self, self_attn: torch.nn.Module):
        """Delta-w setup: zero-initialize the CASA projections and wrap them together with the
        corresponding self-attention projections via `delta_w_factory`"""
        self.override_q_proj = delta_w_factory(
            self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
        )
        self.override_k_proj = delta_w_factory(
            self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
        )
        self.override_v_proj = delta_w_factory(
            self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
        )
        self.override_o_proj = delta_w_factory(
            self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
        )
        with torch.no_grad():
            torch.nn.init.zeros_(self.q_proj_casa.weight)
            torch.nn.init.zeros_(self.k_proj_casa.weight)
            torch.nn.init.zeros_(self.v_proj_casa.weight)
            torch.nn.init.zeros_(self.o_proj_casa.weight)
            if self.q_proj_casa.bias is not None:
                torch.nn.init.zeros_(self.q_proj_casa.bias)
            if self.k_proj_casa.bias is not None:
                torch.nn.init.zeros_(self.k_proj_casa.bias)
            if self.v_proj_casa.bias is not None:
                torch.nn.init.zeros_(self.v_proj_casa.bias)
            if self.o_proj_casa.bias is not None:
                torch.nn.init.zeros_(self.o_proj_casa.bias)

    def init_from_config_proj(
        self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
    ) -> torch.nn.Linear:
        """Initialize the Linear proj in this module"""
        raise NotImplementedError("Abstract class.")

    def apply_position_embeddings(
        self,
        key: Literal["q", "kv"],
        x: torch.Tensor,  # (batch, seq_len, num_heads, head_dim)
        casa_handler: CASAAttentionHandler | None,
        num_queries: int = 0,
        unsqueeze_dim: int = 1,
    ) -> torch.Tensor:  # (batch, seq_len, num_heads, head_dim)
        """Apply position embeddings to query and key states"""
        raise NotImplementedError("Abstract class.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        casa_handler: CASAAttentionHandler | None,
    ) -> torch.Tensor | None:
        """Generic forward for CASA, used for instance in `helium1_attention`"""
        og_dtype = hidden_states.dtype
        if self.is_streaming:
            casa_handler = self.streaming_state.maybe_get_casa_handler(
                casa_handler,
                is_first_casa_layer=self.is_first_casa_layer,
                num_queries=hidden_states.shape[1],
            )

        # Case of text-only samples at training (or inference when no handler was cached):
        # in this case we just skip CASA, so we return None (no CASA update)
        if casa_handler is None:
            return None
        if self.is_streaming:
            assert casa_handler.use_asymetric_qkv, (
                "You should set `use_asymetric_qkv` to True during inference"
            )
        og_shape = hidden_states.shape

        # Build Q inputs
        if casa_handler.use_asymetric_qkv:
            q_inputs = hidden_states.flatten(0, 1)[None, ...]
            if casa_handler.attention_mask is not None:
                q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
        else:
            q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)

        # Case 1: Training or first inference step
        if not self.is_streaming or self.streaming_state.offset == 0:
            kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
        else:
            # during streaming, the KV cache including image embeddings
            # will be inserted later, so for now we only update the incoming queries
            kv_inputs = q_inputs

        # Compute QKV for the blockwise attention
        bs, total_seq_len = kv_inputs.shape[:2]
        hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
        hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
        if self.override_q_proj is None:
            query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
        else:
            query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
        if self.override_k_proj is None:
            key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
        else:
            key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
        if self.override_v_proj is None:
            value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
        else:
            value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)

        # Apply position embeddings at the right offset
        num_queries = 0
        if self.is_streaming and self.streaming_state.offset > 0:
            num_queries = og_shape[1]
        query_states = self.apply_position_embeddings(
            "q", query_states, num_queries=num_queries, casa_handler=casa_handler
        )
        key_states = self.apply_position_embeddings(
            "kv", key_states, num_queries=num_queries, casa_handler=casa_handler
        )

        assert flash_attn_varlen_func is not None, (
            "flash_attention is not installed but required for block-wise attention"
        )
        # FlashAttention has a different, efficient path for streaming.
        # In that case, the KV cache has to be batched and is extended
        # to accommodate the new updates
        if self.is_streaming:
            key_states, value_states = self.streaming_state.extend_kv(
                key_states=key_states, value_states=value_states
            )
        if casa_handler.use_asymetric_qkv:
            cu_seqlens_q = casa_handler.cu_seqlens_q
            max_seqlen_q = casa_handler.max_seqlen_q
        else:
            cu_seqlens_q = casa_handler.cu_seqlens_kv
            max_seqlen_q = casa_handler.max_seqlen_kv
        assert cu_seqlens_q[-1] == query_states.shape[1], (
            f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
        )
        assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
            f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
        )
        attn_output: torch.Tensor = flash_attn_varlen_func(
            query_states[0].to(torch.bfloat16),
            key_states[0].to(torch.bfloat16),
            value_states[0].to(torch.bfloat16),
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=casa_handler.cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=casa_handler.max_seqlen_kv,
            dropout_p=0.0,
            # softmax_scale=None,  # defaults to 1/sqrt(d)
            causal=True,
        ).to(og_dtype)
        attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
        if self.override_o_proj is None:
            attn_output = self.o_proj_casa(attn_output)
        else:
            attn_output = self.override_o_proj(attn_output)
        attn_output = casa_handler.recover_text_embeds(
            attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
        )
        attn_output = attn_output.reshape(og_shape)
        if self.is_streaming:
            self.streaming_state.offset += attn_output.shape[1]
        return attn_output
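

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original module; run it as a
    # module so the relative import above resolves). It shows how `insert_image_tokens` splices
    # image tokens into a flattened text batch; all shapes are arbitrary toy values.
    torch.manual_seed(0)
    inputs_embeds = torch.randn(2, 5, 8)  # (batch=2, text_seq_len=5, dim=8)
    image_embeds = [torch.randn(3, 8), torch.randn(3, 8)]  # one 3-token image per sample
    # insert each image at text position 1 (sample 0) and position 2 (sample 1)
    insertion_points = [torch.tensor([1]), torch.tensor([2])]
    fused, sample_lengths, _, image_tokens_mask = insert_image_tokens(
        inputs_embeds=inputs_embeds,
        image_embeds=image_embeds,
        image_embeds_insertion_points=insertion_points,
        recover_batch_dim=False,
    )
    # Expected: fused (16, 8), sample_lengths [8, 8], image_tokens_mask (16, 1)
    print(fused.shape, sample_lengths, image_tokens_mask.shape)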