| import contextlib |
| from typing import Any, Callable, Optional, Union, Tuple, TypedDict |
|
|
| import os |
| import torch |
| from torch import nn |
|
|
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| from ...integrations import use_kernel_forward_from_hub |
| try: |
| from ...masking_utils import create_causal_mask |
| except Exception: |
| def _compat_past_key_values_length(past_key_values) -> int: |
| if past_key_values is None: |
| return 0 |
| get_seq_length = getattr(past_key_values, "get_seq_length", None) |
| if callable(get_seq_length): |
| try: |
| return max(0, int(get_seq_length())) |
| except Exception: |
| pass |
| if isinstance(past_key_values, tuple) and len(past_key_values) > 0: |
| try: |
| first = past_key_values[0][0] |
| if isinstance(first, torch.Tensor) and first.dim() >= 3: |
| return max(0, int(first.shape[2])) |
| except Exception: |
| pass |
| return 0 |
|
|
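| # Fallback mask builder for older transformers releases that lack |
| # masking_utils.create_causal_mask: returns an additive float mask of shape |
| # (batch, 1, query_len, key_len), 0.0 where attention is allowed and a large |
| # negative value where it is blocked, combining causality with any 2-D padding mask. |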
| def create_causal_mask( |
| *, |
| config, |
| input_embeds, |
| attention_mask, |
| cache_position=None, |
| past_key_values=None, |
| position_ids=None, |
| ): |
| del position_ids |
| batch_size, query_len = input_embeds.shape[:2] |
| device = input_embeds.device |
| dtype = input_embeds.dtype |
| past_key_values_length = _compat_past_key_values_length(past_key_values) |
| if past_key_values_length == 0 and isinstance(cache_position, torch.Tensor) and cache_position.numel() > 0: |
| try: |
| past_key_values_length = max(0, int(cache_position.reshape(-1)[0].item())) |
| except Exception: |
| past_key_values_length = 0 |
| key_len = past_key_values_length + query_len |
| if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: |
| key_len = max(key_len, int(attention_mask.size(1))) |
|
|
| if isinstance(cache_position, torch.Tensor) and cache_position.numel() == query_len: |
| query_positions = ( |
| cache_position.to(device=device, dtype=torch.long).view(1, query_len, 1).expand(batch_size, -1, -1) |
| ) |
| else: |
| query_positions = ( |
| torch.arange(query_len, device=device, dtype=torch.long).view(1, query_len, 1) + past_key_values_length |
| ).expand(batch_size, -1, -1) |
|
|
| key_positions = torch.arange(key_len, device=device, dtype=torch.long).view(1, 1, key_len) |
| allow = key_positions <= query_positions |
|
|
| if isinstance(attention_mask, torch.Tensor): |
| key_valid = attention_mask.to(device=device, dtype=torch.bool) |
| if key_valid.dim() == 1: |
| key_valid = key_valid.unsqueeze(0) |
| if key_valid.dim() == 2: |
| if int(key_valid.size(1)) < key_len: |
| pad = torch.ones( |
| (int(key_valid.size(0)), key_len - int(key_valid.size(1))), |
| device=device, |
| dtype=torch.bool, |
| ) |
| key_valid = torch.cat([pad, key_valid], dim=1) |
| elif int(key_valid.size(1)) > key_len: |
| key_valid = key_valid[:, -key_len:] |
| allow = allow & key_valid.unsqueeze(1) |
| query_valid = key_valid[:, -query_len:] |
| allow = allow & query_valid.unsqueeze(-1) |
|
|
| # Use the mask dtype's own minimum so the value survives a float16/bfloat16 cast. |
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device, dtype=dtype) |
| zeros = torch.tensor(0.0, device=device, dtype=dtype) |
| mask = torch.where(allow, zeros, neg_inf) |
| return mask.view(batch_size, 1, query_len, key_len) |
| try: |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| GenericForSequenceClassification, |
| GenericForTokenClassification, |
| GradientCheckpointingLayer, |
| ) |
| except Exception: |
| class GradientCheckpointingLayer(nn.Module): |
| pass |
|
|
| class GenericForSequenceClassification: |
| pass |
|
|
| class GenericForQuestionAnswering: |
| pass |
|
|
| class GenericForTokenClassification: |
| pass |
| from ...modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from ...processing_utils import Unpack |
| try: |
| from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| except Exception: |
| from ...utils import can_return_tuple, logging |
|
|
| class TransformersKwargs(TypedDict, total=False): |
| pass |
|
|
| def auto_docstring(obj=None, *args, **kwargs): |
| del args, kwargs |
| if callable(obj): |
| return obj |
|
|
| def decorator(inner): |
| return inner |
|
|
| return decorator |
| from ...utils.deprecation import deprecate_kwarg |
| try: |
| from ...utils.generic import check_model_inputs |
| except Exception: |
| def check_model_inputs(fn): |
| return fn |
| from .configuration_llama import LlamaConfig |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| PREFILL_LOWER_REPLAY_USER_PREFIX_KEEP_LEN_KEY = "prefill_lower_replay_user_prefix_keep_len" |
| PREFILL_LOWER_REPLAY_USER_START_KEY = "prefill_lower_replay_user_start" |
| PREFILL_LOWER_REPLAY_USER_LEN_KEY = "prefill_lower_replay_user_len" |
|
|
|
|
| try: |
| from open_instruct.dataset_transformation import ( |
| ASSISTANT_HEADER_START_MASK_KEY, |
| ASSISTANT_HEADER_STARTS_KEY, |
| ) |
| except Exception: |
| ASSISTANT_HEADER_STARTS_KEY = "assistant_header_starts" |
| ASSISTANT_HEADER_START_MASK_KEY = "assistant_header_start_mask" |
|
|
| ASSISTANT_TURN_ENDS_KEY = "assistant_turn_ends" |
|
|
|
|
| try: |
| from open_instruct.llopa_adapter import ( |
| PREFILL_LOWER_SYSTEM_LEN_KEY, |
| build_prefill_lower_upper_indices, |
| normalize_system_prefill, |
| ) |
| except Exception: |
| PREFILL_LOWER_SYSTEM_LEN_KEY = "prefill_lower_system_len" |
|
|
| def normalize_system_prefill(system_prefill: str) -> str: |
| mode = (system_prefill or "full").strip().lower() |
| if mode not in {"full", "no_system", "no_bos_system"}: |
| return "full" |
| return mode |
|
|
| def build_prefill_lower_upper_indices( |
| *, |
| sequence_len: int, |
| split_start: int, |
| system_len: int, |
| system_prefill: str, |
| device, |
| ) -> torch.LongTensor: |
| total_len = max(int(sequence_len), 0) |
| split_start = min(max(int(split_start), 0), total_len) |
| system_len = min(max(int(system_len), 0), split_start) |
| mode = normalize_system_prefill(system_prefill) |
| if mode == "full": |
| prefix_keep = system_len |
| elif mode == "no_system": |
| prefix_keep = min(system_len, 1) |
| else: |
| prefix_keep = 0 |
| pieces = [] |
| if prefix_keep > 0: |
| pieces.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long)) |
| if split_start < total_len: |
| pieces.append(torch.arange(split_start, total_len, device=device, dtype=torch.long)) |
| if not pieces: |
| return torch.empty((0,), device=device, dtype=torch.long) |
| if len(pieces) == 1: |
| return pieces[0] |
| return torch.cat(pieces, dim=0) |
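|
|
| # Example (illustrative): sequence_len=10, split_start=4, system_len=2 and |
| # system_prefill="full" keeps the system span plus everything from the split on, |
| # i.e. tensor([0, 1, 4, 5, 6, 7, 8, 9]). |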
|
|
|
|
| try: |
| from ...utils import is_flash_attn_2_available |
| except Exception: |
| def is_flash_attn_2_available(): |
| return False |
|
|
| try: |
| from flash_attn.flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func |
| except Exception: |
| _flash_attn_varlen_func = None |
|
|
| _WARNED_NO_FLASH_ATTN = False |
|
|
| def _resolve_attn_impl(config): |
| attn_impl = getattr(config, "_attn_implementation", "eager") |
| if attn_impl == "flash_attention_2" and not is_flash_attn_2_available(): |
| require_flash = str(os.environ.get("FLASH_ATTN_REQUIRED", "0")).lower() in ("1", "true", "yes") |
| if require_flash: |
| raise ValueError( |
| "flash_attention_2 requested but flash_attn is unavailable. " |
| "Install flash-attn or unset FLASH_ATTN_REQUIRED to allow sdpa fallback." |
| ) |
| global _WARNED_NO_FLASH_ATTN |
| if not _WARNED_NO_FLASH_ATTN: |
| logger.warning("flash_attention_2 requested but flash_attn is unavailable; falling back to sdpa.") |
| _WARNED_NO_FLASH_ATTN = True |
| attn_impl = "sdpa" |
| try: |
| setattr(config, "_attn_implementation", attn_impl) |
| except Exception: |
| pass |
| return attn_impl |
|
|
| os.environ.setdefault("DS_BUILD_AIO", "0") |
| _USE_DS_CHECKPOINT = os.environ.get("LLOPA_ENABLE_DS_CHECKPOINT", "").strip().lower() in { |
| "1", |
| "true", |
| "yes", |
| "on", |
| } |
| if _USE_DS_CHECKPOINT: |
| try: |
| from deepspeed.checkpointing import checkpoint as ds_checkpoint |
| except Exception: |
| ds_checkpoint = None |
| else: |
| ds_checkpoint = None |
|
|
| try: |
| from checkpoint_utils import checkpoint_with_cache, should_use_checkpoint |
| except Exception: |
| def checkpoint_with_cache(fn, *args, **kwargs): |
| return fn(*args, **kwargs) |
|
|
| def should_use_checkpoint(model) -> bool: |
| return False |
|
|
| def _safe_dynamic_cache(config): |
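| # DynamicCache(config=...) is only accepted by newer transformers releases; older |
| # ones take no constructor arguments, so fall back when the signature rejects it. |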
| try: |
| return DynamicCache(config=config) |
| except TypeError as exc: |
| if "max_cache_len" in str(exc) or "config" in str(exc): |
| return DynamicCache() |
| raise |
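|
|
| def _llopa_log_memory(model, tag: str, *, sequence_len: Optional[int] = None) -> None: |
| # Minimal fallback for the memory-logging helper referenced in LlamaModel.forward; |
| # the project's real helper (assumed to live elsewhere) may log richer statistics. |
| # This sketch only emits a debug line with the current CUDA allocation. |
| del model |
| if not torch.cuda.is_available(): |
| return |
| allocated_mib = torch.cuda.memory_allocated() / (1024**2) |
| logger.debug("llopa memory [%s]: %.1f MiB allocated (sequence_len=%s)", tag, allocated_mib, sequence_len) |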
|
|
|
|
| def _tri_cache_is_empty(past_key_values) -> bool: |
| if past_key_values is None: |
| return True |
| if isinstance(past_key_values, tuple): |
| if len(past_key_values) == 0: |
| return True |
| try: |
| first = past_key_values[0][0] |
| return not isinstance(first, torch.Tensor) or int(first.shape[2]) == 0 |
| except Exception: |
| return False |
| get_seq_length = getattr(past_key_values, "get_seq_length", None) |
| if callable(get_seq_length): |
| try: |
| return int(get_seq_length()) == 0 |
| except Exception: |
| return False |
| return False |
|
|
|
|
| def _normalize_last_layer_module(last_layer_module: str) -> str: |
| mode = str(last_layer_module or "none").strip().lower() |
| aliases = { |
| "": "none", |
| "off": "none", |
| "disabled": "none", |
| "disable": "none", |
| "self-attention": "self", |
| "self_attention": "self", |
| "selfattn": "self", |
| "self_attn": "self", |
| "cross-attention": "cross", |
| "cross_attention": "cross", |
| "crossattn": "cross", |
| "cross_attn": "cross", |
| } |
| mode = aliases.get(mode, mode) |
| if mode not in {"none", "self", "cross"}: |
| mode = "none" |
| return mode |
|
|
|
|
| def _normalize_replay_module(replay_module: str) -> str: |
| return _normalize_last_layer_module(replay_module) |
|
|
|
|
| def _replay_module_enabled(config) -> bool: |
| replay_module = _normalize_replay_module( |
| getattr(config, "capsule_replay_module", getattr(config, "capsule_last_layer_module", "none")) |
| ) |
| return replay_module != "none" |
|
|
|
|
| def _normalize_replay_per_layers(replay_per_layers) -> int: |
| try: |
| value = int(replay_per_layers) |
| except Exception: |
| value = -1 |
| if value <= 0: |
| return -1 |
| return value |
|
|
|
|
| def _normalize_upper_attention_mode(mode: Any) -> str: |
| normalized = str(mode or "causal").strip().lower().replace("-", "_") |
| aliases = { |
| "": "causal", |
| "normal": "causal", |
| "default": "causal", |
| "full": "causal", |
| "causal_attention": "causal", |
| "self": "solo_v2", |
| "self_only": "solo_v2", |
| "solo": "solo_v2", |
| "solo2": "solo_v2", |
| "solo_attention_v2": "solo_v2", |
| } |
| normalized = aliases.get(normalized, normalized) |
| if normalized not in {"causal", "solo_v2"}: |
| raise ValueError("upper_attention_mode must be one of {'causal', 'solo_v2'}.") |
| return normalized |
|
|
|
|
| def _normalize_attention_gate_mode(mode: Any) -> str: |
| normalized = str(mode or "off").strip().lower() |
| aliases = { |
| "": "off", |
| "none": "off", |
| "disabled": "off", |
| "disable": "off", |
| "false": "off", |
| "0": "off", |
| "paper": "sdpa_sigmoid", |
| "sdpa_gate": "sdpa_sigmoid", |
| "sdpa-gate": "sdpa_sigmoid", |
| "sigmoid_after_sdpa": "sdpa_sigmoid", |
| "sdpa_elementwise_sigmoid": "sdpa_sigmoid", |
| } |
| normalized = aliases.get(normalized, normalized) |
| if normalized not in {"off", "sdpa_sigmoid"}: |
| normalized = "off" |
| return normalized |
|
|
|
|
| def _attention_gate_enabled(config) -> bool: |
| return _normalize_attention_gate_mode(getattr(config, "capsule_attention_gate_mode", "off")) == "sdpa_sigmoid" |
|
|
|
|
| def _init_sdpa_gate_linear(linear: nn.Linear) -> None: |
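| # Zero weights with a bias of 4.0 start the gate near identity: sigmoid(4.0) ≈ 0.982, |
| # so the gate barely attenuates the attention output at initialization. |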
| with torch.no_grad(): |
| linear.weight.zero_() |
| if linear.bias is not None: |
| linear.bias.fill_(4.0) |
| setattr(linear, "_is_hf_initialized", True) |
|
|
|
|
| def _apply_sdpa_output_gate(attn_module: nn.Module, hidden_states: torch.Tensor, attn_output: torch.Tensor) -> torch.Tensor: |
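| # Elementwise sigmoid gate over the SDPA output, computed per head and head-dim |
| # from the pre-attention hidden states (the "sdpa_sigmoid" gate mode). |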
| gate_proj = getattr(attn_module, "sdpa_gate_proj", None) |
| if gate_proj is None or not isinstance(hidden_states, torch.Tensor): |
| return attn_output |
| gate = torch.sigmoid(gate_proj(hidden_states)) |
| gate = gate.view( |
| hidden_states.size(0), |
| hidden_states.size(1), |
| attn_module.config.num_attention_heads, |
| attn_module.head_dim, |
| ) |
| return attn_output * gate.to(device=attn_output.device, dtype=attn_output.dtype) |
|
|
|
|
| def _copy_attention_module_state(target_attn: nn.Module, source_attn: nn.Module) -> None: |
| target_attn.load_state_dict(source_attn.state_dict(), strict=False) |
|
|
|
|
| def _seed_replay_cross_attn_from_self_pre_hook( |
| module: nn.Module, |
| state_dict: dict[str, torch.Tensor], |
| prefix: str, |
| local_metadata, |
| strict: bool, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ) -> None: |
| replay_attn = getattr(module, "replay_cross_attn", None) |
| self_attn = getattr(module, "self_attn", None) |
| if replay_attn is None or self_attn is None: |
| return |
|
|
| replay_state = replay_attn.state_dict() |
| for key in replay_state.keys(): |
| replay_key = f"{prefix}replay_cross_attn.{key}" |
| if replay_key in state_dict: |
| continue |
| self_key = f"{prefix}self_attn.{key}" |
| if self_key not in state_dict: |
| continue |
| value = state_dict[self_key] |
| if isinstance(value, torch.Tensor): |
| state_dict[replay_key] = value.detach().clone() |
| else: |
| state_dict[replay_key] = value |
|
|
|
|
| def _tri_replay_layer_index_set( |
| *, |
| upper_layer_indices, |
| replay_per_layers, |
| ) -> set[int]: |
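| # replay_per_layers == -1 replays only at the last upper layer; otherwise replay |
| # runs at every replay_per_layers-th upper layer (indices k-1, 2k-1, ... of the list). |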
| upper = [int(li) for li in upper_layer_indices] |
| if not upper: |
| return set() |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
| if replay_per_layers == -1: |
| return {int(upper[-1])} |
| return {int(upper[idx]) for idx in range(int(replay_per_layers) - 1, len(upper), int(replay_per_layers))} |
|
|
|
|
| @use_kernel_forward_from_hub("RMSNorm") |
| class LlamaRMSNorm(nn.Module): |
| def __init__(self, hidden_size, eps=1e-6): |
| """ |
| LlamaRMSNorm is equivalent to T5LayerNorm |
| """ |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states): |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| return self.weight * hidden_states.to(input_dtype) |
|
|
| def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
|
|
|
|
| class LlamaRotaryEmbedding(nn.Module): |
| inv_freq: torch.Tensor |
|
|
| def __init__(self, config: LlamaConfig, device=None): |
| super().__init__() |
| # BC: "rope_type" was originally "type" |
| if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): |
| self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
| else: |
| self.rope_type = "default" |
| self.max_seq_len_cached = config.max_position_embeddings |
| self.original_max_seq_len = config.max_position_embeddings |
|
|
| self.config = config |
| self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
| inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
| self.original_inv_freq = self.inv_freq |
|
|
| @torch.no_grad() |
| @dynamic_rope_update |
| def forward(self, x, position_ids): |
| inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
| position_ids_expanded = position_ids[:, None, :].float() |
|
|
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() * self.attention_scaling |
| sin = emb.sin() * self.attention_scaling |
|
|
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
| def rotate_half(x): |
| """Rotates half the hidden dims of the input.""" |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| Args: |
| q (`torch.Tensor`): The query tensor. |
| k (`torch.Tensor`): The key tensor. |
| cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| sin (`torch.Tensor`): The sine part of the rotary embedding. |
| position_ids (`torch.Tensor`, *optional*): |
| Deprecated and unused. |
| unsqueeze_dim (`int`, *optional*, defaults to 1): |
| The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| Returns: |
| `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| """ |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
|
|
|
|
| class LlamaMLP(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) |
| self.act_fn = ACT2FN[config.hidden_act] |
|
|
| def forward(self, x): |
| down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
| return down_proj |
|
|
|
|
| def _checkpoint_mlp(mlp, hidden_states): |
| if ds_checkpoint is not None: |
| return ds_checkpoint(mlp, hidden_states) |
| try: |
| return torch.utils.checkpoint.checkpoint(mlp, hidden_states, use_reentrant=False) |
| except TypeError: |
| return torch.utils.checkpoint.checkpoint(mlp, hidden_states) |
|
|
|
|
| def _run_decoder_mlp_only(layer, hidden_states: torch.Tensor, use_checkpoint: bool) -> torch.Tensor: |
| residual = hidden_states |
| hidden_states = layer.post_attention_layernorm(hidden_states) |
| if use_checkpoint: |
| hidden_states = _checkpoint_mlp(layer.mlp, hidden_states) |
| else: |
| hidden_states = layer.mlp(hidden_states) |
| return residual + hidden_states |
|
|
|
|
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| """ |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
| def eager_attention_forward( |
| module: nn.Module, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| scaling: float, |
| dropout: float = 0.0, |
| **kwargs: Unpack[TransformersKwargs], |
| ): |
| key_states = repeat_kv(key, module.num_key_value_groups) |
| value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
| attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
| if attention_mask is not None: |
| causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| attn_weights = attn_weights + causal_mask |
|
|
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
| attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
| attn_output = torch.matmul(attn_weights, value_states) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
| return attn_output, attn_weights |
|
|
|
|
| class LlamaAttention(nn.Module): |
| """Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
| def __init__(self, config: LlamaConfig, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) |
| self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = config.attention_dropout |
| self.is_causal = True |
|
|
| self.q_proj = nn.Linear( |
| config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.k_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.v_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| self.sdpa_gate_proj = None |
| if _attention_gate_enabled(config): |
| self.sdpa_gate_proj = nn.Linear( |
| config.hidden_size, |
| config.num_attention_heads * self.head_dim, |
| bias=True, |
| ) |
| _init_sdpa_gate_linear(self.sdpa_gate_proj) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| if hidden_states.size(1) == 0: |
| # Nothing to attend over for an empty query; return the input unchanged. |
| return hidden_states, None |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
| llopa_special_kv = kwargs.pop("llopa_special_kv", None) |
| llopa_special_span = kwargs.pop("llopa_special_span", None) |
| extra_prefix_kv = kwargs.pop("extra_prefix_kv", None) |
| extra_prefix_valid_mask = kwargs.pop("extra_prefix_valid_mask", None) |
| extra_prefix_query_mask = kwargs.pop("extra_prefix_query_mask", None) |
| extra_prefix_local_valid_mask = kwargs.pop("extra_prefix_local_valid_mask", None) |
|
|
| query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| cos, sin = position_embeddings |
| cos_u = cos.unsqueeze(1) |
| sin_u = sin.unsqueeze(1) |
| query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u) |
|
|
| key_states = None |
| value_states = None |
| if llopa_special_kv is not None and llopa_special_span is not None: |
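| # Splice externally supplied, already-rotated K/V states over the special-token |
| # span [sp_start, sp_end); K/V for the surrounding tokens are recomputed locally. |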
| try: |
| sp_start, sp_end = int(llopa_special_span[0]), int(llopa_special_span[1]) |
| except Exception: |
| sp_start, sp_end = -1, -1 |
| seq_len = int(input_shape[1]) |
| if ( |
| 0 <= sp_start < sp_end <= seq_len |
| and isinstance(llopa_special_kv, (tuple, list)) |
| and len(llopa_special_kv) == 2 |
| ): |
| key_sp, value_sp = llopa_special_kv |
| key_sp = key_sp.to(device=hidden_states.device, dtype=query_states.dtype) |
| value_sp = value_sp.to(device=hidden_states.device, dtype=query_states.dtype) |
| left_len = sp_start |
| right_len = seq_len - sp_end |
| if left_len == 0 and right_len == 0: |
| key_states = key_sp |
| value_states = value_sp |
| else: |
| parts = [] |
| if left_len > 0: |
| parts.append(hidden_states[:, :left_len, :]) |
| if right_len > 0: |
| parts.append(hidden_states[:, sp_end:, :]) |
| non_special = torch.cat(parts, dim=1) if parts else hidden_states[:, :0, :] |
| if non_special.size(1) > 0: |
| non_hidden_shape = (*non_special.shape[:-1], -1, self.head_dim) |
| key_non = self.k_proj(non_special).view(non_hidden_shape).transpose(1, 2) |
| value_non = self.v_proj(non_special).view(non_hidden_shape).transpose(1, 2) |
| if left_len > 0 and right_len > 0: |
| cos_non = torch.cat([cos[:, :left_len, :], cos[:, sp_end:, :]], dim=1) |
| sin_non = torch.cat([sin[:, :left_len, :], sin[:, sp_end:, :]], dim=1) |
| elif left_len > 0: |
| cos_non = cos[:, :left_len, :] |
| sin_non = sin[:, :left_len, :] |
| else: |
| cos_non = cos[:, sp_end:, :] |
| sin_non = sin[:, sp_end:, :] |
| key_non = (key_non * cos_non.unsqueeze(1)) + (rotate_half(key_non) * sin_non.unsqueeze(1)) |
| key_parts = [] |
| value_parts = [] |
| if left_len > 0: |
| key_parts.append(key_non[:, :, :left_len, :]) |
| value_parts.append(value_non[:, :, :left_len, :]) |
| key_parts.append(key_sp) |
| value_parts.append(value_sp) |
| if right_len > 0: |
| key_parts.append(key_non[:, :, left_len:, :]) |
| value_parts.append(value_non[:, :, left_len:, :]) |
| key_states = torch.cat(key_parts, dim=2) |
| value_states = torch.cat(value_parts, dim=2) |
| else: |
| key_states = key_sp |
| value_states = value_sp |
|
|
| if key_states is None or value_states is None: |
| key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| key_states = (key_states * cos_u) + (rotate_half(key_states) * sin_u) |
|
|
| if past_key_values is not None: |
| # sin and cos are specific to RoPE models; cache_position needed for the static cache |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| if isinstance(extra_prefix_kv, (tuple, list)) and len(extra_prefix_kv) == 2: |
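| # Prepend externally computed prefix K/V (e.g. replayed lower-layer states) ahead of |
| # the local keys/values, then widen or rebuild the attention mask to match. |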
| prefix_key_states, prefix_value_states = extra_prefix_kv |
| if isinstance(prefix_key_states, torch.Tensor) and isinstance(prefix_value_states, torch.Tensor): |
| prefix_key_states = prefix_key_states.to(device=hidden_states.device, dtype=query_states.dtype) |
| prefix_value_states = prefix_value_states.to(device=hidden_states.device, dtype=query_states.dtype) |
| if prefix_key_states.numel() > 0 and int(prefix_key_states.size(2)) > 0: |
| local_key_len = int(key_states.size(2)) |
| batch_size = int(hidden_states.size(0)) |
| query_len = int(hidden_states.size(1)) |
| prefix_len = int(prefix_key_states.size(2)) |
| key_states = torch.cat([prefix_key_states, key_states], dim=2) |
| value_states = torch.cat([prefix_value_states, value_states], dim=2) |
|
|
| if attn_impl == "flash_attention_2": |
| local_valid_source = extra_prefix_local_valid_mask |
| if local_valid_source is None and isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: |
| local_valid_source = attention_mask |
| query_prefix_enabled = True |
| if isinstance(extra_prefix_query_mask, torch.Tensor): |
| query_prefix_enabled = bool(extra_prefix_query_mask.to(device=hidden_states.device, dtype=torch.bool).all().item()) |
| if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4: |
| attention_mask = _tri_build_self_prefix_attention_mask( |
| local_attention_mask=attention_mask, |
| batch_size=batch_size, |
| query_len=query_len, |
| local_key_len=local_key_len, |
| prefix_valid_mask=extra_prefix_valid_mask, |
| query_prefix_mask=extra_prefix_query_mask, |
| local_valid_mask=local_valid_source, |
| device=hidden_states.device, |
| dtype=hidden_states.dtype, |
| ) |
| attn_impl = "eager" |
| elif query_prefix_enabled: |
| prefix_valid = _tri_expand_valid_mask( |
| extra_prefix_valid_mask, |
| batch_size=batch_size, |
| total_len=prefix_len, |
| device=hidden_states.device, |
| ) |
| prefix_all_valid = bool(prefix_valid.all().item()) |
| if attention_mask is None and prefix_all_valid: |
| pass |
| else: |
| local_valid = _tri_expand_valid_mask( |
| local_valid_source, |
| batch_size=batch_size, |
| total_len=local_key_len, |
| device=hidden_states.device, |
| ) |
| attention_mask = torch.cat( |
| [prefix_valid.to(dtype=torch.long), local_valid.to(dtype=torch.long)], |
| dim=1, |
| ) |
| else: |
| attention_mask = _tri_build_self_prefix_attention_mask( |
| local_attention_mask=attention_mask if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4 else None, |
| batch_size=batch_size, |
| query_len=query_len, |
| local_key_len=local_key_len, |
| prefix_valid_mask=extra_prefix_valid_mask, |
| query_prefix_mask=extra_prefix_query_mask, |
| local_valid_mask=local_valid_source, |
| device=hidden_states.device, |
| dtype=hidden_states.dtype, |
| ) |
| attn_impl = "eager" |
| else: |
| local_valid_source = extra_prefix_local_valid_mask |
| if local_valid_source is None and isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: |
| local_valid_source = attention_mask |
| attention_mask = _tri_build_self_prefix_attention_mask( |
| local_attention_mask=attention_mask if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4 else None, |
| batch_size=batch_size, |
| query_len=query_len, |
| local_key_len=local_key_len, |
| prefix_valid_mask=extra_prefix_valid_mask, |
| query_prefix_mask=extra_prefix_query_mask, |
| local_valid_mask=local_valid_source, |
| device=hidden_states.device, |
| dtype=hidden_states.dtype, |
| ) |
|
|
| attention_interface: Callable = eager_attention_forward |
| if attn_impl != "eager": |
| attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] |
| if attn_impl == "flash_attention_2" and attention_mask is None and "position_ids" in kwargs: |
| # With no attention mask, position_ids would only steer flash-attn's packed-sequence |
| # (varlen) detection; dropping it keeps a plain contiguous batch on the fast path. |
| kwargs = dict(kwargs) |
| kwargs.pop("position_ids", None) |
|
|
| attn_output, attn_weights = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| **kwargs, |
| ) |
|
|
| attn_output = _apply_sdpa_output_gate(self, hidden_states, attn_output) |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class LlamaDecoderLayer(GradientCheckpointingLayer): |
| def __init__(self, config: LlamaConfig, layer_idx: int): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.partial_gradient_checkpointing = False |
| self.layer_idx = layer_idx |
|
|
| self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) |
| self.replay_cross_attn = None |
| if _replay_module_enabled(config): |
| self.replay_cross_attn = LlamaAttention(config=config, layer_idx=layer_idx) |
|
|
| self.mlp = LlamaMLP(config) |
| self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.register_load_state_dict_pre_hook(_seed_replay_cross_attn_from_self_pre_hook) |
|
|
| def ensure_replay_cross_attn(self) -> Optional[LlamaAttention]: |
| replay_attn = getattr(self, "replay_cross_attn", None) |
| if replay_attn is not None: |
| return replay_attn |
| if self.training: |
| raise RuntimeError( |
| "Replay attention was requested at runtime, but replay parameters were not initialized. " |
| "Build the model with replay_module=cross before training." |
| ) |
| replay_attn = LlamaAttention(config=self.self_attn.config, layer_idx=self.layer_idx) |
| replay_attn.to( |
| device=self.self_attn.q_proj.weight.device, |
| dtype=self.self_attn.q_proj.weight.dtype, |
| ) |
| self.replay_cross_attn = replay_attn |
| self.sync_replay_cross_attn_from_self() |
| return replay_attn |
|
|
| def sync_replay_cross_attn_from_self(self) -> None: |
| replay_attn = getattr(self, "replay_cross_attn", None) |
| if replay_attn is None: |
| return |
| _copy_attention_module_state(replay_attn, self.self_attn) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> torch.Tensor: |
| residual = hidden_states |
| hidden_states = self.input_layernorm(hidden_states) |
| |
| hidden_states, _ = self.self_attn( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| hidden_states = residual + hidden_states |
|
|
| |
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| if self.training and self.partial_gradient_checkpointing: |
| hidden_states = _checkpoint_mlp(self.mlp, hidden_states) |
| else: |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| @auto_docstring |
| class LlamaPreTrainedModel(PreTrainedModel): |
| config: LlamaConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["LlamaDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| _supports_flash_attn = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
|
|
| _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": LlamaDecoderLayer, |
| "attentions": LlamaAttention, |
| } |
|
|
|
|
| @auto_docstring |
| class LlamaModel(LlamaPreTrainedModel): |
| def __init__(self, config: LlamaConfig): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| self.layers = nn.ModuleList( |
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
| ) |
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.rotary_emb = LlamaRotaryEmbedding(config=config) |
| self.gradient_checkpointing = False |
| self.llopa_num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) |
| self.llopa_specials = None |
| if self.llopa_num_specials > 0: |
| self.llopa_specials = nn.ParameterList( |
| [nn.Parameter(torch.empty(self.llopa_num_specials, config.hidden_size)) |
| for _ in range(config.num_hidden_layers)] |
| ) |
| init_range = float(getattr(config, "initializer_range", 0.02)) |
| for p in self.llopa_specials: |
| nn.init.normal_(p, mean=0.0, std=init_range) |
|
|
| # Initialize weights and apply final processing |
| self.post_init() |
| for layer in self.layers: |
| if hasattr(layer, "sync_replay_cross_attn_from_self"): |
| layer.sync_replay_cross_attn_from_self() |
|
|
| def get_input_embeddings(self): |
| return self.embed_tokens |
|
|
| def set_input_embeddings(self, value): |
| self.embed_tokens = value |
|
|
| @check_model_inputs |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> BaseModelOutputWithPast: |
| skip_upper_attention_layers = int(kwargs.pop("skip_upper_attention_layers", 0) or 0) |
| solo_attention_layers = int(kwargs.pop("solo_attention_layers", 0) or 0) |
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
| if inputs_embeds is None: |
| inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = _safe_dynamic_cache(self.config) |
|
|
| if cache_position is None: |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| cache_position: torch.Tensor = torch.arange( |
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| ) |
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| causal_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| skip_from_layer = max(0, min(int(skip_upper_attention_layers), self.config.num_hidden_layers)) |
| solo_from_layer = max(0, min(int(solo_attention_layers), self.config.num_hidden_layers)) |
| solo_mask = None |
| if solo_from_layer > 0: |
| past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| solo_mask = _build_self_only_mask(inputs_embeds.size(0), inputs_embeds.size(1), past_len, inputs_embeds.device, inputs_embeds.dtype) |
| use_checkpoint = should_use_checkpoint(self) |
| _llopa_log_memory( |
| self, |
| "after_vanilla_embed", |
| sequence_len=hidden_states.size(1), |
| ) |
|
|
| for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): |
| if skip_from_layer > 0 and layer_idx >= skip_from_layer: |
| hidden_states = _run_decoder_mlp_only(decoder_layer, hidden_states, use_checkpoint) |
| _llopa_log_memory( |
| self, |
| f"after_vanilla_upper_skip_layer_{layer_idx}", |
| sequence_len=hidden_states.size(1), |
| ) |
| else: |
| layer_mask = causal_mask |
| if solo_from_layer > 0 and layer_idx >= solo_from_layer: |
| layer_mask = solo_mask |
| hidden_states = decoder_layer( |
| hidden_states, |
| attention_mask=layer_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| if skip_from_layer > 0 and layer_idx + 1 == skip_from_layer: |
| _llopa_log_memory( |
| self, |
| "after_vanilla_lower", |
| sequence_len=hidden_states.size(1), |
| ) |
|
|
| hidden_states = self.norm(hidden_states) |
| _llopa_log_memory( |
| self, |
| "after_vanilla_final", |
| sequence_len=hidden_states.size(1), |
| ) |
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| ) |
|
|
|
|
| @auto_docstring |
| class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
| _tp_plan = {"lm_head": "colwise_rep"} |
| _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} |
|
|
| def __init__(self, config, **kwargs): |
| # Accept both "dtype" and the legacy "torch_dtype" kwarg and apply it after init. |
| init_dtype = kwargs.pop("dtype", None) |
| if init_dtype is None: |
| init_dtype = kwargs.pop("torch_dtype", None) |
| super().__init__(config, **kwargs) |
| self.model = LlamaModel(config) |
| self.vocab_size = config.vocab_size |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
| # Initialize weights and apply final processing |
| self.post_init() |
| if isinstance(init_dtype, str): |
| alias = { |
| "bf16": torch.bfloat16, |
| "bfloat16": torch.bfloat16, |
| "fp16": torch.float16, |
| "float16": torch.float16, |
| "fp32": torch.float32, |
| "float32": torch.float32, |
| } |
| init_dtype = alias.get(init_dtype.lower(), None) |
| if isinstance(init_dtype, torch.dtype): |
| self.to(dtype=init_dtype) |
|
|
| def get_input_embeddings(self): |
| return self.model.embed_tokens |
|
|
| def set_input_embeddings(self, value): |
| self.model.embed_tokens = value |
|
|
| def get_output_embeddings(self): |
| return self.lm_head |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.lm_head = new_embeddings |
|
|
| def set_decoder(self, decoder): |
| self.model = decoder |
|
|
| def get_decoder(self): |
| return self.model |
|
|
| @can_return_tuple |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> CausalLMOutputWithPast: |
| r""" |
| Example: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, LlamaForCausalLM |
| |
| >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") |
| >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") |
| |
| >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| >>> inputs = tokenizer(prompt, return_tensors="pt") |
| |
| >>> # Generate |
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
| ```""" |
| skip_upper_attention_layers = int(kwargs.pop("skip_upper_attention_layers", 0) or 0) |
| prefill_lower_layers = int(kwargs.pop("prefill_lower_layers", 0) or 0) |
| prefill_lower_attn = str(kwargs.pop("prefill_lower_attn", "causal") or "causal") |
| prefill_lower_no_upper_attn = bool(kwargs.pop("prefill_lower_no_upper_attn", False)) |
| prefill_lower_upper_attention_mode = _normalize_upper_attention_mode( |
| kwargs.pop( |
| "prefill_lower_upper_attention_mode", |
| getattr(self.config, "capsule_upper_attention_mode", "causal"), |
| ) |
| ) |
| prefill_lower_solo_attention = bool(kwargs.pop("prefill_lower_solo_attention", False)) |
| prefill_lower_solo_attention_v2 = bool(kwargs.pop("prefill_lower_solo_attention_v2", False)) |
| prefill_lower_solo_attention_v2_with_bos = bool(kwargs.pop("prefill_lower_solo_attention_v2_with_bos", False)) |
| prefill_lower_freeze_runtime = bool(kwargs.pop("prefill_lower_freeze_runtime", False)) |
| prefill_lower_split_start = kwargs.pop("prefill_lower_split_start", None) |
| prefill_lower_system_len = kwargs.pop(PREFILL_LOWER_SYSTEM_LEN_KEY, None) |
| prefill_lower_replay_user_prefix_keep_len = kwargs.pop(PREFILL_LOWER_REPLAY_USER_PREFIX_KEEP_LEN_KEY, None) |
| prefill_lower_replay_user_start = kwargs.pop(PREFILL_LOWER_REPLAY_USER_START_KEY, None) |
| prefill_lower_replay_user_len = kwargs.pop(PREFILL_LOWER_REPLAY_USER_LEN_KEY, None) |
| prefill_lower_system_prefill = str(kwargs.pop("prefill_lower_system_prefill", "no_bos_system") or "no_bos_system") |
| prefill_lower_replay_module = _normalize_replay_module( |
| kwargs.pop( |
| "prefill_lower_replay_module", |
| kwargs.pop( |
| "prefill_lower_last_layer_module", |
| getattr(self.config, "capsule_replay_module", getattr(self.config, "capsule_last_layer_module", "none")), |
| ), |
| ) |
| ) |
| prefill_lower_replay_per_layers = _normalize_replay_per_layers( |
| kwargs.pop( |
| "prefill_lower_replay_per_layers", |
| getattr(self.config, "capsule_replay_per_layers", -1), |
| ) |
| ) |
| prefill_lower_see_past_assistant = bool(kwargs.pop("prefill_lower_see_past_assistant", False)) |
| assistant_header_start = kwargs.pop("assistant_header_start", None) |
| assistant_header_starts = kwargs.pop(ASSISTANT_HEADER_STARTS_KEY, None) |
| assistant_turn_ends = kwargs.pop(ASSISTANT_TURN_ENDS_KEY, None) |
| assistant_header_start_mask = kwargs.pop(ASSISTANT_HEADER_START_MASK_KEY, None) |
| runtime_llopa_prefill = bool(kwargs.pop("runtime_llopa_prefill", False)) |
| runtime_llopa_layers = int(kwargs.pop("runtime_llopa_layers", 0) or 0) |
| runtime_llopa_attn = str(kwargs.pop("runtime_llopa_attn", "causal") or "causal") |
| runtime_llopa_no_upper_attn = bool(kwargs.pop("runtime_llopa_no_upper_attn", False)) |
| runtime_llopa_replay_module = _normalize_replay_module( |
| kwargs.pop( |
| "runtime_llopa_replay_module", |
| kwargs.pop("runtime_llopa_last_layer_module", prefill_lower_replay_module), |
| ) |
| ) |
| runtime_llopa_replay_per_layers = _normalize_replay_per_layers( |
| kwargs.pop( |
| "runtime_llopa_replay_per_layers", |
| prefill_lower_replay_per_layers, |
| ) |
| ) |
| llopa_v2_decode = bool(kwargs.pop("llopa_v2_decode", False)) |
| llopa_v2_decode_layers = int( |
| kwargs.pop( |
| "llopa_v2_decode_layers", |
| kwargs.pop("llopa_v2_lower_k", runtime_llopa_layers), |
| ) |
| or 0 |
| ) |
| llopa_v2_decode_no_upper_attn = bool( |
| kwargs.pop("llopa_v2_decode_no_upper_attn", runtime_llopa_no_upper_attn) |
| ) |
| llopa_v2_decode_replay_module = _normalize_replay_module( |
| kwargs.pop("llopa_v2_decode_replay_module", runtime_llopa_replay_module) |
| ) |
| llopa_v2_decode_replay_per_layers = _normalize_replay_per_layers( |
| kwargs.pop("llopa_v2_decode_replay_per_layers", runtime_llopa_replay_per_layers) |
| ) |
| prefill_lower_solo_attention_v2 = bool( |
| prefill_lower_solo_attention_v2 or prefill_lower_solo_attention_v2_with_bos |
| ) |
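| # v2 decode fast path: route assistant tokens through tri_forward_assistant |
| # (assumed to be defined on LlamaModel elsewhere in this module) when decoding |
| # against an existing cache. |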
| if ( |
| llopa_v2_decode |
| and labels is None |
| and inputs_embeds is None |
| and past_key_values is not None |
| and isinstance(input_ids, torch.Tensor) |
| and input_ids.dim() == 2 |
| and input_ids.size(1) > 0 |
| and llopa_v2_decode_layers > 0 |
| ): |
| tri_outputs = self.model.tri_forward_assistant( |
| assistant_ids=input_ids, |
| lower_k=llopa_v2_decode_layers, |
| pkv=past_key_values, |
| S=0, |
| U=0, |
| write_cache=bool(use_cache if use_cache is not None else True), |
| prefill_mode="lower", |
| no_upper_attn=llopa_v2_decode_no_upper_attn, |
| align_cache_position_to_layer_past=False, |
| replay_module=llopa_v2_decode_replay_module, |
| replay_per_layers=llopa_v2_decode_replay_per_layers, |
| ) |
| hidden_states = tri_outputs.last_hidden_state |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| past_key_values=tri_outputs.past_key_values, |
| hidden_states=tri_outputs.hidden_states, |
| attentions=tri_outputs.attentions, |
| ) |
| if prefill_lower_split_start is None: |
| prefill_lower_split_start = assistant_header_start |
| if runtime_llopa_prefill and inputs_embeds is None and (bool(use_cache) or past_key_values is not None): |
| tri_outputs = self.tri_runtime_llopa_prompt_prefill_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| past_key_values=past_key_values, |
| lower_k=runtime_llopa_layers, |
| prefill_attn=runtime_llopa_attn, |
| no_upper_attn=runtime_llopa_no_upper_attn, |
| replay_module=runtime_llopa_replay_module, |
| replay_per_layers=runtime_llopa_replay_per_layers, |
| ) |
| if tri_outputs is not None: |
| return tri_outputs |
| if prefill_lower_layers > 0 and inputs_embeds is None: |
| tri_outputs = None |
| if labels is not None and past_key_values is None and not bool(use_cache): |
| if prefill_lower_solo_attention_v2: |
| tri_outputs = self.tri_vanilla_solo_attention_v2_train_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| with_bos=prefill_lower_solo_attention_v2_with_bos, |
| ) |
| elif prefill_lower_solo_attention: |
| tri_outputs = self.tri_vanilla_solo_prefix_train_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| ) |
| else: |
| tri_outputs = self.tri_vanilla_train_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| assistant_header_starts_hint=assistant_header_starts, |
| assistant_turn_ends_hint=assistant_turn_ends, |
| assistant_header_start_mask_hint=assistant_header_start_mask, |
| see_past_assistant=prefill_lower_see_past_assistant, |
| no_upper_attn=prefill_lower_no_upper_attn, |
| upper_attention_mode=prefill_lower_upper_attention_mode, |
| replay_module=prefill_lower_replay_module, |
| replay_per_layers=prefill_lower_replay_per_layers, |
| ) |
| elif prefill_lower_solo_attention_v2 and past_key_values is None and not bool(use_cache): |
| tri_outputs = self.tri_vanilla_solo_attention_v2_infer_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| with_bos=prefill_lower_solo_attention_v2_with_bos, |
| ) |
| elif prefill_lower_solo_attention and past_key_values is None and not bool(use_cache): |
| tri_outputs = self.tri_vanilla_solo_prefix_infer_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| ) |
| elif prefill_lower_freeze_runtime and past_key_values is None and not bool(use_cache): |
| tri_outputs = self.tri_vanilla_frozen_prefix_infer_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| ) |
| elif bool(use_cache) or past_key_values is not None: |
| tri_outputs = self.tri_vanilla_prefill_decode_forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| past_key_values=past_key_values, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_lower_attn, |
| split_start_hint=prefill_lower_split_start, |
| system_len_hint=prefill_lower_system_len, |
| system_prefill=prefill_lower_system_prefill, |
| assistant_header_starts_hint=assistant_header_starts, |
| assistant_turn_ends_hint=assistant_turn_ends, |
| assistant_header_start_mask_hint=assistant_header_start_mask, |
| replay_user_prefix_keep_len_hint=prefill_lower_replay_user_prefix_keep_len, |
| replay_user_start_hint=prefill_lower_replay_user_start, |
| replay_user_len_hint=prefill_lower_replay_user_len, |
| see_past_assistant=prefill_lower_see_past_assistant, |
| no_upper_attn=prefill_lower_no_upper_attn, |
| replay_module=prefill_lower_replay_module, |
| replay_per_layers=prefill_lower_replay_per_layers, |
| ) |
| if tri_outputs is not None: |
| return tri_outputs |
| outputs: BaseModelOutputWithPast = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| skip_upper_attention_layers=skip_upper_attention_layers, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| # Only compute necessary logits, and do not upcast them into float if we are not computing the loss |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| class LlamaForSequenceClassification(GenericForSequenceClassification, LlamaPreTrainedModel): ... |
|
|
|
|
| class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel): |
| base_model_prefix = "transformer" |
|
|
|
|
| class LlamaForTokenClassification(GenericForTokenClassification, LlamaPreTrainedModel): ... |
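|
|
| # Illustrative sketch of the lower-prefill training call (argument values are |
| # placeholders, not recommendations): |
| # |
| #   model = LlamaForCausalLM.from_pretrained("path/to/checkpoint") |
| #   out = model( |
| #       input_ids=input_ids, |
| #       attention_mask=attention_mask, |
| #       labels=labels, |
| #       prefill_lower_layers=8,  # run only the first 8 layers over the prompt span |
| #       prefill_lower_attn="causal", |
| #       prefill_lower_system_prefill="no_bos_system", |
| #   ) |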
|
|
|
|
| |
| |
| |
| def _tri_arange(start: int, length: int, device) -> torch.LongTensor: |
| return torch.arange(int(start), int(start + length), device=device, dtype=torch.long) |
|
|
| def _prefill_layer_indices(n_layers: int, mode: str, k: int) -> list[int]: |
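| # "periodic" selects every k-th layer (k=4, n_layers=32 -> [0, 4, ..., 28]); any |
| # other mode selects the first k layers ([0, 1, ..., k-1]). |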
| if (mode or "lower") == "periodic": |
| step = int(k) |
| if step <= 0: |
| return [] |
| return list(range(0, n_layers, step)) |
| K_eff = max(0, min(int(k), n_layers)) |
| return list(range(K_eff)) |
|
|
|
|
| def _safe_len_seqlen(x) -> int: |
| if x is None: |
| return 0 |
| if hasattr(x, "shape") and x.dim() >= 3: |
| return int(x.shape[2]) |
| return 0 |
|
|
|
|
| def _layer_past_len(pkv: Cache, li: int) -> int: |
| if pkv is None: |
| return 0 |
| if hasattr(pkv, "key_cache") and isinstance(pkv.key_cache, (list, tuple)) and li < len(pkv.key_cache): |
| return _safe_len_seqlen(pkv.key_cache[li]) |
| if hasattr(pkv, "layers") and isinstance(pkv.layers, (list, tuple)) and li < len(pkv.layers): |
| layer = pkv.layers[li] |
| keys = None if layer is None else getattr(layer, "keys", None) |
| return _safe_len_seqlen(keys) |
| try: |
| kv = pkv[li] |
| if isinstance(kv, (list, tuple)) and len(kv) >= 1: |
| return _safe_len_seqlen(kv[0]) |
| except Exception: |
| pass |
| return 0 |
|
|
|
|
| def _build_tri_mask_local(B: int, Tq: int, past_len: int, device, dtype): |
| Tk = past_len + Tq |
| i = torch.arange(Tq, device=device).unsqueeze(1) |
| j = torch.arange(Tk, device=device).unsqueeze(0) |
| allow = j <= (past_len + i) |
| # Use the mask dtype's minimum so the value survives a float16/bfloat16 cast. |
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device, dtype=dtype) |
| mask = torch.where(allow, torch.tensor(0.0, device=device, dtype=dtype), neg_inf) |
| return mask.view(1, 1, Tq, Tk).expand(B, 1, Tq, Tk) |
|
|
| def _build_tri_mask_full(B: int, Tq: int, past_len: int, device, dtype): |
| Tk = past_len + Tq |
| return torch.zeros((B, 1, Tq, Tk), device=device, dtype=dtype) |
|
|
| def _build_self_only_mask(B: int, Tq: int, past_len: int, device, dtype): |
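| # Solo-attention mask: each query may attend only to its own key position; all |
| # other positions (including the past) are masked with the dtype minimum. |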
| Tk = past_len + Tq |
| min_value = torch.finfo(dtype).min |
| mask = torch.full((B, 1, Tq, Tk), min_value, device=device, dtype=dtype) |
| idx = torch.arange(Tq, device=device) |
| mask[:, :, idx, past_len + idx] = 0 |
| return mask |
|
|
|
|
| def _tri_expand_valid_mask( |
| valid_mask: Optional[torch.Tensor], |
| *, |
| batch_size: int, |
| total_len: int, |
| device, |
| ) -> torch.BoolTensor: |
| if valid_mask is None: |
| return torch.ones((int(batch_size), int(total_len)), device=device, dtype=torch.bool) |
| valid_mask = valid_mask.to(device=device, dtype=torch.bool) |
| if valid_mask.dim() == 1: |
| valid_mask = valid_mask.view(1, -1) |
| if int(valid_mask.size(0)) == 1 and int(batch_size) > 1: |
| valid_mask = valid_mask.expand(int(batch_size), -1) |
| if int(valid_mask.size(1)) < int(total_len): |
| pad = torch.ones( |
| (int(valid_mask.size(0)), int(total_len) - int(valid_mask.size(1))), |
| device=device, |
| dtype=torch.bool, |
| ) |
| valid_mask = torch.cat([pad, valid_mask], dim=1) |
| elif int(valid_mask.size(1)) > int(total_len): |
| valid_mask = valid_mask[:, -int(total_len) :] |
| return valid_mask |
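

| # Illustrative sketch (not part of the original module): the expansion pads with |
| # True on the left (older positions are assumed valid) and truncates from the |
| # left when the supplied mask is longer than the requested length. |
| def _demo_tri_expand_valid_mask() -> None: |
|     device = torch.device("cpu") |
|     mask = torch.tensor([[True, False, True]]) |
|     out = _tri_expand_valid_mask(mask, batch_size=1, total_len=5, device=device) |
|     assert out.tolist() == [[True, True, True, False, True]] |
|     out = _tri_expand_valid_mask(mask, batch_size=1, total_len=2, device=device) |
|     assert out.tolist() == [[False, True]] |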
|
|
|
|
| def _tri_build_prefix_query_allow_mask( |
| *, |
| batch_size: int, |
| query_len: int, |
| prefix_valid_mask: Optional[torch.Tensor], |
| query_prefix_mask: Optional[torch.Tensor], |
| prefix_len: int, |
| device, |
| ) -> torch.BoolTensor: |
| if int(prefix_len) <= 0: |
| return torch.zeros((int(batch_size), 1, int(query_len), 0), device=device, dtype=torch.bool) |
| prefix_valid = _tri_expand_valid_mask( |
| prefix_valid_mask, |
| batch_size=batch_size, |
| total_len=prefix_len, |
| device=device, |
| ) |
| if query_prefix_mask is None: |
| query_prefix_mask = torch.ones((int(batch_size), int(query_len)), device=device, dtype=torch.bool) |
| else: |
| query_prefix_mask = query_prefix_mask.to(device=device, dtype=torch.bool) |
| if query_prefix_mask.dim() == 1: |
| query_prefix_mask = query_prefix_mask.view(1, -1) |
| if int(query_prefix_mask.size(0)) == 1 and int(batch_size) > 1: |
| query_prefix_mask = query_prefix_mask.expand(int(batch_size), -1) |
| return prefix_valid.view(int(batch_size), 1, 1, int(prefix_len)).expand( |
| int(batch_size), 1, int(query_len), int(prefix_len) |
| ) & query_prefix_mask.view(int(batch_size), 1, int(query_len), 1) |
|
|
|
|
| def _tri_build_self_prefix_attention_mask( |
| *, |
| local_attention_mask: Optional[torch.Tensor], |
| batch_size: int, |
| query_len: int, |
| local_key_len: int, |
| prefix_valid_mask: Optional[torch.Tensor], |
| query_prefix_mask: Optional[torch.Tensor], |
| local_valid_mask: Optional[torch.Tensor], |
| device, |
| dtype, |
| ) -> torch.Tensor: |
| prefix_len = 0 if prefix_valid_mask is None else int(prefix_valid_mask.size(-1)) |
| prefix_allow = _tri_build_prefix_query_allow_mask( |
| batch_size=batch_size, |
| query_len=query_len, |
| prefix_valid_mask=prefix_valid_mask, |
| query_prefix_mask=query_prefix_mask, |
| prefix_len=prefix_len, |
| device=device, |
| ) |
|
|
| if isinstance(local_attention_mask, torch.Tensor) and local_attention_mask.dim() == 4: |
| local_mask = local_attention_mask.to(device=device) |
| local_mask = local_mask[:, :, :, : int(local_key_len)] |
| local_allow = local_mask == 0 |
| else: |
| local_valid = _tri_expand_valid_mask( |
| local_valid_mask, |
| batch_size=batch_size, |
| total_len=local_key_len, |
| device=device, |
| ) |
| past_len = max(int(local_key_len) - int(query_len), 0) |
| q_pos = torch.arange(query_len, device=device, dtype=torch.long).view(1, 1, query_len, 1) |
| k_pos = torch.arange(local_key_len, device=device, dtype=torch.long).view(1, 1, 1, local_key_len) |
| local_allow = (k_pos <= (past_len + q_pos)) & local_valid.view(int(batch_size), 1, 1, int(local_key_len)) |
|
|
| allow = torch.cat([prefix_allow, local_allow], dim=-1) |
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device)  # dtype's own min avoids fp16 overflow to -inf on the cast below |
| zeros = torch.tensor(0.0, device=device) |
| return torch.where(allow, zeros, neg_inf).to(dtype) |
|
|
|
|
| def _tri_build_last_layer_memory_index_batch( |
| *, |
| split_starts: torch.LongTensor, |
| prefix_keep_lens: torch.LongTensor, |
| system_lens: Optional[torch.LongTensor] = None, |
| latest_user_starts: Optional[torch.LongTensor] = None, |
| latest_user_lens: Optional[torch.LongTensor] = None, |
| device, |
| owner=None, |
| bucket_multiple: int = 0, |
| workspace_key: str = "single", |
| ) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]: |
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long) |
| batch_size = int(split_starts.size(0)) |
| if isinstance(system_lens, torch.Tensor): |
| system_lens = system_lens.to(device=device, dtype=torch.long) |
| else: |
| system_lens = None |
| if isinstance(latest_user_starts, torch.Tensor): |
| latest_user_starts = latest_user_starts.to(device=device, dtype=torch.long) |
| else: |
| latest_user_starts = None |
| if isinstance(latest_user_lens, torch.Tensor): |
| latest_user_lens = latest_user_lens.to(device=device, dtype=torch.long) |
| else: |
| latest_user_lens = None |
|
|
| memory_indices = [] |
| memory_lens = torch.zeros((batch_size,), device=device, dtype=torch.long) |
| max_memory_len = 0 |
| for row in range(batch_size): |
| prefix_keep = max(int(prefix_keep_lens[row].item()), 0) |
| split_start = max(int(split_starts[row].item()), 0) |
| system_end = prefix_keep |
| if isinstance(system_lens, torch.Tensor): |
| system_end = min(max(int(system_lens[row].item()), prefix_keep), split_start) |
|
|
| idx_parts = [] |
| if system_end > prefix_keep: |
| idx_parts.append(torch.arange(prefix_keep, system_end, device=device, dtype=torch.long)) |
|
|
| if isinstance(latest_user_starts, torch.Tensor): |
| latest_start = min(max(int(latest_user_starts[row].item()), system_end), split_start) |
| if isinstance(latest_user_lens, torch.Tensor): |
| latest_len = max(int(latest_user_lens[row].item()), 0) |
| latest_end = min(latest_start + latest_len, split_start) |
| else: |
| latest_end = split_start |
| if latest_end > latest_start: |
| idx_parts.append(torch.arange(latest_start, latest_end, device=device, dtype=torch.long)) |
|
|
| if idx_parts: |
| idx = idx_parts[0] if len(idx_parts) == 1 else torch.cat(idx_parts, dim=0) |
| elif split_start > prefix_keep: |
| idx = torch.arange(prefix_keep, split_start, device=device, dtype=torch.long) |
| else: |
| idx = torch.empty((0,), device=device, dtype=torch.long) |
| memory_indices.append(idx) |
| memory_lens[row] = int(idx.numel()) |
| if idx.numel() > max_memory_len: |
| max_memory_len = int(idx.numel()) |
|
|
| bucket_multiple = int(bucket_multiple or 0) |
| target_memory_len = int(max_memory_len) |
| if bucket_multiple > 0 and target_memory_len > 0: |
| target_memory_len = ((target_memory_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple |
|
|
| gather_idx = None |
| valid_mask = None |
| if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)): |
| cache = getattr(owner, "_optimized_llopa_last_layer_memory_workspace_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| key = (str(workspace_key), str(device), int(batch_size), int(target_memory_len)) |
| pair = cache.get(key) |
| if ( |
| not isinstance(pair, tuple) |
| or len(pair) != 2 |
| or not isinstance(pair[0], torch.Tensor) |
| or not isinstance(pair[1], torch.Tensor) |
| or pair[0].shape != (int(batch_size), int(target_memory_len)) |
| or pair[1].shape != (int(batch_size), int(target_memory_len)) |
| or pair[0].device != device |
| or pair[1].device != device |
| ): |
| pair = ( |
| torch.zeros((int(batch_size), int(target_memory_len)), device=device, dtype=torch.long), |
| torch.zeros((int(batch_size), int(target_memory_len)), device=device, dtype=torch.bool), |
| ) |
| cache[key] = pair |
| with contextlib.suppress(Exception): |
| setattr(owner, "_optimized_llopa_last_layer_memory_workspace_cache", cache) |
| gather_idx, valid_mask = pair |
| gather_idx.zero_() |
| valid_mask.zero_() |
|
|
| if gather_idx is None or valid_mask is None: |
| gather_idx = torch.zeros((batch_size, target_memory_len), device=device, dtype=torch.long) |
| valid_mask = torch.zeros((batch_size, target_memory_len), device=device, dtype=torch.bool) |
|
|
| for row, idx in enumerate(memory_indices): |
| width = int(idx.numel()) |
| if width <= 0: |
| continue |
| gather_idx[row, :width] = idx |
| valid_mask[row, :width] = True |
| return gather_idx, valid_mask, memory_lens |
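

| # Worked example (illustrative only): with a system span and a latest-user |
| # window, the memory index keeps [prefix_keep, system_end) plus the latest user |
| # tokens and skips everything in between. |
| def _demo_tri_last_layer_memory_index() -> None: |
|     gather_idx, valid_mask, memory_lens = _tri_build_last_layer_memory_index_batch( |
|         split_starts=torch.tensor([6]), |
|         prefix_keep_lens=torch.tensor([2]), |
|         system_lens=torch.tensor([4]), |
|         latest_user_starts=torch.tensor([5]), |
|         latest_user_lens=torch.tensor([1]), |
|         device=torch.device("cpu"), |
|     ) |
|     # Positions 2..3 come from the system span, position 5 from the latest user |
|     # window; position 4 is dropped. |
|     assert gather_idx.tolist() == [[2, 3, 5]] |
|     assert valid_mask.tolist() == [[True, True, True]] |
|     assert memory_lens.tolist() == [3] |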
|
|
|
|
| def _tri_resolve_replay_user_window( |
| *, |
| user_len: int, |
| prefix_keep_len: int, |
| latest_user_start: Optional[int], |
| latest_user_len: Optional[int], |
| ) -> tuple[int, int, int]: |
| user_len = max(int(user_len), 0) |
| prefix_keep_len = min(max(int(prefix_keep_len), 0), user_len) |
| if latest_user_start is None: |
| return prefix_keep_len, prefix_keep_len, user_len |
|
|
| start = min(max(int(latest_user_start), prefix_keep_len), user_len) |
| if latest_user_len is None: |
| end = user_len |
| else: |
| end = min(max(start + int(latest_user_len), start), user_len) |
| if end <= start and user_len > prefix_keep_len: |
| start = prefix_keep_len |
| end = user_len |
| return prefix_keep_len, start, end |
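

| # Worked example (illustrative only): resolving the replayed user window. |
| def _demo_tri_resolve_replay_user_window() -> None: |
|     # A well-formed window [6, 9) inside a 10-token user segment with a |
|     # 2-token kept prefix. |
|     assert _tri_resolve_replay_user_window( |
|         user_len=10, prefix_keep_len=2, latest_user_start=6, latest_user_len=3 |
|     ) == (2, 6, 9) |
|     # A degenerate request falls back to replaying everything after the prefix. |
|     assert _tri_resolve_replay_user_window( |
|         user_len=10, prefix_keep_len=2, latest_user_start=12, latest_user_len=0 |
|     ) == (2, 2, 10) |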
|
|
|
|
| def _tri_build_last_layer_query_mask( |
| *, |
| prefix_keep_lens: torch.LongTensor, |
| valid_lens: torch.LongTensor, |
| total_len: int, |
| device, |
| ) -> torch.BoolTensor: |
| if total_len <= 0: |
| return torch.zeros((int(valid_lens.size(0)), 0), device=device, dtype=torch.bool) |
| prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long) |
| valid_lens = valid_lens.to(device=device, dtype=torch.long) |
| positions = torch.arange(total_len, device=device, dtype=torch.long).unsqueeze(0) |
| return (positions >= prefix_keep_lens.unsqueeze(1)) & (positions < valid_lens.unsqueeze(1)) |
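

| # Worked example (illustrative only): query positions in [prefix_keep, valid_len) |
| # are the ones eligible for replay. |
| def _demo_tri_last_layer_query_mask() -> None: |
|     mask = _tri_build_last_layer_query_mask( |
|         prefix_keep_lens=torch.tensor([2]), |
|         valid_lens=torch.tensor([5]), |
|         total_len=6, |
|         device=torch.device("cpu"), |
|     ) |
|     assert mask.tolist() == [[False, False, True, True, True, False]] |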
|
|
|
|
| def _tri_store_last_layer_memory( |
| pkv: Optional[Cache], |
| *, |
| hidden_states: Optional[torch.Tensor], |
| position_ids: Optional[torch.LongTensor], |
| valid_mask: Optional[torch.Tensor], |
| module_type: str, |
| replay_per_layers: int = -1, |
| ) -> None: |
| if pkv is None: |
| return |
| module_type = _normalize_replay_module(module_type) |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
| if module_type == "none": |
| return |
| if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0: |
| return |
| try: |
| setattr(pkv, "_tri_last_layer_memory_hidden", hidden_states) |
| setattr(pkv, "_tri_last_layer_memory_position_ids", position_ids) |
| setattr(pkv, "_tri_last_layer_memory_valid_mask", valid_mask) |
| setattr(pkv, "_tri_last_layer_module", module_type) |
| setattr(pkv, "_tri_replay_module", module_type) |
| setattr(pkv, "_tri_replay_per_layers", int(replay_per_layers)) |
| except Exception: |
| pass |
|
|
|
|
| def _tri_segment_position_ids( |
| *, |
| batch_size: int, |
| start: int, |
| length: int, |
| device, |
| ) -> torch.LongTensor: |
| if int(length) <= 0: |
| return torch.empty((int(batch_size), 0), device=device, dtype=torch.long) |
| position_ids = _tri_arange(int(start), int(length), device).unsqueeze(0) |
| return position_ids.expand(int(batch_size), -1) |
|
|
|
|
| def _tri_valid_lengths( |
| valid_mask: Optional[torch.Tensor], |
| *, |
| batch_size: int, |
| default_len: int, |
| device, |
| ) -> torch.LongTensor: |
| if valid_mask is None: |
| return torch.full((int(batch_size),), int(default_len), device=device, dtype=torch.long) |
| valid_mask = valid_mask.to(device=device, dtype=torch.bool) |
| if valid_mask.dim() == 1: |
| valid_mask = valid_mask.view(1, -1) |
| return valid_mask.sum(dim=1, dtype=torch.long) |
|
|
|
|
| def _tri_cu_seq_lens_from_lengths(lengths: torch.LongTensor) -> tuple[torch.IntTensor, int]: |
| lengths = lengths.to(dtype=torch.int32) |
| if lengths.numel() == 0: |
| return torch.zeros((1,), device=lengths.device, dtype=torch.int32), 0 |
| cu_seq_lens = torch.cat( |
| [ |
| torch.zeros((1,), device=lengths.device, dtype=torch.int32), |
| torch.cumsum(lengths, dim=0, dtype=torch.int32), |
| ], |
| dim=0, |
| ) |
| return cu_seq_lens, int(lengths.max().item()) |
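

| # Minimal sketch (illustrative only): cumulative sequence lengths in the format |
| # expected by flash-attention varlen kernels. |
| def _demo_tri_cu_seq_lens() -> None: |
|     cu, max_len = _tri_cu_seq_lens_from_lengths(torch.tensor([3, 5, 2])) |
|     assert cu.tolist() == [0, 3, 8, 10] |
|     assert max_len == 5 |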
|
|
|
|
| def _tri_build_replay_attention_mask( |
| *, |
| batch_size: int, |
| query_len: int, |
| local_valid_mask: Optional[torch.Tensor], |
| memory_valid_mask: Optional[torch.Tensor], |
| module_type: str, |
| device, |
| dtype, |
| ) -> Optional[torch.Tensor]: |
| module_type = _normalize_last_layer_module(module_type) |
| memory_len = 0 if memory_valid_mask is None else int(memory_valid_mask.size(1)) |
| if module_type == "none" or query_len <= 0 or (module_type == "cross" and memory_len <= 0): |
| return None |
|
|
| pieces = [] |
| if memory_len > 0: |
| mem_valid = memory_valid_mask.to(device=device, dtype=torch.bool).view(batch_size, 1, 1, memory_len) |
| pieces.append(mem_valid.expand(batch_size, 1, query_len, memory_len)) |
|
|
| if module_type == "self": |
| if local_valid_mask is None: |
| local_valid_mask = torch.ones((batch_size, query_len), device=device, dtype=torch.bool) |
| else: |
| local_valid_mask = local_valid_mask.to(device=device, dtype=torch.bool) |
| local_len = int(local_valid_mask.size(1)) |
| q_pos = torch.arange(query_len, device=device, dtype=torch.long).view(1, 1, query_len, 1) |
| k_pos = torch.arange(local_len, device=device, dtype=torch.long).view(1, 1, 1, local_len) |
| local_allow = (k_pos <= q_pos) & local_valid_mask.view(batch_size, 1, 1, local_len) |
| pieces.append(local_allow) |
|
|
| if not pieces: |
| return None |
|
|
| allow = torch.cat(pieces, dim=-1) |
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device)  # dtype's own min avoids fp16 overflow to -inf on the cast below |
| zeros = torch.tensor(0.0, device=device) |
| return torch.where(allow, zeros, neg_inf).to(dtype) |
|
|
|
|
| def _tri_project_memory_kv( |
| *, |
| attn_module: "LlamaAttention", |
| memory_hidden_states: torch.Tensor, |
| memory_position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| target_device, |
| target_dtype, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| memory_hidden_states = memory_hidden_states.to(device=target_device, dtype=target_dtype) |
| memory_shape = (*memory_hidden_states.shape[:-1], -1, attn_module.head_dim) |
| memory_key_states = attn_module.k_proj(memory_hidden_states).view(memory_shape).transpose(1, 2) |
| memory_value_states = attn_module.v_proj(memory_hidden_states).view(memory_shape).transpose(1, 2) |
| memory_cos, memory_sin = memory_position_embeddings |
| memory_key_states = ( |
| memory_key_states * memory_cos.unsqueeze(1) |
| ) + (rotate_half(memory_key_states) * memory_sin.unsqueeze(1)) |
| return memory_key_states, memory_value_states |
|
|
|
|
| def _tri_replay_attention_forward( |
| *, |
| attn_module: "LlamaAttention", |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| local_valid_mask: Optional[torch.Tensor], |
| query_replay_mask: Optional[torch.Tensor], |
| memory_hidden_states: Optional[torch.Tensor], |
| memory_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], |
| memory_valid_mask: Optional[torch.Tensor], |
| module_type: str, |
| ) -> torch.Tensor: |
| module_type = _normalize_last_layer_module(module_type) |
| if module_type == "none" or hidden_states.size(1) == 0: |
| return hidden_states.new_zeros(hidden_states.shape) |
| if not isinstance(memory_hidden_states, torch.Tensor) or memory_hidden_states.numel() == 0 or int(memory_hidden_states.size(1)) == 0: |
| return hidden_states.new_zeros(hidden_states.shape) |
|
|
| batch_size, query_len = int(hidden_states.size(0)), int(hidden_states.size(1)) |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, attn_module.head_dim) |
|
|
| query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| cos, sin = position_embeddings |
| cos_u = cos.unsqueeze(1) |
| sin_u = sin.unsqueeze(1) |
| query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u) |
|
|
| memory_key_states, memory_value_states = _tri_project_memory_kv( |
| attn_module=attn_module, |
| memory_hidden_states=memory_hidden_states, |
| memory_position_embeddings=memory_position_embeddings, |
| target_device=hidden_states.device, |
| target_dtype=hidden_states.dtype, |
| ) |
|
|
| if module_type == "self": |
| local_key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| local_value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| local_key_states = (local_key_states * cos_u) + (rotate_half(local_key_states) * sin_u) |
| key_states = torch.cat([memory_key_states, local_key_states], dim=2) |
| value_states = torch.cat([memory_value_states, local_value_states], dim=2) |
| else: |
| key_states = memory_key_states |
| value_states = memory_value_states |
|
|
| attn_impl = _resolve_attn_impl(attn_module.config) |
| attention_interface: Callable = eager_attention_forward |
| if attn_impl != "eager": |
| attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] |
|
|
| attn_kwargs = {} |
| replay_mask = None |
| restore_is_causal = None |
| if attn_impl == "flash_attention_2": |
| if module_type == "self": |
| memory_valid = memory_valid_mask |
| if memory_valid is None: |
| memory_valid = torch.ones( |
| (batch_size, memory_key_states.size(2)), |
| device=hidden_states.device, |
| dtype=torch.bool, |
| ) |
| local_valid = local_valid_mask |
| if local_valid is None: |
| local_valid = torch.ones((batch_size, query_len), device=hidden_states.device, dtype=torch.bool) |
| replay_mask = torch.cat( |
| [ |
| memory_valid.to(device=hidden_states.device, dtype=torch.long), |
| local_valid.to(device=hidden_states.device, dtype=torch.long), |
| ], |
| dim=1, |
| ) |
| else: |
| query_lengths = _tri_valid_lengths( |
| local_valid_mask, |
| batch_size=batch_size, |
| default_len=query_len, |
| device=hidden_states.device, |
| ) |
| key_lengths = _tri_valid_lengths( |
| memory_valid_mask, |
| batch_size=batch_size, |
| default_len=memory_key_states.size(2), |
| device=hidden_states.device, |
| ) |
| if bool(torch.all(query_lengths > 0).item()) and bool(torch.all(key_lengths > 0).item()): |
| cu_q, max_q = _tri_cu_seq_lens_from_lengths(query_lengths) |
| cu_k, max_k = _tri_cu_seq_lens_from_lengths(key_lengths) |
| attn_kwargs.update( |
| { |
| "cu_seq_lens_q": cu_q, |
| "cu_seq_lens_k": cu_k, |
| "max_length_q": max_q, |
| "max_length_k": max_k, |
| } |
| ) |
| restore_is_causal = bool(attn_module.is_causal) |
| attn_module.is_causal = False |
| else: |
| attn_impl = "eager" |
| attention_interface = eager_attention_forward |
|
|
| if attn_impl != "flash_attention_2": |
| replay_mask = _tri_build_replay_attention_mask( |
| batch_size=batch_size, |
| query_len=query_len, |
| local_valid_mask=local_valid_mask, |
| memory_valid_mask=memory_valid_mask, |
| module_type=module_type, |
| device=hidden_states.device, |
| dtype=hidden_states.dtype, |
| ) |
|
|
| try: |
| attn_output, _ = attention_interface( |
| attn_module, |
| query_states, |
| key_states, |
| value_states, |
| replay_mask, |
| dropout=0.0 if not attn_module.training else attn_module.attention_dropout, |
| scaling=attn_module.scaling, |
| **attn_kwargs, |
| ) |
| finally: |
| if restore_is_causal is not None: |
| attn_module.is_causal = restore_is_causal |
| attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output) |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = attn_module.o_proj(attn_output) |
|
|
| if query_replay_mask is not None: |
| query_replay_mask = query_replay_mask.to(device=hidden_states.device, dtype=torch.bool) |
| attn_output = attn_output.masked_fill(~query_replay_mask.unsqueeze(-1), 0) |
| return attn_output |
|
|
|
|
| def _tri_run_last_layer_with_replay( |
| *, |
| layer: "LlamaDecoderLayer", |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| position_ids: Optional[torch.LongTensor], |
| past_key_values: Optional[Cache], |
| use_cache: Optional[bool], |
| cache_position: Optional[torch.LongTensor], |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| memory_hidden_states: Optional[torch.Tensor], |
| memory_position_ids: Optional[torch.LongTensor], |
| memory_valid_mask: Optional[torch.Tensor], |
| local_valid_mask: Optional[torch.Tensor], |
| query_replay_mask: Optional[torch.Tensor], |
| module_type: str, |
| rotary_emb: "LlamaRotaryEmbedding", |
| use_checkpoint: bool, |
| ) -> torch.Tensor: |
| residual = hidden_states |
| hidden_states = layer.input_layernorm(hidden_states) |
| module_type = _normalize_last_layer_module(module_type) |
| has_memory = ( |
| module_type != "none" |
| and isinstance(memory_hidden_states, torch.Tensor) |
| and memory_hidden_states.numel() > 0 |
| and int(memory_hidden_states.size(1)) > 0 |
| and isinstance(memory_position_ids, torch.Tensor) |
| ) |
| attn_kwargs = {} |
| if has_memory and module_type == "self": |
| replay_position_embeddings = rotary_emb(memory_hidden_states, memory_position_ids) |
| memory_key_states, memory_value_states = _tri_project_memory_kv( |
| attn_module=layer.self_attn, |
| memory_hidden_states=memory_hidden_states, |
| memory_position_embeddings=replay_position_embeddings, |
| target_device=hidden_states.device, |
| target_dtype=hidden_states.dtype, |
| ) |
| attn_kwargs.update( |
| { |
| "extra_prefix_kv": (memory_key_states, memory_value_states), |
| "extra_prefix_valid_mask": memory_valid_mask, |
| "extra_prefix_query_mask": query_replay_mask, |
| "extra_prefix_local_valid_mask": local_valid_mask, |
| } |
| ) |
| hidden_states, _ = layer.self_attn( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **attn_kwargs, |
| ) |
| hidden_states = residual + hidden_states |
|
|
| if ( |
| has_memory |
| and module_type == "cross" |
| ): |
| replay_attn = layer.ensure_replay_cross_attn() |
| replay_residual = hidden_states |
| replay_hidden = layer.post_attention_layernorm(hidden_states) |
| replay_position_embeddings = rotary_emb(memory_hidden_states, memory_position_ids) |
| replay_out = _tri_replay_attention_forward( |
| attn_module=replay_attn, |
| hidden_states=replay_hidden, |
| position_embeddings=position_embeddings, |
| local_valid_mask=local_valid_mask, |
| query_replay_mask=query_replay_mask, |
| memory_hidden_states=memory_hidden_states, |
| memory_position_embeddings=replay_position_embeddings, |
| memory_valid_mask=memory_valid_mask, |
| module_type=module_type, |
| ) |
| hidden_states = replay_residual + replay_out |
|
|
| residual = hidden_states |
| hidden_states = layer.post_attention_layernorm(hidden_states) |
| if use_checkpoint: |
| hidden_states = _checkpoint_mlp(layer.mlp, hidden_states) |
| else: |
| hidden_states = layer.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| def _tri_upper_solo_v2_attention_forward( |
| attn_module: "LlamaAttention", |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| *, |
| with_bos: bool = False, |
| ) -> torch.Tensor: |
| if hidden_states.size(1) == 0: |
| return hidden_states |
|
|
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, attn_module.head_dim) |
|
|
| value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| value_states = repeat_kv(value_states, attn_module.num_key_value_groups) |
|
|
| if not bool(with_bos): |
| attn_output = value_states.transpose(1, 2).contiguous() |
| attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output) |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| return attn_module.o_proj(attn_output) |
|
|
| query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
| key_states = repeat_kv(key_states, attn_module.num_key_value_groups) |
|
|
| bos_hidden_states = hidden_states[:, :1, :] |
| bos_shape = (bos_hidden_states.size(0), bos_hidden_states.size(1), -1, attn_module.head_dim) |
| bos_key_states = attn_module.k_proj(bos_hidden_states).view(bos_shape).transpose(1, 2) |
| bos_value_states = attn_module.v_proj(bos_hidden_states).view(bos_shape).transpose(1, 2) |
| bos_cos = cos[:, :1, :] |
| bos_sin = sin[:, :1, :] |
| bos_key_states = (bos_key_states * bos_cos.unsqueeze(1)) + (rotate_half(bos_key_states) * bos_sin.unsqueeze(1)) |
| bos_key_states = repeat_kv(bos_key_states, attn_module.num_key_value_groups) |
| bos_value_states = repeat_kv(bos_value_states, attn_module.num_key_value_groups) |
|
|
| bos_scores = ( |
| query_states * bos_key_states.expand(-1, -1, query_states.size(2), -1) |
| ).sum(dim=-1) * attn_module.scaling |
| self_scores = (query_states * key_states).sum(dim=-1) * attn_module.scaling |
|
|
| attn_weights = torch.stack([bos_scores, self_scores], dim=-1) |
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
| attn_weights = nn.functional.dropout( |
| attn_weights, |
| p=0.0 if not attn_module.training else attn_module.attention_dropout, |
| training=attn_module.training, |
| ) |
| bos_weight = attn_weights[..., :1] |
| self_weight = attn_weights[..., 1:2] |
| attn_output = ( |
| bos_weight * bos_value_states.expand(-1, -1, query_states.size(2), -1) |
| + self_weight * value_states |
| ) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
| attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output) |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| return attn_module.o_proj(attn_output) |
|
|
|
|
| def _tri_run_decoder_layer_upper_solo_v2( |
| *, |
| layer: "LlamaDecoderLayer", |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| with_bos: bool, |
| use_checkpoint: bool, |
| ) -> torch.Tensor: |
| residual = hidden_states |
| attn_input = layer.input_layernorm(hidden_states) |
| attn_output = _tri_upper_solo_v2_attention_forward( |
| layer.self_attn, |
| attn_input, |
| position_embeddings, |
| with_bos=bool(with_bos), |
| ) |
| hidden_states = residual + attn_output |
|
|
| residual = hidden_states |
| hidden_states = layer.post_attention_layernorm(hidden_states) |
| if use_checkpoint: |
| hidden_states = _checkpoint_mlp(layer.mlp, hidden_states) |
| else: |
| hidden_states = layer.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| def _tri_compact_solo_v2_upper_training_inputs( |
| *, |
| hidden_states: torch.Tensor, |
| position_ids: torch.LongTensor, |
| labels: torch.LongTensor, |
| valid_lens: torch.LongTensor, |
| logits_to_keep: int, |
| with_bos: bool, |
| ): |
| valid_lens_device = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| keep_lens = torch.minimum( |
| valid_lens_device, |
| torch.full_like(valid_lens_device, int(logits_to_keep)), |
| ) |
| split_starts = (valid_lens_device - keep_lens).clamp_min(0) |
| if bool(with_bos) and bool(torch.any(split_starts <= 0).item()): |
| return None |
|
|
| upper_hidden, upper_valid_mask, upper_lens, _ = _tri_pack_suffix_tensor( |
| hidden_states, |
| split_starts=split_starts, |
| valid_lens=valid_lens, |
| pad_value=0.0, |
| ) |
| upper_position_ids, _, _, _ = _tri_pack_suffix_tensor( |
| position_ids.to(device=hidden_states.device), |
| split_starts=split_starts, |
| valid_lens=valid_lens, |
| pad_value=0, |
| ) |
| upper_labels, _, _, _ = _tri_pack_suffix_tensor( |
| labels.to(device=hidden_states.device), |
| split_starts=split_starts, |
| valid_lens=valid_lens, |
| pad_value=-100, |
| ) |
|
|
| if bool(with_bos): |
| bos_valid = torch.ones((hidden_states.size(0), 1), device=hidden_states.device, dtype=torch.bool) |
| upper_hidden = torch.cat([hidden_states[:, :1, :], upper_hidden], dim=1) |
| upper_position_ids = torch.cat( |
| [position_ids.to(device=hidden_states.device)[:, :1], upper_position_ids], |
| dim=1, |
| ) |
| bos_labels = torch.full( |
| (labels.size(0), 1), |
| -100, |
| device=hidden_states.device, |
| dtype=labels.dtype, |
| ) |
| upper_labels = torch.cat([bos_labels, upper_labels], dim=1) |
| upper_valid_mask = torch.cat([bos_valid, upper_valid_mask], dim=1) |
| upper_lens = upper_lens + 1 |
|
|
| return upper_hidden, upper_position_ids, upper_labels, upper_valid_mask, upper_lens |
|
|
|
|
| def _build_tri_mask_prefix_full(B: int, prefix_len: int, total_len: int, device, dtype): |
| """Mask where the first `prefix_len` rows attend bidirectionally inside the |
| prefix and every later row attends causally over the full sequence.""" |
| if total_len <= 0: |
| return torch.zeros((B, 1, 0, 0), device=device, dtype=dtype) |
| P = max(0, min(int(prefix_len), int(total_len))) |
| i = torch.arange(total_len, device=device).unsqueeze(1) |
| j = torch.arange(total_len, device=device).unsqueeze(0) |
| allow = torch.where(i < P, j < P, j <= i) |
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device)  # dtype's own min avoids fp16 overflow to -inf on the cast below |
| mask = torch.where(allow, torch.tensor(0.0, device=device), neg_inf).to(dtype) |
| return mask.view(1, 1, total_len, total_len).expand(B, 1, total_len, total_len) |
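

| # Illustrative sketch (not part of the original module): with prefix_len=2 on a |
| # 4-token sequence, the first two rows attend bidirectionally within the prefix |
| # while the remaining rows stay causal. |
| def _demo_build_tri_mask_prefix_full() -> None: |
|     mask = _build_tri_mask_prefix_full(1, 2, 4, torch.device("cpu"), torch.float32) |
|     assert (mask[0, 0] == 0).tolist() == [ |
|         [True, True, False, False], |
|         [True, True, False, False], |
|         [True, True, True, False], |
|         [True, True, True, True], |
|     ] |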
|
|
|
|
| def _build_tri_mask_prefix_full_varlen( |
| prefix_lens: torch.LongTensor, |
| attention_mask: torch.Tensor, |
| dtype, |
| ): |
| """Variable-length variant of the prefix-full mask: per-row prefix lengths plus key padding from `attention_mask`.""" |
| if attention_mask.dim() != 2: |
| raise ValueError("attention_mask must be rank-2 for variable-length prefix-full masking.") |
| B, total_len = attention_mask.shape |
| device = attention_mask.device |
| if total_len <= 0: |
| return torch.zeros((B, 1, 0, 0), device=device, dtype=dtype) |
|
|
| # Clamp out-of-place: `.to()` can return the caller's tensor unchanged, and an in-place clamp_ would then mutate it. |
| prefix_lens = prefix_lens.to(device=device, dtype=torch.long).clamp(min=0, max=total_len) |
| i = torch.arange(total_len, device=device, dtype=torch.long).view(1, total_len, 1) |
| j = torch.arange(total_len, device=device, dtype=torch.long).view(1, 1, total_len) |
| prefix = prefix_lens.view(B, 1, 1) |
| allow = torch.where(i < prefix, j < prefix, j <= i) |
| key_valid = attention_mask.to(device=device, dtype=torch.bool).view(B, 1, total_len) |
| allow = allow & key_valid |
|
|
| neg_inf = torch.tensor(torch.finfo(dtype).min, device=device)  # dtype's own min avoids fp16 overflow to -inf on the cast below |
| mask = torch.where(allow, torch.tensor(0.0, device=device), neg_inf).to(dtype) |
| return mask.view(B, 1, total_len, total_len) |
|
|
|
|
| def _can_use_implicit_causal_mask(config) -> bool: |
| attn_impl = _resolve_attn_impl(config) |
| if attn_impl not in {"sdpa", "flash_attention_2", "flex_attention"}: |
| return False |
| # Sliding-window configs need an explicit mask, so only plain causal models |
| # may rely on the kernel's implicit causal masking. |
| return getattr(config, "sliding_window", None) is None |
|
|
|
|
| def _llopa_upper_indices(S: int, U: int, A: int, device) -> torch.LongTensor: |
| if U <= 0: |
| return torch.arange(0, S + A, device=device, dtype=torch.long) |
| if A <= 0: |
| return torch.arange(0, S, device=device, dtype=torch.long) |
| left = torch.arange(0, S, device=device, dtype=torch.long) |
| right = torch.arange(S + U, S + U + A, device=device, dtype=torch.long) |
| return torch.cat([left, right], dim=0) |
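

| # Worked example (illustrative only): upper layers skip the U-token gap between |
| # the S-token system prefix and the A-token assistant suffix. |
| def _demo_llopa_upper_indices() -> None: |
|     device = torch.device("cpu") |
|     assert _llopa_upper_indices(2, 3, 2, device).tolist() == [0, 1, 5, 6] |
|     # With no gap the index is simply contiguous. |
|     assert _llopa_upper_indices(2, 0, 2, device).tolist() == [0, 1, 2, 3] |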
|
|
|
|
| def _llopa_default_attention_mask(token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.LongTensor: |
| if attention_mask is None: |
| return torch.ones_like(token_ids, dtype=torch.long) |
| return attention_mask.to(device=token_ids.device, dtype=torch.long) |
|
|
|
|
| def _llopa_mask_is_all_ones(attention_mask: Optional[torch.Tensor]) -> bool: |
| if attention_mask is None: |
| return True |
| return bool(torch.all(attention_mask > 0).item()) |
|
|
|
|
| def _llopa_position_ids_from_mask(attention_mask: torch.LongTensor) -> torch.LongTensor: |
| if attention_mask.numel() == 0: |
| return attention_mask.to(dtype=torch.long) |
| position_ids = attention_mask.cumsum(dim=1, dtype=torch.long) - 1 |
| return position_ids.clamp_min_(0) |
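

| # Minimal sketch (illustrative only): left-padded rows restart their position |
| # ids at the first valid token. |
| def _demo_llopa_position_ids_from_mask() -> None: |
|     mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long) |
|     assert _llopa_position_ids_from_mask(mask).tolist() == [[0, 0, 0, 1, 2]] |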
|
|
|
|
| def _llopa_profile_enabled(model) -> bool: |
| return bool(getattr(model, "_llopa_profile_memory_enabled", False)) |
|
|
|
|
| def _llopa_profile_step(model) -> int: |
| return int(getattr(model, "_llopa_profile_memory_step", 0) or 0) |
|
|
|
|
| def _llopa_profile_rank() -> int: |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): |
| return int(torch.distributed.get_rank()) |
| return 0 |
|
|
|
|
| def _llopa_segment_length(token_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor]) -> int: |
| ref = attention_mask if attention_mask is not None else token_ids |
| if ref is None: |
| return 0 |
| if attention_mask is not None and attention_mask.numel() > 0: |
| return int(attention_mask.sum(dim=1, dtype=torch.long).max().item()) |
| if ref.dim() >= 2: |
| return int(ref.size(1)) |
| return int(ref.numel()) |
|
|
|
|
| def _llopa_memory_device( |
| *tensors: Optional[torch.Tensor], |
| ) -> Optional[torch.device]: |
| for tensor in tensors: |
| if isinstance(tensor, torch.Tensor): |
| return tensor.device |
| return None |
|
|
|
|
| def _llopa_reset_peak_memory(*tensors: Optional[torch.Tensor]) -> None: |
| device = _llopa_memory_device(*tensors) |
| if device is None or device.type != "cuda": |
| return |
| torch.cuda.reset_peak_memory_stats(device=device) |
|
|
|
|
| def _llopa_log_memory( |
| model, |
| stage: str, |
| *, |
| system_ids: Optional[torch.Tensor] = None, |
| user_ids: Optional[torch.Tensor] = None, |
| assistant_ids: Optional[torch.Tensor] = None, |
| system_attention_mask: Optional[torch.Tensor] = None, |
| user_attention_mask: Optional[torch.Tensor] = None, |
| assistant_attention_mask: Optional[torch.Tensor] = None, |
| sequence_len: Optional[int] = None, |
| ) -> None: |
| if not _llopa_profile_enabled(model): |
| return |
| device = _llopa_memory_device( |
| system_ids, |
| user_ids, |
| assistant_ids, |
| system_attention_mask, |
| user_attention_mask, |
| assistant_attention_mask, |
| ) |
| if device is None or device.type != "cuda": |
| return |
|
|
| logger.info( |
| "[LLOPA_MEM] step=%s rank=%s stage=%s peak_scope=since_reset system_len=%s user_len=%s assistant_len=%s sequence_len=%s " |
| "allocated_GiB=%.3f reserved_GiB=%.3f max_allocated_GiB=%.3f max_reserved_GiB=%.3f", |
| _llopa_profile_step(model), |
| _llopa_profile_rank(), |
| stage, |
| _llopa_segment_length(system_ids, system_attention_mask), |
| _llopa_segment_length(user_ids, user_attention_mask), |
| _llopa_segment_length(assistant_ids, assistant_attention_mask), |
| -1 if sequence_len is None else int(sequence_len), |
| torch.cuda.memory_allocated(device=device) / 2**30, |
| torch.cuda.memory_reserved(device=device) / 2**30, |
| torch.cuda.max_memory_allocated(device=device) / 2**30, |
| torch.cuda.max_memory_reserved(device=device) / 2**30, |
| ) |
|
|
|
|
| def _llopa_pack_upper_hidden(hidden_states: torch.Tensor, S: int, U: int) -> torch.Tensor: |
| if U <= 0: |
| return hidden_states |
| B, _, H = hidden_states.shape |
| A = hidden_states.size(1) - S - U |
| packed = hidden_states.new_empty((B, S + A, H)) |
| if S > 0: |
| packed[:, :S, :] = hidden_states[:, :S, :] |
| if A > 0: |
| packed[:, S:, :] = hidden_states[:, S + U :, :] |
| return packed |
|
|
|
|
| def _llopa_pack_upper_tokens(token_tensor: torch.Tensor, S: int, U: int) -> torch.Tensor: |
| if U <= 0: |
| return token_tensor |
| B = token_tensor.size(0) |
| A = token_tensor.size(1) - S - U |
| packed = token_tensor.new_empty((B, S + A)) |
| if S > 0: |
| packed[:, :S] = token_tensor[:, :S] |
| if A > 0: |
| packed[:, S:] = token_tensor[:, S + U :] |
| return packed |
|
|
|
|
| def _llopa_insert_specials(upper_hidden: torch.Tensor, specials: torch.Tensor, S: int) -> torch.Tensor: |
| B, compact_len, H = upper_hidden.shape |
| num_specials = specials.size(1) |
| layer_hidden = upper_hidden.new_empty((B, compact_len + num_specials, H)) |
| if S > 0: |
| layer_hidden[:, :S, :] = upper_hidden[:, :S, :] |
| layer_hidden[:, S : S + num_specials, :] = specials |
| if compact_len > S: |
| layer_hidden[:, S + num_specials :, :] = upper_hidden[:, S:, :] |
| return layer_hidden |
|
|
|
|
| def _llopa_unpack_upper_hidden(layer_hidden: torch.Tensor, S: int, num_specials: int) -> torch.Tensor: |
| if num_specials <= 0: |
| return layer_hidden |
| B, expanded_len, H = layer_hidden.shape |
| compact = layer_hidden.new_empty((B, expanded_len - num_specials, H)) |
| if S > 0: |
| compact[:, :S, :] = layer_hidden[:, :S, :] |
| if expanded_len > S + num_specials: |
| compact[:, S:, :] = layer_hidden[:, S + num_specials :, :] |
| return compact |
|
|
|
|
| def _llopa_remove_specials(layer_hidden: torch.Tensor, S: int, num_specials: int) -> torch.Tensor: |
| return _llopa_unpack_upper_hidden(layer_hidden, S, num_specials) |
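

| # Illustrative round trip (not part of the original module): pack out the gap, |
| # splice the specials in at position S, then strip them again. |
| def _demo_llopa_pack_roundtrip() -> None: |
|     S, U, A, H = 2, 3, 2, 4 |
|     hidden = torch.randn(1, S + U + A, H) |
|     packed = _llopa_pack_upper_hidden(hidden, S, U) |
|     assert packed.shape == (1, S + A, H) |
|     specials = torch.randn(1, 2, H) |
|     expanded = _llopa_insert_specials(packed, specials, S) |
|     assert expanded.shape == (1, S + A + 2, H) |
|     restored = _llopa_unpack_upper_hidden(expanded, S, 2) |
|     assert torch.equal(restored, packed) |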
|
|
|
|
| def _llopa_use_optimized_training_path( |
| model, |
| *, |
| attn_impl: str, |
| prefix_full: bool, |
| no_upper_attn: bool, |
| ) -> bool: |
| return ( |
| bool(getattr(model, "training", False)) |
| and attn_impl == "flash_attention_2" |
| and not prefix_full |
| and not bool(no_upper_attn) |
| ) |
|
|
|
|
| def _llopa_warn_once(model, flag: str, msg: str) -> None: |
| if getattr(model, flag, False): |
| return |
| try: |
| setattr(model, flag, True) |
| except Exception: |
| pass |
| logger.warning(msg) |
|
|
|
|
| def _llopa_split_system(system_ids: torch.Tensor, system_prefill: str): |
| """Return (system_upper, system_lower_extra) based on system_prefill.""" |
| sys_prefill = (system_prefill or "full").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "full" |
| if sys_prefill == "full": |
| return system_ids, system_ids[:, :0] |
| if sys_prefill == "no_system": |
| if system_ids.size(1) < 1: |
| return system_ids[:, :0], system_ids[:, :0] |
| return system_ids[:, :1], system_ids[:, 1:] |
| return system_ids[:, :0], system_ids |
|
|
|
|
| def _llopa_merge_user(system_ids: torch.Tensor, user_ids: torch.Tensor, system_prefill: str): |
| sys_upper, sys_lower_extra = _llopa_split_system(system_ids, system_prefill) |
| if sys_lower_extra.numel() == 0: |
| return sys_upper, user_ids |
| if user_ids.numel() == 0: |
| return sys_upper, sys_lower_extra |
| return sys_upper, torch.cat([sys_lower_extra, user_ids], dim=1) |
|
|
|
|
| def _llopa_effective_specials(model, gap_u: int) -> int: |
| num = int(getattr(model, "llopa_num_specials", 0) or 0) |
| if num <= 0: |
| return 0 |
| if getattr(model, "llopa_specials", None) is None: |
| _llopa_warn_once(model, "_llopa_warned_no_specials", |
| "[LLOPA] llopa_num_specials set but llopa_specials missing; specials disabled.") |
| return 0 |
| if gap_u <= 0: |
| _llopa_warn_once(model, "_llopa_warned_no_gap", |
| f"[LLOPA] num_specials={num} but gap=0 (U=0); specials disabled.") |
| return 0 |
| if num > gap_u: |
| _llopa_warn_once(model, "_llopa_warned_short_gap", |
| f"[LLOPA] num_specials={num} exceeds gap U={gap_u}; using {gap_u} specials.") |
| return gap_u |
| return num |
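

| # Illustrative sketch (not part of the original module): the effective special |
| # count is clamped to the gap width, with a one-time warning flag set on the |
| # model object. |
| def _demo_llopa_effective_specials() -> None: |
|     from types import SimpleNamespace |
|     model = SimpleNamespace(llopa_num_specials=4, llopa_specials=object()) |
|     assert _llopa_effective_specials(model, gap_u=8) == 4 |
|     # A gap narrower than the configured count clamps (and warns once). |
|     assert _llopa_effective_specials(model, gap_u=2) == 2 |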
|
|
|
|
| def _llopa_can_cache_special_kv(model) -> bool: |
| return (not bool(getattr(model, "training", False))) and (not torch.is_grad_enabled()) |
|
|
|
|
| def _llopa_special_kv_for_layer( |
| model: "LlamaModel", |
| layer: "LlamaDecoderLayer", |
| layer_idx: int, |
| num_specials_eff: int, |
| start_pos: int, |
| batch_size: int, |
| dtype: torch.dtype, |
| device, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if num_specials_eff <= 0: |
| empty = torch.empty((batch_size, layer.self_attn.num_key_value_heads, 0, layer.self_attn.head_dim), |
| device=device, dtype=dtype) |
| return empty, empty |
|
|
| specials_all = getattr(model, "llopa_specials", None) |
| if specials_all is None: |
| empty = torch.empty((batch_size, layer.self_attn.num_key_value_heads, 0, layer.self_attn.head_dim), |
| device=device, dtype=dtype) |
| return empty, empty |
|
|
| can_cache = _llopa_can_cache_special_kv(model) |
| cache = None |
| cache_key = None |
| param_ptr = None |
| param_ver = None |
| if can_cache: |
| cache = getattr(model, "_llopa_special_kv_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| setattr(model, "_llopa_special_kv_cache", cache) |
| dev_index = -1 if getattr(device, "index", None) is None else int(device.index) |
| cache_key = (int(layer_idx), int(num_specials_eff), int(start_pos), str(dtype), str(device.type), dev_index) |
| try: |
| param_ptr = int(specials_all[layer_idx].data_ptr()) |
| except Exception: |
| param_ptr = None |
| try: |
| param_ver = int(getattr(specials_all[layer_idx], "_version", 0)) |
| except Exception: |
| param_ver = None |
| entry = cache.get(cache_key) |
| if isinstance(entry, dict) and entry.get("param_ptr") == param_ptr and entry.get("param_ver") == param_ver: |
| k_base = entry.get("k") |
| v_base = entry.get("v") |
| if isinstance(k_base, torch.Tensor) and isinstance(v_base, torch.Tensor): |
| if k_base.device == device and v_base.device == device: |
| if k_base.dtype == dtype and v_base.dtype == dtype: |
| if k_base.size(0) == 1 and batch_size > 1: |
| return k_base.expand(batch_size, -1, -1, -1), v_base.expand(batch_size, -1, -1, -1) |
| return k_base, v_base |
|
|
| specials = specials_all[layer_idx][:num_specials_eff].to(device=device, dtype=dtype) |
| if can_cache: |
| specials = specials.unsqueeze(0) |
| sp_pos = torch.arange(start_pos, start_pos + num_specials_eff, device=device, dtype=torch.long).unsqueeze(0) |
| else: |
| specials = specials.unsqueeze(0).expand(batch_size, -1, -1) |
| sp_pos = torch.arange(start_pos, start_pos + num_specials_eff, device=device, dtype=torch.long).unsqueeze(0) |
| sp_pos = sp_pos.expand(batch_size, -1) |
|
|
| specials = layer.input_layernorm(specials) |
| hidden_shape = (*specials.shape[:-1], -1, layer.self_attn.head_dim) |
| key_states = layer.self_attn.k_proj(specials).view(hidden_shape).transpose(1, 2) |
| value_states = layer.self_attn.v_proj(specials).view(hidden_shape).transpose(1, 2) |
| cos_sp, sin_sp = model.rotary_emb(specials, sp_pos) |
| key_states = (key_states * cos_sp.unsqueeze(1)) + (rotate_half(key_states) * sin_sp.unsqueeze(1)) |
|
|
| if can_cache: |
| key_states = key_states.detach() |
| value_states = value_states.detach() |
| if cache is not None and cache_key is not None: |
| cache[cache_key] = {"k": key_states, "v": value_states, "param_ptr": param_ptr, "param_ver": param_ver} |
| try: |
| max_entries = int(os.environ.get("LLOPA_SPECIAL_KV_CACHE_MAX", "128") or 128) |
| except Exception: |
| max_entries = 128 |
| if max_entries > 0: |
| while len(cache) > max_entries: |
| cache.pop(next(iter(cache))) |
| if batch_size > 1: |
| return key_states.expand(batch_size, -1, -1, -1), value_states.expand(batch_size, -1, -1, -1) |
| return key_states, value_states |
|
|
|
|
| def _tri_prefix_keep_lengths(system_lens: torch.LongTensor, system_prefill: str) -> torch.LongTensor: |
| system_lens = system_lens.to(dtype=torch.long) |
| mode = normalize_system_prefill(system_prefill) |
| if mode == "full": |
| return system_lens |
| if mode == "no_system": |
| return torch.minimum(system_lens, torch.ones_like(system_lens)) |
| return torch.zeros_like(system_lens) |
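

| # Illustrative sketch (assumes normalize_system_prefill passes the literal modes |
| # "full" and "no_system" through unchanged): |
| def _demo_tri_prefix_keep_lengths() -> None: |
|     lens = torch.tensor([4, 7], dtype=torch.long) |
|     assert _tri_prefix_keep_lengths(lens, "full").tolist() == [4, 7] |
|     # "no_system" keeps at most the BOS token from each system segment. |
|     assert _tri_prefix_keep_lengths(lens, "no_system").tolist() == [1, 1] |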
|
|
|
|
| def _tri_effective_suffix_special_token_ids(model) -> list[int]: |
| config = getattr(model, "config", None) |
| num_suffix_specials = int(getattr(config, "capsule_num_suffix_specials", 0) or 0) |
| if num_suffix_specials <= 0: |
| return [] |
|
|
| raw_ids = getattr(config, "capsule_suffix_special_token_ids", None) |
| if not isinstance(raw_ids, (list, tuple)) or len(raw_ids) == 0: |
| _tri_vanilla_warn_once( |
| model, |
| "_tri_warned_missing_suffix_special_ids", |
| "[TRI] capsule_num_suffix_specials is set but capsule_suffix_special_token_ids is missing; fusion specials disabled.", |
| ) |
| return [] |
|
|
| token_ids: list[int] = [] |
| for raw_id in list(raw_ids)[:num_suffix_specials]: |
| try: |
| token_id = int(raw_id) |
| except Exception: |
| continue |
| if token_id >= 0: |
| token_ids.append(token_id) |
|
|
| if len(token_ids) < num_suffix_specials: |
| _tri_vanilla_warn_once( |
| model, |
| "_tri_warned_short_suffix_special_ids", |
| "[TRI] Some configured fusion special token ids are invalid; using the available subset.", |
| ) |
| return token_ids |
|
|
|
|
| def _tri_normalize_fusion_mode(mode) -> str: |
| normalized = str(mode or "upper_only").strip().lower() |
| return normalized or "upper_only" |
|
|
|
|
| def _tri_effective_fusion_mode(model) -> str: |
| config = getattr(model, "config", None) |
| return _tri_normalize_fusion_mode(getattr(config, "capsule_fusion_mode", "upper_only")) |
|
|
|
|
| def _tri_build_suffix_special_position_ids( |
| split_starts: torch.LongTensor, |
| *, |
| num_suffix_specials: int, |
| device, |
| ) -> torch.LongTensor: |
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| if num_suffix_specials <= 0: |
| return torch.empty((split_starts.size(0), 0), device=device, dtype=torch.long) |
| offsets = torch.arange(num_suffix_specials, device=device, dtype=torch.long).unsqueeze(0) |
| return (split_starts.unsqueeze(1) - int(num_suffix_specials) + offsets).clamp_min_(0) |
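

| # Worked example (illustrative only): the specials occupy the last |
| # num_suffix_specials positions before each split point, clamped at zero. |
| def _demo_tri_suffix_special_position_ids() -> None: |
|     pos = _tri_build_suffix_special_position_ids( |
|         torch.tensor([10, 2]), |
|         num_suffix_specials=3, |
|         device=torch.device("cpu"), |
|     ) |
|     assert pos.tolist() == [[7, 8, 9], [0, 0, 1]] |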
|
|
|
|
| def _tri_repack_upper_with_suffix_specials( |
| model: "LlamaForCausalLM", |
| *, |
| upper_hidden: torch.Tensor, |
| upper_position_ids: torch.LongTensor, |
| upper_attention_mask: torch.Tensor, |
| decode_labels: torch.LongTensor, |
| prefix_keep_lens: torch.LongTensor, |
| split_starts: torch.LongTensor, |
| ) -> tuple[torch.Tensor, torch.LongTensor, torch.Tensor, torch.LongTensor, torch.LongTensor]: |
| if _tri_effective_fusion_mode(model) != "upper_only": |
| valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long) |
| return upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, valid_lens |
| token_ids = _tri_effective_suffix_special_token_ids(model) |
| if not token_ids: |
| valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long) |
| return upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, valid_lens |
|
|
| B, _, H = upper_hidden.shape |
| device = upper_hidden.device |
| dtype = upper_hidden.dtype |
| prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long) |
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long).to(device=device) |
| suffix_lens = (valid_lens - prefix_keep_lens).clamp_min_(0) |
| num_suffix_specials = len(token_ids) |
|
|
| token_id_tensor = torch.tensor(token_ids, device=device, dtype=torch.long).unsqueeze(0).expand(B, -1) |
| special_hidden = model.model.embed_tokens(token_id_tensor).to(dtype=dtype) |
| special_position_ids = _tri_build_suffix_special_position_ids( |
| split_starts, |
| num_suffix_specials=num_suffix_specials, |
| device=device, |
| ) |
|
|
| repacked_valid_lens = prefix_keep_lens + num_suffix_specials + suffix_lens |
| total_len = int(repacked_valid_lens.max().item()) if B > 0 else 0 |
| repacked_hidden = upper_hidden.new_zeros((B, total_len, H)) |
| repacked_position_ids = upper_position_ids.new_zeros((B, total_len)) |
| repacked_attention_mask = upper_attention_mask.new_zeros((B, total_len)) |
| repacked_labels = decode_labels.new_full((B, total_len), -100) |
|
|
| for row in range(B): |
| prefix_keep = int(prefix_keep_lens[row].item()) |
| suffix_len = int(suffix_lens[row].item()) |
| special_start = prefix_keep |
| if prefix_keep > 0: |
| repacked_hidden[row, :prefix_keep, :] = upper_hidden[row, :prefix_keep, :] |
| repacked_position_ids[row, :prefix_keep] = upper_position_ids[row, :prefix_keep] |
| repacked_attention_mask[row, :prefix_keep] = upper_attention_mask[row, :prefix_keep] |
| repacked_labels[row, :prefix_keep] = decode_labels[row, :prefix_keep] |
|
|
| repacked_hidden[row, special_start : special_start + num_suffix_specials, :] = special_hidden[row] |
| repacked_position_ids[row, special_start : special_start + num_suffix_specials] = special_position_ids[row] |
| repacked_attention_mask[row, special_start : special_start + num_suffix_specials] = 1 |
|
|
| if suffix_len > 0: |
| src_start = prefix_keep |
| dst_start = special_start + num_suffix_specials |
| repacked_hidden[row, dst_start : dst_start + suffix_len, :] = upper_hidden[row, src_start : src_start + suffix_len, :] |
| repacked_position_ids[row, dst_start : dst_start + suffix_len] = upper_position_ids[row, src_start : src_start + suffix_len] |
| repacked_attention_mask[row, dst_start : dst_start + suffix_len] = upper_attention_mask[row, src_start : src_start + suffix_len] |
| repacked_labels[row, dst_start : dst_start + suffix_len] = decode_labels[row, src_start : src_start + suffix_len] |
|
|
| return repacked_hidden, repacked_position_ids, repacked_attention_mask, repacked_labels, repacked_valid_lens |
|
|
|
|
| def _tri_prefill_suffix_specials_upper_cache( |
| model: "LlamaModel", |
| pkv: Cache, |
| *, |
| lower_k: int, |
| split_start: Optional[int] = None, |
| split_starts=None, |
| no_upper_attn: bool, |
| ) -> int: |
| if _tri_effective_fusion_mode(model) != "upper_only": |
| return 0 |
| token_ids = _tri_effective_suffix_special_token_ids(model) |
| if not token_ids: |
| return 0 |
| if bool(no_upper_attn): |
| _tri_vanilla_warn_once( |
| model, |
| "_tri_warned_suffix_specials_no_upper_attn", |
| "[TRI] Fusion specials are enabled, but no_upper_attn=True disables upper-layer attention; skipping fusion specials.", |
| ) |
| return 0 |
| if bool(getattr(pkv, "_capsule_suffix_specials_written", False)): |
| return int(getattr(pkv, "_capsule_suffix_specials_count", 0) or 0) |
|
|
| n_layers = len(model.layers) |
| K = max(0, min(int(lower_k), n_layers)) |
| if K >= n_layers: |
| return 0 |
|
|
| split_positions: list[int] = [] |
| if isinstance(split_starts, torch.Tensor): |
| split_positions = [int(v) for v in split_starts.view(-1).tolist()] |
| elif isinstance(split_starts, (list, tuple)): |
| split_positions = [int(v) for v in split_starts] |
| elif split_start is not None: |
| split_positions = [int(split_start)] |
| split_positions = sorted({int(v) for v in split_positions if int(v) >= 0}) |
| if not split_positions: |
| return 0 |
|
|
| device = model.embed_tokens.weight.device |
| token_id_blocks: list[int] = [] |
| for _ in split_positions: |
| token_id_blocks.extend(token_ids) |
| token_id_tensor = torch.tensor(token_id_blocks, device=device, dtype=torch.long).unsqueeze(0) |
| hidden_states = model.embed_tokens(token_id_tensor) |
| num_suffix_specials = int(token_id_tensor.size(1)) |
| split_start_tensor = torch.tensor(split_positions, device=device, dtype=torch.long) |
| cache_position = _tri_build_suffix_special_position_ids( |
| split_start_tensor, |
| num_suffix_specials=len(token_ids), |
| device=device, |
| ).reshape(-1) |
| position_ids = cache_position.unsqueeze(0) |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| attn_impl = _resolve_attn_impl(model.config) |
| implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config) |
|
|
| past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)] |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| if implicit_causal: |
| attn_mask = None |
| else: |
| attn_mask = _build_tri_mask_local( |
| 1, |
| num_suffix_specials, |
| int(past_lens[li]), |
| device, |
| hidden_states.dtype, |
| ) |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=attn_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| try: |
| setattr(pkv, "_tri_past_len_cache", None) |
| setattr(pkv, "_capsule_suffix_specials_written", True) |
| setattr(pkv, "_capsule_suffix_specials_count", num_suffix_specials) |
| except Exception: |
| pass |
| return num_suffix_specials |
|
|
|
|
| def _tri_prefill_upper_history_cache( |
| model: "LlamaModel", |
| pkv: Cache, |
| *, |
| hidden_states: torch.Tensor, |
| position_ids: torch.LongTensor, |
| lower_k: int, |
| no_upper_attn: bool, |
| ) -> int: |
| if bool(no_upper_attn): |
| return 0 |
| if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) <= 0: |
| return 0 |
| if not isinstance(position_ids, torch.Tensor) or position_ids.shape[:2] != hidden_states.shape[:2]: |
| return 0 |
|
|
| n_layers = len(model.layers) |
| K = max(0, min(int(lower_k), n_layers)) |
| if K >= n_layers: |
| return 0 |
|
|
| device = hidden_states.device |
| token_count = int(hidden_states.size(1)) |
| cache_position = position_ids.reshape(-1).to(device=device, dtype=torch.long) |
| layer_position_ids = position_ids.to(device=device, dtype=torch.long) |
| position_embeddings = model.rotary_emb(hidden_states, layer_position_ids) |
| attn_impl = _resolve_attn_impl(model.config) |
| implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config) |
|
|
| past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)] |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| if implicit_causal: |
| attn_mask = None |
| else: |
| attn_mask = _build_tri_mask_local( |
| int(hidden_states.size(0)), |
| token_count, |
| int(past_lens[li]), |
| device, |
| hidden_states.dtype, |
| ) |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| try: |
| setattr(pkv, "_tri_past_len_cache", None) |
| except Exception: |
| pass |
| return token_count |
|
|
|
|
| def tri_prefill_system_all( |
| self: "LlamaModel", |
| system_ids: torch.LongTensor, |
| past_key_values: Optional[Cache] = None, |
| prefill_attn: str = "causal", |
| ): |
| device = system_ids.device |
| if past_key_values is None: |
| past_key_values = _safe_dynamic_cache(self.config) |
|
|
| inputs_embeds = self.embed_tokens(system_ids) |
| start = past_key_values.get_seq_length() |
| cache_position = _tri_arange(start, inputs_embeds.shape[1], device) |
| position_ids = cache_position.unsqueeze(0) |
| attention_mask = torch.ones_like(system_ids, dtype=torch.long, device=device) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| if (prefill_attn or "causal") == "full": |
| if attn_impl == "flash_attention_2": |
| raise ValueError("prefill_attn=full is not supported with flash_attention_2.") |
| causal_mask = _build_tri_mask_full(system_ids.size(0), inputs_embeds.shape[1], start, device, inputs_embeds.dtype) |
| elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config): |
| causal_mask = None |
| else: |
| causal_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| for layer in self.layers: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| hidden_states = self.norm(hidden_states) |
| return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) |
|
|
|
|
| def tri_prefill_user_lower( |
| self: "LlamaModel", |
| user_ids: torch.LongTensor, |
| lower_k: int, |
| past_key_values: Cache, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| return_pre_norm_hidden: bool = False, |
| ): |
| device = user_ids.device |
| inputs_embeds = self.embed_tokens(user_ids) |
| start = past_key_values.get_seq_length() |
| cache_position = _tri_arange(start, inputs_embeds.shape[1], device) |
| position_ids = cache_position.unsqueeze(0) |
| attention_mask = torch.ones_like(user_ids, dtype=torch.long, device=device) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| if (prefill_attn or "causal") == "full": |
| if attn_impl == "flash_attention_2": |
| raise ValueError("prefill_attn=full is not supported with flash_attention_2.") |
| causal_mask = _build_tri_mask_full(user_ids.size(0), inputs_embeds.shape[1], start, device, inputs_embeds.dtype) |
| elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config): |
| causal_mask = None |
| else: |
| causal_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| select = _prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k) |
| for li in select: |
| layer = self.layers[li] |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| pre_norm_hidden = hidden_states |
| hidden_states = self.norm(hidden_states) |
| out = BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) |
| if bool(return_pre_norm_hidden): |
| setattr(out, "pre_norm_hidden", pre_norm_hidden) |
| return out |
|
|
|
|
| def tri_build_caches( |
| self: "LlamaModel", |
| system_ids: torch.LongTensor, |
| user_ids: torch.LongTensor, |
| lower_k: int, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| ) -> Tuple[Cache, int, int]: |
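| """Build the asymmetric tri cache in one call: the system segment runs |
| through all layers, then the user segment through the lower layers |
| only. Returns ``(cache, S, U)`` with the system/user segment lengths. |
| """ |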
| out = self.tri_prefill_system_all(system_ids, past_key_values=None, prefill_attn=prefill_attn) |
| pkv = out.past_key_values |
| S = system_ids.size(1) |
| _ = self.tri_prefill_user_lower( |
| user_ids, |
| lower_k=lower_k, |
| past_key_values=pkv, |
| prefill_mode=prefill_mode, |
| prefill_attn=prefill_attn, |
| ) |
| U = user_ids.size(1) |
| return pkv, S, U |
|
|
|
|
| def _tri_build_valid_cache_decode_mask( |
| valid_mask: torch.Tensor, |
| *, |
| query_len: int, |
| device, |
| dtype: torch.dtype, |
| ) -> torch.Tensor: |
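| """Expand a boolean cache-validity mask of shape (B, past_len) into an |
| additive attention mask of shape (B, 1, query_len, past_len + query_len): |
| invalid past slots get ``finfo(dtype).min`` and the query-to-query |
| block is lower triangular. |
| """ |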
| valid_mask = valid_mask.to(device=device, dtype=torch.bool) |
| B, past_len = valid_mask.shape |
| query_len = int(query_len) |
| past_visible = valid_mask.unsqueeze(1).expand(B, query_len, past_len) |
| query_visible = torch.ones((B, query_len, query_len), device=device, dtype=torch.bool) |
| if query_len > 1: |
| query_visible = torch.tril(query_visible) |
| visible = torch.cat([past_visible, query_visible], dim=-1) |
| mask = torch.zeros((B, 1, query_len, past_len + query_len), device=device, dtype=dtype) |
| mask = mask.masked_fill(~visible.unsqueeze(1), torch.finfo(dtype).min) |
| return mask |
|
|
|
|
| def _tri_batch_packed_varlen_decode_enabled(hidden_states: torch.Tensor, *, batch_size: int, query_len: int) -> bool: |
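| """Gate for the packed flash-attention varlen *decode* fast path: the |
| LLOPA_BATCH_PACKED_VARLEN env toggle must not be off, the varlen |
| kernel must be importable, and decoding must be batched (B > 1), |
| single-token, on CUDA, in fp16/bf16. |
| """ |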
| raw = str(os.environ.get("LLOPA_BATCH_PACKED_VARLEN", "1")).strip().lower() |
| if raw in {"0", "false", "no", "off"}: |
| return False |
| if _flash_attn_varlen_func is None: |
| return False |
| if int(batch_size) <= 1 or int(query_len) != 1: |
| return False |
| if not isinstance(hidden_states, torch.Tensor) or not hidden_states.is_cuda: |
| return False |
| if hidden_states.dtype not in (torch.float16, torch.bfloat16): |
| return False |
| return True |
|
|
|
|
| def _tri_packed_varlen_attention_decode( |
| attn_module: "LlamaAttention", |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| *, |
| past_key_values: Cache, |
| cache_position: torch.LongTensor, |
| layer_valid_mask: torch.Tensor, |
| ) -> Optional[torch.Tensor]: |
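| """Single-token decode attention for a batch whose rows expose |
| different numbers of *valid* cache slots. Valid K/V rows are packed |
| into one ragged batch and handed to ``flash_attn_varlen_func`` with |
| per-row key lengths. Returns ``None`` whenever a precondition fails, |
| signalling the caller to fall back to the dense masked path. |
| """ |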
| if _flash_attn_varlen_func is None: |
| return None |
| if hidden_states.dim() != 3 or int(hidden_states.size(1)) != 1: |
| return None |
| if not hidden_states.is_cuda or hidden_states.dtype not in (torch.float16, torch.bfloat16): |
| return None |
| if not isinstance(layer_valid_mask, torch.Tensor) or layer_valid_mask.dim() != 2: |
| return None |
|
|
| batch_size = int(hidden_states.size(0)) |
| past_width = int(layer_valid_mask.size(1)) |
| if batch_size <= 1 or int(layer_valid_mask.size(0)) != batch_size: |
| return None |
| if _layer_past_len(past_key_values, int(attn_module.layer_idx)) != past_width: |
| return None |
|
|
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, attn_module.head_dim) |
| query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
| cos, sin = position_embeddings |
| cos_u = cos.unsqueeze(1) |
| sin_u = sin.unsqueeze(1) |
| query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u) |
| key_states = (key_states * cos_u) + (rotate_half(key_states) * sin_u) |
|
|
| # Build the validity bookkeeping before touching the cache so the |
| # early fallback return below can never leave a half-updated cache. |
| old_valid = layer_valid_mask.to(device=hidden_states.device, dtype=torch.bool) |
| current_valid = torch.ones((batch_size, 1), device=hidden_states.device, dtype=torch.bool) |
| full_valid = torch.cat([old_valid, current_valid], dim=1) |
| key_lengths = full_valid.sum(dim=1, dtype=torch.int32) |
| if not bool(torch.all(key_lengths > 0).item()): |
| return None |
|
|
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update( |
| key_states, |
| value_states, |
| attn_module.layer_idx, |
| cache_kwargs, |
| ) |
| key_width = int(key_states.size(2)) |
| if key_width != past_width + 1: |
| raise RuntimeError( |
| f"packed varlen LLoPA cache length mismatch at layer {attn_module.layer_idx}: " |
| f"got {key_width}, expected {past_width + 1}" |
| ) |
|
|
| key_seq = key_states.transpose(1, 2).contiguous() |
| value_seq = value_states.transpose(1, 2).contiguous() |
| packed_key = key_seq[full_valid].contiguous() |
| packed_value = value_seq[full_valid].contiguous() |
| packed_query = query_states.squeeze(2).contiguous() |
|
|
| cu_q = torch.arange(batch_size + 1, device=hidden_states.device, dtype=torch.int32) |
| cu_k = torch.cat( |
| [ |
| torch.zeros((1,), device=hidden_states.device, dtype=torch.int32), |
| torch.cumsum(key_lengths, dim=0).to(dtype=torch.int32), |
| ], |
| dim=0, |
| ) |
| max_k = int(key_lengths.max().item()) |
| attn_output = _flash_attn_varlen_func( |
| packed_query, |
| packed_key, |
| packed_value, |
| cu_seqlens_q=cu_q, |
| cu_seqlens_k=cu_k, |
| max_seqlen_q=1, |
| max_seqlen_k=max_k, |
| dropout_p=0.0 if not attn_module.training else attn_module.attention_dropout, |
| softmax_scale=attn_module.scaling, |
| causal=False, |
| ) |
| if isinstance(attn_output, tuple): |
| attn_output = attn_output[0] |
| attn_output = attn_output.view(batch_size, 1, -1, attn_module.head_dim) |
| attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output) |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| return attn_module.o_proj(attn_output) |
|
|
|
|
| def _tri_run_decoder_layer_packed_varlen_decode( |
| *, |
| layer: "LlamaDecoderLayer", |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| past_key_values: Cache, |
| cache_position: torch.LongTensor, |
| layer_valid_mask: torch.Tensor, |
| use_checkpoint: bool, |
| ) -> Optional[torch.Tensor]: |
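| """Run one full decoder layer (attention and MLP residual branches) |
| on top of the packed varlen decode attention. Returns ``None`` when |
| the fast path is unavailable so the caller can fall back. |
| """ |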
| if use_checkpoint: |
| return None |
| residual = hidden_states |
| attn_input = layer.input_layernorm(hidden_states) |
| attn_output = _tri_packed_varlen_attention_decode( |
| layer.self_attn, |
| attn_input, |
| position_embeddings, |
| past_key_values=past_key_values, |
| cache_position=cache_position, |
| layer_valid_mask=layer_valid_mask, |
| ) |
| if attn_output is None: |
| return None |
| hidden_states = residual + attn_output |
|
|
| residual = hidden_states |
| hidden_states = layer.post_attention_layernorm(hidden_states) |
| hidden_states = layer.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| def _tri_batch_packed_varlen_prefill_enabled(hidden_states: torch.Tensor, valid_mask: torch.Tensor) -> bool: |
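| """Gate for the packed varlen *prefill* path: the same env/kernel/dtype |
| checks as the decode gate, plus every row of ``valid_mask`` must keep |
| at least one token. |
| """ |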
| raw = str(os.environ.get("LLOPA_BATCH_PACKED_VARLEN", "1")).strip().lower() |
| if raw in {"0", "false", "no", "off"}: |
| return False |
| if _flash_attn_varlen_func is None: |
| return False |
| if not isinstance(hidden_states, torch.Tensor) or not hidden_states.is_cuda: |
| return False |
| if hidden_states.dtype not in (torch.float16, torch.bfloat16): |
| return False |
| if not isinstance(valid_mask, torch.Tensor) or valid_mask.dim() != 2 or int(valid_mask.size(0)) <= 1: |
| return False |
| lengths = valid_mask.to(device=hidden_states.device, dtype=torch.bool).sum(dim=1) |
| if not bool(torch.all(lengths > 0).item()): |
| return False |
| return True |
|
|
|
|
| def _tri_scatter_packed_hidden( |
| packed_hidden: torch.Tensor, |
| valid_mask: torch.Tensor, |
| *, |
| batch_size: int, |
| width: int, |
| ) -> torch.Tensor: |
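| """Scatter packed ``(total_tokens, H)`` hidden states back into a |
| dense ``(B, width, H)`` tensor, leaving invalid slots zero. |
| """ |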
| out = torch.zeros( |
| (int(batch_size), int(width), int(packed_hidden.size(-1))), |
| device=packed_hidden.device, |
| dtype=packed_hidden.dtype, |
| ) |
| out[valid_mask.to(device=packed_hidden.device, dtype=torch.bool)] = packed_hidden |
| return out |
|
|
|
|
| def _tri_scatter_packed_kv( |
| packed_key: torch.Tensor, |
| packed_value: torch.Tensor, |
| valid_mask: torch.Tensor, |
| *, |
| batch_size: int, |
| width: int, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
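| """Scatter packed per-token K/V rows back into dense |
| ``(B, num_heads, width, head_dim)`` tensors for cache storage. |
| """ |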
| key_dense = torch.zeros( |
| ( |
| int(batch_size), |
| int(width), |
| int(packed_key.size(1)), |
| int(packed_key.size(2)), |
| ), |
| device=packed_key.device, |
| dtype=packed_key.dtype, |
| ) |
| value_dense = torch.zeros( |
| ( |
| int(batch_size), |
| int(width), |
| int(packed_value.size(1)), |
| int(packed_value.size(2)), |
| ), |
| device=packed_value.device, |
| dtype=packed_value.dtype, |
| ) |
| valid_mask = valid_mask.to(device=packed_key.device, dtype=torch.bool) |
| key_dense[valid_mask] = packed_key |
| value_dense[valid_mask] = packed_value |
| return key_dense.transpose(1, 2).contiguous(), value_dense.transpose(1, 2).contiguous() |
|
|
|
|
| def _tri_run_packed_varlen_prefill_layers( |
| *, |
| model: "LlamaModel", |
| hidden_states: torch.Tensor, |
| position_ids: torch.LongTensor, |
| valid_mask: torch.Tensor, |
| past_key_values: Cache, |
| start_layer: int, |
| end_layer: int, |
| use_checkpoint: bool, |
| ) -> Optional[torch.Tensor]: |
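| """Run layers ``[start_layer, end_layer)`` over packed |
| ``(total_tokens, H)`` hidden states with flash varlen self-attention |
| (causal within each row's segment), scattering dense K/V into the |
| cache along the way. Returns ``None`` when any prerequisite fails. |
| """ |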
| if use_checkpoint or _flash_attn_varlen_func is None: |
| return None |
| if hidden_states.dim() != 2 or position_ids.dim() != 1: |
| return None |
| if int(hidden_states.size(0)) != int(position_ids.numel()): |
| return None |
| if not _tri_batch_packed_varlen_prefill_enabled(hidden_states, valid_mask): |
| return None |
|
|
| valid_mask = valid_mask.to(device=hidden_states.device, dtype=torch.bool) |
| batch_size, width = int(valid_mask.size(0)), int(valid_mask.size(1)) |
| lengths = valid_mask.sum(dim=1, dtype=torch.int32) |
| cu = torch.cat( |
| [ |
| torch.zeros((1,), device=hidden_states.device, dtype=torch.int32), |
| torch.cumsum(lengths, dim=0).to(dtype=torch.int32), |
| ], |
| dim=0, |
| ) |
| max_len = int(lengths.max().item()) |
| cache_position = torch.arange(width, device=hidden_states.device, dtype=torch.long) |
|
|
| for li in range(int(start_layer), int(end_layer)): |
| layer = model.layers[li] |
| residual = hidden_states |
| attn_input = layer.input_layernorm(hidden_states) |
| hidden_shape = (int(attn_input.size(0)), -1, layer.self_attn.head_dim) |
|
|
| query_states = layer.self_attn.q_proj(attn_input).view(hidden_shape) |
| key_states = layer.self_attn.k_proj(attn_input).view(hidden_shape) |
| value_states = layer.self_attn.v_proj(attn_input).view(hidden_shape) |
|
|
| cos, sin = model.rotary_emb(attn_input.view(1, int(attn_input.size(0)), -1), position_ids.view(1, -1)) |
| cos = cos[0].unsqueeze(1) |
| sin = sin[0].unsqueeze(1) |
| query_states = (query_states * cos) + (rotate_half(query_states) * sin) |
| key_states = (key_states * cos) + (rotate_half(key_states) * sin) |
|
|
| key_dense, value_dense = _tri_scatter_packed_kv( |
| key_states, |
| value_states, |
| valid_mask, |
| batch_size=batch_size, |
| width=width, |
| ) |
| cache_kwargs = {"sin": None, "cos": None, "cache_position": cache_position} |
| past_key_values.update(key_dense, value_dense, li, cache_kwargs) |
|
|
| attn_output = _flash_attn_varlen_func( |
| query_states.contiguous(), |
| key_states.contiguous(), |
| value_states.contiguous(), |
| cu_seqlens_q=cu, |
| cu_seqlens_k=cu, |
| max_seqlen_q=max_len, |
| max_seqlen_k=max_len, |
| dropout_p=0.0 if not layer.self_attn.training else layer.self_attn.attention_dropout, |
| softmax_scale=layer.self_attn.scaling, |
| causal=True, |
| ) |
| if isinstance(attn_output, tuple): |
| attn_output = attn_output[0] |
| attn_output = attn_output.view(int(hidden_states.size(0)), 1, -1, layer.self_attn.head_dim) |
| attn_output = _apply_sdpa_output_gate(layer.self_attn, attn_input.view(int(hidden_states.size(0)), 1, -1), attn_output) |
| attn_output = attn_output.reshape(int(hidden_states.size(0)), -1).contiguous() |
| hidden_states = residual + layer.self_attn.o_proj(attn_output) |
|
|
| residual = hidden_states |
| hidden_states = layer.post_attention_layernorm(hidden_states) |
| hidden_states = layer.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
|
|
| return hidden_states |
|
|
|
|
| def _tri_append_batch_cache_valid_masks(pkv: Cache, input_valid: torch.Tensor) -> None: |
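| """After a decode step, grow (or trim) each layer's stored validity |
| mask so its width matches that layer's cache length, marking the |
| newly appended slots with ``input_valid``. |
| """ |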
| masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None) |
| if not isinstance(masks, list): |
| return |
| input_valid = input_valid.to(dtype=torch.bool) |
| new_masks = [] |
| for layer_idx, mask in enumerate(masks): |
| if not isinstance(mask, torch.Tensor) or mask.dim() != 2: |
| new_masks.append(mask) |
| continue |
| layer_len = _layer_past_len(pkv, int(layer_idx)) |
| cur_len = int(mask.size(1)) |
| if layer_len > cur_len: |
| add_width = int(layer_len - cur_len) |
| add = input_valid.to(device=mask.device).view(mask.size(0), -1) |
| if int(add.size(1)) != add_width: |
| if int(add.size(1)) == 1: |
| add = add.expand(mask.size(0), add_width) |
| else: |
| add = add[:, :add_width] |
| if int(add.size(1)) < add_width: |
| pad = torch.zeros( |
| (mask.size(0), add_width - int(add.size(1))), |
| device=mask.device, |
| dtype=torch.bool, |
| ) |
| add = torch.cat([add, pad], dim=1) |
| mask = torch.cat([mask, add], dim=1) |
| elif layer_len < cur_len: |
| mask = mask[:, :layer_len] |
| new_masks.append(mask) |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_llopa_batch_layer_valid_masks", new_masks) |
| setattr(pkv, "_tri_past_len_cache", None) |
|
|
|
|
| def tri_forward_assistant( |
| self: "LlamaModel", |
| assistant_ids: torch.LongTensor, |
| lower_k: int, |
| pkv: Cache, |
| S: int, |
| U: int, |
| write_cache: bool = True, |
| prefill_mode: str = "lower", |
| no_upper_attn: bool = False, |
| align_cache_position_to_layer_past: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ): |
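| """Forward assistant tokens against a tri cache: lower layers attend |
| to the full cached prefix while upper layers attend only to their |
| own, shorter per-layer pasts. Also handles per-layer batch validity |
| masks, the packed varlen decode fast path, optional replay layers, |
| the ``no_upper_attn`` FFN-only mode, and per-layer cache-position |
| alignment. |
| """ |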
| device = assistant_ids.device |
| B, T = assistant_ids.shape |
| inputs_embeds = self.embed_tokens(assistant_ids) |
| prefix_global = pkv.get_seq_length() |
| cache_position = _tri_arange(prefix_global, T, device) |
| position_ids = cache_position.unsqueeze(0).expand(B, -1) |
| batch_valid_masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None) |
| batch_decode_position_ids = getattr(pkv, "_llopa_batch_decode_position_ids", None) |
| if isinstance(batch_decode_position_ids, torch.Tensor): |
| batch_decode_position_ids = batch_decode_position_ids.to(device=device, dtype=torch.long) |
| if batch_decode_position_ids.shape != (B, T): |
| batch_decode_position_ids = None |
| else: |
| batch_decode_position_ids = None |
| if not ( |
| isinstance(batch_valid_masks, (list, tuple)) |
| and len(batch_valid_masks) >= len(self.layers) |
| and all(isinstance(mask, torch.Tensor) and mask.dim() == 2 for mask in batch_valid_masks[: len(self.layers)]) |
| ): |
| batch_valid_masks = None |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config) |
| select = set(_prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k)) |
| n_layers = len(self.layers) |
| has_lower = bool(select) |
| has_upper = len(select) < n_layers |
| replay_module = _normalize_replay_module( |
| replay_module or getattr(pkv, "_tri_replay_module", getattr(pkv, "_tri_last_layer_module", "none")) |
| ) |
| replay_per_layers = _normalize_replay_per_layers( |
| replay_per_layers if replay_per_layers is not None else getattr(pkv, "_tri_replay_per_layers", -1) |
| ) |
| replay_hidden = getattr(pkv, "_tri_last_layer_memory_hidden", None) |
| replay_position_ids = getattr(pkv, "_tri_last_layer_memory_position_ids", None) |
| replay_valid_mask = getattr(pkv, "_tri_last_layer_memory_valid_mask", None) |
| upper_layer_indices = [li for li in range(n_layers) if li not in select] |
| replay_layer_set = _tri_replay_layer_index_set( |
| upper_layer_indices=upper_layer_indices, |
| replay_per_layers=replay_per_layers, |
| ) |
| replay_enabled = ( |
| replay_module != "none" |
| and has_upper |
| and not bool(no_upper_attn) |
| and bool(replay_layer_set) |
| and isinstance(replay_hidden, torch.Tensor) |
| and replay_hidden.numel() > 0 |
| and int(replay_hidden.size(1)) > 0 |
| ) |
|
|
| past_lens = None |
| layer_position_offsets = None |
| if batch_valid_masks is None and bool(align_cache_position_to_layer_past): |
| cached_offsets = getattr(pkv, "_tri_layer_position_offsets", None) |
| if isinstance(cached_offsets, (list, tuple)) and len(cached_offsets) == n_layers: |
| layer_position_offsets = [int(offset) for offset in cached_offsets] |
|
|
| need_past_lens = batch_valid_masks is None and ( |
| (not implicit_causal and has_upper) |
| or (bool(align_cache_position_to_layer_past) and layer_position_offsets is None) |
| ) |
| if need_past_lens: |
| cache = getattr(pkv, "_tri_past_len_cache", None) |
| if isinstance(cache, dict): |
| last = cache.get("last_seq_len") |
| cached = cache.get("past_lens") |
| if isinstance(cached, list) and len(cached) == n_layers: |
| if isinstance(last, int) and last == prefix_global: |
| past_lens = cached |
| elif isinstance(last, int) and last == prefix_global - T: |
| past_lens = [pl + T for pl in cached] |
| if past_lens is None: |
| past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)] |
| try: |
| pkv._tri_past_len_cache = {"last_seq_len": prefix_global, "past_lens": past_lens} |
| except Exception: |
| pass |
| if bool(align_cache_position_to_layer_past) and layer_position_offsets is None: |
| layer_position_offsets = [int(prefix_global - int(pl)) for pl in past_lens] |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_tri_layer_position_offsets", layer_position_offsets) |
|
|
| mask_lower = None |
| mask_upper = None |
| if batch_valid_masks is None and not implicit_causal: |
| if has_lower: |
| attn_mask_tokens = torch.ones((B, prefix_global + T), device=device, dtype=torch.long) |
| mask_kwargs = { |
| "config": self.config, |
| "input_embeds": inputs_embeds, |
| "attention_mask": attn_mask_tokens, |
| "cache_position": cache_position, |
| "past_key_values": pkv, |
| "position_ids": position_ids, |
| } |
| mask_lower = create_causal_mask(**mask_kwargs) |
| if has_upper: |
| upper_past_len = None |
| upper_uniform = True |
| for li in range(n_layers): |
| if li in select: |
| continue |
| past_len = past_lens[li] if past_lens is not None else _layer_past_len(pkv, li) |
| if upper_past_len is None: |
| upper_past_len = past_len |
| elif past_len != upper_past_len: |
| upper_uniform = False |
| break |
| if upper_uniform and upper_past_len is not None: |
| mask_upper = _build_tri_mask_local(B, T, upper_past_len, device, inputs_embeds.dtype) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| position_embeddings_by_offset = ( |
| {0: position_embeddings} |
| if batch_valid_masks is None and batch_decode_position_ids is None |
| else None |
| ) |
| use_checkpoint = should_use_checkpoint(self) |
| packed_varlen_decode = ( |
| bool(write_cache) |
| and batch_valid_masks is not None |
| and not replay_enabled |
| and _tri_batch_packed_varlen_decode_enabled(inputs_embeds, batch_size=B, query_len=T) |
| ) |
| for li, layer in enumerate(self.layers): |
| layer_cache_position = cache_position |
| layer_position_ids = position_ids |
| layer_position_offset = 0 |
| if batch_decode_position_ids is not None: |
| layer_position_ids = batch_decode_position_ids |
| elif batch_valid_masks is not None: |
| valid_mask_li = batch_valid_masks[li].to(device=device, dtype=torch.bool) |
| layer_past_lens = valid_mask_li.sum(dim=1).to(device=device, dtype=torch.long) |
| layer_position_ids = layer_past_lens.unsqueeze(1) + torch.arange(T, device=device, dtype=torch.long).unsqueeze(0) |
| elif bool(align_cache_position_to_layer_past) and layer_position_offsets is not None: |
| offset = int(layer_position_offsets[li]) |
| if offset: |
| layer_cache_position = cache_position - offset |
| layer_position_ids = position_ids - offset |
| layer_position_offset = offset |
| if bool(no_upper_attn) and (li not in select): |
| residual = hidden_states |
| ffn_input = layer.post_attention_layernorm(hidden_states) |
| if use_checkpoint: |
| ffn_out = _checkpoint_mlp(layer.mlp, ffn_input) |
| else: |
| ffn_out = layer.mlp(ffn_input) |
| hidden_states = residual + ffn_out |
| continue |
| if position_embeddings_by_offset is not None: |
| layer_position_embeddings = position_embeddings_by_offset.get(layer_position_offset) |
| if layer_position_embeddings is None: |
| layer_position_embeddings = self.rotary_emb(hidden_states, layer_position_ids) |
| position_embeddings_by_offset[layer_position_offset] = layer_position_embeddings |
| else: |
| layer_position_embeddings = ( |
| self.rotary_emb(hidden_states, layer_position_ids) |
| if batch_valid_masks is not None or layer_position_ids is not position_ids |
| else position_embeddings |
| ) |
| attn_mask = None |
| if batch_valid_masks is not None: |
| if bool(packed_varlen_decode): |
| packed_hidden_states = _tri_run_decoder_layer_packed_varlen_decode( |
| layer=layer, |
| hidden_states=hidden_states, |
| position_embeddings=layer_position_embeddings, |
| past_key_values=pkv, |
| cache_position=layer_cache_position, |
| layer_valid_mask=batch_valid_masks[li], |
| use_checkpoint=use_checkpoint, |
| ) |
| if packed_hidden_states is not None: |
| hidden_states = packed_hidden_states |
| continue |
| attn_mask = _tri_build_valid_cache_decode_mask( |
| batch_valid_masks[li], |
| query_len=T, |
| device=device, |
| dtype=inputs_embeds.dtype, |
| ) |
| elif implicit_causal: |
| attn_mask = None |
| elif li in select: |
| attn_mask = mask_lower |
| else: |
| attn_mask = mask_upper |
| if attn_mask is None: |
| past_len_li = past_lens[li] if past_lens is not None else _layer_past_len(pkv, li) |
| attn_mask = _build_tri_mask_local(B, T, past_len_li, device, inputs_embeds.dtype) |
| use_replay_layer = replay_enabled and (li in replay_layer_set) |
| if use_replay_layer: |
| hidden_states = _tri_run_last_layer_with_replay( |
| layer=layer, |
| hidden_states=hidden_states, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=write_cache, |
| cache_position=layer_cache_position, |
| position_embeddings=layer_position_embeddings, |
| memory_hidden_states=replay_hidden, |
| memory_position_ids=replay_position_ids, |
| memory_valid_mask=replay_valid_mask, |
| local_valid_mask=torch.ones((B, T), device=device, dtype=torch.bool), |
| query_replay_mask=torch.ones((B, T), device=device, dtype=torch.bool), |
| module_type=replay_module, |
| rotary_emb=self.rotary_emb, |
| use_checkpoint=use_checkpoint, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=write_cache, |
| cache_position=layer_cache_position, |
| position_embeddings=layer_position_embeddings, |
| ) |
| hidden_states = self.norm(hidden_states) |
| if batch_decode_position_ids is not None: |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_llopa_batch_decode_position_ids", None) |
| return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=pkv) |
|
|
|
|
| LlamaModel.tri_prefill_system_all = tri_prefill_system_all |
| LlamaModel.tri_prefill_user_lower = tri_prefill_user_lower |
| LlamaModel.tri_build_caches = tri_build_caches |
| LlamaModel.tri_forward_assistant = tri_forward_assistant |
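| # A minimal end-to-end sketch of the tri-cache flow (illustrative only: |
| # `model`, `system_ids`, `user_ids`, and `assistant_ids` are hypothetical |
| # stand-ins, not names defined in this module): |
| # |
| #   pkv, S, U = model.model.tri_build_caches( |
| #       system_ids, user_ids, lower_k=8, prefill_mode="lower" |
| #   ) |
| #   out = model.model.tri_forward_assistant( |
| #       assistant_ids, lower_k=8, pkv=pkv, S=S, U=U, write_cache=True |
| #   ) |
| #   next_hidden = out.last_hidden_state[:, -1]  # feed into lm_head |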
|
|
|
|
| def llopa_prefill_cache( |
| self: "LlamaModel", |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| return_last_assistant_hidden: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| replay_user_prefix_keep_len: int = 0, |
| replay_user_start: Optional[int] = None, |
| replay_user_len: Optional[int] = None, |
| ) -> Union[Cache, Tuple[Cache, Optional[torch.Tensor]]]: |
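| """Build a LLoPA cache: the lower ``lower_k`` layers consume |
| system + user + assistant causally, then the upper layers rerun on |
| the compacted sequence that drops the user segment (optionally with |
| learned special tokens inserted after the system prefix and replay |
| memory stored on the cache). Returns the cache, plus the final |
| assistant hidden state when ``return_last_assistant_hidden=True``. |
| """ |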
| |
| if (prefill_mode or "lower") != "lower": |
| raise ValueError("llopa_prefill_cache requires prefill_mode='lower' (periodic not supported).") |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("llopa_prefill_cache requires prefill_attn in {'causal','full'}.") |
| prefix_full = (attn_mode == "full") |
| replay_module = _normalize_replay_module(replay_module) |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
|
|
| ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids) |
| if ref is None: |
| raise ValueError("llopa_prefill_cache requires at least one non-empty input segment.") |
|
|
| if system_ids is None: |
| system_ids = ref[:, :0] |
| if user_ids is None: |
| user_ids = ref[:, :0] |
| if assistant_ids is None: |
| assistant_ids = ref[:, :0] |
|
|
| if system_ids.size(0) != ref.size(0) or user_ids.size(0) != ref.size(0) or assistant_ids.size(0) != ref.size(0): |
| raise ValueError("llopa_prefill_cache expects all segments to share the same batch size.") |
|
|
| n_layers = len(self.layers) |
| K = max(0, min(int(lower_k), n_layers)) |
| |
| user_ids_eff = user_ids if K > 0 else user_ids[:, :0] |
| S, U, A = system_ids.size(1), user_ids_eff.size(1), assistant_ids.size(1) |
|
|
| full_ids = torch.cat([system_ids, user_ids_eff, assistant_ids], dim=1) |
| B, L = full_ids.shape |
| pkv = _safe_dynamic_cache(self.config) |
| def _finish(last_hidden_source: Optional[torch.Tensor]): |
| if not bool(return_last_assistant_hidden): |
| return pkv |
| if A <= 0 or not isinstance(last_hidden_source, torch.Tensor) or last_hidden_source.numel() == 0: |
| return pkv, None |
| return pkv, self.norm(last_hidden_source[:, -1:, :]) |
| if L == 0: |
| return _finish(None) |
|
|
| device = full_ids.device |
| inputs_embeds = self.embed_tokens(full_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = cache_position.unsqueeze(0).expand(B, -1) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("llopa_prefill_cache prefill_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)): |
| lower_mask = None |
| elif prefix_full: |
| lower_mask = _build_tri_mask_prefix_full(B, S + U, L, device, inputs_embeds.dtype) |
| else: |
| attention_mask = torch.ones_like(full_ids, dtype=torch.long, device=device) |
| lower_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(self) |
|
|
| for li in range(K): |
| layer = self.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| |
| if K >= n_layers: |
| return _finish(hidden_states) |
|
|
| if U == 0: |
| upper_hidden = hidden_states |
| |
| |
| upper_position_ids = position_ids.to(device=upper_hidden.device) |
| upper_cache_position = cache_position.to(device=upper_hidden.device) |
| else: |
| split_device = hidden_states.device |
| upper_idx = _llopa_upper_indices(S, U, A, split_device) |
| upper_hidden = hidden_states.index_select(dim=1, index=upper_idx) |
| upper_position_ids = upper_idx.unsqueeze(0).expand(B, -1) |
| upper_cache_position = upper_idx |
|
|
| if upper_hidden.size(1) == 0: |
| return _finish(None) |
|
|
| replay_hidden = None |
| replay_position_ids = None |
| replay_valid_mask = None |
| replay_query_mask = None |
| upper_layer_indices = list(range(K, n_layers)) |
| replay_layer_set = _tri_replay_layer_index_set( |
| upper_layer_indices=upper_layer_indices, |
| replay_per_layers=replay_per_layers, |
| ) |
| replay_enabled = replay_module != "none" and U > 0 and K < n_layers and bool(replay_layer_set) |
| if replay_enabled: |
| replay_prefix_keep_len, replay_start, replay_end = _tri_resolve_replay_user_window( |
| user_len=U, |
| prefix_keep_len=int(replay_user_prefix_keep_len or 0), |
| latest_user_start=replay_user_start, |
| latest_user_len=replay_user_len, |
| ) |
| replay_hidden_parts = [] |
| replay_position_parts = [] |
| if replay_prefix_keep_len > 0: |
| replay_hidden_parts.append(hidden_states[:, S : S + replay_prefix_keep_len, :]) |
| replay_position_parts.append(position_ids[:, S : S + replay_prefix_keep_len]) |
| if replay_end > replay_start: |
| replay_hidden_parts.append(hidden_states[:, S + replay_start : S + replay_end, :]) |
| replay_position_parts.append(position_ids[:, S + replay_start : S + replay_end]) |
| if replay_hidden_parts: |
| replay_hidden = ( |
| replay_hidden_parts[0] |
| if len(replay_hidden_parts) == 1 |
| else torch.cat(replay_hidden_parts, dim=1) |
| ).to(device=upper_hidden.device) |
| replay_position_ids = ( |
| replay_position_parts[0] |
| if len(replay_position_parts) == 1 |
| else torch.cat(replay_position_parts, dim=1) |
| ).to(device=upper_hidden.device, dtype=torch.long) |
| replay_valid_mask = torch.ones( |
| (B, int(replay_hidden.size(1))), |
| device=upper_hidden.device, |
| dtype=torch.bool, |
| ) |
| replay_query_mask = _tri_build_last_layer_query_mask( |
| prefix_keep_lens=torch.full((B,), int(S), device=upper_hidden.device, dtype=torch.long), |
| valid_lens=torch.full((B,), int(upper_hidden.size(1)), device=upper_hidden.device, dtype=torch.long), |
| total_len=int(upper_hidden.size(1)), |
| device=upper_hidden.device, |
| ) |
| if isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0: |
| _tri_store_last_layer_memory( |
| pkv, |
| hidden_states=replay_hidden, |
| position_ids=replay_position_ids, |
| valid_mask=replay_valid_mask, |
| module_type=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
|
|
| num_specials_eff = _llopa_effective_specials(self, U) |
|
|
| if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)): |
| upper_mask = None |
| elif prefix_full: |
| upper_mask = _build_tri_mask_prefix_full(B, S, upper_hidden.size(1), device, upper_hidden.dtype) |
| else: |
| upper_mask = _build_tri_mask_local(B, upper_hidden.size(1), 0, device, upper_hidden.dtype) |
|
|
| upper_pos_emb = self.rotary_emb(upper_hidden, upper_position_ids) |
|
|
| if num_specials_eff > 0: |
| special_pos = torch.arange( |
| S, |
| S + num_specials_eff, |
| device=upper_cache_position.device, |
| dtype=upper_cache_position.dtype, |
| ) |
| special_pos_b = special_pos.unsqueeze(0).expand(B, -1) |
| upper_position_ids_sp = torch.cat( |
| [upper_position_ids[:, :S], special_pos_b, upper_position_ids[:, S:]], |
| dim=1, |
| ) |
| upper_cache_position_sp = torch.cat( |
| [upper_cache_position[:S], special_pos, upper_cache_position[S:]], |
| dim=0, |
| ) |
| if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)): |
| upper_mask_sp = None |
| elif prefix_full: |
| upper_mask_sp = _build_tri_mask_prefix_full( |
| B, S, upper_position_ids_sp.size(1), device, upper_hidden.dtype |
| ) |
| else: |
| upper_mask_sp = _build_tri_mask_local( |
| B, upper_position_ids_sp.size(1), 0, device, upper_hidden.dtype |
| ) |
| else: |
| upper_position_ids_sp = None |
| upper_cache_position_sp = None |
| upper_mask_sp = None |
|
|
| for li in range(K, n_layers): |
| layer = self.layers[li] |
| llopa_special_kv = None |
| llopa_special_span = None |
| if num_specials_eff > 0: |
| specials = self.llopa_specials[li][:num_specials_eff] |
| specials = specials.to(dtype=upper_hidden.dtype, device=upper_hidden.device) |
| specials = specials.unsqueeze(0).expand(B, -1, -1) |
| layer_hidden = torch.cat( |
| [upper_hidden[:, :S, :], specials, upper_hidden[:, S:, :]], |
| dim=1, |
| ) |
| layer_position_ids = upper_position_ids_sp |
| layer_cache_position = upper_cache_position_sp |
| attn_mask = upper_mask_sp |
| pos_emb = self.rotary_emb(layer_hidden, layer_position_ids) |
| if _llopa_can_cache_special_kv(self): |
| llopa_special_kv = _llopa_special_kv_for_layer( |
| self, |
| layer, |
| layer_idx=li, |
| num_specials_eff=num_specials_eff, |
| start_pos=S, |
| batch_size=B, |
| dtype=upper_hidden.dtype, |
| device=upper_hidden.device, |
| ) |
| llopa_special_span = (S, S + num_specials_eff) |
| else: |
| layer_hidden = upper_hidden |
| layer_position_ids = upper_position_ids |
| layer_cache_position = upper_cache_position |
| attn_mask = upper_mask |
| pos_emb = upper_pos_emb |
| use_replay_layer = ( |
| replay_enabled |
| and li in replay_layer_set |
| and num_specials_eff == 0 |
| and isinstance(replay_hidden, torch.Tensor) |
| and replay_hidden.numel() > 0 |
| ) |
| if use_replay_layer: |
| layer_out = _tri_run_last_layer_with_replay( |
| layer=layer, |
| hidden_states=layer_hidden, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=layer_cache_position, |
| position_embeddings=pos_emb, |
| memory_hidden_states=replay_hidden, |
| memory_position_ids=replay_position_ids, |
| memory_valid_mask=replay_valid_mask, |
| local_valid_mask=torch.ones((B, layer_hidden.size(1)), device=layer_hidden.device, dtype=torch.bool), |
| query_replay_mask=replay_query_mask, |
| module_type=replay_module, |
| rotary_emb=self.rotary_emb, |
| use_checkpoint=use_checkpoint, |
| ) |
| elif use_checkpoint: |
| layer_out = checkpoint_with_cache( |
| layer.__call__, |
| layer_hidden, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=layer_cache_position, |
| position_embeddings=pos_emb, |
| llopa_special_kv=llopa_special_kv, |
| llopa_special_span=llopa_special_span, |
| ) |
| else: |
| layer_out = layer( |
| layer_hidden, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=layer_cache_position, |
| position_embeddings=pos_emb, |
| llopa_special_kv=llopa_special_kv, |
| llopa_special_span=llopa_special_span, |
| ) |
| if num_specials_eff > 0: |
| upper_hidden = torch.cat( |
| [layer_out[:, :S, :], layer_out[:, S + num_specials_eff :, :]], |
| dim=1, |
| ) |
| else: |
| upper_hidden = layer_out |
|
|
| return _finish(upper_hidden) |
|
|
|
|
| def llopa_forward( |
| self: "LlamaModel", |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| system_attention_mask: Optional[torch.LongTensor] = None, |
| user_attention_mask: Optional[torch.LongTensor] = None, |
| assistant_attention_mask: Optional[torch.LongTensor] = None, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| no_upper_attn: bool = False, |
| ): |
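| """Training-time LLoPA forward without a KV cache: the lower |
| ``lower_k`` layers see the full system + user + assistant sequence, |
| the upper layers only the packed system + assistant tokens. Returns |
| hidden states for the assistant segment only. |
| """ |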
| |
| if (prefill_mode or "lower") != "lower": |
| raise ValueError("llopa requires prefill_mode='lower' (periodic not supported).") |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("llopa requires prefill_attn in {'causal','full'}.") |
| prefix_full = (attn_mode == "full") |
|
|
| ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids) |
| if ref is None: |
| raise ValueError("llopa_forward requires at least one non-empty input segment.") |
|
|
| if system_ids is None: |
| system_ids = ref[:, :0] |
| if user_ids is None: |
| user_ids = ref[:, :0] |
| if assistant_ids is None: |
| assistant_ids = ref[:, :0] |
| system_attention_mask = _llopa_default_attention_mask(system_ids, system_attention_mask) |
| user_attention_mask = _llopa_default_attention_mask(user_ids, user_attention_mask) |
| assistant_attention_mask = _llopa_default_attention_mask(assistant_ids, assistant_attention_mask) |
|
|
| if system_ids.size(0) != ref.size(0) or user_ids.size(0) != ref.size(0) or assistant_ids.size(0) != ref.size(0): |
| raise ValueError("llopa_forward expects all segments to share the same batch size.") |
|
|
| n_layers = len(self.layers) |
| K = max(0, min(int(lower_k), n_layers)) |
|
|
| |
| user_ids_eff = user_ids if K > 0 else user_ids[:, :0] |
| user_attention_mask_eff = user_attention_mask if K > 0 else user_attention_mask[:, :0] |
| S, U, A = system_ids.size(1), user_ids_eff.size(1), assistant_ids.size(1) |
|
|
| full_ids = torch.cat([system_ids, user_ids_eff, assistant_ids], dim=1) |
| full_attention_mask = torch.cat([system_attention_mask, user_attention_mask_eff, assistant_attention_mask], dim=1) |
| B, L = full_ids.shape |
| if L == 0: |
| # Match the embedding dtype rather than inheriting int64 from full_ids. |
| empty = self.embed_tokens.weight.new_zeros((B, 0, self.config.hidden_size)) |
| return BaseModelOutputWithPast(last_hidden_state=empty, past_key_values=None) |
|
|
| device = full_ids.device |
| _llopa_reset_peak_memory(full_ids) |
| inputs_embeds = self.embed_tokens(full_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
| _llopa_log_memory( |
| self, |
| "after_embed", |
| system_ids=system_ids, |
| user_ids=user_ids_eff, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask_eff, |
| assistant_attention_mask=assistant_attention_mask, |
| sequence_len=L, |
| ) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("llopa prefill_attn=full is not supported with flash_attention_2.") |
| if prefix_full and not _llopa_mask_is_all_ones(full_attention_mask): |
| raise ValueError("llopa padded batches require prefill_attn='causal'.") |
| if attn_impl == "flash_attention_2": |
| lower_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| lower_mask = None |
| elif prefix_full: |
| lower_mask = _build_tri_mask_prefix_full(B, S + U, L, device, inputs_embeds.dtype) |
| else: |
| lower_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(self) |
| use_optimized_training_path = _llopa_use_optimized_training_path( |
| self, |
| attn_impl=attn_impl, |
| prefix_full=prefix_full, |
| no_upper_attn=bool(no_upper_attn), |
| ) |
| use_custom_checkpoint = use_checkpoint and not use_optimized_training_path |
|
|
| _llopa_reset_peak_memory(hidden_states) |
| for li in range(K): |
| layer = self.layers[li] |
| if use_custom_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| _llopa_log_memory( |
| self, |
| "after_lower", |
| system_ids=system_ids, |
| user_ids=user_ids_eff, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask_eff, |
| assistant_attention_mask=assistant_attention_mask, |
| sequence_len=L, |
| ) |
|
|
| |
| if K >= n_layers: |
| hidden_states = self.norm(hidden_states) |
| assistant_hidden = hidden_states[:, S + U :, :] |
| return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None) |
|
|
| |
| if U == 0: |
| upper_hidden = hidden_states |
| upper_attention_mask = full_attention_mask.to(device=upper_hidden.device) |
| upper_position_ids = position_ids.to(device=upper_hidden.device) |
| else: |
| _llopa_reset_peak_memory(hidden_states) |
| upper_hidden = _llopa_pack_upper_hidden(hidden_states, S, U) |
| upper_attention_mask = _llopa_pack_upper_tokens(full_attention_mask.to(device=hidden_states.device), S, U) |
| upper_position_ids = _llopa_pack_upper_tokens(position_ids.to(device=hidden_states.device), S, U) |
| upper_cache_position = _tri_arange(0, upper_hidden.size(1), upper_hidden.device) |
| _llopa_log_memory( |
| self, |
| "after_upper_compaction", |
| system_ids=system_ids, |
| user_ids=user_ids_eff, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask_eff, |
| assistant_attention_mask=assistant_attention_mask, |
| sequence_len=upper_hidden.size(1), |
| ) |
|
|
| if bool(no_upper_attn): |
| for li in range(K, n_layers): |
| layer = self.layers[li] |
| residual = upper_hidden |
| ffn_input = layer.post_attention_layernorm(upper_hidden) |
| if use_checkpoint: |
| ffn_out = _checkpoint_mlp(layer.mlp, ffn_input) |
| else: |
| ffn_out = layer.mlp(ffn_input) |
| upper_hidden = residual + ffn_out |
| upper_hidden = self.norm(upper_hidden) |
| assistant_hidden = upper_hidden[:, S:, :] |
| return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None) |
|
|
| num_specials_eff = _llopa_effective_specials(self, U) |
|
|
| if attn_impl == "flash_attention_2": |
| upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(upper_attention_mask): |
| upper_mask = None |
| elif prefix_full: |
| upper_mask = _build_tri_mask_prefix_full(B, S, upper_hidden.size(1), device, upper_hidden.dtype) |
| else: |
| upper_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=upper_hidden, |
| attention_mask=upper_attention_mask, |
| cache_position=upper_cache_position, |
| past_key_values=None, |
| position_ids=upper_position_ids, |
| ) |
|
|
| upper_pos_emb = self.rotary_emb(upper_hidden, upper_position_ids) |
|
|
| if num_specials_eff > 0: |
| special_steps = torch.arange(num_specials_eff, device=upper_hidden.device, dtype=torch.long).unsqueeze(0) |
| system_lengths = system_attention_mask.sum(dim=1, dtype=torch.long).to(device=upper_hidden.device).unsqueeze(1) |
| special_pos_b = system_lengths + special_steps |
| upper_position_ids_sp = torch.cat( |
| [upper_position_ids[:, :S], special_pos_b, upper_position_ids[:, S:]], |
| dim=1, |
| ) |
| upper_cache_position_sp = _tri_arange(0, upper_position_ids_sp.size(1), upper_hidden.device) |
| special_mask = torch.ones((B, num_specials_eff), device=upper_hidden.device, dtype=upper_attention_mask.dtype) |
| upper_attention_mask_sp = torch.cat( |
| [upper_attention_mask[:, :S], special_mask, upper_attention_mask[:, S:]], |
| dim=1, |
| ) |
| if attn_impl == "flash_attention_2": |
| upper_mask_sp = None if _llopa_mask_is_all_ones(upper_attention_mask_sp) else upper_attention_mask_sp |
| elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(upper_attention_mask_sp): |
| upper_mask_sp = None |
| elif prefix_full: |
| upper_mask_sp = _build_tri_mask_prefix_full( |
| B, S, upper_position_ids_sp.size(1), device, upper_hidden.dtype |
| ) |
| else: |
| pad_hidden = upper_hidden.new_empty((B, upper_position_ids_sp.size(1), upper_hidden.size(-1))) |
| upper_mask_sp = create_causal_mask( |
| config=self.config, |
| input_embeds=pad_hidden, |
| attention_mask=upper_attention_mask_sp, |
| cache_position=upper_cache_position_sp, |
| past_key_values=None, |
| position_ids=upper_position_ids_sp, |
| ) |
| else: |
| upper_position_ids_sp = None |
| upper_cache_position_sp = None |
| upper_mask_sp = None |
|
|
| for li in range(K, n_layers): |
| layer = self.layers[li] |
| _llopa_reset_peak_memory(upper_hidden) |
| if num_specials_eff > 0: |
| specials = self.llopa_specials[li][:num_specials_eff] |
| specials = specials.to(dtype=upper_hidden.dtype, device=upper_hidden.device) |
| specials = specials.unsqueeze(0).expand(B, -1, -1) |
| layer_hidden = _llopa_insert_specials(upper_hidden, specials, S) |
| layer_position_ids = upper_position_ids_sp |
| layer_cache_position = upper_cache_position_sp |
| attn_mask = upper_mask_sp |
| pos_emb = self.rotary_emb(layer_hidden, layer_position_ids) |
| else: |
| layer_hidden = upper_hidden |
| layer_position_ids = upper_position_ids |
| layer_cache_position = upper_cache_position |
| attn_mask = upper_mask |
| pos_emb = upper_pos_emb |
| if use_custom_checkpoint: |
| layer_out = checkpoint_with_cache( |
| layer.__call__, |
| layer_hidden, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=layer_cache_position, |
| position_embeddings=pos_emb, |
| ) |
| else: |
| layer_out = layer( |
| layer_hidden, |
| attention_mask=attn_mask, |
| position_ids=layer_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=layer_cache_position, |
| position_embeddings=pos_emb, |
| ) |
| if num_specials_eff > 0: |
| upper_hidden = _llopa_remove_specials(layer_out, S, num_specials_eff) |
| else: |
| upper_hidden = layer_out |
| _llopa_log_memory( |
| self, |
| f"after_upper_layer_{li}", |
| system_ids=system_ids, |
| user_ids=user_ids_eff, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask_eff, |
| assistant_attention_mask=assistant_attention_mask, |
| sequence_len=upper_hidden.size(1), |
| ) |
|
|
| upper_hidden = self.norm(upper_hidden) |
| assistant_hidden = upper_hidden[:, S:, :] |
| return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None) |
|
|
|
|
| LlamaModel.llopa_forward = llopa_forward |
| LlamaModel.llopa_prefill_cache = llopa_prefill_cache |
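| # A minimal LLoPA sketch (illustrative only; `model` and the `*_ids` |
| # tensors below are hypothetical stand-ins): |
| # |
| #   out = model.model.llopa_forward(          # training-style, no cache |
| #       system_ids, user_ids, assistant_ids, lower_k=8 |
| #   ) |
| #   pkv = model.model.llopa_prefill_cache(    # inference-style cache |
| #       system_ids, user_ids, assistant_ids=None, lower_k=8 |
| #   ) |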
|
|
|
|
| def llopa_step_logits( |
| self: "LlamaForCausalLM", |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| system_attention_mask: Optional[torch.LongTensor] = None, |
| user_attention_mask: Optional[torch.LongTensor] = None, |
| assistant_attention_mask: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| labels: Optional[torch.LongTensor] = None, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| no_upper_attn: bool = False, |
| ): |
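| """Convenience wrapper: run ``llopa_forward``, project the kept |
| positions through ``lm_head``, and optionally compute the LM loss. |
| An int ``logits_to_keep`` keeps only the last n positions (0 keeps |
| all); a tensor is used directly as an index. |
| """ |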
| out = self.model.llopa_forward( |
| system_ids=system_ids, |
| system_attention_mask=system_attention_mask, |
| user_ids=user_ids, |
| user_attention_mask=user_attention_mask, |
| assistant_ids=assistant_ids, |
| assistant_attention_mask=assistant_attention_mask, |
| lower_k=lower_k, |
| prefill_mode=prefill_mode, |
| prefill_attn=prefill_attn, |
| no_upper_attn=no_upper_attn, |
| ) |
| hidden_states = out.last_hidden_state |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| _llopa_reset_peak_memory(hidden_states) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| loss = None |
| if labels is not None: |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) |
| _llopa_log_memory( |
| self.model, |
| "after_loss", |
| system_ids=system_ids, |
| user_ids=user_ids, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask, |
| assistant_attention_mask=assistant_attention_mask, |
| sequence_len=hidden_states.size(1), |
| ) |
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=hidden_states, |
| ) |
|
|
|
|
| LlamaForCausalLM.llopa_step_logits = llopa_step_logits |
|
|
|
|
| def _tri_pack_prefill_training_segments( |
| *, |
| model, |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| system_attention_mask: Optional[torch.Tensor] = None, |
| user_attention_mask: Optional[torch.Tensor] = None, |
| assistant_attention_mask: Optional[torch.Tensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| ): |
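| """Left-align each row's system/user/assistant tokens into one packed |
| batch, with labels set to -100 everywhere except the assistant span. |
| Rows without assistant tokens are dropped; returns ``None`` when |
| nothing remains, otherwise a dict with ``input_ids``, |
| ``attention_mask``, ``labels``, ``split_starts``, ``system_lens``. |
| """ |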
| ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids) |
| if ref is None: |
| raise ValueError("segmented prefill step requires at least one non-empty input segment.") |
|
|
| device = ref.device |
| batch_size = int(ref.size(0)) |
| if system_ids is None: |
| system_ids = ref[:, :0] |
| if user_ids is None: |
| user_ids = ref[:, :0] |
| if assistant_ids is None: |
| assistant_ids = ref[:, :0] |
| system_attention_mask = _llopa_default_attention_mask(system_ids, system_attention_mask) |
| user_attention_mask = _llopa_default_attention_mask(user_ids, user_attention_mask) |
| assistant_attention_mask = _llopa_default_attention_mask(assistant_ids, assistant_attention_mask) |
| if labels is None: |
| labels = assistant_ids.new_full(assistant_ids.shape, -100) |
|
|
| if ( |
| system_ids.size(0) != batch_size |
| or user_ids.size(0) != batch_size |
| or assistant_ids.size(0) != batch_size |
| or labels.size(0) != batch_size |
| ): |
| raise ValueError("segmented prefill step expects all segments to share the same batch size.") |
|
|
| pad_token_id = getattr(model.config, "pad_token_id", None) |
| if pad_token_id is None: |
| pad_token_id = 0 |
|
|
| packed_input_rows = [] |
| packed_label_rows = [] |
| split_starts = [] |
| system_lens = [] |
| for row in range(batch_size): |
| sys_len = int(system_attention_mask[row].sum().item()) |
| user_len = int(user_attention_mask[row].sum().item()) |
| assistant_len = int(assistant_attention_mask[row].sum().item()) |
| if assistant_len <= 0: |
| continue |
| sys_row = system_ids[row : row + 1, :sys_len] |
| user_row = user_ids[row : row + 1, :user_len] |
| assistant_row = assistant_ids[row : row + 1, :assistant_len] |
| label_row = labels[row : row + 1, :assistant_len] |
| packed_input_rows.append(torch.cat([sys_row, user_row, assistant_row], dim=1)) |
| packed_label_rows.append( |
| torch.cat( |
| [ |
| torch.full((1, sys_len + user_len), -100, dtype=label_row.dtype, device=device), |
| label_row, |
| ], |
| dim=1, |
| ) |
| ) |
| split_starts.append(sys_len + user_len) |
| system_lens.append(sys_len) |
|
|
| if not packed_input_rows: |
| return None |
|
|
| max_len = max(int(row.size(1)) for row in packed_input_rows) |
| input_ids = torch.full((len(packed_input_rows), max_len), int(pad_token_id), dtype=system_ids.dtype, device=device) |
| attention_mask = torch.zeros((len(packed_input_rows), max_len), dtype=torch.long, device=device) |
| full_labels = torch.full((len(packed_label_rows), max_len), -100, dtype=labels.dtype, device=device) |
| for row_idx, (input_row, label_row) in enumerate(zip(packed_input_rows, packed_label_rows)): |
| width = int(input_row.size(1)) |
| input_ids[row_idx, :width] = input_row.squeeze(0) |
| attention_mask[row_idx, :width] = 1 |
| full_labels[row_idx, :width] = label_row.squeeze(0) |
|
|
| return { |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "labels": full_labels, |
| "split_starts": torch.tensor(split_starts, dtype=torch.long, device=device), |
| "system_lens": torch.tensor(system_lens, dtype=torch.long, device=device), |
| } |
|
|
|
|
| def segmented_prefill_lower_freeze_step_logits( |
| self: "LlamaForCausalLM", |
| *, |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| system_attention_mask: Optional[torch.LongTensor] = None, |
| user_attention_mask: Optional[torch.LongTensor] = None, |
| assistant_attention_mask: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| labels: Optional[torch.LongTensor] = None, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| ): |
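| """Pack the segments and dispatch to |
| ``tri_vanilla_frozen_prefix_train_forward``; the ``_solo`` and |
| ``_solo_v2`` variants below differ only in the train-forward they |
| call. Returns ``None`` when no row has assistant tokens. |
| """ |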
| packed = _tri_pack_prefill_training_segments( |
| model=self, |
| system_ids=system_ids, |
| user_ids=user_ids, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask, |
| assistant_attention_mask=assistant_attention_mask, |
| labels=labels, |
| ) |
| if packed is None: |
| return None |
| return self.tri_vanilla_frozen_prefix_train_forward( |
| input_ids=packed["input_ids"], |
| attention_mask=packed["attention_mask"], |
| labels=packed["labels"], |
| use_cache=False, |
| logits_to_keep=logits_to_keep, |
| past_key_values=None, |
| prefill_lower_layers=int(lower_k), |
| prefill_attn=str(prefill_attn), |
| split_start_hint=packed["split_starts"], |
| system_len_hint=packed["system_lens"], |
| system_prefill=str(system_prefill), |
| ) |
|
|
|
|
| def segmented_prefill_lower_solo_step_logits( |
| self: "LlamaForCausalLM", |
| *, |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| system_attention_mask: Optional[torch.LongTensor] = None, |
| user_attention_mask: Optional[torch.LongTensor] = None, |
| assistant_attention_mask: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| labels: Optional[torch.LongTensor] = None, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| ): |
| packed = _tri_pack_prefill_training_segments( |
| model=self, |
| system_ids=system_ids, |
| user_ids=user_ids, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask, |
| assistant_attention_mask=assistant_attention_mask, |
| labels=labels, |
| ) |
| if packed is None: |
| return None |
| return self.tri_vanilla_solo_prefix_train_forward( |
| input_ids=packed["input_ids"], |
| attention_mask=packed["attention_mask"], |
| labels=packed["labels"], |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=int(lower_k), |
| prefill_attn=str(prefill_attn), |
| split_start_hint=packed["split_starts"], |
| system_len_hint=packed["system_lens"], |
| system_prefill=str(system_prefill), |
| ) |
|
|
|
|
| def segmented_prefill_lower_solo_v2_step_logits( |
| self: "LlamaForCausalLM", |
| *, |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| system_attention_mask: Optional[torch.LongTensor] = None, |
| user_attention_mask: Optional[torch.LongTensor] = None, |
| assistant_attention_mask: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| labels: Optional[torch.LongTensor] = None, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| with_bos: bool = False, |
| ): |
| packed = _tri_pack_prefill_training_segments( |
| model=self, |
| system_ids=system_ids, |
| user_ids=user_ids, |
| assistant_ids=assistant_ids, |
| system_attention_mask=system_attention_mask, |
| user_attention_mask=user_attention_mask, |
| assistant_attention_mask=assistant_attention_mask, |
| labels=labels, |
| ) |
| if packed is None: |
| return None |
| return self.tri_vanilla_solo_attention_v2_train_forward( |
| input_ids=packed["input_ids"], |
| attention_mask=packed["attention_mask"], |
| labels=packed["labels"], |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=int(lower_k), |
| prefill_attn=str(prefill_attn), |
| split_start_hint=packed["split_starts"], |
| system_len_hint=packed["system_lens"], |
| system_prefill=str(system_prefill), |
| with_bos=bool(with_bos), |
| ) |
|
|
|
|
| LlamaForCausalLM.segmented_prefill_lower_freeze_step_logits = segmented_prefill_lower_freeze_step_logits |
| LlamaForCausalLM.segmented_prefill_lower_solo_step_logits = segmented_prefill_lower_solo_step_logits |
| LlamaForCausalLM.segmented_prefill_lower_solo_v2_step_logits = segmented_prefill_lower_solo_v2_step_logits |
|
|
|
|
| def tri_step_logits( |
| self: "LlamaForCausalLM", |
| assistant_ids: torch.LongTensor, |
| lower_k: int, |
| pkv: Cache, |
| S: int, |
| U: int, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| labels: Optional[torch.LongTensor] = None, |
| write_cache: bool = True, |
| prefill_mode: str = "lower", |
| no_upper_attn: bool = False, |
| align_cache_position_to_layer_past: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ): |
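| # One decode step over `assistant_ids` against the already-prefilled cache |
| # `pkv`. S and U appear to be the system and user prefix lengths that |
| # tri_forward_assistant needs to lay out cache positions (inferred from the |
| # prefill convention used elsewhere in this file). |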
| out = self.model.tri_forward_assistant( |
| assistant_ids=assistant_ids, |
| lower_k=lower_k, |
| pkv=pkv, |
| S=S, |
| U=U, |
| write_cache=write_cache, |
| prefill_mode=prefill_mode, |
| no_upper_attn=no_upper_attn, |
| align_cache_position_to_layer_past=align_cache_position_to_layer_past, |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
| hidden_states = out.last_hidden_state |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| loss = None |
| if labels is not None: |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) |
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=pkv if write_cache else None, |
| hidden_states=hidden_states, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_step_logits = tri_step_logits |
| LlamaForCausalLM.llopa_decode_step_logits = tri_step_logits |
|
|
|
|
| def tri_prefill_lower_cache( |
| self: "LlamaModel", |
| prefix_ids: torch.LongTensor, |
| lower_k: int, |
| prefill_mode: str = "lower", |
| prefill_attn: str = "causal", |
| ) -> Cache: |
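| # Prefill a fresh KV cache for `prefix_ids` by running only the layers |
| # selected by `prefill_mode`/`lower_k`, optionally with a full |
| # (bidirectional) attention pattern over the prefix instead of the causal |
| # one. |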
| pkv = _safe_dynamic_cache(self.config) |
| if prefix_ids is None or prefix_ids.size(1) == 0: |
| return pkv |
|
|
| device = prefix_ids.device |
| inputs_embeds = self.embed_tokens(prefix_ids) |
| cache_position = _tri_arange(0, inputs_embeds.shape[1], device) |
| position_ids = cache_position.unsqueeze(0).expand(prefix_ids.size(0), -1) |
| attention_mask = torch.ones_like(prefix_ids, dtype=torch.long, device=device) |
|
|
| attn_impl = _resolve_attn_impl(self.config) |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode == "full": |
| if attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| causal_mask = _build_tri_mask_full(prefix_ids.size(0), inputs_embeds.shape[1], 0, device, inputs_embeds.dtype) |
| elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config): |
| causal_mask = None |
| else: |
| causal_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(self) |
| for li in _prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k): |
| layer = self.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| return pkv |
|
|
|
|
| LlamaModel.tri_prefill_lower_cache = tri_prefill_lower_cache |
|
|
|
|
| def _tri_vanilla_warn_once(model, flag: str, msg: str) -> None: |
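| # Log `msg` at most once per model instance, using attribute `flag` as a |
| # sticky already-warned marker. |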
| if getattr(model, flag, False): |
| return |
| try: |
| setattr(model, flag, True) |
| except Exception: |
| pass |
| logger.warning(msg) |
|
|
|
|
| def _tri_parse_split_start_hint(split_start_hint, batch_size: int) -> Optional[int]: |
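| # Reduce a scalar/tensor/sequence split-start hint to a single int taken |
| # from the first row; returns None when the hint is absent or has an |
| # unexpected shape. |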
| if split_start_hint is None: |
| return None |
| if isinstance(split_start_hint, torch.Tensor): |
| if split_start_hint.numel() == 0: |
| return None |
| if split_start_hint.dim() == 0: |
| return int(split_start_hint.item()) |
| if split_start_hint.dim() == 1 and split_start_hint.size(0) == batch_size: |
| return int(split_start_hint[0].item()) |
| return None |
| if isinstance(split_start_hint, (list, tuple)): |
| if len(split_start_hint) == 0: |
| return None |
| return int(split_start_hint[0]) |
| return int(split_start_hint) |
|
|
|
|
| def _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size: int, |
| *, |
| device, |
| ) -> Optional[torch.LongTensor]: |
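| # Per-row variant of the parser above: a scalar or length-1 hint is |
| # broadcast to every row, a length-`batch_size` hint is used as-is, and |
| # anything else yields None. |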
| if split_start_hint is None: |
| return None |
| if isinstance(split_start_hint, torch.Tensor): |
| if split_start_hint.numel() == 0: |
| return None |
| if split_start_hint.dim() == 0: |
| value = int(split_start_hint.item()) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| flat = split_start_hint.flatten() |
| if flat.numel() == 1: |
| value = int(flat[0].item()) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| if flat.numel() == batch_size: |
| return flat.to(device=device, dtype=torch.long) |
| return None |
| if isinstance(split_start_hint, (list, tuple)): |
| if len(split_start_hint) == 0: |
| return None |
| if len(split_start_hint) == 1: |
| value = int(split_start_hint[0]) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| if len(split_start_hint) == batch_size: |
| return torch.tensor(split_start_hint, device=device, dtype=torch.long) |
| return None |
| value = int(split_start_hint) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
|
|
|
|
| def _tri_parse_system_len_hint(system_len_hint, batch_size: int) -> Optional[int]: |
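| # Mirrors _tri_parse_split_start_hint, but for the system-length hint. |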
| if system_len_hint is None: |
| return None |
| if isinstance(system_len_hint, torch.Tensor): |
| if system_len_hint.numel() == 0: |
| return None |
| if system_len_hint.dim() == 0: |
| return int(system_len_hint.item()) |
| if system_len_hint.dim() == 1 and system_len_hint.size(0) == batch_size: |
| return int(system_len_hint[0].item()) |
| return None |
| if isinstance(system_len_hint, (list, tuple)): |
| if len(system_len_hint) == 0: |
| return None |
| return int(system_len_hint[0]) |
| return int(system_len_hint) |
|
|
|
|
| def _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size: int, |
| *, |
| device, |
| ) -> Optional[torch.LongTensor]: |
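| # Mirrors _tri_parse_split_start_hints, but for per-row system lengths. |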
| if system_len_hint is None: |
| return None |
| if isinstance(system_len_hint, torch.Tensor): |
| if system_len_hint.numel() == 0: |
| return None |
| if system_len_hint.dim() == 0: |
| value = int(system_len_hint.item()) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| flat = system_len_hint.flatten() |
| if flat.numel() == 1: |
| value = int(flat[0].item()) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| if flat.numel() == batch_size: |
| return flat.to(device=device, dtype=torch.long) |
| return None |
| if isinstance(system_len_hint, (list, tuple)): |
| if len(system_len_hint) == 0: |
| return None |
| if len(system_len_hint) == 1: |
| value = int(system_len_hint[0]) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
| if len(system_len_hint) == batch_size: |
| return torch.tensor(system_len_hint, device=device, dtype=torch.long) |
| return None |
| value = int(system_len_hint) |
| return torch.full((batch_size,), value, device=device, dtype=torch.long) |
|
|
|
|
| def _tri_parse_assistant_header_start_hints( |
| assistant_header_starts_hint, |
| assistant_header_start_mask_hint, |
| batch_size: int, |
| *, |
| device, |
| ) -> Optional[tuple[torch.LongTensor, torch.BoolTensor]]: |
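| # Normalize assistant-header start hints (tensor, list, or nested lists) |
| # into a (batch, turns) long tensor plus a boolean validity mask; negative |
| # starts are treated as padding and masked out. |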
| if assistant_header_starts_hint is None: |
| return None |
|
|
| starts = assistant_header_starts_hint |
| if isinstance(starts, torch.Tensor): |
| if starts.numel() == 0: |
| return None |
| if starts.dim() == 1: |
| starts = starts.view(1, -1) |
| elif starts.dim() != 2: |
| return None |
| if starts.size(0) == 1 and batch_size > 1: |
| starts = starts.expand(batch_size, -1) |
| if starts.size(0) != batch_size: |
| return None |
| starts = starts.to(device=device, dtype=torch.long) |
| elif isinstance(starts, (list, tuple)): |
| if len(starts) == 0: |
| return None |
| if isinstance(starts[0], (list, tuple)): |
| if len(starts) != batch_size: |
| return None |
| max_turns = max((len(row) for row in starts), default=0) |
| starts_tensor = torch.full((batch_size, max_turns), -1, device=device, dtype=torch.long) |
| for row_idx, row in enumerate(starts): |
| if len(row) > 0: |
| starts_tensor[row_idx, : len(row)] = torch.tensor(list(row), device=device, dtype=torch.long) |
| starts = starts_tensor |
| else: |
| starts_tensor = torch.tensor(list(starts), device=device, dtype=torch.long) |
| starts = starts_tensor.view(1, -1).expand(batch_size, -1) |
| else: |
| return None |
|
|
| if assistant_header_start_mask_hint is None: |
| start_mask = starts >= 0 |
| else: |
| start_mask = assistant_header_start_mask_hint |
| if isinstance(start_mask, torch.Tensor): |
| if start_mask.numel() == 0: |
| return None |
| if start_mask.dim() == 1: |
| start_mask = start_mask.view(1, -1) |
| elif start_mask.dim() != 2: |
| return None |
| if start_mask.size(0) == 1 and batch_size > 1: |
| start_mask = start_mask.expand(batch_size, -1) |
| if start_mask.size(0) != batch_size or start_mask.size(1) != starts.size(1): |
| return None |
| start_mask = start_mask.to(device=device, dtype=torch.bool) |
| elif isinstance(start_mask, (list, tuple)): |
| if len(start_mask) != batch_size: |
| return None |
| mask_tensor = torch.zeros_like(starts, dtype=torch.bool) |
| for row_idx, row in enumerate(start_mask): |
| if isinstance(row, (list, tuple)): |
| width = min(len(row), starts.size(1)) |
| if width > 0: |
| mask_tensor[row_idx, :width] = torch.tensor(list(row)[:width], device=device, dtype=torch.bool) |
| elif bool(row): |
| mask_tensor[row_idx, 0] = True |
| start_mask = mask_tensor |
| else: |
| return None |
|
|
| start_mask = start_mask & (starts >= 0) |
| return starts, start_mask |
|
|
|
|
| def _tri_find_valid_label_segments(labels_row: torch.Tensor, valid_len: int) -> list[tuple[int, int]]: |
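| # Scan the first `valid_len` labels and return the [start, end) runs whose |
| # labels are supervised (not the -100 ignore index). |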
| if labels_row.dim() != 1: |
| raise ValueError("labels_row must be rank-1.") |
| limit = min(max(int(valid_len), 0), int(labels_row.numel())) |
| if limit <= 0: |
| return [] |
|
|
| valid_mask = labels_row[:limit] != -100 |
| segments: list[tuple[int, int]] = [] |
| seg_start: Optional[int] = None |
| for idx in range(limit): |
| if bool(valid_mask[idx].item()): |
| if seg_start is None: |
| seg_start = idx |
| continue |
| if seg_start is not None: |
| segments.append((seg_start, idx)) |
| seg_start = None |
| if seg_start is not None: |
| segments.append((seg_start, limit)) |
| return segments |
|
|
|
|
| def _tri_resolve_assistant_turn_spans( |
| *, |
| labels_row: torch.Tensor, |
| valid_len: int, |
| assistant_header_starts: torch.Tensor, |
| assistant_header_start_mask: torch.Tensor, |
| ) -> list[tuple[int, int]]: |
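| # Pair each assistant header start with the first supervised label segment |
| # that begins at or after it (and before the next header), producing |
| # (header_start, segment_end) spans for multi-turn training. |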
| valid_len = min(max(int(valid_len), 0), int(labels_row.numel())) |
| if valid_len <= 0: |
| return [] |
|
|
| starts: list[int] = [] |
| for col in range(int(assistant_header_starts.numel())): |
| if not bool(assistant_header_start_mask[col].item()): |
| continue |
| start = int(assistant_header_starts[col].item()) |
| if 0 <= start < valid_len: |
| starts.append(start) |
| if not starts: |
| return [] |
| starts = sorted(set(starts)) |
|
|
| segments = _tri_find_valid_label_segments(labels_row, valid_len=valid_len) |
| if not segments: |
| return [] |
|
|
| spans: list[tuple[int, int]] = [] |
| seg_cursor = 0 |
| for idx, start in enumerate(starts): |
| next_start = starts[idx + 1] if idx + 1 < len(starts) else valid_len |
| while seg_cursor < len(segments) and segments[seg_cursor][1] <= start: |
| seg_cursor += 1 |
| probe = seg_cursor |
| while probe < len(segments): |
| seg_start, seg_end = segments[probe] |
| if seg_start >= next_start: |
| break |
| if seg_start >= start: |
| spans.append((start, seg_end)) |
| seg_cursor = probe + 1 |
| break |
| probe += 1 |
| return spans |
|
|
|
|
| def _tri_build_prefill_lower_multiturn_index_batch( |
| *, |
| turn_starts: torch.LongTensor, |
| turn_ends: torch.LongTensor, |
| system_lens: torch.LongTensor, |
| system_prefill: str, |
| device, |
| owner=None, |
| bucket_multiple: int = 0, |
| ) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]: |
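| # For each extracted turn, build gather indices over the packed sequence |
| # via build_prefill_lower_upper_indices, padded to a shared width; the width |
| # is optionally rounded up to `bucket_multiple`, and at eval time the |
| # index/mask buffers are reused from a workspace cache stored on `owner`. |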
| turn_starts = turn_starts.to(device=device, dtype=torch.long) |
| turn_ends = turn_ends.to(device=device, dtype=torch.long) |
| system_lens = system_lens.to(device=device, dtype=torch.long) |
| row_count = int(turn_starts.size(0)) |
|
|
| upper_indices = [] |
| upper_lens = torch.zeros((row_count,), device=device, dtype=torch.long) |
| max_upper_len = 0 |
| for row in range(row_count): |
| idx = build_prefill_lower_upper_indices( |
| sequence_len=int(turn_ends[row].item()), |
| split_start=int(turn_starts[row].item()), |
| system_len=int(system_lens[row].item()), |
| system_prefill=system_prefill, |
| device=device, |
| ) |
| upper_indices.append(idx) |
| upper_lens[row] = int(idx.numel()) |
| if idx.numel() > max_upper_len: |
| max_upper_len = int(idx.numel()) |
|
|
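| # Round the padded width up to the next multiple of `bucket_multiple` so |
| # the eval-time workspace below can be reused across nearby lengths (e.g. a |
| # max width of 37 becomes 48 when the multiple is 16). |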
| bucket_multiple = int(bucket_multiple or 0) |
| target_upper_len = int(max_upper_len) |
| if bucket_multiple > 0 and target_upper_len > 0: |
| target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple |
|
|
| gather_idx = None |
| valid_mask = None |
| if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)): |
| cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| key = ("multiturn", str(device), int(row_count), int(target_upper_len)) |
| pair = cache.get(key) |
| if ( |
| not isinstance(pair, tuple) |
| or len(pair) != 2 |
| or not isinstance(pair[0], torch.Tensor) |
| or not isinstance(pair[1], torch.Tensor) |
| or pair[0].shape != (int(row_count), int(target_upper_len)) |
| or pair[1].shape != (int(row_count), int(target_upper_len)) |
| or pair[0].device != device |
| or pair[1].device != device |
| ): |
| pair = ( |
| torch.zeros((int(row_count), int(target_upper_len)), device=device, dtype=torch.long), |
| torch.zeros((int(row_count), int(target_upper_len)), device=device, dtype=torch.bool), |
| ) |
| cache[key] = pair |
| with contextlib.suppress(Exception): |
| setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache) |
| gather_idx, valid_mask = pair |
| gather_idx.zero_() |
| valid_mask.zero_() |
|
|
| if gather_idx is None or valid_mask is None: |
| gather_idx = torch.zeros((row_count, target_upper_len), device=device, dtype=torch.long) |
| valid_mask = torch.zeros((row_count, target_upper_len), device=device, dtype=torch.bool) |
|
|
| for row, idx in enumerate(upper_indices): |
| width = int(idx.numel()) |
| if width <= 0: |
| continue |
| gather_idx[row, :width] = idx |
| valid_mask[row, :width] = True |
| return gather_idx, valid_mask, upper_lens |
|
|
|
|
| def _tri_build_prefill_lower_past_assistant_index_batch( |
| *, |
| split_starts: torch.LongTensor, |
| valid_lens: torch.LongTensor, |
| system_lens: torch.LongTensor, |
| system_prefill: str, |
| assistant_header_starts: torch.LongTensor, |
| assistant_turn_ends: torch.LongTensor, |
| assistant_header_start_mask: torch.BoolTensor, |
| device, |
| owner=None, |
| bucket_multiple: int = 0, |
| ) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]: |
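| # Like the multi-turn builder above, but each row's visible tokens are the |
| # kept system prefix, the assistant turns preceding the split (clipped at |
| # the split start), and the suffix from `split_start` to the row's valid |
| # length. |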
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| valid_lens = valid_lens.to(device=device, dtype=torch.long) |
| system_lens = system_lens.to(device=device, dtype=torch.long) |
| assistant_header_starts = assistant_header_starts.to(device=device, dtype=torch.long) |
| assistant_turn_ends = assistant_turn_ends.to(device=device, dtype=torch.long) |
| assistant_header_start_mask = assistant_header_start_mask.to(device=device, dtype=torch.bool) |
| batch_size = int(split_starts.size(0)) |
| prefix_keep_lens = _tri_prefix_keep_lengths(system_lens, system_prefill) |
|
|
| upper_indices = [] |
| upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long) |
| max_upper_len = 0 |
| for row in range(batch_size): |
| valid_len = int(valid_lens[row].item()) |
| split_start = max(0, min(int(split_starts[row].item()), valid_len)) |
| prefix_keep = min(max(int(prefix_keep_lens[row].item()), 0), split_start) |
| row_indices: list[torch.Tensor] = [] |
|
|
| if prefix_keep > 0: |
| row_indices.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long)) |
|
|
| num_turns = min(int(assistant_header_starts.size(1)), int(assistant_turn_ends.size(1))) |
| for col in range(num_turns): |
| if not bool(assistant_header_start_mask[row, col].item()): |
| continue |
| turn_start = max(int(assistant_header_starts[row, col].item()), prefix_keep) |
| turn_end = int(assistant_turn_ends[row, col].item()) |
| if turn_end <= turn_start: |
| continue |
| turn_end = min(turn_end, split_start) |
| if turn_end <= turn_start: |
| continue |
| row_indices.append(torch.arange(turn_start, turn_end, device=device, dtype=torch.long)) |
|
|
| if split_start < valid_len: |
| row_indices.append(torch.arange(split_start, valid_len, device=device, dtype=torch.long)) |
|
|
| idx = torch.cat(row_indices, dim=0) if row_indices else torch.empty((0,), device=device, dtype=torch.long) |
| upper_indices.append(idx) |
| upper_lens[row] = int(idx.numel()) |
| if idx.numel() > max_upper_len: |
| max_upper_len = int(idx.numel()) |
|
|
| bucket_multiple = int(bucket_multiple or 0) |
| target_upper_len = int(max_upper_len) |
| if bucket_multiple > 0 and target_upper_len > 0: |
| target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple |
|
|
| gather_idx = None |
| valid_mask = None |
| if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)): |
| cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| key = ("single_past_assistant", str(device), int(batch_size), int(target_upper_len)) |
| pair = cache.get(key) |
| if ( |
| not isinstance(pair, tuple) |
| or len(pair) != 2 |
| or not isinstance(pair[0], torch.Tensor) |
| or not isinstance(pair[1], torch.Tensor) |
| or pair[0].shape != (int(batch_size), int(target_upper_len)) |
| or pair[1].shape != (int(batch_size), int(target_upper_len)) |
| or pair[0].device != device |
| or pair[1].device != device |
| ): |
| pair = ( |
| torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.long), |
| torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.bool), |
| ) |
| cache[key] = pair |
| with contextlib.suppress(Exception): |
| setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache) |
| gather_idx, valid_mask = pair |
| gather_idx.zero_() |
| valid_mask.zero_() |
|
|
| if gather_idx is None or valid_mask is None: |
| gather_idx = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.long) |
| valid_mask = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.bool) |
|
|
| for row, idx in enumerate(upper_indices): |
| width = int(idx.numel()) |
| if width <= 0: |
| continue |
| gather_idx[row, :width] = idx |
| valid_mask[row, :width] = True |
| return gather_idx, valid_mask, upper_lens |
|
|
|
|
| def _tri_has_effective_past_assistant_history( |
| *, |
| split_start: int, |
| assistant_header_hints, |
| assistant_turn_end_hints, |
| ) -> bool: |
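| # True when the hints describe at least one non-empty assistant turn |
| # starting before `split_start`; only row 0 of the hint tensors is |
| # inspected. |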
| if assistant_header_hints is None or assistant_turn_end_hints is None: |
| return False |
|
|
| header_starts, header_start_mask = assistant_header_hints |
| turn_ends, _ = assistant_turn_end_hints |
| if ( |
| not isinstance(header_starts, torch.Tensor) |
| or not isinstance(header_start_mask, torch.Tensor) |
| or not isinstance(turn_ends, torch.Tensor) |
| or header_starts.ndim != 2 |
| or header_start_mask.ndim != 2 |
| or turn_ends.ndim != 2 |
| or header_starts.size(0) <= 0 |
| or header_start_mask.size(0) <= 0 |
| or turn_ends.size(0) <= 0 |
| ): |
| return False |
|
|
| split_start = int(split_start) |
| num_turns = min(int(header_starts.size(1)), int(turn_ends.size(1)), int(header_start_mask.size(1))) |
| for col in range(num_turns): |
| if not bool(header_start_mask[0, col].item()): |
| continue |
| turn_start = int(header_starts[0, col].item()) |
| turn_end = int(turn_ends[0, col].item()) |
| if turn_end <= turn_start: |
| continue |
| if turn_start < split_start: |
| return True |
| return False |
|
|
|
|
| def _tri_build_prefill_lower_seed_upper_index_batch( |
| *, |
| split_starts: torch.LongTensor, |
| system_lens: torch.LongTensor, |
| system_prefill: str, |
| assistant_header_starts: Optional[torch.LongTensor] = None, |
| assistant_turn_ends: Optional[torch.LongTensor] = None, |
| assistant_header_start_mask: Optional[torch.BoolTensor] = None, |
| see_past_assistant: bool = False, |
| device, |
| ) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]: |
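| # Gather indices used to seed the upper layers before decoding: the kept |
| # system prefix plus, when `see_past_assistant` metadata is present, the |
| # assistant turns that precede the split. |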
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| system_lens = system_lens.to(device=device, dtype=torch.long) |
| batch_size = int(split_starts.size(0)) |
| prefix_keep_lens = _tri_prefix_keep_lengths(system_lens, system_prefill) |
| if isinstance(assistant_header_starts, torch.Tensor): |
| assistant_header_starts = assistant_header_starts.to(device=device, dtype=torch.long) |
| if isinstance(assistant_turn_ends, torch.Tensor): |
| assistant_turn_ends = assistant_turn_ends.to(device=device, dtype=torch.long) |
| if isinstance(assistant_header_start_mask, torch.Tensor): |
| assistant_header_start_mask = assistant_header_start_mask.to(device=device, dtype=torch.bool) |
|
|
| upper_indices: list[torch.Tensor] = [] |
| upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long) |
| max_upper_len = 0 |
| has_history_metadata = ( |
| bool(see_past_assistant) |
| and isinstance(assistant_header_starts, torch.Tensor) |
| and isinstance(assistant_turn_ends, torch.Tensor) |
| and isinstance(assistant_header_start_mask, torch.Tensor) |
| and assistant_header_starts.dim() == 2 |
| and assistant_turn_ends.dim() == 2 |
| and assistant_header_start_mask.dim() == 2 |
| ) |
| for row in range(batch_size): |
| split_start = max(int(split_starts[row].item()), 0) |
| prefix_keep = min(max(int(prefix_keep_lens[row].item()), 0), split_start) |
| row_parts: list[torch.Tensor] = [] |
| if prefix_keep > 0: |
| row_parts.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long)) |
| if has_history_metadata: |
| num_turns = min(int(assistant_header_starts.size(1)), int(assistant_turn_ends.size(1))) |
| for col in range(num_turns): |
| if not bool(assistant_header_start_mask[row, col].item()): |
| continue |
| turn_start = max(int(assistant_header_starts[row, col].item()), prefix_keep) |
| turn_end = min(int(assistant_turn_ends[row, col].item()), split_start) |
| if turn_end > turn_start: |
| row_parts.append(torch.arange(turn_start, turn_end, device=device, dtype=torch.long)) |
| idx = torch.cat(row_parts, dim=0) if row_parts else torch.empty((0,), device=device, dtype=torch.long) |
| upper_indices.append(idx) |
| upper_lens[row] = int(idx.numel()) |
| max_upper_len = max(max_upper_len, int(idx.numel())) |
|
|
| gather_idx = torch.zeros((batch_size, max_upper_len), device=device, dtype=torch.long) |
| valid_mask = torch.zeros((batch_size, max_upper_len), device=device, dtype=torch.bool) |
| for row, idx in enumerate(upper_indices): |
| width = int(idx.numel()) |
| if width > 0: |
| gather_idx[row, :width] = idx |
| valid_mask[row, :width] = True |
| return gather_idx, valid_mask, upper_lens |
|
|
|
|
| def _tri_prefill_lower_prompt_cache( |
| model: "LlamaModel", |
| prefix_ids: torch.LongTensor, |
| *, |
| lower_k: int, |
| prefill_attn: str, |
| system_len: int, |
| system_prefill: str, |
| return_replay_memory: bool = False, |
| return_prefix_hidden: bool = False, |
| replay_user_prefix_keep_len: int = 0, |
| replay_user_start: Optional[int] = None, |
| replay_user_len: Optional[int] = None, |
| ) -> Cache | tuple[Cache, Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[torch.LongTensor]]: |
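| # Prefill the prompt cache according to the normalized `system_prefill` |
| # mode: |
| #   full      - all layers over the system segment, lower layers over user |
| #   no_system - all layers over the BOS token only, lower layers over rest |
| #   otherwise - lower layers over both the system and user segments |
| # Optionally captures pre-norm hidden states (with positions) for the |
| # replay memory and/or prefix-hidden return paths. |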
| system_len = min(max(int(system_len), 0), int(prefix_ids.size(1))) |
| if prefix_ids.size(1) == 0: |
| empty_cache = _safe_dynamic_cache(model.config) |
| if not bool(return_replay_memory) and not bool(return_prefix_hidden): |
| return empty_cache |
| empty_pos = torch.empty((prefix_ids.size(0), 0), device=prefix_ids.device, dtype=torch.long) |
| replay_hidden = model.embed_tokens(prefix_ids) if bool(return_replay_memory) else None |
| replay_pos = empty_pos if bool(return_replay_memory) else None |
| prefix_hidden = model.embed_tokens(prefix_ids) if bool(return_prefix_hidden) else None |
| prefix_pos = empty_pos if bool(return_prefix_hidden) else None |
| return empty_cache, replay_hidden, replay_pos, prefix_hidden, prefix_pos |
|
|
| system_ids = prefix_ids[:, :system_len] |
| user_ids = prefix_ids[:, system_len:] |
| mode = normalize_system_prefill(system_prefill) |
| batch_size = int(prefix_ids.size(0)) |
| replay_hidden_parts: list[torch.Tensor] = [] |
| replay_pos_parts: list[torch.LongTensor] = [] |
| prefix_hidden_parts: list[torch.Tensor] = [] |
| prefix_pos_parts: list[torch.LongTensor] = [] |
|
|
| def _capture_replay_segment(hidden_states: Optional[torch.Tensor], start_pos: int) -> None: |
| if not bool(return_replay_memory): |
| return |
| if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0: |
| return |
| replay_hidden_parts.append(hidden_states) |
| replay_pos_parts.append( |
| _tri_segment_position_ids( |
| batch_size=batch_size, |
| start=int(start_pos), |
| length=int(hidden_states.size(1)), |
| device=hidden_states.device, |
| ) |
| ) |
|
|
| def _capture_prefix_segment(hidden_states: Optional[torch.Tensor], start_pos: int) -> None: |
| if not bool(return_prefix_hidden): |
| return |
| if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0: |
| return |
| prefix_hidden_parts.append(hidden_states) |
| prefix_pos_parts.append( |
| _tri_segment_position_ids( |
| batch_size=batch_size, |
| start=int(start_pos), |
| length=int(hidden_states.size(1)), |
| device=hidden_states.device, |
| ) |
| ) |
|
|
| def _capture_user_replay(hidden_states: Optional[torch.Tensor], start_pos: int) -> None: |
| if not bool(return_replay_memory): |
| return |
| if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0: |
| return |
| replay_prefix_keep, replay_start, replay_end = _tri_resolve_replay_user_window( |
| user_len=int(hidden_states.size(1)), |
| prefix_keep_len=int(replay_user_prefix_keep_len or 0), |
| latest_user_start=replay_user_start, |
| latest_user_len=replay_user_len, |
| ) |
| if replay_prefix_keep > 0: |
| _capture_replay_segment(hidden_states[:, :replay_prefix_keep, :], start_pos=start_pos) |
| if replay_end > replay_start: |
| _capture_replay_segment( |
| hidden_states[:, replay_start:replay_end, :], |
| start_pos=int(start_pos) + int(replay_start), |
| ) |
|
|
| def _finalize( |
| pkv: Cache, |
| ) -> Cache | tuple[Cache, Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[torch.LongTensor]]: |
| if not bool(return_replay_memory) and not bool(return_prefix_hidden): |
| return pkv |
| replay_hidden = None |
| replay_position_ids = None |
| prefix_hidden = None |
| prefix_position_ids = None |
| if replay_hidden_parts: |
| replay_hidden = replay_hidden_parts[0] if len(replay_hidden_parts) == 1 else torch.cat(replay_hidden_parts, dim=1) |
| replay_position_ids = replay_pos_parts[0] if len(replay_pos_parts) == 1 else torch.cat(replay_pos_parts, dim=1) |
| if prefix_hidden_parts: |
| prefix_hidden = prefix_hidden_parts[0] if len(prefix_hidden_parts) == 1 else torch.cat(prefix_hidden_parts, dim=1) |
| prefix_position_ids = prefix_pos_parts[0] if len(prefix_pos_parts) == 1 else torch.cat(prefix_pos_parts, dim=1) |
| return pkv, replay_hidden, replay_position_ids, prefix_hidden, prefix_position_ids |
|
|
| if mode == "full": |
| system_out = model.tri_prefill_system_all( |
| system_ids, |
| past_key_values=None, |
| prefill_attn=prefill_attn, |
| ) |
| pkv = system_out.past_key_values |
| if user_ids.size(1) > 0: |
| user_out = model.tri_prefill_user_lower( |
| user_ids, |
| lower_k=int(lower_k), |
| past_key_values=pkv, |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden), |
| ) |
| _capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| _capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| return _finalize(pkv) |
|
|
| pkv = _safe_dynamic_cache(model.config) |
| if mode == "no_system": |
| if system_ids.size(1) > 0: |
| bos_ids = system_ids[:, :1] |
| rest_ids = system_ids[:, 1:] |
| out = model.tri_prefill_system_all( |
| bos_ids, |
| past_key_values=None, |
| prefill_attn=prefill_attn, |
| ) |
| pkv = out.past_key_values |
| if rest_ids.size(1) > 0: |
| rest_out = model.tri_prefill_user_lower( |
| rest_ids, |
| lower_k=int(lower_k), |
| past_key_values=pkv, |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden), |
| ) |
| _capture_replay_segment(getattr(rest_out, "pre_norm_hidden", None), start_pos=1) |
| _capture_prefix_segment(getattr(rest_out, "pre_norm_hidden", None), start_pos=1) |
| if user_ids.size(1) > 0: |
| user_out = model.tri_prefill_user_lower( |
| user_ids, |
| lower_k=int(lower_k), |
| past_key_values=pkv, |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden), |
| ) |
| _capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| _capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| return _finalize(pkv) |
|
|
| if system_ids.size(1) > 0: |
| system_out = model.tri_prefill_user_lower( |
| system_ids, |
| lower_k=int(lower_k), |
| past_key_values=pkv, |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden), |
| ) |
| _capture_replay_segment(getattr(system_out, "pre_norm_hidden", None), start_pos=0) |
| _capture_prefix_segment(getattr(system_out, "pre_norm_hidden", None), start_pos=0) |
| if user_ids.size(1) > 0: |
| user_out = model.tri_prefill_user_lower( |
| user_ids, |
| lower_k=int(lower_k), |
| past_key_values=pkv, |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden), |
| ) |
| _capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| _capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len) |
| return _finalize(pkv) |
|
|
|
|
| def _tri_prefill_lower_prompt_hidden( |
| model: "LlamaModel", |
| prefix_ids: torch.LongTensor, |
| *, |
| lower_k: int, |
| prefill_attn: str, |
| ) -> tuple[torch.Tensor, torch.LongTensor]: |
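| # Run only the first `lower_k` layers over the prompt without writing any |
| # cache, returning the resulting hidden states and their position ids. |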
| if prefix_ids.size(1) == 0: |
| empty_hidden = model.embed_tokens(prefix_ids) |
| empty_pos = torch.empty((prefix_ids.size(0), 0), device=prefix_ids.device, dtype=torch.long) |
| return empty_hidden, empty_pos |
|
|
| n_layers = len(model.layers) |
| lower_k = max(0, min(int(lower_k), n_layers)) |
| if lower_k <= 0: |
| raise ValueError("_tri_prefill_lower_prompt_hidden requires lower_k > 0.") |
|
|
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in {"causal", "full"}: |
| raise ValueError("_tri_prefill_lower_prompt_hidden requires prefill_attn in {'causal', 'full'}.") |
|
|
| device = prefix_ids.device |
| inputs_embeds = model.embed_tokens(prefix_ids) |
| attention_mask = torch.ones_like(prefix_ids, dtype=torch.long, device=device) |
| cache_position = _tri_arange(0, inputs_embeds.shape[1], device) |
| position_ids = _llopa_position_ids_from_mask(attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if attn_mode == "full": |
| if attn_impl == "flash_attention_2": |
| raise ValueError("_tri_prefill_lower_prompt_hidden prefill_attn=full is not supported with flash_attention_2.") |
| lower_mask = _build_tri_mask_full(prefix_ids.size(0), inputs_embeds.shape[1], 0, device, inputs_embeds.dtype) |
| elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config): |
| lower_mask = None |
| else: |
| lower_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
| for li in range(lower_k): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| return hidden_states, position_ids |
|
|
|
|
| def _tri_build_prefill_lower_upper_index_batch( |
| *, |
| split_starts: torch.LongTensor, |
| valid_lens: torch.LongTensor, |
| system_lens: torch.LongTensor, |
| system_prefill: str, |
| device, |
| owner=None, |
| bucket_multiple: int = 0, |
| ) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]: |
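| # Single-turn counterpart of the multi-turn index builder above: one index |
| # row per batch element, covering that row's full valid length. |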
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| valid_lens = valid_lens.to(device=device, dtype=torch.long) |
| system_lens = system_lens.to(device=device, dtype=torch.long) |
| batch_size = int(split_starts.size(0)) |
|
|
| upper_indices = [] |
| upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long) |
| max_upper_len = 0 |
| for row in range(batch_size): |
| idx = build_prefill_lower_upper_indices( |
| sequence_len=int(valid_lens[row].item()), |
| split_start=int(split_starts[row].item()), |
| system_len=int(system_lens[row].item()), |
| system_prefill=system_prefill, |
| device=device, |
| ) |
| upper_indices.append(idx) |
| upper_lens[row] = int(idx.numel()) |
| if idx.numel() > max_upper_len: |
| max_upper_len = int(idx.numel()) |
|
|
| bucket_multiple = int(bucket_multiple or 0) |
| target_upper_len = int(max_upper_len) |
| if bucket_multiple > 0 and target_upper_len > 0: |
| target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple |
|
|
| gather_idx = None |
| valid_mask = None |
| if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)): |
| cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| key = ("single", str(device), int(batch_size), int(target_upper_len)) |
| pair = cache.get(key) |
| if ( |
| not isinstance(pair, tuple) |
| or len(pair) != 2 |
| or not isinstance(pair[0], torch.Tensor) |
| or not isinstance(pair[1], torch.Tensor) |
| or pair[0].shape != (int(batch_size), int(target_upper_len)) |
| or pair[1].shape != (int(batch_size), int(target_upper_len)) |
| or pair[0].device != device |
| or pair[1].device != device |
| ): |
| pair = ( |
| torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.long), |
| torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.bool), |
| ) |
| cache[key] = pair |
| with contextlib.suppress(Exception): |
| setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache) |
| gather_idx, valid_mask = pair |
| gather_idx.zero_() |
| valid_mask.zero_() |
|
|
| if gather_idx is None or valid_mask is None: |
| gather_idx = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.long) |
| valid_mask = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.bool) |
|
|
| for row, idx in enumerate(upper_indices): |
| width = int(idx.numel()) |
| if width <= 0: |
| continue |
| gather_idx[row, :width] = idx |
| valid_mask[row, :width] = True |
| return gather_idx, valid_mask, upper_lens |
|
|
|
|
| def _tri_pack_indexed_tensor( |
| tensor: torch.Tensor, |
| *, |
| gather_idx: torch.LongTensor, |
| valid_mask: torch.Tensor, |
| pad_value, |
| owner=None, |
| workspace_name: Optional[str] = None, |
| ): |
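| # Gather positions along dim 1 of `tensor` according to `gather_idx`, |
| # filling slots where `valid_mask` is False with `pad_value`; at eval time |
| # the packed result can be written into a reusable workspace buffer cached |
| # on `owner`. |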
| if tensor.dim() < 2: |
| raise ValueError("Expected rank-2 or higher tensor for indexed packing.") |
| B, total_len = tensor.shape[:2] |
| device = tensor.device |
| gather_idx = gather_idx.to(device=device, dtype=torch.long) |
| valid_mask = valid_mask.to(device=device, dtype=torch.bool) |
| clamped_idx = gather_idx.clamp(min=0, max=max(total_len - 1, 0)) |
| packed_lens = valid_mask.sum(dim=1, dtype=torch.long) |
|
|
| use_workspace = ( |
| owner is not None |
| and workspace_name is not None |
| and not bool(getattr(owner, "training", False)) |
| and int(clamped_idx.size(1)) > 0 |
| ) |
| packed = None |
| if use_workspace: |
| cache = getattr(owner, "_optimized_llopa_upper_pack_workspace_cache", None) |
| if not isinstance(cache, dict): |
| cache = {} |
| target_shape = (B, int(clamped_idx.size(1)), *tensor.shape[2:]) |
| key = (str(workspace_name), str(device), str(tensor.dtype), tuple(int(x) for x in target_shape)) |
| packed = cache.get(key) |
| if ( |
| not isinstance(packed, torch.Tensor) |
| or packed.device != device |
| or packed.dtype != tensor.dtype |
| or tuple(int(x) for x in packed.shape) != tuple(int(x) for x in target_shape) |
| ): |
| packed = torch.empty(target_shape, device=device, dtype=tensor.dtype) |
| cache[key] = packed |
| with contextlib.suppress(Exception): |
| setattr(owner, "_optimized_llopa_upper_pack_workspace_cache", cache) |
|
|
| if tensor.dim() == 2: |
| if packed is not None: |
| torch.gather(tensor, 1, clamped_idx, out=packed) |
| packed.masked_fill_(~valid_mask, pad_value) |
| else: |
| packed = tensor.gather(1, clamped_idx) |
| packed = packed.masked_fill(~valid_mask, pad_value) |
| else: |
| expand_shape = (*clamped_idx.shape, *tensor.shape[2:]) |
| gather_index = clamped_idx.view(B, clamped_idx.size(1), *([1] * (tensor.dim() - 2))).expand(*expand_shape) |
| mask_view = ~valid_mask.view(B, clamped_idx.size(1), *([1] * (tensor.dim() - 2))) |
| if packed is not None: |
| torch.gather(tensor, 1, gather_index, out=packed) |
| packed.masked_fill_(mask_view, pad_value) |
| else: |
| packed = tensor.gather(1, gather_index) |
| packed = packed.masked_fill(mask_view, pad_value) |
| return packed, packed_lens |
|
|
|
|
| def _tri_find_last_subsequence_start(input_ids: torch.Tensor, pattern: torch.Tensor) -> Optional[int]: |
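| # Return the start offset of the last occurrence of `pattern` within |
| # `input_ids` (both 1-D, or 2-D with a leading batch dim of 1), else None. |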
| if not isinstance(input_ids, torch.Tensor) or not isinstance(pattern, torch.Tensor): |
| return None |
| if input_ids.dim() == 2: |
| if input_ids.size(0) != 1: |
| return None |
| seq = input_ids[0] |
| elif input_ids.dim() == 1: |
| seq = input_ids |
| else: |
| return None |
| if pattern.dim() == 2: |
| if pattern.size(0) != 1: |
| return None |
| pat = pattern[0] |
| elif pattern.dim() == 1: |
| pat = pattern |
| else: |
| return None |
|
|
| pat_len = int(pat.numel()) |
| seq_len = int(seq.numel()) |
| if pat_len <= 0 or seq_len < pat_len: |
| return None |
|
|
| windows = seq.unfold(0, pat_len, 1) |
| matches = (windows == pat).all(dim=-1).nonzero(as_tuple=False) |
| if matches.numel() == 0: |
| return None |
| return int(matches[-1].item()) |
|
|
|
|
| def _tri_pack_suffix_tensor( |
| tensor: torch.Tensor, |
| *, |
| split_starts: torch.LongTensor, |
| valid_lens: torch.LongTensor, |
| pad_value, |
| ): |
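| # Left-align each row's suffix tensor[row, split_start:valid_len] into a |
| # (batch, max_suffix_len, ...) block, padding the tail with `pad_value`. |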
| if tensor.dim() < 2: |
| raise ValueError("Expected rank-2 or higher tensor for suffix packing.") |
| B, total_len = tensor.shape[:2] |
| device = tensor.device |
| split_starts = split_starts.to(device=device, dtype=torch.long) |
| valid_lens = valid_lens.to(device=device, dtype=torch.long) |
| suffix_lens = valid_lens - split_starts |
| if torch.any(suffix_lens <= 0): |
| raise ValueError("All packed suffixes must have positive length.") |
|
|
| max_suffix_len = int(suffix_lens.max().item()) |
| offsets = torch.arange(max_suffix_len, device=device, dtype=torch.long).unsqueeze(0) |
| gather_pos = split_starts.unsqueeze(1) + offsets |
| valid = gather_pos < valid_lens.unsqueeze(1) |
| gather_idx = gather_pos.clamp(max=max(total_len - 1, 0)) |
|
|
| if tensor.dim() == 2: |
| packed = tensor.gather(1, gather_idx) |
| packed = packed.masked_fill(~valid, pad_value) |
| else: |
| expand_shape = (*gather_idx.shape, *tensor.shape[2:]) |
| packed = tensor.gather(1, gather_idx.view(B, max_suffix_len, *([1] * (tensor.dim() - 2))).expand(*expand_shape)) |
| packed = packed.masked_fill(~valid.view(B, max_suffix_len, *([1] * (tensor.dim() - 2))), pad_value) |
| return packed, valid, suffix_lens, gather_idx |
|
|
|
|
| def _tri_insert_suffix_specials_inband( |
| *, |
| token_ids: list[int], |
| input_ids: torch.LongTensor, |
| attention_mask: torch.Tensor, |
| labels: Optional[torch.LongTensor], |
| split_starts: Optional[torch.LongTensor] = None, |
| assistant_header_starts: Optional[torch.LongTensor] = None, |
| assistant_header_start_mask: Optional[torch.BoolTensor] = None, |
| ) -> tuple[ |
| torch.LongTensor, |
| torch.Tensor, |
| Optional[torch.LongTensor], |
| Optional[torch.LongTensor], |
| Optional[torch.LongTensor], |
| Optional[torch.BoolTensor], |
| ]: |
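| # Splice the suffix special `token_ids` in-band immediately before each |
| # assistant header start (or before the split start when no header metadata |
| # is provided), label the inserted tokens with -100, and remap the split |
| # starts and header starts into post-insertion coordinates. |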
| if not token_ids: |
| return input_ids, attention_mask, labels, split_starts, assistant_header_starts, assistant_header_start_mask |
|
|
| batch_size = int(input_ids.size(0)) |
| device = input_ids.device |
| token_count = len(token_ids) |
| valid_lens = attention_mask.to(device=device, dtype=torch.long).sum(dim=1, dtype=torch.long) |
|
|
| starts_tensor = assistant_header_starts |
| starts_mask = assistant_header_start_mask |
| generated_starts = False |
| if starts_tensor is None and split_starts is not None: |
| starts_tensor = split_starts.view(batch_size, 1) |
| starts_mask = torch.ones((batch_size, 1), device=device, dtype=torch.bool) |
| generated_starts = True |
| elif starts_tensor is not None and starts_mask is None: |
| starts_mask = starts_tensor >= 0 |
|
|
| if starts_tensor is not None: |
| starts_tensor = starts_tensor.to(device=device, dtype=torch.long) |
| if starts_tensor.dim() == 1: |
| starts_tensor = starts_tensor.view(batch_size, -1) |
| if starts_mask is not None: |
| starts_mask = starts_mask.to(device=device, dtype=torch.bool) |
| if starts_mask.dim() == 1: |
| starts_mask = starts_mask.view(batch_size, -1) |
|
|
| row_ids: list[torch.Tensor] = [] |
| row_masks: list[torch.Tensor] = [] |
| row_labels: list[torch.Tensor] = [] |
| remapped_split_starts: list[int] = [] |
| remapped_turn_starts: list[list[int]] = [] |
| remapped_turn_masks: list[list[bool]] = [] |
| max_len = 0 |
|
|
| for row in range(batch_size): |
| valid_len = int(valid_lens[row].item()) |
| ids_list = [int(v) for v in input_ids[row, :valid_len].tolist()] |
| mask_list = [int(v) for v in attention_mask[row, :valid_len].tolist()] |
| labels_list = None |
| if labels is not None: |
| labels_list = [int(v) for v in labels[row, :valid_len].tolist()] |
|
|
| positions_to_cols: dict[int, list[int]] = {} |
| num_turns = 0 |
| if starts_tensor is not None and starts_mask is not None: |
| num_turns = int(starts_tensor.size(1)) |
| for col in range(num_turns): |
| if not bool(starts_mask[row, col].item()): |
| continue |
| start = int(starts_tensor[row, col].item()) |
| if 0 <= start < valid_len: |
| positions_to_cols.setdefault(start, []).append(col) |
|
|
| remapped_row_turns = [-1] * num_turns |
| remapped_row_mask = [False] * num_turns |
| offset = 0 |
| for start in sorted(positions_to_cols.keys()): |
| insert_at = min(max(int(start) + offset, 0), len(ids_list)) |
| ids_list[insert_at:insert_at] = list(token_ids) |
| mask_list[insert_at:insert_at] = [1] * token_count |
| if labels_list is not None: |
| labels_list[insert_at:insert_at] = [-100] * token_count |
| for col in positions_to_cols[start]: |
| remapped_row_turns[col] = insert_at |
| remapped_row_mask[col] = True |
| offset += token_count |
|
|
| if split_starts is not None: |
| base_split = int(split_starts[row].item()) |
| remapped_split = base_split |
| if 0 <= base_split <= valid_len: |
| # Shift the boundary by the specials inserted strictly before it. |
| remapped_split = base_split + token_count * sum(1 for start in positions_to_cols.keys() if start < base_split) |
| remapped_split_starts.append(remapped_split) |
|
|
| row_valid_len = len(ids_list) |
| max_len = max(max_len, row_valid_len) |
| row_ids.append(torch.tensor(ids_list, device=device, dtype=input_ids.dtype)) |
| row_masks.append(torch.tensor(mask_list, device=device, dtype=attention_mask.dtype)) |
|
|
| if labels is not None and labels_list is not None: |
| row_labels.append(torch.tensor(labels_list, device=device, dtype=labels.dtype)) |
|
|
| if starts_tensor is not None: |
| remapped_turn_starts.append(remapped_row_turns) |
| remapped_turn_masks.append(remapped_row_mask) |
|
|
| new_input_ids = input_ids.new_zeros((batch_size, max_len)) |
| new_attention_mask = attention_mask.new_zeros((batch_size, max_len)) |
| new_labels = labels.new_full((batch_size, max_len), -100) if labels is not None else None |
| for row in range(batch_size): |
| width = int(row_ids[row].size(0)) |
| new_input_ids[row, :width] = row_ids[row] |
| new_attention_mask[row, :width] = row_masks[row] |
| if new_labels is not None: |
| new_labels[row, :width] = row_labels[row] |
|
|
| new_split_starts = None |
| if split_starts is not None: |
| new_split_starts = torch.tensor(remapped_split_starts, device=device, dtype=torch.long) |
|
|
| if starts_tensor is None: |
| return new_input_ids, new_attention_mask, new_labels, new_split_starts, None, None |
|
|
| max_turns = int(starts_tensor.size(1)) |
| new_turn_starts = torch.full((batch_size, max_turns), -1, device=device, dtype=torch.long) |
| new_turn_mask = torch.zeros((batch_size, max_turns), device=device, dtype=torch.bool) |
| for row in range(batch_size): |
| for col in range(max_turns): |
| if col >= len(remapped_turn_starts[row]): |
| continue |
| new_turn_starts[row, col] = int(remapped_turn_starts[row][col]) |
| new_turn_mask[row, col] = bool(remapped_turn_masks[row][col]) |
|
|
| if generated_starts: |
| return new_input_ids, new_attention_mask, new_labels, new_split_starts, None, None |
| return new_input_ids, new_attention_mask, new_labels, new_split_starts, new_turn_starts, new_turn_mask |
|
|
|
|
| def tri_vanilla_train_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| assistant_header_starts_hint=None, |
| assistant_turn_ends_hint=None, |
| assistant_header_start_mask_hint=None, |
| see_past_assistant: bool = False, |
| no_upper_attn: bool = False, |
| upper_attention_mode: str = "causal", |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ) -> Optional[CausalLMOutputWithPast]: |
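| # Prefill-lower training forward: run the lower K layers over the full |
| # packed sequence, then gather only the tokens the upper layers should see |
| # and continue from there. Every early `return None` is a signal for the |
| # caller to fall back to the vanilla forward. |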
| if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2: |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full( |
| (batch_size,), |
| int(input_ids.size(1)), |
| device=input_ids.device, |
| dtype=torch.long, |
| ) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| trimmed_labels = labels[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints( |
| assistant_header_starts_hint, |
| assistant_header_start_mask_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints( |
| assistant_turn_ends_hint, |
| assistant_header_start_mask_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
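| # In-band fusion splices the suffix special tokens into the token stream |
| # before each assistant header, so lengths and hints must be recomputed. |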
| if _tri_effective_fusion_mode(self) == "inband": |
| header_starts = None |
| header_start_mask = None |
| if parsed_assistant_header_hints is not None: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| ( |
| trimmed_input_ids, |
| full_attention_mask, |
| trimmed_labels, |
| parsed_split_starts, |
| header_starts, |
| header_start_mask, |
| ) = _tri_insert_suffix_specials_inband( |
| token_ids=_tri_effective_suffix_special_token_ids(self), |
| input_ids=trimmed_input_ids, |
| attention_mask=full_attention_mask, |
| labels=trimmed_labels, |
| split_starts=parsed_split_starts, |
| assistant_header_starts=header_starts, |
| assistant_header_start_mask=header_start_mask, |
| ) |
| valid_lens = full_attention_mask.sum(dim=1, dtype=torch.long) |
| max_valid_len = int(valid_lens.max().item()) |
| parsed_assistant_header_hints = ( |
| (header_starts, header_start_mask) if header_starts is not None and header_start_mask is not None else None |
| ) |
|
|
| valid_label_mask = trimmed_labels != -100 |
| valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_counts <= 0): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_train_no_labels", |
| "[TRI] prefill_lower_layers training path found no supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0) |
| first_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, max_valid_len), |
| ).min(dim=1).values |
| last_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, -1), |
| ).max(dim=1).values |
| contiguous_suffix = valid_counts == (last_valid - first_valid + 1) |
| use_multiturn = not bool(torch.all(contiguous_suffix).item()) |
| turn_sample_rows: list[int] = [] |
| turn_starts_list: list[int] = [] |
| turn_ends_list: list[int] = [] |
| if use_multiturn: |
| if parsed_assistant_header_hints is None: |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_train_noncontiguous", |
| "[TRI] prefill_lower_layers training path found multi-turn labels without assistant turn metadata; falling back to vanilla.", |
| ) |
| return None |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| for row in range(batch_size): |
| spans = _tri_resolve_assistant_turn_spans( |
| labels_row=trimmed_labels[row], |
| valid_len=int(valid_lens[row].item()), |
| assistant_header_starts=header_starts[row], |
| assistant_header_start_mask=header_start_mask[row], |
| ) |
| if not spans: |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_train_noncontiguous", |
| "[TRI] prefill_lower_layers training path could not resolve multi-turn assistant spans; falling back to vanilla.", |
| ) |
| return None |
| for turn_start, turn_end in spans: |
| turn_sample_rows.append(row) |
| turn_starts_list.append(int(turn_start)) |
| turn_ends_list.append(int(turn_end)) |
|
|
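| # Default split boundary: one token before the first supervised label, so |
| # the suffix begins with the token whose logits predict that label. |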
| split_starts = (first_valid - 1).clamp_min_(0) |
| if parsed_split_starts is not None: |
| valid_split_mask = ( |
| (parsed_split_starts >= 0) |
| & (parsed_split_starts < valid_lens) |
| & (parsed_split_starts <= first_valid) |
| ) |
| if not bool(torch.all(valid_split_mask).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_train_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower split start for training; falling back to the label-derived boundary.", |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
| upper_attention_mode = _normalize_upper_attention_mode(upper_attention_mode) |
| replay_module = _normalize_replay_module(replay_module) |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
| if bool(no_upper_attn) and upper_attention_mode != "causal": |
| raise ValueError("no_upper_attn requires upper_attention_mode='causal'.") |
| if replay_module != "none" and upper_attention_mode != "causal": |
| raise ValueError("replay_module requires upper_attention_mode='causal'.") |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| full_ids = trimmed_input_ids |
| B, L = full_ids.shape |
| device = full_ids.device |
|
|
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(full_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| lower_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| lower_mask = None |
| elif prefix_full: |
| lower_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| lower_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
| use_optimized_training_path = _llopa_use_optimized_training_path( |
| model, |
| attn_impl=attn_impl, |
| prefix_full=prefix_full, |
| no_upper_attn=bool(no_upper_attn), |
| ) |
| use_custom_checkpoint = use_checkpoint and not use_optimized_training_path |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_custom_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long) |
| valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long) |
| upper_prepare_mode = str(getattr(self, "_optimized_llopa_upper_prepare_mode", "exact") or "exact").strip().lower() |
| upper_bucket_multiple = int(getattr(self, "_optimized_llopa_upper_bucket_multiple", 0) or 0) |
| use_upper_workspace = ( |
| not bool(getattr(self, "training", False)) |
| and upper_prepare_mode == "bucketed_workspace" |
| and upper_bucket_multiple > 0 |
| ) |
| upper_prepare_owner = self if use_upper_workspace else None |
| upper_position_ids_src = position_ids.to(device=hidden_states.device, dtype=torch.long) |
| decode_labels_src = trimmed_labels.to(device=hidden_states.device) |
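| # Build the upper-stack inputs. In the multi-turn case each sample is expanded to one |
| # row per assistant turn (turn_sample_rows), and _tri_pack_indexed_tensor gathers only |
| # the tokens the upper layers may see, left-packed and padded per row. |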
| if use_multiturn: |
| turn_sample_ids = torch.tensor(turn_sample_rows, device=hidden_states.device, dtype=torch.long) |
| turn_starts = torch.tensor(turn_starts_list, device=hidden_states.device, dtype=torch.long) |
| turn_ends = torch.tensor(turn_ends_list, device=hidden_states.device, dtype=torch.long) |
| turn_system_lens = system_lens_upper.index_select(0, turn_sample_ids) |
| prefix_keep_lens = _tri_prefix_keep_lengths(turn_system_lens, system_prefill) |
| split_starts_for_suffix_specials = turn_starts |
| upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_multiturn_index_batch( |
| turn_starts=turn_starts, |
| turn_ends=turn_ends, |
| system_lens=turn_system_lens, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| owner=upper_prepare_owner, |
| bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0, |
| ) |
| hidden_states_upper_src = hidden_states.index_select(0, turn_sample_ids) |
| upper_hidden, _ = _tri_pack_indexed_tensor( |
| hidden_states_upper_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0.0, |
| owner=upper_prepare_owner, |
| workspace_name="upper_hidden", |
| ) |
| upper_position_ids_src = upper_position_ids_src.index_select(0, turn_sample_ids) |
| upper_position_ids, _ = _tri_pack_indexed_tensor( |
| upper_position_ids_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0, |
| owner=upper_prepare_owner, |
| workspace_name="upper_position_ids", |
| ) |
| decode_labels_src = decode_labels_src.index_select(0, turn_sample_ids) |
| decode_labels, _ = _tri_pack_indexed_tensor( |
| decode_labels_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=-100, |
| owner=upper_prepare_owner, |
| workspace_name="upper_decode_labels", |
| ) |
| else: |
| prefix_keep_lens = _tri_prefix_keep_lengths(system_lens_upper, system_prefill) |
| split_starts_for_suffix_specials = split_starts_upper |
| use_past_assistant = ( |
| bool(see_past_assistant) |
| and parsed_assistant_header_hints is not None |
| and parsed_assistant_turn_end_hints is not None |
| ) |
| if use_past_assistant: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| turn_ends, _ = parsed_assistant_turn_end_hints |
| upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_past_assistant_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| assistant_header_starts=header_starts, |
| assistant_turn_ends=turn_ends, |
| assistant_header_start_mask=header_start_mask, |
| device=hidden_states.device, |
| owner=upper_prepare_owner, |
| bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0, |
| ) |
| else: |
| upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_upper_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| owner=upper_prepare_owner, |
| bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0, |
| ) |
| upper_hidden, _ = _tri_pack_indexed_tensor( |
| hidden_states, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0.0, |
| owner=upper_prepare_owner, |
| workspace_name="upper_hidden", |
| ) |
| upper_position_ids, _ = _tri_pack_indexed_tensor( |
| upper_position_ids_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0, |
| owner=upper_prepare_owner, |
| workspace_name="upper_position_ids", |
| ) |
| decode_labels, _ = _tri_pack_indexed_tensor( |
| decode_labels_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=-100, |
| owner=upper_prepare_owner, |
| workspace_name="upper_decode_labels", |
| ) |
| upper_attention_mask = upper_valid_mask.to(dtype=full_attention_mask.dtype, device=hidden_states.device) |
| replay_hidden = None |
| replay_position_ids = None |
| replay_valid_mask = None |
| upper_layer_indices = list(range(K, n_layers)) |
| replay_layer_set = _tri_replay_layer_index_set( |
| upper_layer_indices=upper_layer_indices, |
| replay_per_layers=replay_per_layers, |
| ) |
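| # Resolve which upper layers run with the replay module attached; the subset is |
| # derived from replay_per_layers by _tri_replay_layer_index_set. |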
| if not bool(no_upper_attn) and _tri_effective_fusion_mode(self) == "upper_only": |
| upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, upper_lens = _tri_repack_upper_with_suffix_specials( |
| self, |
| upper_hidden=upper_hidden, |
| upper_position_ids=upper_position_ids, |
| upper_attention_mask=upper_attention_mask, |
| decode_labels=decode_labels, |
| prefix_keep_lens=prefix_keep_lens, |
| split_starts=split_starts_for_suffix_specials, |
| ) |
| replay_query_mask = None |
| replay_enabled = ( |
| replay_module != "none" |
| and not bool(no_upper_attn) |
| and K < n_layers |
| and bool(replay_layer_set) |
| ) |
| if replay_enabled: |
| if use_multiturn: |
| replay_gather_idx, replay_valid_mask, _ = _tri_build_last_layer_memory_index_batch( |
| split_starts=turn_starts, |
| prefix_keep_lens=prefix_keep_lens, |
| device=hidden_states.device, |
| owner=upper_prepare_owner, |
| bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0, |
| workspace_key="multiturn", |
| ) |
| replay_hidden, _ = _tri_pack_indexed_tensor( |
| hidden_states_upper_src, |
| gather_idx=replay_gather_idx, |
| valid_mask=replay_valid_mask, |
| pad_value=0.0, |
| owner=upper_prepare_owner, |
| workspace_name="last_layer_memory_hidden", |
| ) |
| replay_position_ids, _ = _tri_pack_indexed_tensor( |
| upper_position_ids_src, |
| gather_idx=replay_gather_idx, |
| valid_mask=replay_valid_mask, |
| pad_value=0, |
| owner=upper_prepare_owner, |
| workspace_name="last_layer_memory_position_ids", |
| ) |
| else: |
| replay_gather_idx, replay_valid_mask, _ = _tri_build_last_layer_memory_index_batch( |
| split_starts=split_starts_upper, |
| prefix_keep_lens=prefix_keep_lens, |
| device=hidden_states.device, |
| owner=upper_prepare_owner, |
| bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0, |
| workspace_key="single", |
| ) |
| replay_hidden, _ = _tri_pack_indexed_tensor( |
| hidden_states, |
| gather_idx=replay_gather_idx, |
| valid_mask=replay_valid_mask, |
| pad_value=0.0, |
| owner=upper_prepare_owner, |
| workspace_name="last_layer_memory_hidden", |
| ) |
| replay_position_ids, _ = _tri_pack_indexed_tensor( |
| upper_position_ids_src, |
| gather_idx=replay_gather_idx, |
| valid_mask=replay_valid_mask, |
| pad_value=0, |
| owner=upper_prepare_owner, |
| workspace_name="last_layer_memory_position_ids", |
| ) |
| if isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0: |
| replay_query_mask = _tri_build_last_layer_query_mask( |
| prefix_keep_lens=prefix_keep_lens, |
| valid_lens=upper_lens, |
| total_len=int(upper_hidden.size(1)), |
| device=upper_hidden.device, |
| ) |
| else: |
| replay_hidden = None |
| replay_position_ids = None |
| replay_valid_mask = None |
| replay_enabled = False |
| upper_cache_position = _tri_arange(0, upper_hidden.size(1), upper_hidden.device) |
|
|
| if K < n_layers: |
| if bool(no_upper_attn): |
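| # no_upper_attn: skip the attention sublayer entirely; each upper layer contributes |
| # only post_attention_layernorm + MLP with a residual connection. |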
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| residual = upper_hidden |
| ffn_input = layer.post_attention_layernorm(upper_hidden) |
| if use_checkpoint: |
| ffn_out = _checkpoint_mlp(layer.mlp, ffn_input) |
| else: |
| ffn_out = layer.mlp(ffn_input) |
| upper_hidden = residual + ffn_out |
| elif upper_attention_mode == "solo_v2": |
| upper_pos_emb = model.rotary_emb(upper_hidden, upper_position_ids) |
| for li in range(K, n_layers): |
| upper_hidden = _tri_run_decoder_layer_upper_solo_v2( |
| layer=model.layers[li], |
| hidden_states=upper_hidden, |
| position_embeddings=upper_pos_emb, |
| with_bos=False, |
| use_checkpoint=use_checkpoint, |
| ) |
| else: |
| if attn_impl == "flash_attention_2": |
| upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask |
| elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(upper_attention_mask): |
| upper_mask = None |
| else: |
| upper_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=upper_hidden, |
| attention_mask=upper_attention_mask, |
| cache_position=upper_cache_position, |
| past_key_values=None, |
| position_ids=upper_position_ids, |
| ) |
|
|
| upper_pos_emb = model.rotary_emb(upper_hidden, upper_position_ids) |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| use_replay_layer = ( |
| replay_enabled |
| and li in replay_layer_set |
| and isinstance(replay_hidden, torch.Tensor) |
| and replay_hidden.numel() > 0 |
| ) |
| if use_replay_layer: |
| upper_hidden = _tri_run_last_layer_with_replay( |
| layer=layer, |
| hidden_states=upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_pos_emb, |
| memory_hidden_states=replay_hidden, |
| memory_position_ids=replay_position_ids, |
| memory_valid_mask=replay_valid_mask, |
| local_valid_mask=upper_attention_mask.to(device=upper_hidden.device, dtype=torch.bool), |
| query_replay_mask=replay_query_mask, |
| module_type=replay_module, |
| rotary_emb=model.rotary_emb, |
| use_checkpoint=use_checkpoint, |
| ) |
| elif use_custom_checkpoint: |
| upper_hidden = checkpoint_with_cache( |
| layer.__call__, |
| upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_pos_emb, |
| ) |
| else: |
| upper_hidden = layer( |
| upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_pos_emb, |
| ) |
|
|
| upper_hidden = model.norm(upper_hidden) |
| kept_labels = decode_labels |
| if isinstance(logits_to_keep, int) and logits_to_keep > 0: |
| keep_lens = torch.minimum( |
| upper_lens, |
| torch.full_like(upper_lens, int(logits_to_keep)), |
| ) |
| # Hoist the shared suffix boundary so it is computed once for both packs. |
| keep_split_starts = (upper_lens - keep_lens).clamp_min_(0) |
| keep_hidden, _, _, _ = _tri_pack_suffix_tensor( |
| upper_hidden, |
| split_starts=keep_split_starts, |
| valid_lens=upper_lens, |
| pad_value=0.0, |
| ) |
| kept_labels, _, _, _ = _tri_pack_suffix_tensor( |
| decode_labels, |
| split_starts=keep_split_starts, |
| valid_lens=upper_lens, |
| pad_value=-100, |
| ) |
| logits = self.lm_head(keep_hidden) |
| else: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=upper_hidden.device) |
| logits = self.lm_head(upper_hidden[:, slice_indices, :]) |
| if not isinstance(logits_to_keep, int): |
| kept_labels = decode_labels[:, slice_indices] |
| loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size) |
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=upper_hidden, |
| ) |
|
|
|
|
| def tri_vanilla_frozen_prefix_train_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| use_cache: Optional[bool] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| past_key_values=None, |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| ) -> Optional[CausalLMOutputWithPast]: |
| if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2: |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full( |
| (batch_size,), |
| int(input_ids.size(1)), |
| device=input_ids.device, |
| dtype=torch.long, |
| ) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| trimmed_labels = labels[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| valid_label_mask = trimmed_labels != -100 |
| valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_counts <= 0): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_freeze_no_labels", |
| "[TRI] prefill_lower_freeze path found no supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0) |
| first_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, max_valid_len), |
| ).min(dim=1).values |
| last_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, -1), |
| ).max(dim=1).values |
| contiguous_suffix = valid_counts == (last_valid - first_valid + 1) |
| if not bool(torch.all(contiguous_suffix).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_freeze_noncontiguous", |
| "[TRI] prefill_lower_freeze expects a contiguous supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| split_starts = (first_valid - 1).clamp_min_(0) |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_split_starts is not None: |
| valid_split_mask = ( |
| (parsed_split_starts >= 0) |
| & (parsed_split_starts < valid_lens) |
| & (parsed_split_starts <= first_valid) |
| ) |
| if not bool(torch.all(valid_split_mask).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_freeze_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower freeze split start; falling back to the label-derived boundary.", |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| full_ids = trimmed_input_ids |
| B, L = full_ids.shape |
| device = full_ids.device |
|
|
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(full_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| full_mask = None |
| elif prefix_full: |
| full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| full_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| if K < n_layers: |
| split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long) |
| valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long) |
| upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| ) |
| upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool) |
| if upper_gather_idx.numel() > 0 and upper_valid_mask.any(): |
| upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask) |
| valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1) |
| prefix_freeze_mask = valid_token_mask & (~upper_visible_mask) |
| prefix_freeze_mask = prefix_freeze_mask.unsqueeze(-1) |
| frozen_hidden_states = hidden_states |
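| # Freeze the prefix through the upper stack: tokens that are valid but not visible to |
| # the upper layers keep their layer-K states, re-injected before each layer (and in |
| # the final output). |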
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| layer_input = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states) |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| hidden_states = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states) |
|
|
| hidden_states = model.norm(hidden_states) |
| if isinstance(logits_to_keep, int) and logits_to_keep > 0: |
| keep_lens = torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep))) |
| keep_split_starts = (valid_lens - keep_lens).clamp_min_(0) |
| # Only the hidden states are needed from the first pack; the labels are packed below. |
| keep_hidden, _, _, _ = _tri_pack_suffix_tensor( |
| hidden_states, |
| split_starts=keep_split_starts, |
| valid_lens=valid_lens, |
| pad_value=0.0, |
| ) |
| kept_labels, _, _, _ = _tri_pack_suffix_tensor( |
| trimmed_labels, |
| split_starts=keep_split_starts, |
| valid_lens=valid_lens, |
| pad_value=-100, |
| ) |
| logits = self.lm_head(keep_hidden) |
| loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size) |
| else: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| selected_labels = trimmed_labels[:, slice_indices] if isinstance(slice_indices, slice) else trimmed_labels.index_select(1, slice_indices) |
| loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size) |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=hidden_states, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_train_forward = tri_vanilla_train_forward |
| LlamaForCausalLM.tri_vanilla_frozen_prefix_train_forward = tri_vanilla_frozen_prefix_train_forward |
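| |
| # A minimal usage sketch for the patched training forward (illustrative only: the |
| # checkpoint name, batch fields, and hint values below are assumptions, not part of |
| # this module): |
| # |
| #     model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B") |
| #     out = model.tri_vanilla_frozen_prefix_train_forward( |
| #         input_ids=batch["input_ids"], |
| #         attention_mask=batch["attention_mask"], |
| #         labels=batch["labels"], |
| #         logits_to_keep=0, |
| #         prefill_lower_layers=8, |
| #         prefill_attn="causal", |
| #         split_start_hint=None, |
| #         system_len_hint=None, |
| #         system_prefill="none", |
| #     ) |
| #     if out is None:  # preconditions failed; use the stock forward instead |
| #         out = model(input_ids=batch["input_ids"], labels=batch["labels"]) |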
|
|
|
|
| def tri_vanilla_frozen_prefix_infer_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| ) -> Optional[CausalLMOutputWithPast]: |
| if input_ids is None or input_ids.dim() != 2: |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full( |
| (batch_size,), |
| int(input_ids.size(1)), |
| device=input_ids.device, |
| dtype=torch.long, |
| ) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| if parsed_split_starts is None: |
| split_starts = (valid_lens - 1).clamp_min_(0) |
| else: |
| valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens) |
| if not bool(torch.all(valid_split_mask).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_freeze_infer_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower freeze split start for inference; using the last prompt token.", |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0)) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| full_ids = trimmed_input_ids |
| B, L = full_ids.shape |
| device = full_ids.device |
|
|
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(full_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| full_mask = None |
| elif prefix_full: |
| full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| full_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| if K < n_layers: |
| split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long) |
| valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long) |
| upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| ) |
| upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool) |
| if upper_gather_idx.numel() > 0 and upper_valid_mask.any(): |
| upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask) |
| token_positions = torch.arange(L, device=hidden_states.device, dtype=torch.long).unsqueeze(0) |
| valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1) |
| prefix_freeze_mask = valid_token_mask & (~upper_visible_mask) |
| prefix_freeze_mask = prefix_freeze_mask.unsqueeze(-1) |
| frozen_hidden_states = hidden_states |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| layer_input = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states) |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| hidden_states = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states) |
|
|
| hidden_states = model.norm(hidden_states) |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=hidden_states, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_frozen_prefix_infer_forward = tri_vanilla_frozen_prefix_infer_forward |
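| # Like the training variant, this inference forward returns None whenever its |
| # preconditions fail (bad shapes, empty rows, K <= 0) so the caller can fall back to |
| # the stock LlamaForCausalLM.forward. |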
|
|
|
|
| def tri_vanilla_solo_prefix_infer_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| ) -> Optional[CausalLMOutputWithPast]: |
| if input_ids is None or input_ids.dim() != 2: |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| if parsed_split_starts is None: |
| split_starts = (valid_lens - 1).clamp_min_(0) |
| else: |
| valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens) |
| if not bool(torch.all(valid_split_mask).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_infer_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower solo split start for inference; using the last prompt token.", |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0)) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| B, L = trimmed_input_ids.shape |
| device = trimmed_input_ids.device |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(trimmed_input_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| full_mask = None |
| elif prefix_full: |
| full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| full_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| if K < n_layers: |
| split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long) |
| valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long) |
| upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| ) |
| upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool) |
| if upper_gather_idx.numel() > 0 and upper_valid_mask.any(): |
| upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask) |
| token_positions = torch.arange(L, device=hidden_states.device, dtype=torch.long).unsqueeze(0) |
| valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1) |
| prefix_self_mask = valid_token_mask & (~upper_visible_mask) |
| has_prefix_self = bool(prefix_self_mask.any().item()) |
| if has_prefix_self: |
| prefix_row_idx, prefix_col_idx = prefix_self_mask.nonzero(as_tuple=True) |
| singleton_cache_position = _tri_arange(0, 1, hidden_states.device) |
| singleton_attention_mask = None |
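| # Solo attention for non-visible prefix tokens: each one is re-run through the upper |
| # layer as a length-1 sequence (attention_mask=None, cache_position=[0]), so it |
| # attends only to itself, and is then merged back into the batched output. |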
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| layer_input = hidden_states |
| if use_checkpoint: |
| upper_hidden = checkpoint_with_cache( |
| layer.__call__, |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| upper_hidden = layer( |
| layer_input, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| if has_prefix_self: |
| singleton_hidden = layer_input[prefix_row_idx, prefix_col_idx].unsqueeze(1) |
| singleton_position_ids = position_ids[prefix_row_idx, prefix_col_idx].unsqueeze(1) |
| singleton_pos_emb = model.rotary_emb(singleton_hidden, singleton_position_ids) |
| if use_checkpoint: |
| singleton_outputs = checkpoint_with_cache( |
| layer.__call__, |
| singleton_hidden, |
| attention_mask=singleton_attention_mask, |
| position_ids=singleton_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=singleton_cache_position, |
| position_embeddings=singleton_pos_emb, |
| ) |
| else: |
| singleton_outputs = layer( |
| singleton_hidden, |
| attention_mask=singleton_attention_mask, |
| position_ids=singleton_position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=singleton_cache_position, |
| position_embeddings=singleton_pos_emb, |
| ) |
| merged_hidden = upper_hidden.clone() |
| merged_hidden[prefix_row_idx, prefix_col_idx] = singleton_outputs.squeeze(1) |
| hidden_states = merged_hidden |
| else: |
| hidden_states = upper_hidden |
|
|
| hidden_states = model.norm(hidden_states) |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=hidden_states, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_solo_prefix_infer_forward = tri_vanilla_solo_prefix_infer_forward |
|
|
|
|
| def _tri_vanilla_solo_attention_v2_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| with_bos: bool = False, |
| ) -> Optional[CausalLMOutputWithPast]: |
| if input_ids is None or input_ids.dim() != 2: |
| return None |
| if labels is not None and (labels.dim() != 2 or labels.shape != input_ids.shape): |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| trimmed_labels = labels[:, :max_valid_len] if labels is not None else None |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| if trimmed_labels is not None: |
| valid_label_mask = trimmed_labels != -100 |
| valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_counts <= 0): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_v2_no_labels", |
| "[TRI] prefill_lower_solo_attention_v2 path found no supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0) |
| first_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, max_valid_len), |
| ).min(dim=1).values |
| last_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, -1), |
| ).max(dim=1).values |
| contiguous_suffix = valid_counts == (last_valid - first_valid + 1) |
| if not bool(torch.all(contiguous_suffix).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_v2_noncontiguous", |
| "[TRI] prefill_lower_solo_attention_v2 expects a contiguous supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
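| # With labels, the split is derived from the first supervised token and may be |
| # overridden by a valid hint; without labels, the hint is used directly, falling back |
| # to the last prompt token. |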
| split_starts = (first_valid - 1).clamp_min_(0) |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_split_starts is not None: |
| valid_split_mask = ( |
| (parsed_split_starts >= 0) |
| & (parsed_split_starts < valid_lens) |
| & (parsed_split_starts <= first_valid) |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
| else: |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| if parsed_split_starts is None: |
| split_starts = (valid_lens - 1).clamp_min_(0) |
| else: |
| valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens) |
| if not bool(torch.all(valid_split_mask).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_v2_infer_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower solo_v2 split start for inference; using the last prompt token.", |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0)) |
|
|
| _ = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_input_ids.device, |
| ) |
| _ = normalize_system_prefill(system_prefill) |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| _, L = trimmed_input_ids.shape |
| device = trimmed_input_ids.device |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(trimmed_input_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| full_mask = None |
| elif prefix_full: |
| full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| full_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
| use_optimized_training_path = _llopa_use_optimized_training_path( |
| model, |
| attn_impl=attn_impl, |
| prefix_full=prefix_full, |
| no_upper_attn=False, |
| ) |
| use_custom_checkpoint = use_checkpoint and not use_optimized_training_path |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_custom_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=full_mask, |
| position_ids=position_ids, |
| past_key_values=None, |
| use_cache=False, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
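| # Optional compaction for solo_v2 training: when only the last logits_to_keep |
| # positions are supervised, hidden states, position ids, and labels are packed down |
| # to the supervised suffix (plus a BOS slot when with_bos) before the upper layers. |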
| compacted_solo_v2_upper = False |
| compacted_labels = None |
| if trimmed_labels is not None and isinstance(logits_to_keep, int) and int(logits_to_keep) > 0: |
| compacted = _tri_compact_solo_v2_upper_training_inputs( |
| hidden_states=hidden_states, |
| position_ids=position_ids, |
| labels=trimmed_labels, |
| valid_lens=valid_lens, |
| logits_to_keep=int(logits_to_keep), |
| with_bos=bool(with_bos), |
| ) |
| if compacted is not None: |
| hidden_states, position_ids, compacted_labels, _, _ = compacted |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| compacted_solo_v2_upper = True |
|
|
| if K < n_layers: |
| for li in range(K, n_layers): |
| hidden_states = _tri_run_decoder_layer_upper_solo_v2( |
| layer=model.layers[li], |
| hidden_states=hidden_states, |
| position_embeddings=position_embeddings, |
| with_bos=bool(with_bos), |
| use_checkpoint=use_custom_checkpoint, |
| ) |
|
|
| hidden_states = model.norm(hidden_states) |
| if compacted_solo_v2_upper and compacted_labels is not None: |
| if bool(with_bos): |
| logits_hidden = hidden_states[:, 1:, :] |
| loss_labels = compacted_labels[:, 1:] |
| else: |
| logits_hidden = hidden_states |
| loss_labels = compacted_labels |
| logits = self.lm_head(logits_hidden) |
| loss = self.loss_function(logits=logits, labels=loss_labels, vocab_size=self.config.vocab_size) |
| return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states) |
|
|
| if trimmed_labels is None: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| past_key_values=None, |
| hidden_states=hidden_states, |
| ) |
|
|
| if isinstance(logits_to_keep, int) and logits_to_keep > 0: |
| keep_lens = torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep))) |
| keep_split_starts = (valid_lens - keep_lens).clamp_min_(0) |
| keep_hidden, _, _, _ = _tri_pack_suffix_tensor( |
| hidden_states, |
| split_starts=keep_split_starts, |
| valid_lens=valid_lens, |
| pad_value=0.0, |
| ) |
| kept_labels, _, _, _ = _tri_pack_suffix_tensor( |
| trimmed_labels, |
| split_starts=keep_split_starts, |
| valid_lens=valid_lens, |
| pad_value=-100, |
| ) |
| logits = self.lm_head(keep_hidden) |
| loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size) |
| else: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| selected_labels = ( |
| trimmed_labels[:, slice_indices] |
| if isinstance(slice_indices, slice) |
| else trimmed_labels.index_select(1, slice_indices) |
| ) |
| loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size) |
|
|
| return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states) |
|
|
|
|
| def tri_vanilla_solo_attention_v2_infer_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| with_bos: bool = False, |
| ) -> Optional[CausalLMOutputWithPast]: |
| return _tri_vanilla_solo_attention_v2_forward( |
| self, |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=None, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_attn, |
| split_start_hint=split_start_hint, |
| system_len_hint=system_len_hint, |
| system_prefill=system_prefill, |
| with_bos=bool(with_bos), |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_solo_attention_v2_infer_forward = tri_vanilla_solo_attention_v2_infer_forward |
|
|
|
|
| def tri_vanilla_solo_attention_v2_train_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| with_bos: bool = False, |
| ) -> Optional[CausalLMOutputWithPast]: |
| return _tri_vanilla_solo_attention_v2_forward( |
| self, |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_attn, |
| split_start_hint=split_start_hint, |
| system_len_hint=system_len_hint, |
| system_prefill=system_prefill, |
| with_bos=bool(with_bos), |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_solo_attention_v2_train_forward = tri_vanilla_solo_attention_v2_train_forward |
|
|
|
|
| def tri_vanilla_solo_prefix_train_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| logits_to_keep: Union[int, torch.Tensor], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| ) -> Optional[CausalLMOutputWithPast]: |
| if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2: |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) |
| valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long) |
| if torch.any(valid_lens <= 0): |
| return None |
|
|
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| trimmed_labels = labels[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
|
|
| valid_label_mask = trimmed_labels != -100 |
| valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_counts <= 0): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_no_labels", |
| "[TRI] prefill_lower_solo_attention path found no supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0) |
| first_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, max_valid_len), |
| ).min(dim=1).values |
| last_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, -1), |
| ).max(dim=1).values |
| contiguous_suffix = valid_counts == (last_valid - first_valid + 1) |
| if not bool(torch.all(contiguous_suffix).item()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_solo_noncontiguous", |
| "[TRI] prefill_lower_solo_attention expects a contiguous supervised suffix; falling back to vanilla.", |
| ) |
| return None |
|
|
| split_starts = (first_valid - 1).clamp_min_(0) |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_split_starts is not None: |
| valid_split_mask = ( |
| (parsed_split_starts >= 0) |
| & (parsed_split_starts < valid_lens) |
| & (parsed_split_starts <= first_valid) |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=trimmed_labels.device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| B, L = trimmed_input_ids.shape |
| device = trimmed_input_ids.device |
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode not in ("causal", "full"): |
| raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.") |
| prefix_full = attn_mode == "full" |
|
|
| inputs_embeds = model.embed_tokens(trimmed_input_ids) |
| cache_position = _tri_arange(0, L, device) |
| position_ids = _llopa_position_ids_from_mask(full_attention_mask) |
|
|
| attn_impl = _resolve_attn_impl(model.config) |
| if prefix_full and attn_impl == "flash_attention_2": |
| raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.") |
| if attn_impl == "flash_attention_2": |
| full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask |
| elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask): |
| full_mask = None |
| elif prefix_full: |
| full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype) |
| else: |
| full_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| use_checkpoint = should_use_checkpoint(model) |
|
|
| for li in range(K): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, hidden_states, attention_mask=full_mask, position_ids=position_ids, |
| past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, attention_mask=full_mask, position_ids=position_ids, |
| past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings, |
| ) |
|
|
| if K < n_layers: |
| split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long) |
| valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long) |
| system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long) |
| upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch( |
| split_starts=split_starts_upper, |
| valid_lens=valid_lens_upper, |
| system_lens=system_lens_upper, |
| system_prefill=system_prefill, |
| device=hidden_states.device, |
| ) |
| upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool) |
| if upper_gather_idx.numel() > 0 and upper_valid_mask.any(): |
| upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask) |
| valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1) |
| prefix_self_mask = valid_token_mask & (~upper_visible_mask) |
| has_prefix_self = bool(prefix_self_mask.any().item()) |
| if has_prefix_self: |
| prefix_row_idx, prefix_col_idx = prefix_self_mask.nonzero(as_tuple=True) |
| singleton_cache_position = _tri_arange(0, 1, device) |
| singleton_attention_mask = None |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| layer_input = hidden_states |
| if use_checkpoint: |
| upper_hidden = checkpoint_with_cache( |
| layer.__call__, layer_input, attention_mask=full_mask, position_ids=position_ids, |
| past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings, |
| ) |
| else: |
| upper_hidden = layer( |
| layer_input, attention_mask=full_mask, position_ids=position_ids, |
| past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings, |
| ) |
| if has_prefix_self: |
| singleton_hidden = layer_input[prefix_row_idx, prefix_col_idx].unsqueeze(1) |
| singleton_position_ids = position_ids[prefix_row_idx, prefix_col_idx].unsqueeze(1) |
| singleton_pos_emb = model.rotary_emb(singleton_hidden, singleton_position_ids) |
| if use_checkpoint: |
| singleton_outputs = checkpoint_with_cache( |
| layer.__call__, singleton_hidden, attention_mask=singleton_attention_mask, position_ids=singleton_position_ids, |
| past_key_values=None, use_cache=False, cache_position=singleton_cache_position, position_embeddings=singleton_pos_emb, |
| ) |
| else: |
| singleton_outputs = layer( |
| singleton_hidden, attention_mask=singleton_attention_mask, position_ids=singleton_position_ids, |
| past_key_values=None, use_cache=False, cache_position=singleton_cache_position, position_embeddings=singleton_pos_emb, |
| ) |
| merged_hidden = upper_hidden.clone() |
| merged_hidden[prefix_row_idx, prefix_col_idx] = singleton_outputs.squeeze(1) |
| hidden_states = merged_hidden |
| else: |
| hidden_states = upper_hidden |
|
|
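| # Final norm, then score either the packed supervised suffix or a sliced tail. |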
| hidden_states = model.norm(hidden_states) |
| if isinstance(logits_to_keep, int) and logits_to_keep > 0: |
| # Keep only the last ``logits_to_keep`` valid tokens per row; compute the |
| # shared suffix start once instead of duplicating the expression. |
| keep_starts = (valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min(0) |
| keep_hidden, _, _, _ = _tri_pack_suffix_tensor( |
| hidden_states, |
| split_starts=keep_starts, |
| valid_lens=valid_lens, |
| pad_value=0.0, |
| ) |
| kept_labels, _, _, _ = _tri_pack_suffix_tensor( |
| trimmed_labels, |
| split_starts=keep_starts, |
| valid_lens=valid_lens, |
| pad_value=-100, |
| ) |
| logits = self.lm_head(keep_hidden) |
| loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size) |
| else: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| if isinstance(slice_indices, torch.Tensor): |
| slice_indices = slice_indices.to(device=hidden_states.device) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| selected_labels = trimmed_labels[:, slice_indices] if isinstance(slice_indices, slice) else trimmed_labels.index_select(1, slice_indices) |
| loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size) |
|
|
| return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states) |
|
|
|
|
|
|
| LlamaForCausalLM.tri_vanilla_solo_prefix_train_forward = tri_vanilla_solo_prefix_train_forward |
|
|
|
|
| def _tri_vanilla_prefill_decode_forward_batch_seed( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: torch.LongTensor, |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| use_cache: Optional[bool], |
| logits_to_keep: Union[int, torch.Tensor], |
| past_key_values: Optional[Cache], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| assistant_header_starts_hint=None, |
| assistant_turn_ends_hint=None, |
| assistant_header_start_mask_hint=None, |
| see_past_assistant: bool = False, |
| no_upper_attn: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ) -> Optional[CausalLMOutputWithPast]: |
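| """Batched (B > 1) seed path for the lower-prefill + decode split. |
| |
| Runs the first ``prefill_lower_layers`` layers over each row's prompt prefix, |
| seeds the upper layers with only the visible prefix tokens, then scores the |
| equal-width supervised suffix via ``tri_step_logits``. Returns ``None`` when |
| any precondition fails (non-empty cache, replay enabled, non-causal prefill |
| attention, in-band suffix specials, ragged decode widths) so the caller can |
| fall back to the vanilla forward. |
| """ |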
| if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) <= 1: |
| return None |
| if not _tri_cache_is_empty(past_key_values): |
| return None |
| replay_module = _normalize_replay_module(replay_module) |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
| if replay_module != "none": |
| return None |
|
|
| model = self.model |
| n_layers = len(model.layers) |
| K = max(0, min(int(prefill_lower_layers), n_layers)) |
| if K <= 0: |
| return None |
|
|
| attn_mode = (prefill_attn or "causal").strip().lower() |
| if attn_mode == "prefix_full": |
| attn_mode = "full" |
| if attn_mode != "causal": |
| return None |
| if _tri_effective_suffix_special_token_ids(self): |
| return None |
|
|
| batch_size = int(input_ids.size(0)) |
| device = input_ids.device |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size: |
| return None |
| batch_attention_mask = attention_mask.to(device=device, dtype=torch.long) |
| else: |
| batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device) |
| valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_lens <= 0): |
| return None |
| max_valid_len = int(valid_lens.max().item()) |
| trimmed_input_ids = input_ids[:, :max_valid_len] |
| full_attention_mask = batch_attention_mask[:, :max_valid_len] |
| trimmed_labels = labels[:, :max_valid_len] if isinstance(labels, torch.Tensor) and labels.dim() == 2 else None |
|
|
| parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints( |
| assistant_header_starts_hint, |
| assistant_header_start_mask_hint, |
| batch_size=batch_size, |
| device=device, |
| ) |
| parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints( |
| assistant_turn_ends_hint, |
| assistant_header_start_mask_hint, |
| batch_size=batch_size, |
| device=device, |
| ) |
| parsed_split_starts = _tri_parse_split_start_hints( |
| split_start_hint, |
| batch_size=batch_size, |
| device=device, |
| ) |
|
|
| if trimmed_labels is not None: |
| valid_label_mask = trimmed_labels != -100 |
| valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long) |
| if torch.any(valid_counts <= 0): |
| return None |
| token_positions = torch.arange(max_valid_len, device=device, dtype=torch.long).unsqueeze(0) |
| first_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, max_valid_len), |
| ).min(dim=1).values |
| last_valid = torch.where( |
| valid_label_mask, |
| token_positions, |
| torch.full_like(token_positions, -1), |
| ).max(dim=1).values |
| contiguous_suffix = valid_counts == (last_valid - first_valid + 1) |
| if not bool(torch.all(contiguous_suffix).item()): |
| return None |
| split_starts = (first_valid - 1).clamp_min_(0) |
| if parsed_split_starts is not None: |
| valid_split_mask = ( |
| (parsed_split_starts >= 0) |
| & (parsed_split_starts < valid_lens) |
| & (parsed_split_starts <= first_valid) |
| ) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
| else: |
| split_starts = (valid_lens - 1).clamp_min_(0) |
| if parsed_split_starts is not None: |
| valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens) |
| split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts) |
|
|
| parsed_system_lens = _tri_parse_system_len_hints( |
| system_len_hint, |
| batch_size=batch_size, |
| device=device, |
| ) |
| if parsed_system_lens is None: |
| system_lens = torch.zeros_like(split_starts) |
| else: |
| system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts) |
| system_prefill = normalize_system_prefill(system_prefill) |
|
|
| decode_lens = valid_lens - split_starts |
| if torch.any(decode_lens <= 0): |
| return None |
| if not bool(torch.all(decode_lens == decode_lens[0]).item()): |
| return None |
| decode_width = int(decode_lens[0].item()) |
| if decode_width <= 0: |
| return None |
|
|
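| # Prefill the lower K layers over each row's left-aligned prompt prefix, |
| # writing keys/values into ``pkv`` (packed varlen path when available). |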
| prefix_width = int(split_starts.max().item()) |
| pkv = _safe_dynamic_cache(model.config) |
| if prefix_width > 0: |
| prefix_offsets = torch.arange(prefix_width, device=device, dtype=torch.long).unsqueeze(0) |
| prefix_valid = prefix_offsets < split_starts.unsqueeze(1) |
| prefix_gather = prefix_offsets.expand(batch_size, -1).clamp(max=max(max_valid_len - 1, 0)) |
| prefix_ids = trimmed_input_ids.gather(1, prefix_gather).masked_fill(~prefix_valid, 0) |
| prefix_attention_mask = prefix_valid.to(dtype=torch.long) |
| inputs_embeds = model.embed_tokens(prefix_ids) |
| cache_position = _tri_arange(0, prefix_width, device) |
| position_ids = _llopa_position_ids_from_mask(prefix_attention_mask) |
| use_checkpoint = should_use_checkpoint(model) |
| packed_lower_hidden = None |
| if _tri_batch_packed_varlen_prefill_enabled(inputs_embeds, prefix_valid): |
| packed_lower_hidden = _tri_run_packed_varlen_prefill_layers( |
| model=model, |
| hidden_states=inputs_embeds[prefix_valid], |
| position_ids=position_ids[prefix_valid], |
| valid_mask=prefix_valid, |
| past_key_values=pkv, |
| start_layer=0, |
| end_layer=K, |
| use_checkpoint=use_checkpoint, |
| ) |
| if packed_lower_hidden is not None: |
| hidden_states = _tri_scatter_packed_hidden( |
| packed_lower_hidden, |
| prefix_valid, |
| batch_size=batch_size, |
| width=prefix_width, |
| ) |
| else: |
| attn_impl = _resolve_attn_impl(model.config) |
| if attn_impl == "flash_attention_2": |
| lower_mask = None if _llopa_mask_is_all_ones(prefix_attention_mask) else prefix_attention_mask |
| elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(prefix_attention_mask): |
| lower_mask = None |
| else: |
| lower_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=inputs_embeds, |
| attention_mask=prefix_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
| hidden_states = inputs_embeds |
| position_embeddings = model.rotary_emb(hidden_states, position_ids) |
| for li in range(K): |
| layer = model.layers[li] |
| if use_checkpoint: |
| hidden_states = checkpoint_with_cache( |
| layer.__call__, |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
| else: |
| prefix_valid = torch.zeros((batch_size, 0), device=device, dtype=torch.bool) |
| hidden_states = model.embed_tokens(input_ids[:, :0]) |
| position_ids = torch.empty((batch_size, 0), device=device, dtype=torch.long) |
| use_checkpoint = should_use_checkpoint(model) |
|
|
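| # Seed the upper layers with only the prefix tokens the decode segment is |
| # allowed to see (e.g. system prompt and assistant history, per the hints). |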
| if parsed_assistant_header_hints is not None: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| else: |
| header_starts = None |
| header_start_mask = None |
| if parsed_assistant_turn_end_hints is not None: |
| turn_ends, _ = parsed_assistant_turn_end_hints |
| else: |
| turn_ends = None |
|
|
| upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_seed_upper_index_batch( |
| split_starts=split_starts, |
| system_lens=system_lens, |
| system_prefill=system_prefill, |
| assistant_header_starts=header_starts, |
| assistant_turn_ends=turn_ends, |
| assistant_header_start_mask=header_start_mask, |
| see_past_assistant=bool(see_past_assistant), |
| device=device, |
| ) |
| upper_width = int(upper_gather_idx.size(1)) |
| if K < n_layers and upper_width > 0 and not bool(no_upper_attn): |
| upper_hidden, _ = _tri_pack_indexed_tensor( |
| hidden_states, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0.0, |
| ) |
| upper_position_ids, _ = _tri_pack_indexed_tensor( |
| position_ids, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0, |
| ) |
| packed_upper_hidden = None |
| if _tri_batch_packed_varlen_prefill_enabled(upper_hidden, upper_valid_mask): |
| packed_upper_hidden = _tri_run_packed_varlen_prefill_layers( |
| model=model, |
| hidden_states=upper_hidden[upper_valid_mask], |
| position_ids=upper_position_ids[upper_valid_mask], |
| valid_mask=upper_valid_mask, |
| past_key_values=pkv, |
| start_layer=K, |
| end_layer=n_layers, |
| use_checkpoint=use_checkpoint, |
| ) |
| if packed_upper_hidden is None: |
| upper_attention_mask = upper_valid_mask.to(dtype=torch.long) |
| upper_cache_position = _tri_arange(0, upper_width, device) |
| attn_impl = _resolve_attn_impl(model.config) |
| if attn_impl == "flash_attention_2": |
| upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask |
| elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(upper_attention_mask): |
| upper_mask = None |
| else: |
| upper_mask = create_causal_mask( |
| config=model.config, |
| input_embeds=upper_hidden, |
| attention_mask=upper_attention_mask, |
| cache_position=upper_cache_position, |
| past_key_values=None, |
| position_ids=upper_position_ids, |
| ) |
| upper_position_embeddings = model.rotary_emb(upper_hidden, upper_position_ids) |
| for li in range(K, n_layers): |
| layer = model.layers[li] |
| if use_checkpoint: |
| upper_hidden = checkpoint_with_cache( |
| layer.__call__, |
| upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_position_embeddings, |
| ) |
| else: |
| upper_hidden = layer( |
| upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_position_embeddings, |
| ) |
|
|
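| # Record, per layer, which cached positions are valid: lower layers hold the |
| # full prefix, upper layers only the seeded visible subset. |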
| layer_valid_masks: list[torch.Tensor] = [] |
| for li in range(n_layers): |
| if li < K: |
| layer_valid_masks.append(prefix_valid.to(device=device, dtype=torch.bool).clone()) |
| else: |
| layer_valid_masks.append(upper_valid_mask.to(device=device, dtype=torch.bool).clone()) |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_llopa_batch_layer_valid_masks", layer_valid_masks) |
| setattr(pkv, "_tri_past_len_cache", None) |
|
|
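| # All rows share the same decode width (checked above), so the supervised |
| # segment can be gathered into a dense [B, decode_width] block. |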
| decode_offsets = torch.arange(decode_width, device=device, dtype=torch.long).unsqueeze(0) |
| decode_gather = split_starts.unsqueeze(1) + decode_offsets |
| decode_ids = trimmed_input_ids.gather(1, decode_gather) |
| decode_labels = trimmed_labels.gather(1, decode_gather) if trimmed_labels is not None else None |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_llopa_batch_decode_position_ids", decode_gather.to(device=device, dtype=torch.long)) |
|
|
| write_cache = bool(self.config.use_cache if use_cache is None else use_cache) |
| out = self.tri_step_logits( |
| assistant_ids=decode_ids, |
| lower_k=int(prefill_lower_layers), |
| pkv=pkv, |
| S=0, |
| U=0, |
| logits_to_keep=logits_to_keep, |
| labels=decode_labels, |
| write_cache=write_cache, |
| prefill_mode="lower", |
| no_upper_attn=bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
| if write_cache: |
| _tri_append_batch_cache_valid_masks( |
| pkv, |
| torch.ones((batch_size, decode_width), device=device, dtype=torch.bool), |
| ) |
| with contextlib.suppress(Exception): |
| setattr( |
| pkv, |
| "_tri_prefill_seed_meta", |
| { |
| "S": 0, |
| "U": 0, |
| "split_start": split_starts.detach(), |
| }, |
| ) |
| return out |
|
|
|
|
| def tri_vanilla_prefill_decode_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.LongTensor], |
| use_cache: Optional[bool], |
| logits_to_keep: Union[int, torch.Tensor], |
| past_key_values: Optional[Cache], |
| prefill_lower_layers: int, |
| prefill_attn: str, |
| split_start_hint, |
| system_len_hint, |
| system_prefill: str, |
| assistant_header_starts_hint=None, |
| assistant_turn_ends_hint=None, |
| assistant_header_start_mask_hint=None, |
| replay_user_prefix_keep_len_hint=None, |
| replay_user_start_hint=None, |
| replay_user_len_hint=None, |
| see_past_assistant: bool = False, |
| no_upper_attn: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ) -> Optional[CausalLMOutputWithPast]: |
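| """Single-sequence (B == 1) lower-prefill + decode forward. |
| |
| B > 1 inputs are routed to the batched seed path. Otherwise, on an empty |
| cache the prompt is split at the supervised-suffix boundary (or the given |
| ``split_start_hint``), the lower ``prefill_lower_layers`` layers prefill the |
| prefix into a KV cache, optional upper-layer seeding/replay state is stored, |
| and the decode segment is scored with ``tri_step_logits``. Returns ``None`` |
| to request the vanilla forward. |
| """ |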
| if input_ids is None or input_ids.dim() != 2: |
| return None |
| if input_ids.size(0) != 1: |
| batch_out = _tri_vanilla_prefill_decode_forward_batch_seed( |
| self, |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| past_key_values=past_key_values, |
| prefill_lower_layers=prefill_lower_layers, |
| prefill_attn=prefill_attn, |
| split_start_hint=split_start_hint, |
| system_len_hint=system_len_hint, |
| system_prefill=system_prefill, |
| assistant_header_starts_hint=assistant_header_starts_hint, |
| assistant_turn_ends_hint=assistant_turn_ends_hint, |
| assistant_header_start_mask_hint=assistant_header_start_mask_hint, |
| see_past_assistant=bool(see_past_assistant), |
| no_upper_attn=bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
| return batch_out |
|
|
| valid_len = int(input_ids.size(1)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0): |
| return None |
| valid_len = int(attention_mask[0].sum().item()) |
| if valid_len <= 0: |
| return None |
|
|
| trimmed_input_ids = input_ids[:, :valid_len] |
| trimmed_labels = labels[:, :valid_len] if labels is not None else None |
| # Normalize replay settings once so both the prefill branch and the |
| # continuing-decode branch below pass canonical values to tri_step_logits. |
| replay_module = _normalize_replay_module(replay_module) |
| replay_per_layers = _normalize_replay_per_layers(replay_per_layers) |
|
|
| if _tri_cache_is_empty(past_key_values): |
| parsed_split_start = _tri_parse_split_start_hint(split_start_hint, batch_size=input_ids.size(0)) |
| parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints( |
| assistant_header_starts_hint, |
| assistant_header_start_mask_hint, |
| batch_size=input_ids.size(0), |
| device=trimmed_input_ids.device, |
| ) |
| parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints( |
| assistant_turn_ends_hint, |
| assistant_header_start_mask_hint, |
| batch_size=input_ids.size(0), |
| device=trimmed_input_ids.device, |
| ) |
| parsed_system_len = _tri_parse_system_len_hint(system_len_hint, batch_size=input_ids.size(0)) |
| parsed_replay_user_prefix_keep_len = _tri_parse_system_len_hint( |
| replay_user_prefix_keep_len_hint, |
| batch_size=input_ids.size(0), |
| ) |
| parsed_replay_user_start = _tri_parse_system_len_hint( |
| replay_user_start_hint, |
| batch_size=input_ids.size(0), |
| ) |
| parsed_replay_user_len = _tri_parse_system_len_hint( |
| replay_user_len_hint, |
| batch_size=input_ids.size(0), |
| ) |
| if trimmed_labels is not None: |
| valid_mask = trimmed_labels[0] != -100 |
| valid_positions = valid_mask.nonzero(as_tuple=False).flatten() |
| if valid_positions.numel() == 0: |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_no_labels", |
| "[TRI] prefill_lower_layers found no supervised suffix; falling back to vanilla.", |
| ) |
| return None |
| first_valid = int(valid_positions[0].item()) |
| last_valid = int(valid_positions[-1].item()) |
| if not bool(valid_mask[first_valid : last_valid + 1].all()): |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_noncontiguous", |
| "[TRI] prefill_lower_layers expects a contiguous supervised suffix; falling back to vanilla.", |
| ) |
| return None |
| split_start = max(first_valid - 1, 0) |
| if parsed_split_start is not None: |
| if 0 <= parsed_split_start < valid_len and parsed_split_start <= first_valid: |
| split_start = int(parsed_split_start) |
| else: |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower split start; falling back to the label-derived boundary.", |
| ) |
| else: |
| split_start = max(valid_len - 1, 0) |
| if parsed_split_start is not None: |
| if 0 <= parsed_split_start < valid_len: |
| split_start = int(parsed_split_start) |
| else: |
| _tri_vanilla_warn_once( |
| self, |
| "_tri_warned_prefill_lower_infer_invalid_split_start", |
| "[TRI] Ignoring invalid prefill_lower split start for inference; using the last prompt token.", |
| ) |
| if parsed_system_len is None: |
| system_len = 0 |
| else: |
| system_len = min(max(int(parsed_system_len), 0), split_start) |
| system_prefill = normalize_system_prefill(system_prefill) |
| if _tri_effective_fusion_mode(self) == "inband": |
| inband_attention_mask = torch.ones_like(trimmed_input_ids, dtype=torch.long, device=trimmed_input_ids.device) |
| header_starts = None |
| turn_ends = None |
| header_start_mask = None |
| if parsed_assistant_header_hints is not None: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| if parsed_assistant_turn_end_hints is not None: |
| turn_ends, _ = parsed_assistant_turn_end_hints |
| ( |
| trimmed_input_ids, |
| inband_attention_mask, |
| trimmed_labels, |
| remapped_split_starts, |
| remapped_header_starts, |
| remapped_header_start_mask, |
| ) = _tri_insert_suffix_specials_inband( |
| token_ids=_tri_effective_suffix_special_token_ids(self), |
| input_ids=trimmed_input_ids, |
| attention_mask=inband_attention_mask, |
| labels=trimmed_labels, |
| split_starts=torch.tensor([split_start], device=trimmed_input_ids.device, dtype=torch.long), |
| assistant_header_starts=header_starts, |
| assistant_header_start_mask=header_start_mask, |
| ) |
| valid_len = int(inband_attention_mask[0].sum().item()) |
| split_start = int(remapped_split_starts[0].item()) |
| if remapped_header_starts is not None and remapped_header_start_mask is not None: |
| parsed_assistant_header_hints = (remapped_header_starts, remapped_header_start_mask) |
| if ( |
| turn_ends is not None |
| and parsed_assistant_header_hints is not None |
| ): |
| remapped_header_starts, remapped_header_start_mask = parsed_assistant_header_hints |
| remapped_turn_ends = turn_ends.clone() |
| token_count = len(_tri_effective_suffix_special_token_ids(self)) |
| if token_count > 0: |
| # Guard against hint tensors of different widths before indexing both by column. |
| valid_turn_count = min(int(remapped_turn_ends.size(1)), int(remapped_header_start_mask.size(1))) |
| for col in range(valid_turn_count): |
| if not bool(remapped_header_start_mask[0, col].item()): |
| continue |
| # Each earlier turn inserted ``token_count`` specials, so turn ``col`` shifts by (col + 1) blocks. |
| remapped_turn_ends[0, col] += token_count * (col + 1) |
| parsed_assistant_turn_end_hints = (remapped_turn_ends, remapped_header_start_mask) |
| prefix_ids = trimmed_input_ids[:, :split_start] |
| decode_ids = trimmed_input_ids[:, split_start:valid_len] |
| decode_labels = trimmed_labels[:, split_start:valid_len] if trimmed_labels is not None else None |
| if system_prefill == "full": |
| visible_prefix_len = system_len |
| elif system_prefill == "no_system": |
| visible_prefix_len = min(system_len, 1) |
| else: |
| visible_prefix_len = 0 |
| need_replay_memory = replay_module != "none" and not bool(no_upper_attn) and int(split_start) > int(visible_prefix_len) |
| effective_see_past_assistant = bool(see_past_assistant) and _tri_has_effective_past_assistant_history( |
| split_start=int(split_start), |
| assistant_header_hints=parsed_assistant_header_hints, |
| assistant_turn_end_hints=parsed_assistant_turn_end_hints, |
| ) |
| prefill_out = _tri_prefill_lower_prompt_cache( |
| self.model, |
| prefix_ids, |
| lower_k=int(prefill_lower_layers), |
| prefill_attn=prefill_attn, |
| system_len=system_len, |
| system_prefill=system_prefill, |
| return_replay_memory=need_replay_memory, |
| return_prefix_hidden=bool(effective_see_past_assistant), |
| replay_user_prefix_keep_len=0 if parsed_replay_user_prefix_keep_len is None else int(parsed_replay_user_prefix_keep_len), |
| replay_user_start=parsed_replay_user_start, |
| replay_user_len=parsed_replay_user_len, |
| ) |
| if need_replay_memory or bool(effective_see_past_assistant): |
| pkv, replay_hidden, replay_position_ids, prefix_hidden, prefix_position_ids = prefill_out |
| else: |
| pkv = prefill_out |
| replay_hidden = None |
| replay_position_ids = None |
| prefix_hidden = None |
| prefix_position_ids = None |
| if _tri_effective_fusion_mode(self) == "upper_only": |
| split_positions = [int(split_start)] |
| if parsed_assistant_header_hints is not None: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| valid_turns = header_starts[0][header_start_mask[0]] |
| if valid_turns.numel() > 0: |
| split_positions = [int(v) for v in valid_turns.tolist()] |
| if int(split_start) not in split_positions: |
| split_positions.append(int(split_start)) |
| _tri_prefill_suffix_specials_upper_cache( |
| self.model, |
| pkv, |
| lower_k=int(prefill_lower_layers), |
| split_starts=split_positions, |
| no_upper_attn=bool(no_upper_attn), |
| ) |
| if bool(effective_see_past_assistant) and parsed_assistant_header_hints is not None and parsed_assistant_turn_end_hints is not None: |
| header_starts, header_start_mask = parsed_assistant_header_hints |
| turn_ends, _ = parsed_assistant_turn_end_hints |
| history_indices: list[torch.Tensor] = [] |
| num_turns = min(int(header_starts.size(1)), int(turn_ends.size(1))) |
| for col in range(num_turns): |
| if not bool(header_start_mask[0, col].item()): |
| continue |
| turn_start = int(header_starts[0, col].item()) |
| turn_end = int(turn_ends[0, col].item()) |
| if turn_end <= turn_start or turn_start >= int(split_start): |
| continue |
| turn_end = min(turn_end, int(split_start)) |
| if turn_end <= turn_start: |
| continue |
| history_indices.append( |
| torch.arange(turn_start, turn_end, device=trimmed_input_ids.device, dtype=torch.long) |
| ) |
| if history_indices: |
| history_idx = torch.cat(history_indices, dim=0) |
| if isinstance(prefix_hidden, torch.Tensor) and isinstance(prefix_position_ids, torch.Tensor): |
| prefix_pos_row = prefix_position_ids[0] |
| history_mask = torch.isin(prefix_pos_row, history_idx) |
| history_local_idx = history_mask.nonzero(as_tuple=False).flatten() |
| if history_local_idx.numel() > 0: |
| history_hidden = prefix_hidden.index_select(1, history_local_idx) |
| history_position_ids = prefix_position_ids.index_select(1, history_local_idx) |
| _tri_prefill_upper_history_cache( |
| self.model, |
| pkv, |
| hidden_states=history_hidden, |
| position_ids=history_position_ids, |
| lower_k=int(prefill_lower_layers), |
| no_upper_attn=bool(no_upper_attn), |
| ) |
| if need_replay_memory and isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0: |
| _tri_store_last_layer_memory( |
| pkv, |
| hidden_states=replay_hidden, |
| position_ids=replay_position_ids, |
| valid_mask=torch.ones( |
| (replay_hidden.size(0), replay_hidden.size(1)), |
| device=replay_hidden.device, |
| dtype=torch.bool, |
| ), |
| module_type=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
| with contextlib.suppress(Exception): |
| setattr( |
| pkv, |
| "_tri_prefill_seed_meta", |
| { |
| "S": int(visible_prefix_len), |
| "U": int(max(split_start - visible_prefix_len, 0)), |
| "split_start": int(split_start), |
| }, |
| ) |
| else: |
| split_start = 0 |
| decode_ids = trimmed_input_ids |
| decode_labels = trimmed_labels |
| pkv = past_key_values |
| visible_prefix_len = 0 |
|
|
| if decode_ids.size(1) == 0: |
| return None |
|
|
| write_cache = bool(self.config.use_cache if use_cache is None else use_cache) |
| return self.tri_step_logits( |
| assistant_ids=decode_ids, |
| lower_k=int(prefill_lower_layers), |
| pkv=pkv, |
| S=visible_prefix_len, |
| U=max(split_start - visible_prefix_len, 0), |
| logits_to_keep=logits_to_keep, |
| labels=decode_labels, |
| write_cache=write_cache, |
| prefill_mode="lower", |
| no_upper_attn=bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_vanilla_prefill_decode_forward = tri_vanilla_prefill_decode_forward |
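| |
| # Minimal usage sketch (hedged): one way a training step might call the |
| # prefill/decode split and fall back to the vanilla forward. The ``model`` and |
| # ``batch`` names and the hint keys are hypothetical placeholders, not part of |
| # this module; the keyword names follow the signature above. |
| # |
| # out = model.tri_vanilla_prefill_decode_forward( |
| #     input_ids=batch["input_ids"], |
| #     attention_mask=batch["attention_mask"], |
| #     labels=batch["labels"], |
| #     use_cache=False, |
| #     logits_to_keep=0, |
| #     past_key_values=None, |
| #     prefill_lower_layers=8, |
| #     prefill_attn="causal", |
| #     split_start_hint=batch.get("split_start"), |
| #     system_len_hint=batch.get("system_len"), |
| #     system_prefill="full", |
| # ) |
| # if out is None: |
| #     out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) |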
|
|
|
|
| def llopa_reference_prefill_seed( |
| self: "LlamaForCausalLM", |
| *, |
| system_ids: Optional[torch.LongTensor], |
| user_ids: Optional[torch.LongTensor], |
| assistant_ids: Optional[torch.LongTensor], |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| no_upper_attn: bool, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| replay_user_prefix_keep_len: int = 0, |
| replay_user_start: Optional[int] = None, |
| replay_user_len: Optional[int] = None, |
| ): |
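| """Seed a decode cache through the model's ``llopa_prefill_cache`` path. |
| |
| Splits the system prompt per ``system_prefill``, folds the lower portion into |
| the user segment, and returns ``(pkv, system_len, user_len, last_logits)``, |
| or ``None`` when the LLOPA prefill is unavailable or unsupported here. |
| """ |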
| if bool(no_upper_attn): |
| return None |
| if _tri_effective_fusion_mode(self) == "inband" and _tri_effective_suffix_special_token_ids(self): |
| return None |
| llopa_prefill_fn = getattr(self.model, "llopa_prefill_cache", None) |
| if not callable(llopa_prefill_fn): |
| return None |
| sys_upper, sys_lower_extra = _llopa_split_system(system_ids, system_prefill) |
| if sys_lower_extra.numel() == 0: |
| user_llopa = user_ids |
| elif user_ids is None or user_ids.numel() == 0: |
| user_llopa = sys_lower_extra |
| else: |
| user_llopa = torch.cat([sys_lower_extra, user_ids], dim=1) |
| merged_replay_prefix_keep_len = int(sys_lower_extra.size(1)) + int(replay_user_prefix_keep_len or 0) |
| merged_replay_user_start = None if replay_user_start is None else int(sys_lower_extra.size(1)) + int(replay_user_start) |
| prefill_out = llopa_prefill_fn( |
| system_ids=sys_upper, |
| user_ids=user_llopa, |
| assistant_ids=assistant_ids, |
| lower_k=int(lower_k), |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_last_assistant_hidden=True, |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| replay_user_prefix_keep_len=merged_replay_prefix_keep_len, |
| replay_user_start=merged_replay_user_start, |
| replay_user_len=replay_user_len, |
| ) |
| if not isinstance(prefill_out, tuple) or len(prefill_out) != 2: |
| return None |
| pkv, last_hidden = prefill_out |
| if pkv is None or not isinstance(last_hidden, torch.Tensor) or last_hidden.numel() == 0: |
| return None |
| logits = self.lm_head(last_hidden[:, -1:, :]).to(torch.float32) |
| system_len = int(sys_upper.size(1)) if isinstance(sys_upper, torch.Tensor) else 0 |
| user_len = int(user_llopa.size(1)) if isinstance(user_llopa, torch.Tensor) else 0 |
| return pkv, system_len, user_len, logits[:, -1, :] |
|
|
|
|
| LlamaForCausalLM.llopa_reference_prefill_seed = llopa_reference_prefill_seed |
|
|
|
|
| def tri_reference_prefill_seed( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| use_cache: Optional[bool], |
| logits_to_keep: Union[int, torch.Tensor], |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| no_upper_attn: bool, |
| prefill_lower_split_start, |
| prefill_lower_system_len, |
| prefill_lower_replay_user_prefix_keep_len=None, |
| prefill_lower_replay_user_start=None, |
| prefill_lower_replay_user_len=None, |
| assistant_header_starts=None, |
| assistant_turn_ends=None, |
| assistant_header_start_mask=None, |
| prefill_lower_see_past_assistant: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ): |
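| """Seed a decode cache by invoking the full forward with prefill kwargs. |
| |
| Recovers the (S, U) visibility split from ``_tri_prefill_seed_meta`` on the |
| returned cache and yields ``(past_key_values, S, U, last_logits)``; ``None`` |
| signals that seeding failed and the caller should fall back. |
| """ |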
| outputs = self( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| prefill_lower_layers=int(lower_k), |
| prefill_lower_attn=prefill_attn, |
| prefill_lower_system_prefill=system_prefill, |
| prefill_lower_no_upper_attn=bool(no_upper_attn), |
| prefill_lower_split_start=prefill_lower_split_start, |
| prefill_lower_system_len=prefill_lower_system_len, |
| prefill_lower_replay_user_prefix_keep_len=prefill_lower_replay_user_prefix_keep_len, |
| prefill_lower_replay_user_start=prefill_lower_replay_user_start, |
| prefill_lower_replay_user_len=prefill_lower_replay_user_len, |
| assistant_header_starts=assistant_header_starts, |
| assistant_turn_ends=assistant_turn_ends, |
| assistant_header_start_mask=assistant_header_start_mask, |
| prefill_lower_see_past_assistant=bool(prefill_lower_see_past_assistant), |
| prefill_lower_replay_module=replay_module, |
| prefill_lower_replay_per_layers=replay_per_layers, |
| ) |
| if outputs is None or outputs.past_key_values is None: |
| return None |
| logits = getattr(outputs, "logits", None) |
| if not isinstance(logits, torch.Tensor) or logits.numel() == 0: |
| return None |
| meta = getattr(outputs.past_key_values, "_tri_prefill_seed_meta", None) |
| if not isinstance(meta, dict): |
| return None |
| try: |
| S = int(meta.get("S", 0)) |
| U = int(meta.get("U", 0)) |
| except Exception: |
| return None |
| return outputs.past_key_values, S, U, logits[:, -1, :].to(torch.float32) |
|
|
|
|
| LlamaForCausalLM.tri_reference_prefill_seed = tri_reference_prefill_seed |
| LlamaForCausalLM.llopa_full_prompt_prefill_seed = tri_reference_prefill_seed |
|
|
|
|
| def tri_runtime_llopa_prompt_prefill_forward( |
| self: "LlamaForCausalLM", |
| *, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| use_cache: Optional[bool], |
| logits_to_keep: Union[int, torch.Tensor], |
| past_key_values: Optional[Cache], |
| lower_k: int, |
| prefill_attn: str, |
| no_upper_attn: bool, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| ) -> Optional[CausalLMOutputWithPast]: |
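| """Runtime prompt prefill that splits at the last assistant header. |
| |
| On an empty cache, everything before the final ``_runtime_llopa_header_ids`` |
| match is prefilled through the lower layers as the user segment and the rest |
| is treated as the assistant segment; later calls decode against the stored |
| cache. Returns ``None`` to fall back to the vanilla forward. |
| """ |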
| if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) != 1: |
| return None |
|
|
| header_ids = getattr(self, "_runtime_llopa_header_ids", None) |
| if not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0: |
| return None |
|
|
| valid_len = int(input_ids.size(1)) |
| if attention_mask is not None: |
| if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0): |
| return None |
| valid_len = int(attention_mask[0].sum().item()) |
| if valid_len <= 0: |
| return None |
|
|
| trimmed_input_ids = input_ids[:, :valid_len] |
| write_cache = bool(self.config.use_cache if use_cache is None else use_cache) |
|
|
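| # First call (empty cache): split the prompt at the last assistant header |
| # and prefill everything before it through the lower layers. |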
| if _tri_cache_is_empty(past_key_values): |
| hdr = header_ids.to(device=trimmed_input_ids.device, dtype=trimmed_input_ids.dtype) |
| assistant_start = _tri_find_last_subsequence_start(trimmed_input_ids, hdr) |
| if assistant_start is None: |
| return None |
| prefix_ids = trimmed_input_ids[:, :assistant_start] |
| assistant_ids = trimmed_input_ids[:, assistant_start:] |
| if assistant_ids.numel() == 0: |
| return None |
| prefill_out = self.model.llopa_prefill_cache( |
| system_ids=trimmed_input_ids[:, :0], |
| user_ids=prefix_ids, |
| assistant_ids=assistant_ids, |
| lower_k=int(lower_k), |
| prefill_mode="lower", |
| prefill_attn=prefill_attn, |
| return_last_assistant_hidden=not bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
| last_hidden = None |
| if isinstance(prefill_out, tuple): |
| pkv, last_hidden = prefill_out |
| else: |
| pkv = prefill_out |
| try: |
| setattr(pkv, "_runtime_llopa_user_len", int(prefix_ids.size(1))) |
| except Exception: |
| pass |
| if isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(last_hidden[:, slice_indices, :]) |
| return CausalLMOutputWithPast( |
| loss=None, |
| logits=logits, |
| past_key_values=pkv if write_cache else None, |
| hidden_states=last_hidden, |
| ) |
| last = assistant_ids[:, -1:] |
| return self.tri_step_logits( |
| assistant_ids=last, |
| lower_k=int(lower_k), |
| pkv=pkv, |
| S=0, |
| U=int(prefix_ids.size(1)), |
| logits_to_keep=logits_to_keep, |
| labels=None, |
| write_cache=write_cache, |
| prefill_mode="lower", |
| no_upper_attn=bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
|
|
| user_len = int(getattr(past_key_values, "_runtime_llopa_user_len", 0) or 0) |
| return self.tri_step_logits( |
| assistant_ids=trimmed_input_ids, |
| lower_k=int(lower_k), |
| pkv=past_key_values, |
| S=0, |
| U=user_len, |
| logits_to_keep=logits_to_keep, |
| labels=None, |
| write_cache=write_cache, |
| prefill_mode="lower", |
| no_upper_attn=bool(no_upper_attn), |
| replay_module=replay_module, |
| replay_per_layers=replay_per_layers, |
| ) |
|
|
|
|
| LlamaForCausalLM.tri_runtime_llopa_prompt_prefill_forward = tri_runtime_llopa_prompt_prefill_forward |
|
|
|
|
| __all__ = [ |
| "LlamaForCausalLM", |
| "LlamaModel", |
| "LlamaPreTrainedModel", |
| "LlamaForSequenceClassification", |
| "LlamaForQuestionAnswering", |
| "LlamaForTokenClassification", |
| ] |
|
|