# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
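# This file is a TRI / LLoPA-modified variant of the standard Hugging Face Llama modeling code:
# the original attention, decoder-layer, and causal-LM classes are kept, and a set of "TRI"
# utilities (System/User/Assistant split prefill, replay cross-attention, self-only attention
# masks) is layered on top. Fallback shims below keep the file importable across transformers
# versions and without optional dependencies (flash-attn, deepspeed, open_instruct).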
import contextlib
from typing import Any, Callable, Optional, Union, Tuple, TypedDict
import os
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
try:
from ...masking_utils import create_causal_mask
except Exception:
def _compat_past_key_values_length(past_key_values) -> int:
if past_key_values is None:
return 0
get_seq_length = getattr(past_key_values, "get_seq_length", None)
if callable(get_seq_length):
try:
return max(0, int(get_seq_length()))
except Exception:
pass
if isinstance(past_key_values, tuple) and len(past_key_values) > 0:
try:
first = past_key_values[0][0]
if isinstance(first, torch.Tensor) and first.dim() >= 3:
return max(0, int(first.shape[2]))
except Exception:
pass
return 0
def create_causal_mask(
*,
config,
input_embeds,
attention_mask,
cache_position=None,
past_key_values=None,
position_ids=None,
):
del position_ids
batch_size, query_len = input_embeds.shape[:2]
device = input_embeds.device
dtype = input_embeds.dtype
past_key_values_length = _compat_past_key_values_length(past_key_values)
if past_key_values_length == 0 and isinstance(cache_position, torch.Tensor) and cache_position.numel() > 0:
try:
past_key_values_length = max(0, int(cache_position.reshape(-1)[0].item()))
except Exception:
past_key_values_length = 0
key_len = past_key_values_length + query_len
if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
key_len = max(key_len, int(attention_mask.size(1)))
if isinstance(cache_position, torch.Tensor) and cache_position.numel() == query_len:
query_positions = cache_position.to(device=device, dtype=torch.long).view(1, query_len, 1)
if query_positions.dim() == 3 and query_positions.size(0) != batch_size:
query_positions = query_positions.expand(batch_size, -1, -1)
else:
query_positions = (
torch.arange(query_len, device=device, dtype=torch.long).view(1, query_len, 1) + past_key_values_length
).expand(batch_size, -1, -1)
key_positions = torch.arange(key_len, device=device, dtype=torch.long).view(1, 1, key_len)
allow = key_positions <= query_positions
if isinstance(attention_mask, torch.Tensor):
key_valid = attention_mask.to(device=device, dtype=torch.bool)
if key_valid.dim() == 1:
key_valid = key_valid.unsqueeze(0)
if key_valid.dim() == 2:
if int(key_valid.size(1)) < key_len:
pad = torch.ones(
(int(key_valid.size(0)), key_len - int(key_valid.size(1))),
device=device,
dtype=torch.bool,
)
key_valid = torch.cat([pad, key_valid], dim=1)
elif int(key_valid.size(1)) > key_len:
key_valid = key_valid[:, -key_len:]
allow = allow & key_valid.unsqueeze(1)
query_valid = key_valid[:, -query_len:]
allow = allow & query_valid.unsqueeze(-1)
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
zeros = torch.tensor(0.0, device=device)
mask = torch.where(allow, zeros, neg_inf).to(dtype)
return mask.view(batch_size, 1, query_len, key_len)
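# Illustrative behaviour of the fallback create_causal_mask above (example values, not executed
# here): for a 3-token prompt with no cache and no padding it returns an additive float mask of
# shape (batch, 1, 3, 3) whose upper triangle is a large negative value
# (torch.finfo(float32).min cast to the input dtype) and whose lower triangle is 0:
#   [[0,   -inf, -inf],
#    [0,    0,   -inf],
#    [0,    0,    0  ]]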
try:
from ...modeling_layers import (
GenericForQuestionAnswering,
GenericForSequenceClassification,
GenericForTokenClassification,
GradientCheckpointingLayer,
)
except Exception:
class GradientCheckpointingLayer(nn.Module):
pass
class GenericForSequenceClassification:
pass
class GenericForQuestionAnswering:
pass
class GenericForTokenClassification:
pass
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
try:
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
except Exception:
from ...utils import can_return_tuple, logging
class TransformersKwargs(TypedDict, total=False):
pass
def auto_docstring(obj=None, *args, **kwargs):
del args, kwargs
if callable(obj):
return obj
def decorator(inner):
return inner
return decorator
from ...utils.deprecation import deprecate_kwarg
try:
from ...utils.generic import check_model_inputs
except Exception:
def check_model_inputs(fn):
return fn
from .configuration_llama import LlamaConfig
logger = logging.get_logger(__name__)
PREFILL_LOWER_REPLAY_USER_PREFIX_KEEP_LEN_KEY = "prefill_lower_replay_user_prefix_keep_len"
PREFILL_LOWER_REPLAY_USER_START_KEY = "prefill_lower_replay_user_start"
PREFILL_LOWER_REPLAY_USER_LEN_KEY = "prefill_lower_replay_user_len"
try:
from open_instruct.dataset_transformation import (
ASSISTANT_HEADER_START_MASK_KEY,
ASSISTANT_HEADER_STARTS_KEY,
)
except Exception:
ASSISTANT_HEADER_STARTS_KEY = "assistant_header_starts"
ASSISTANT_HEADER_START_MASK_KEY = "assistant_header_start_mask"
ASSISTANT_TURN_ENDS_KEY = "assistant_turn_ends"
try:
from open_instruct.llopa_adapter import (
PREFILL_LOWER_SYSTEM_LEN_KEY,
build_prefill_lower_upper_indices,
normalize_system_prefill,
)
except Exception:
PREFILL_LOWER_SYSTEM_LEN_KEY = "prefill_lower_system_len"
def normalize_system_prefill(system_prefill: str) -> str:
mode = (system_prefill or "full").strip().lower()
if mode not in {"full", "no_system", "no_bos_system"}:
return "full"
return mode
def build_prefill_lower_upper_indices(
*,
sequence_len: int,
split_start: int,
system_len: int,
system_prefill: str,
device,
) -> torch.LongTensor:
total_len = max(int(sequence_len), 0)
split_start = min(max(int(split_start), 0), total_len)
system_len = min(max(int(system_len), 0), split_start)
mode = normalize_system_prefill(system_prefill)
if mode == "full":
prefix_keep = system_len
elif mode == "no_system":
prefix_keep = min(system_len, 1)
else:
prefix_keep = 0
pieces = []
if prefix_keep > 0:
pieces.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long))
if split_start < total_len:
pieces.append(torch.arange(split_start, total_len, device=device, dtype=torch.long))
if not pieces:
return torch.empty((0,), device=device, dtype=torch.long)
if len(pieces) == 1:
return pieces[0]
return torch.cat(pieces, dim=0)
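# Illustrative example for the fallback build_prefill_lower_upper_indices above: with
# sequence_len=10, split_start=6, system_len=3 and system_prefill="full" the kept indices are
# [0, 1, 2, 6, 7, 8, 9] (system prefix plus everything from the split onward); with "no_system"
# only index 0 survives from the prefix, and with "no_bos_system" the prefix is dropped
# entirely, leaving [6, 7, 8, 9].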
try:
from ...utils import is_flash_attn_2_available
except Exception:
def is_flash_attn_2_available():
return False
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func
except Exception:
_flash_attn_varlen_func = None
_WARNED_NO_FLASH_ATTN = False
def _resolve_attn_impl(config):
attn_impl = getattr(config, "_attn_implementation", "eager")
if attn_impl == "flash_attention_2" and not is_flash_attn_2_available():
require_flash = str(os.environ.get("FLASH_ATTN_REQUIRED", "0")).lower() in ("1", "true", "yes")
if require_flash:
raise ValueError(
"flash_attention_2 requested but flash_attn is unavailable. "
"Install flash-attn or unset FLASH_ATTN_REQUIRED to allow sdpa fallback."
)
global _WARNED_NO_FLASH_ATTN
if not _WARNED_NO_FLASH_ATTN:
logger.warning("flash_attention_2 requested but flash_attn is unavailable; falling back to sdpa.")
_WARNED_NO_FLASH_ATTN = True
attn_impl = "sdpa"
try:
setattr(config, "_attn_implementation", attn_impl)
except Exception:
pass
return attn_impl
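# Fallback behaviour of _resolve_attn_impl, assuming flash-attn is not installed: a config
# requesting "flash_attention_2" is rewritten to "sdpa" with a one-time warning, unless the
# FLASH_ATTN_REQUIRED environment variable is set (to 1/true/yes), in which case a ValueError
# is raised instead of falling back.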
os.environ.setdefault("DS_BUILD_AIO", "0")
_USE_DS_CHECKPOINT = os.environ.get("LLOPA_ENABLE_DS_CHECKPOINT", "").strip().lower() in {
"1",
"true",
"yes",
"on",
}
if _USE_DS_CHECKPOINT:
try:
from deepspeed.checkpointing import checkpoint as ds_checkpoint
except Exception:
ds_checkpoint = None
else:
ds_checkpoint = None
try:
from checkpoint_utils import checkpoint_with_cache, should_use_checkpoint
except Exception:
def checkpoint_with_cache(fn, *args, **kwargs):
return fn(*args, **kwargs)
def should_use_checkpoint(model) -> bool:
return False
def _safe_dynamic_cache(config):
try:
return DynamicCache(config=config)
except TypeError as exc:
if "max_cache_len" in str(exc) or "config" in str(exc):
return DynamicCache()
raise
def _tri_cache_is_empty(past_key_values) -> bool:
if past_key_values is None:
return True
if isinstance(past_key_values, tuple):
if len(past_key_values) == 0:
return True
try:
first = past_key_values[0][0]
return not isinstance(first, torch.Tensor) or int(first.shape[2]) == 0
except Exception:
return False
get_seq_length = getattr(past_key_values, "get_seq_length", None)
if callable(get_seq_length):
try:
return int(get_seq_length()) == 0
except Exception:
return False
return False
def _normalize_last_layer_module(last_layer_module: str) -> str:
mode = str(last_layer_module or "none").strip().lower()
aliases = {
"": "none",
"off": "none",
"disabled": "none",
"disable": "none",
"self-attention": "self",
"self_attention": "self",
"selfattn": "self",
"self_attn": "self",
"cross-attention": "cross",
"cross_attention": "cross",
"crossattn": "cross",
"cross_attn": "cross",
}
mode = aliases.get(mode, mode)
if mode not in {"none", "self", "cross"}:
mode = "none"
return mode
def _normalize_replay_module(replay_module: str) -> str:
return _normalize_last_layer_module(replay_module)
def _replay_module_enabled(config) -> bool:
replay_module = _normalize_replay_module(
getattr(config, "capsule_replay_module", getattr(config, "capsule_last_layer_module", "none"))
)
return replay_module != "none"
def _normalize_replay_per_layers(replay_per_layers) -> int:
try:
value = int(replay_per_layers)
except Exception:
value = -1
if value == -1:
return -1
if value <= 0:
return -1
return value
def _normalize_upper_attention_mode(mode: Any) -> str:
normalized = str(mode or "causal").strip().lower().replace("-", "_")
aliases = {
"": "causal",
"normal": "causal",
"default": "causal",
"full": "causal",
"causal_attention": "causal",
"self": "solo_v2",
"self_only": "solo_v2",
"solo": "solo_v2",
"solo2": "solo_v2",
"solo_attention_v2": "solo_v2",
}
normalized = aliases.get(normalized, normalized)
if normalized not in {"causal", "solo_v2"}:
raise ValueError("upper_attention_mode must be one of {'causal', 'solo_v2'}.")
return normalized
def _normalize_attention_gate_mode(mode: Any) -> str:
normalized = str(mode or "off").strip().lower()
aliases = {
"": "off",
"none": "off",
"disabled": "off",
"disable": "off",
"false": "off",
"0": "off",
"paper": "sdpa_sigmoid",
"sdpa_gate": "sdpa_sigmoid",
"sdpa-gate": "sdpa_sigmoid",
"sigmoid_after_sdpa": "sdpa_sigmoid",
"sdpa_elementwise_sigmoid": "sdpa_sigmoid",
}
normalized = aliases.get(normalized, normalized)
if normalized not in {"off", "sdpa_sigmoid"}:
normalized = "off"
return normalized
def _attention_gate_enabled(config) -> bool:
return _normalize_attention_gate_mode(getattr(config, "capsule_attention_gate_mode", "off")) == "sdpa_sigmoid"
def _init_sdpa_gate_linear(linear: nn.Linear) -> None:
with torch.no_grad():
linear.weight.zero_()
if linear.bias is not None:
linear.bias.fill_(4.0)
setattr(linear, "_is_hf_initialized", True)
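# Note on the gate initialisation above: with zero weights and a bias of 4.0 the gate starts at
# sigmoid(4.0) ~= 0.98 for every channel, so a freshly added SDPA output gate is almost fully
# open and the gated attention output initially stays close to the ungated baseline.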
def _apply_sdpa_output_gate(attn_module: nn.Module, hidden_states: torch.Tensor, attn_output: torch.Tensor) -> torch.Tensor:
gate_proj = getattr(attn_module, "sdpa_gate_proj", None)
if gate_proj is None or not isinstance(hidden_states, torch.Tensor):
return attn_output
gate = torch.sigmoid(gate_proj(hidden_states))
gate = gate.view(
hidden_states.size(0),
hidden_states.size(1),
int(getattr(attn_module.config, "num_attention_heads")),
int(getattr(attn_module, "head_dim")),
)
return attn_output * gate.to(device=attn_output.device, dtype=attn_output.dtype)
def _copy_attention_module_state(target_attn: nn.Module, source_attn: nn.Module) -> None:
target_attn.load_state_dict(source_attn.state_dict(), strict=False)
def _seed_replay_cross_attn_from_self_pre_hook(
module: nn.Module,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict: bool,
missing_keys,
unexpected_keys,
error_msgs,
) -> None:
replay_attn = getattr(module, "replay_cross_attn", None)
self_attn = getattr(module, "self_attn", None)
if replay_attn is None or self_attn is None:
return
replay_state = replay_attn.state_dict()
for key in replay_state.keys():
replay_key = f"{prefix}replay_cross_attn.{key}"
if replay_key in state_dict:
continue
self_key = f"{prefix}self_attn.{key}"
if self_key not in state_dict:
continue
value = state_dict[self_key]
if isinstance(value, torch.Tensor):
state_dict[replay_key] = value.detach().clone()
else:
state_dict[replay_key] = value
def _tri_replay_layer_index_set(
*,
upper_layer_indices,
replay_per_layers,
) -> set[int]:
upper = [int(li) for li in upper_layer_indices]
if not upper:
return set()
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
if replay_per_layers == -1:
return {int(upper[-1])}
return {int(upper[idx]) for idx in range(int(replay_per_layers) - 1, len(upper), int(replay_per_layers))}
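# Illustrative selection for _tri_replay_layer_index_set: with upper_layer_indices=[16..31] and
# replay_per_layers=4 the replay set is {19, 23, 27, 31} (every 4th upper layer, counted from
# the first); with replay_per_layers=-1 only the last upper layer, {31}, replays.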
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class LlamaRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: LlamaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def _checkpoint_mlp(mlp, hidden_states):
if ds_checkpoint is not None:
return ds_checkpoint(mlp, hidden_states)
try:
return torch.utils.checkpoint.checkpoint(mlp, hidden_states, use_reentrant=False)
except TypeError:
return torch.utils.checkpoint.checkpoint(mlp, hidden_states)
def _run_decoder_mlp_only(layer, hidden_states: torch.Tensor, use_checkpoint: bool) -> torch.Tensor:
residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
if use_checkpoint:
hidden_states = _checkpoint_mlp(layer.mlp, hidden_states)
else:
hidden_states = layer.mlp(hidden_states)
return residual + hidden_states
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.sdpa_gate_proj = None
if _attention_gate_enabled(config):
self.sdpa_gate_proj = nn.Linear(
config.hidden_size,
config.num_attention_heads * self.head_dim,
bias=True,
)
_init_sdpa_gate_linear(self.sdpa_gate_proj)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if hidden_states.size(1) == 0:
# Nothing to attend over; happens when TRI splits yield empty segments.
return hidden_states, None
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
llopa_special_kv = kwargs.pop("llopa_special_kv", None)
llopa_special_span = kwargs.pop("llopa_special_span", None)
extra_prefix_kv = kwargs.pop("extra_prefix_kv", None)
extra_prefix_valid_mask = kwargs.pop("extra_prefix_valid_mask", None)
extra_prefix_query_mask = kwargs.pop("extra_prefix_query_mask", None)
extra_prefix_local_valid_mask = kwargs.pop("extra_prefix_local_valid_mask", None)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
cos_u = cos.unsqueeze(1)
sin_u = sin.unsqueeze(1)
query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u)
key_states = None
value_states = None
if llopa_special_kv is not None and llopa_special_span is not None:
try:
sp_start, sp_end = int(llopa_special_span[0]), int(llopa_special_span[1])
except Exception:
sp_start, sp_end = -1, -1
seq_len = int(input_shape[1])
if (
0 <= sp_start < sp_end <= seq_len
and isinstance(llopa_special_kv, (tuple, list))
and len(llopa_special_kv) == 2
):
key_sp, value_sp = llopa_special_kv
key_sp = key_sp.to(device=hidden_states.device, dtype=query_states.dtype)
value_sp = value_sp.to(device=hidden_states.device, dtype=query_states.dtype)
left_len = sp_start
right_len = seq_len - sp_end
if left_len == 0 and right_len == 0:
key_states = key_sp
value_states = value_sp
else:
parts = []
if left_len > 0:
parts.append(hidden_states[:, :left_len, :])
if right_len > 0:
parts.append(hidden_states[:, sp_end:, :])
non_special = torch.cat(parts, dim=1) if parts else hidden_states[:, :0, :]
if non_special.size(1) > 0:
non_hidden_shape = (*non_special.shape[:-1], -1, self.head_dim)
key_non = self.k_proj(non_special).view(non_hidden_shape).transpose(1, 2)
value_non = self.v_proj(non_special).view(non_hidden_shape).transpose(1, 2)
if left_len > 0 and right_len > 0:
cos_non = torch.cat([cos[:, :left_len, :], cos[:, sp_end:, :]], dim=1)
sin_non = torch.cat([sin[:, :left_len, :], sin[:, sp_end:, :]], dim=1)
elif left_len > 0:
cos_non = cos[:, :left_len, :]
sin_non = sin[:, :left_len, :]
else:
cos_non = cos[:, sp_end:, :]
sin_non = sin[:, sp_end:, :]
key_non = (key_non * cos_non.unsqueeze(1)) + (rotate_half(key_non) * sin_non.unsqueeze(1))
key_parts = []
value_parts = []
if left_len > 0:
key_parts.append(key_non[:, :, :left_len, :])
value_parts.append(value_non[:, :, :left_len, :])
key_parts.append(key_sp)
value_parts.append(value_sp)
if right_len > 0:
key_parts.append(key_non[:, :, left_len:, :])
value_parts.append(value_non[:, :, left_len:, :])
key_states = torch.cat(key_parts, dim=2)
value_states = torch.cat(value_parts, dim=2)
else:
key_states = key_sp
value_states = value_sp
if key_states is None or value_states is None:
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = (key_states * cos_u) + (rotate_half(key_states) * sin_u)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_impl = _resolve_attn_impl(self.config)
if isinstance(extra_prefix_kv, (tuple, list)) and len(extra_prefix_kv) == 2:
prefix_key_states, prefix_value_states = extra_prefix_kv
if isinstance(prefix_key_states, torch.Tensor) and isinstance(prefix_value_states, torch.Tensor):
prefix_key_states = prefix_key_states.to(device=hidden_states.device, dtype=query_states.dtype)
prefix_value_states = prefix_value_states.to(device=hidden_states.device, dtype=query_states.dtype)
if prefix_key_states.numel() > 0 and int(prefix_key_states.size(2)) > 0:
local_key_len = int(key_states.size(2))
batch_size = int(hidden_states.size(0))
query_len = int(hidden_states.size(1))
prefix_len = int(prefix_key_states.size(2))
key_states = torch.cat([prefix_key_states, key_states], dim=2)
value_states = torch.cat([prefix_value_states, value_states], dim=2)
if attn_impl == "flash_attention_2":
local_valid_source = extra_prefix_local_valid_mask
if local_valid_source is None and isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
local_valid_source = attention_mask
query_prefix_enabled = True
if isinstance(extra_prefix_query_mask, torch.Tensor):
query_prefix_enabled = bool(extra_prefix_query_mask.to(device=hidden_states.device, dtype=torch.bool).all().item())
if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
attention_mask = _tri_build_self_prefix_attention_mask(
local_attention_mask=attention_mask,
batch_size=batch_size,
query_len=query_len,
local_key_len=local_key_len,
prefix_valid_mask=extra_prefix_valid_mask,
query_prefix_mask=extra_prefix_query_mask,
local_valid_mask=local_valid_source,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
attn_impl = "eager"
elif query_prefix_enabled:
prefix_valid = _tri_expand_valid_mask(
extra_prefix_valid_mask,
batch_size=batch_size,
total_len=prefix_len,
device=hidden_states.device,
)
prefix_all_valid = bool(prefix_valid.all().item())
if attention_mask is None and prefix_all_valid:
pass
else:
local_valid = _tri_expand_valid_mask(
local_valid_source,
batch_size=batch_size,
total_len=local_key_len,
device=hidden_states.device,
)
attention_mask = torch.cat(
[prefix_valid.to(dtype=torch.long), local_valid.to(dtype=torch.long)],
dim=1,
)
else:
attention_mask = _tri_build_self_prefix_attention_mask(
local_attention_mask=attention_mask if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4 else None,
batch_size=batch_size,
query_len=query_len,
local_key_len=local_key_len,
prefix_valid_mask=extra_prefix_valid_mask,
query_prefix_mask=extra_prefix_query_mask,
local_valid_mask=local_valid_source,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
attn_impl = "eager"
else:
local_valid_source = extra_prefix_local_valid_mask
if local_valid_source is None and isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
local_valid_source = attention_mask
attention_mask = _tri_build_self_prefix_attention_mask(
local_attention_mask=attention_mask if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4 else None,
batch_size=batch_size,
query_len=query_len,
local_key_len=local_key_len,
prefix_valid_mask=extra_prefix_valid_mask,
query_prefix_mask=extra_prefix_query_mask,
local_valid_mask=local_valid_source,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
attention_interface: Callable = eager_attention_forward
if attn_impl != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
if attn_impl == "flash_attention_2" and attention_mask is None and "position_ids" in kwargs:
# HF flash-attn varlen path expects flattened `position_ids` to contain segment starts at 0.
# In TRI/LLoPA we can have gap-preserving position ids (e.g., [0..S-1, S+U..]) with no padding mask,
# which can trigger an empty cu_seqlens.diff() in `_prepare_from_posids`.
            # For fully valid (non-padded) sequences, dropping position_ids safely falls back to the dense flash path.
kwargs = dict(kwargs)
kwargs.pop("position_ids", None)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = _apply_sdpa_output_gate(self, hidden_states, attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
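# Summary of the optional prefix path in LlamaAttention.forward above: when extra_prefix_kv is
# supplied, pre-computed prefix key/value states are concatenated in front of the local keys and
# values. Under SDPA/eager the combined 4D additive mask from _tri_build_self_prefix_attention_mask
# (defined later in this file) is passed through directly; under flash-attention the layer either
# keeps a simple 2D validity mask or, when per-query prefix masking is required, builds the 4D
# mask and falls back to the eager path, since the flash path only accepts 2D padding masks.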
class LlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.partial_gradient_checkpointing = False
self.layer_idx = layer_idx
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.replay_cross_attn = None
if _replay_module_enabled(config):
self.replay_cross_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.register_load_state_dict_pre_hook(_seed_replay_cross_attn_from_self_pre_hook)
def ensure_replay_cross_attn(self) -> Optional[LlamaAttention]:
replay_attn = getattr(self, "replay_cross_attn", None)
if replay_attn is not None:
return replay_attn
if self.training:
raise RuntimeError(
"Replay attention was requested at runtime, but replay parameters were not initialized. "
"Build the model with replay_module=cross before training."
)
replay_attn = LlamaAttention(config=self.self_attn.config, layer_idx=self.layer_idx)
replay_attn.to(
device=self.self_attn.q_proj.weight.device,
dtype=self.self_attn.q_proj.weight.dtype,
)
self.replay_cross_attn = replay_attn
self.sync_replay_cross_attn_from_self()
return replay_attn
def sync_replay_cross_attn_from_self(self) -> None:
replay_attn = getattr(self, "replay_cross_attn", None)
if replay_attn is None:
return
_copy_attention_module_state(replay_attn, self.self_attn)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.training and self.partial_gradient_checkpointing:
hidden_states = _checkpoint_mlp(self.mlp, hidden_states)
else:
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@auto_docstring
class LlamaPreTrainedModel(PreTrainedModel):
config: LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": LlamaDecoderLayer,
"attentions": LlamaAttention,
}
@auto_docstring
class LlamaModel(LlamaPreTrainedModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.llopa_num_specials = int(getattr(config, "llopa_num_specials", 0) or 0)
self.llopa_specials = None
if self.llopa_num_specials > 0:
self.llopa_specials = nn.ParameterList(
[nn.Parameter(torch.empty(self.llopa_num_specials, config.hidden_size))
for _ in range(config.num_hidden_layers)]
)
init_range = float(getattr(config, "initializer_range", 0.02))
for p in self.llopa_specials:
nn.init.normal_(p, mean=0.0, std=init_range)
# Initialize weights and apply final processing
self.post_init()
for layer in self.layers:
if hasattr(layer, "sync_replay_cross_attn_from_self"):
layer.sync_replay_cross_attn_from_self()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
skip_upper_attention_layers = int(kwargs.pop("skip_upper_attention_layers", 0) or 0)
solo_attention_layers = int(kwargs.pop("solo_attention_layers", 0) or 0)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = _safe_dynamic_cache(self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
skip_from_layer = max(0, min(int(skip_upper_attention_layers), self.config.num_hidden_layers))
solo_from_layer = max(0, min(int(solo_attention_layers), self.config.num_hidden_layers))
solo_mask = None
if solo_from_layer > 0:
past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
solo_mask = _build_self_only_mask(inputs_embeds.size(0), inputs_embeds.size(1), past_len, inputs_embeds.device, inputs_embeds.dtype)
use_checkpoint = should_use_checkpoint(self)
_llopa_log_memory(
self,
"after_vanilla_embed",
sequence_len=hidden_states.size(1),
)
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if skip_from_layer > 0 and layer_idx >= skip_from_layer:
hidden_states = _run_decoder_mlp_only(decoder_layer, hidden_states, use_checkpoint)
_llopa_log_memory(
self,
f"after_vanilla_upper_skip_layer_{layer_idx}",
sequence_len=hidden_states.size(1),
)
else:
layer_mask = causal_mask
if solo_from_layer > 0 and layer_idx >= solo_from_layer:
layer_mask = solo_mask
hidden_states = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
if skip_from_layer > 0 and layer_idx + 1 == skip_from_layer:
_llopa_log_memory(
self,
"after_vanilla_lower",
sequence_len=hidden_states.size(1),
)
hidden_states = self.norm(hidden_states)
_llopa_log_memory(
self,
"after_vanilla_final",
sequence_len=hidden_states.size(1),
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config, **kwargs):
# Keep compatibility with newer transformers `from_pretrained(..., dtype=...)` kwargs.
init_dtype = kwargs.pop("dtype", None)
if init_dtype is None:
init_dtype = kwargs.pop("torch_dtype", None)
super().__init__(config, **kwargs)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
if isinstance(init_dtype, str):
alias = {
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"fp16": torch.float16,
"float16": torch.float16,
"fp32": torch.float32,
"float32": torch.float32,
}
init_dtype = alias.get(init_dtype.lower(), None)
if isinstance(init_dtype, torch.dtype):
self.to(dtype=init_dtype)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
skip_upper_attention_layers = int(kwargs.pop("skip_upper_attention_layers", 0) or 0)
prefill_lower_layers = int(kwargs.pop("prefill_lower_layers", 0) or 0)
prefill_lower_attn = str(kwargs.pop("prefill_lower_attn", "causal") or "causal")
prefill_lower_no_upper_attn = bool(kwargs.pop("prefill_lower_no_upper_attn", False))
prefill_lower_upper_attention_mode = _normalize_upper_attention_mode(
kwargs.pop(
"prefill_lower_upper_attention_mode",
getattr(self.config, "capsule_upper_attention_mode", "causal"),
)
)
prefill_lower_solo_attention = bool(kwargs.pop("prefill_lower_solo_attention", False))
prefill_lower_solo_attention_v2 = bool(kwargs.pop("prefill_lower_solo_attention_v2", False))
prefill_lower_solo_attention_v2_with_bos = bool(kwargs.pop("prefill_lower_solo_attention_v2_with_bos", False))
prefill_lower_freeze_runtime = bool(kwargs.pop("prefill_lower_freeze_runtime", False))
prefill_lower_split_start = kwargs.pop("prefill_lower_split_start", None)
prefill_lower_system_len = kwargs.pop(PREFILL_LOWER_SYSTEM_LEN_KEY, None)
prefill_lower_replay_user_prefix_keep_len = kwargs.pop(PREFILL_LOWER_REPLAY_USER_PREFIX_KEEP_LEN_KEY, None)
prefill_lower_replay_user_start = kwargs.pop(PREFILL_LOWER_REPLAY_USER_START_KEY, None)
prefill_lower_replay_user_len = kwargs.pop(PREFILL_LOWER_REPLAY_USER_LEN_KEY, None)
prefill_lower_system_prefill = str(kwargs.pop("prefill_lower_system_prefill", "no_bos_system") or "no_bos_system")
prefill_lower_replay_module = _normalize_replay_module(
kwargs.pop(
"prefill_lower_replay_module",
kwargs.pop(
"prefill_lower_last_layer_module",
getattr(self.config, "capsule_replay_module", getattr(self.config, "capsule_last_layer_module", "none")),
),
)
)
prefill_lower_replay_per_layers = _normalize_replay_per_layers(
kwargs.pop(
"prefill_lower_replay_per_layers",
getattr(self.config, "capsule_replay_per_layers", -1),
)
)
prefill_lower_see_past_assistant = bool(kwargs.pop("prefill_lower_see_past_assistant", False))
assistant_header_start = kwargs.pop("assistant_header_start", None)
assistant_header_starts = kwargs.pop(ASSISTANT_HEADER_STARTS_KEY, None)
assistant_turn_ends = kwargs.pop(ASSISTANT_TURN_ENDS_KEY, None)
assistant_header_start_mask = kwargs.pop(ASSISTANT_HEADER_START_MASK_KEY, None)
runtime_llopa_prefill = bool(kwargs.pop("runtime_llopa_prefill", False))
runtime_llopa_layers = int(kwargs.pop("runtime_llopa_layers", 0) or 0)
runtime_llopa_attn = str(kwargs.pop("runtime_llopa_attn", "causal") or "causal")
runtime_llopa_no_upper_attn = bool(kwargs.pop("runtime_llopa_no_upper_attn", False))
runtime_llopa_replay_module = _normalize_replay_module(
kwargs.pop(
"runtime_llopa_replay_module",
kwargs.pop("runtime_llopa_last_layer_module", prefill_lower_replay_module),
)
)
runtime_llopa_replay_per_layers = _normalize_replay_per_layers(
kwargs.pop(
"runtime_llopa_replay_per_layers",
prefill_lower_replay_per_layers,
)
)
llopa_v2_decode = bool(kwargs.pop("llopa_v2_decode", False))
llopa_v2_decode_layers = int(
kwargs.pop(
"llopa_v2_decode_layers",
kwargs.pop("llopa_v2_lower_k", runtime_llopa_layers),
)
or 0
)
llopa_v2_decode_no_upper_attn = bool(
kwargs.pop("llopa_v2_decode_no_upper_attn", runtime_llopa_no_upper_attn)
)
llopa_v2_decode_replay_module = _normalize_replay_module(
kwargs.pop("llopa_v2_decode_replay_module", runtime_llopa_replay_module)
)
llopa_v2_decode_replay_per_layers = _normalize_replay_per_layers(
kwargs.pop("llopa_v2_decode_replay_per_layers", runtime_llopa_replay_per_layers)
)
prefill_lower_solo_attention_v2 = bool(
prefill_lower_solo_attention_v2 or prefill_lower_solo_attention_v2_with_bos
)
if (
llopa_v2_decode
and labels is None
and inputs_embeds is None
and past_key_values is not None
and isinstance(input_ids, torch.Tensor)
and input_ids.dim() == 2
and input_ids.size(1) > 0
and llopa_v2_decode_layers > 0
):
tri_outputs = self.model.tri_forward_assistant(
assistant_ids=input_ids,
lower_k=llopa_v2_decode_layers,
pkv=past_key_values,
S=0,
U=0,
write_cache=bool(use_cache if use_cache is not None else True),
prefill_mode="lower",
no_upper_attn=llopa_v2_decode_no_upper_attn,
align_cache_position_to_layer_past=False,
replay_module=llopa_v2_decode_replay_module,
replay_per_layers=llopa_v2_decode_replay_per_layers,
)
hidden_states = tri_outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=tri_outputs.past_key_values,
hidden_states=tri_outputs.hidden_states,
attentions=tri_outputs.attentions,
)
if prefill_lower_split_start is None:
prefill_lower_split_start = assistant_header_start
if runtime_llopa_prefill and inputs_embeds is None and (bool(use_cache) or past_key_values is not None):
tri_outputs = self.tri_runtime_llopa_prompt_prefill_forward(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
past_key_values=past_key_values,
lower_k=runtime_llopa_layers,
prefill_attn=runtime_llopa_attn,
no_upper_attn=runtime_llopa_no_upper_attn,
replay_module=runtime_llopa_replay_module,
replay_per_layers=runtime_llopa_replay_per_layers,
)
if tri_outputs is not None:
return tri_outputs
if prefill_lower_layers > 0 and inputs_embeds is None:
tri_outputs = None
if labels is not None and past_key_values is None and not bool(use_cache):
if prefill_lower_solo_attention_v2:
tri_outputs = self.tri_vanilla_solo_attention_v2_train_forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
with_bos=prefill_lower_solo_attention_v2_with_bos,
)
elif prefill_lower_solo_attention:
tri_outputs = self.tri_vanilla_solo_prefix_train_forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
)
else:
tri_outputs = self.tri_vanilla_train_forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
assistant_header_starts_hint=assistant_header_starts,
assistant_turn_ends_hint=assistant_turn_ends,
assistant_header_start_mask_hint=assistant_header_start_mask,
see_past_assistant=prefill_lower_see_past_assistant,
no_upper_attn=prefill_lower_no_upper_attn,
upper_attention_mode=prefill_lower_upper_attention_mode,
replay_module=prefill_lower_replay_module,
replay_per_layers=prefill_lower_replay_per_layers,
)
elif prefill_lower_solo_attention_v2 and past_key_values is None and not bool(use_cache):
tri_outputs = self.tri_vanilla_solo_attention_v2_infer_forward(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
with_bos=prefill_lower_solo_attention_v2_with_bos,
)
elif prefill_lower_solo_attention and past_key_values is None and not bool(use_cache):
tri_outputs = self.tri_vanilla_solo_prefix_infer_forward(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
)
elif prefill_lower_freeze_runtime and past_key_values is None and not bool(use_cache):
tri_outputs = self.tri_vanilla_frozen_prefix_infer_forward(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
)
elif bool(use_cache) or past_key_values is not None:
tri_outputs = self.tri_vanilla_prefill_decode_forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
past_key_values=past_key_values,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_lower_attn,
split_start_hint=prefill_lower_split_start,
system_len_hint=prefill_lower_system_len,
system_prefill=prefill_lower_system_prefill,
assistant_header_starts_hint=assistant_header_starts,
assistant_turn_ends_hint=assistant_turn_ends,
assistant_header_start_mask_hint=assistant_header_start_mask,
replay_user_prefix_keep_len_hint=prefill_lower_replay_user_prefix_keep_len,
replay_user_start_hint=prefill_lower_replay_user_start,
replay_user_len_hint=prefill_lower_replay_user_len,
see_past_assistant=prefill_lower_see_past_assistant,
no_upper_attn=prefill_lower_no_upper_attn,
replay_module=prefill_lower_replay_module,
replay_per_layers=prefill_lower_replay_per_layers,
)
if tri_outputs is not None:
return tri_outputs
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
skip_upper_attention_layers=skip_upper_attention_layers,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
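# Note on the dispatch in LlamaForCausalLM.forward above: all TRI / LLoPA-specific kwargs are
# popped from **kwargs before the vanilla LlamaModel call so the base forward never sees them;
# whenever one of the tri_* helpers declines to handle the request (returns None), execution
# falls through to the standard causal-LM path unchanged.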
class LlamaForSequenceClassification(GenericForSequenceClassification, LlamaPreTrainedModel): ...
class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel):
base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
class LlamaForTokenClassification(GenericForTokenClassification, LlamaPreTrainedModel): ...
# =============================================================================
# TRI utilities (System/User/Assistant split)
# =============================================================================
def _tri_arange(start: int, length: int, device) -> torch.LongTensor:
return torch.arange(int(start), int(start + length), device=device, dtype=torch.long)
def _prefill_layer_indices(n_layers: int, mode: str, k: int) -> list[int]:
if (mode or "lower") == "periodic":
step = int(k)
if step <= 0:
return []
return list(range(0, n_layers, step))
K_eff = max(0, min(int(k), n_layers))
return list(range(K_eff))
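# Illustrative example for _prefill_layer_indices: with n_layers=32, mode="lower", k=8 the
# selected (lower) layers are [0, 1, ..., 7]; with mode="periodic", k=8 they are [0, 8, 16, 24].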
def _safe_len_seqlen(x) -> int:
if x is None:
return 0
if hasattr(x, "shape") and x.dim() >= 3:
return int(x.shape[2])
return 0
def _layer_past_len(pkv: Cache, li: int) -> int:
if pkv is None:
return 0
if hasattr(pkv, "key_cache") and isinstance(pkv.key_cache, (list, tuple)) and li < len(pkv.key_cache):
return _safe_len_seqlen(pkv.key_cache[li])
if hasattr(pkv, "layers") and isinstance(pkv.layers, (list, tuple)) and li < len(pkv.layers):
layer = pkv.layers[li]
keys = None if layer is None else getattr(layer, "keys", None)
return _safe_len_seqlen(keys)
try:
kv = pkv[li]
if isinstance(kv, (list, tuple)) and len(kv) >= 1:
return _safe_len_seqlen(kv[0])
except Exception:
pass
return 0
def _build_tri_mask_local(B: int, Tq: int, past_len: int, device, dtype):
Tk = past_len + Tq
i = torch.arange(Tq, device=device).unsqueeze(1)
j = torch.arange(Tk, device=device).unsqueeze(0)
allow = j <= (past_len + i)
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
mask = torch.where(allow, torch.tensor(0.0, device=device), neg_inf).to(dtype)
return mask.view(1, 1, Tq, Tk).expand(B, 1, Tq, Tk)
def _build_tri_mask_full(B: int, Tq: int, past_len: int, device, dtype):
Tk = past_len + Tq
return torch.zeros((B, 1, Tq, Tk), device=device, dtype=dtype)
def _build_self_only_mask(B: int, Tq: int, past_len: int, device, dtype):
Tk = past_len + Tq
min_value = torch.finfo(dtype).min
mask = torch.full((B, 1, Tq, Tk), min_value, device=device, dtype=dtype)
idx = torch.arange(Tq, device=device)
mask[:, :, idx, past_len + idx] = 0
return mask
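# Illustrative example for _build_self_only_mask: with Tq=3 and past_len=2 the returned
# (B, 1, 3, 5) mask is fully masked except entries (0, 2), (1, 3) and (2, 4), i.e. every query
# token may attend only to its own key and to nothing else, including the cached past.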
def _tri_expand_valid_mask(
valid_mask: Optional[torch.Tensor],
*,
batch_size: int,
total_len: int,
device,
) -> torch.BoolTensor:
if valid_mask is None:
return torch.ones((int(batch_size), int(total_len)), device=device, dtype=torch.bool)
valid_mask = valid_mask.to(device=device, dtype=torch.bool)
if valid_mask.dim() == 1:
valid_mask = valid_mask.view(1, -1)
if int(valid_mask.size(0)) == 1 and int(batch_size) > 1:
valid_mask = valid_mask.expand(int(batch_size), -1)
if int(valid_mask.size(1)) < int(total_len):
pad = torch.ones(
(int(valid_mask.size(0)), int(total_len) - int(valid_mask.size(1))),
device=device,
dtype=torch.bool,
)
valid_mask = torch.cat([pad, valid_mask], dim=1)
elif int(valid_mask.size(1)) > int(total_len):
valid_mask = valid_mask[:, -int(total_len) :]
return valid_mask
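# Padding behaviour of _tri_expand_valid_mask above: a missing mask is treated as all-valid; a
# mask shorter than total_len is left-padded with True (the unseen prefix positions are assumed
# valid), and a longer mask keeps only its last total_len columns, matching the left-truncation
# convention used by the other mask helpers in this file.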
def _tri_build_prefix_query_allow_mask(
*,
batch_size: int,
query_len: int,
prefix_valid_mask: Optional[torch.Tensor],
query_prefix_mask: Optional[torch.Tensor],
prefix_len: int,
device,
) -> torch.BoolTensor:
if int(prefix_len) <= 0:
return torch.zeros((int(batch_size), 1, int(query_len), 0), device=device, dtype=torch.bool)
prefix_valid = _tri_expand_valid_mask(
prefix_valid_mask,
batch_size=batch_size,
total_len=prefix_len,
device=device,
)
if query_prefix_mask is None:
query_prefix_mask = torch.ones((int(batch_size), int(query_len)), device=device, dtype=torch.bool)
else:
query_prefix_mask = query_prefix_mask.to(device=device, dtype=torch.bool)
if query_prefix_mask.dim() == 1:
query_prefix_mask = query_prefix_mask.view(1, -1)
if int(query_prefix_mask.size(0)) == 1 and int(batch_size) > 1:
query_prefix_mask = query_prefix_mask.expand(int(batch_size), -1)
return prefix_valid.view(int(batch_size), 1, 1, int(prefix_len)).expand(
int(batch_size), 1, int(query_len), int(prefix_len)
) & query_prefix_mask.view(int(batch_size), 1, int(query_len), 1)
def _tri_build_self_prefix_attention_mask(
*,
local_attention_mask: Optional[torch.Tensor],
batch_size: int,
query_len: int,
local_key_len: int,
prefix_valid_mask: Optional[torch.Tensor],
query_prefix_mask: Optional[torch.Tensor],
local_valid_mask: Optional[torch.Tensor],
device,
dtype,
) -> torch.Tensor:
prefix_len = 0 if prefix_valid_mask is None else int(prefix_valid_mask.size(-1))
prefix_allow = _tri_build_prefix_query_allow_mask(
batch_size=batch_size,
query_len=query_len,
prefix_valid_mask=prefix_valid_mask,
query_prefix_mask=query_prefix_mask,
prefix_len=prefix_len,
device=device,
)
if isinstance(local_attention_mask, torch.Tensor) and local_attention_mask.dim() == 4:
local_mask = local_attention_mask.to(device=device)
local_mask = local_mask[:, :, :, : int(local_key_len)]
local_allow = local_mask == 0
else:
local_valid = _tri_expand_valid_mask(
local_valid_mask,
batch_size=batch_size,
total_len=local_key_len,
device=device,
)
past_len = max(int(local_key_len) - int(query_len), 0)
q_pos = torch.arange(query_len, device=device, dtype=torch.long).view(1, 1, query_len, 1)
k_pos = torch.arange(local_key_len, device=device, dtype=torch.long).view(1, 1, 1, local_key_len)
local_allow = (k_pos <= (past_len + q_pos)) & local_valid.view(int(batch_size), 1, 1, int(local_key_len))
allow = torch.cat([prefix_allow, local_allow], dim=-1)
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
zeros = torch.tensor(0.0, device=device)
return torch.where(allow, zeros, neg_inf).to(dtype)
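# The combined mask returned above is an additive float tensor of shape
# (batch, 1, query_len, prefix_len + local_key_len): prefix columns are gated by the prefix
# validity and per-query prefix masks, while local columns follow the usual causal rule,
# optionally restricted by a pre-built 4D local mask or a 2D validity mask.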
def _tri_build_last_layer_memory_index_batch(
*,
split_starts: torch.LongTensor,
prefix_keep_lens: torch.LongTensor,
system_lens: Optional[torch.LongTensor] = None,
latest_user_starts: Optional[torch.LongTensor] = None,
latest_user_lens: Optional[torch.LongTensor] = None,
device,
owner=None,
bucket_multiple: int = 0,
workspace_key: str = "single",
) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
split_starts = split_starts.to(device=device, dtype=torch.long)
prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long)
batch_size = int(split_starts.size(0))
if isinstance(system_lens, torch.Tensor):
system_lens = system_lens.to(device=device, dtype=torch.long)
else:
system_lens = None
if isinstance(latest_user_starts, torch.Tensor):
latest_user_starts = latest_user_starts.to(device=device, dtype=torch.long)
else:
latest_user_starts = None
if isinstance(latest_user_lens, torch.Tensor):
latest_user_lens = latest_user_lens.to(device=device, dtype=torch.long)
else:
latest_user_lens = None
memory_indices = []
memory_lens = torch.zeros((batch_size,), device=device, dtype=torch.long)
max_memory_len = 0
for row in range(batch_size):
prefix_keep = max(int(prefix_keep_lens[row].item()), 0)
split_start = max(int(split_starts[row].item()), 0)
system_end = prefix_keep
if isinstance(system_lens, torch.Tensor):
system_end = min(max(int(system_lens[row].item()), prefix_keep), split_start)
idx_parts = []
if system_end > prefix_keep:
idx_parts.append(torch.arange(prefix_keep, system_end, device=device, dtype=torch.long))
if isinstance(latest_user_starts, torch.Tensor):
latest_start = min(max(int(latest_user_starts[row].item()), system_end), split_start)
if isinstance(latest_user_lens, torch.Tensor):
latest_len = max(int(latest_user_lens[row].item()), 0)
latest_end = min(latest_start + latest_len, split_start)
else:
latest_end = split_start
if latest_end > latest_start:
idx_parts.append(torch.arange(latest_start, latest_end, device=device, dtype=torch.long))
if idx_parts:
idx = idx_parts[0] if len(idx_parts) == 1 else torch.cat(idx_parts, dim=0)
elif split_start > prefix_keep:
idx = torch.arange(prefix_keep, split_start, device=device, dtype=torch.long)
else:
idx = torch.empty((0,), device=device, dtype=torch.long)
memory_indices.append(idx)
memory_lens[row] = int(idx.numel())
if idx.numel() > max_memory_len:
max_memory_len = int(idx.numel())
bucket_multiple = int(bucket_multiple or 0)
target_memory_len = int(max_memory_len)
if bucket_multiple > 0 and target_memory_len > 0:
target_memory_len = ((target_memory_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple
gather_idx = None
valid_mask = None
if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)):
cache = getattr(owner, "_optimized_llopa_last_layer_memory_workspace_cache", None)
if not isinstance(cache, dict):
cache = {}
key = (str(workspace_key), str(device), int(batch_size), int(target_memory_len))
pair = cache.get(key)
if (
not isinstance(pair, tuple)
or len(pair) != 2
or not isinstance(pair[0], torch.Tensor)
or not isinstance(pair[1], torch.Tensor)
or pair[0].shape != (int(batch_size), int(target_memory_len))
or pair[1].shape != (int(batch_size), int(target_memory_len))
or pair[0].device != device
or pair[1].device != device
):
pair = (
torch.zeros((int(batch_size), int(target_memory_len)), device=device, dtype=torch.long),
torch.zeros((int(batch_size), int(target_memory_len)), device=device, dtype=torch.bool),
)
cache[key] = pair
with contextlib.suppress(Exception):
setattr(owner, "_optimized_llopa_last_layer_memory_workspace_cache", cache)
gather_idx, valid_mask = pair
gather_idx.zero_()
valid_mask.zero_()
if gather_idx is None or valid_mask is None:
gather_idx = torch.zeros((batch_size, target_memory_len), device=device, dtype=torch.long)
valid_mask = torch.zeros((batch_size, target_memory_len), device=device, dtype=torch.bool)
for row, idx in enumerate(memory_indices):
width = int(idx.numel())
if width <= 0:
continue
gather_idx[row, :width] = idx
valid_mask[row, :width] = True
return gather_idx, valid_mask, memory_lens
def _tri_resolve_replay_user_window(
*,
user_len: int,
prefix_keep_len: int,
latest_user_start: Optional[int],
latest_user_len: Optional[int],
) -> tuple[int, int, int]:
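    """Clamp the replay window to the user segment and return (prefix_keep_len, start, end).

    Without latest_user_start the whole post-prefix span is used; if the resolved window
    is empty it falls back to [prefix_keep_len, user_len) when that span is non-empty.
    """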
user_len = max(int(user_len), 0)
prefix_keep_len = min(max(int(prefix_keep_len), 0), user_len)
if latest_user_start is None:
return prefix_keep_len, prefix_keep_len, user_len
start = min(max(int(latest_user_start), prefix_keep_len), user_len)
if latest_user_len is None:
end = user_len
else:
end = min(max(start + int(latest_user_len), start), user_len)
if end <= start and user_len > prefix_keep_len:
start = prefix_keep_len
end = user_len
return prefix_keep_len, start, end
def _tri_build_last_layer_query_mask(
*,
prefix_keep_lens: torch.LongTensor,
valid_lens: torch.LongTensor,
total_len: int,
device,
) -> torch.BoolTensor:
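    """Mark query positions in [prefix_keep_len, valid_len) of each row as replay targets."""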
if total_len <= 0:
return torch.zeros((int(valid_lens.size(0)), 0), device=device, dtype=torch.bool)
prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long)
valid_lens = valid_lens.to(device=device, dtype=torch.long)
positions = torch.arange(total_len, device=device, dtype=torch.long).unsqueeze(0)
return (positions >= prefix_keep_lens.unsqueeze(1)) & (positions < valid_lens.unsqueeze(1))
def _tri_store_last_layer_memory(
pkv: Optional[Cache],
*,
hidden_states: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor],
valid_mask: Optional[torch.Tensor],
module_type: str,
replay_per_layers: int = -1,
) -> None:
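    """Stash replay memory tensors on the cache object as private attributes.

    No-op when the cache is missing, the module type normalizes to "none", or the hidden
    states are empty; attribute errors are swallowed deliberately.
    """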
if pkv is None:
return
module_type = _normalize_replay_module(module_type)
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
if module_type == "none":
return
if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0:
return
try:
setattr(pkv, "_tri_last_layer_memory_hidden", hidden_states)
setattr(pkv, "_tri_last_layer_memory_position_ids", position_ids)
setattr(pkv, "_tri_last_layer_memory_valid_mask", valid_mask)
setattr(pkv, "_tri_last_layer_module", module_type)
setattr(pkv, "_tri_replay_module", module_type)
setattr(pkv, "_tri_replay_per_layers", int(replay_per_layers))
except Exception:
pass
def _tri_segment_position_ids(
*,
batch_size: int,
start: int,
length: int,
device,
) -> torch.LongTensor:
if int(length) <= 0:
return torch.empty((int(batch_size), 0), device=device, dtype=torch.long)
position_ids = _tri_arange(int(start), int(length), device).unsqueeze(0)
return position_ids.expand(int(batch_size), -1)
def _tri_valid_lengths(
valid_mask: Optional[torch.Tensor],
*,
batch_size: int,
default_len: int,
device,
) -> torch.LongTensor:
if valid_mask is None:
return torch.full((int(batch_size),), int(default_len), device=device, dtype=torch.long)
valid_mask = valid_mask.to(device=device, dtype=torch.bool)
if valid_mask.dim() == 1:
valid_mask = valid_mask.view(1, -1)
return valid_mask.sum(dim=1, dtype=torch.long)
def _tri_cu_seq_lens_from_lengths(lengths: torch.LongTensor) -> tuple[torch.IntTensor, int]:
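    """Convert per-row lengths into flash-attention style cumulative sequence offsets.

    Illustrative example: lengths [3, 5] -> cu_seq_lens [0, 3, 8] with max length 5.
    """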
lengths = lengths.to(dtype=torch.int32)
if lengths.numel() == 0:
return torch.zeros((1,), device=lengths.device, dtype=torch.int32), 0
cu_seq_lens = torch.cat(
[
torch.zeros((1,), device=lengths.device, dtype=torch.int32),
torch.cumsum(lengths, dim=0, dtype=torch.int32),
],
dim=0,
)
return cu_seq_lens, int(lengths.max().item())
def _tri_build_replay_attention_mask(
*,
batch_size: int,
query_len: int,
local_valid_mask: Optional[torch.Tensor],
memory_valid_mask: Optional[torch.Tensor],
module_type: str,
device,
dtype,
) -> Optional[torch.Tensor]:
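    """Build the additive mask used by the replay attention module.

    "cross" attends to memory keys only; "self" additionally appends a causal block over
    the local queries. Returns None when the module is disabled or there is nothing to
    attend to.
    """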
module_type = _normalize_last_layer_module(module_type)
memory_len = 0 if memory_valid_mask is None else int(memory_valid_mask.size(1))
if module_type == "none" or query_len <= 0 or (module_type == "cross" and memory_len <= 0):
return None
pieces = []
if memory_len > 0:
mem_valid = memory_valid_mask.to(device=device, dtype=torch.bool).view(batch_size, 1, 1, memory_len)
pieces.append(mem_valid.expand(batch_size, 1, query_len, memory_len))
if module_type == "self":
if local_valid_mask is None:
local_valid_mask = torch.ones((batch_size, query_len), device=device, dtype=torch.bool)
else:
local_valid_mask = local_valid_mask.to(device=device, dtype=torch.bool)
local_len = int(local_valid_mask.size(1))
q_pos = torch.arange(query_len, device=device, dtype=torch.long).view(1, 1, query_len, 1)
k_pos = torch.arange(local_len, device=device, dtype=torch.long).view(1, 1, 1, local_len)
local_allow = (k_pos <= q_pos) & local_valid_mask.view(batch_size, 1, 1, local_len)
pieces.append(local_allow)
if not pieces:
return None
allow = torch.cat(pieces, dim=-1)
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
zeros = torch.tensor(0.0, device=device)
return torch.where(allow, zeros, neg_inf).to(dtype)
def _tri_project_memory_kv(
*,
attn_module: "LlamaAttention",
memory_hidden_states: torch.Tensor,
memory_position_embeddings: tuple[torch.Tensor, torch.Tensor],
target_device,
target_dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
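    """Project replay memory hidden states to K/V heads, applying RoPE to the keys only."""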
memory_hidden_states = memory_hidden_states.to(device=target_device, dtype=target_dtype)
memory_shape = (*memory_hidden_states.shape[:-1], -1, attn_module.head_dim)
memory_key_states = attn_module.k_proj(memory_hidden_states).view(memory_shape).transpose(1, 2)
memory_value_states = attn_module.v_proj(memory_hidden_states).view(memory_shape).transpose(1, 2)
memory_cos, memory_sin = memory_position_embeddings
memory_key_states = (
memory_key_states * memory_cos.unsqueeze(1)
) + (rotate_half(memory_key_states) * memory_sin.unsqueeze(1))
return memory_key_states, memory_value_states
def _tri_replay_attention_forward(
*,
attn_module: "LlamaAttention",
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
local_valid_mask: Optional[torch.Tensor],
query_replay_mask: Optional[torch.Tensor],
memory_hidden_states: Optional[torch.Tensor],
memory_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
memory_valid_mask: Optional[torch.Tensor],
module_type: str,
) -> torch.Tensor:
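    """Run the replay attention pass over stored memory hidden states.

    Memory K/V come from _tri_project_memory_kv; "self" mode appends the current tokens
    as extra keys/values. flash_attention_2 receives var-length metadata in "cross" mode
    (or a 2-D padding mask in "self" mode); other implementations use the additive mask
    from _tri_build_replay_attention_mask. Outputs outside query_replay_mask are zeroed.
    """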
module_type = _normalize_last_layer_module(module_type)
if module_type == "none" or hidden_states.size(1) == 0:
return hidden_states.new_zeros(hidden_states.shape)
if not isinstance(memory_hidden_states, torch.Tensor) or memory_hidden_states.numel() == 0 or int(memory_hidden_states.size(1)) == 0:
return hidden_states.new_zeros(hidden_states.shape)
batch_size, query_len = int(hidden_states.size(0)), int(hidden_states.size(1))
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attn_module.head_dim)
query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
cos_u = cos.unsqueeze(1)
sin_u = sin.unsqueeze(1)
query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u)
memory_key_states, memory_value_states = _tri_project_memory_kv(
attn_module=attn_module,
memory_hidden_states=memory_hidden_states,
memory_position_embeddings=memory_position_embeddings,
target_device=hidden_states.device,
target_dtype=hidden_states.dtype,
)
if module_type == "self":
local_key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
local_value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
local_key_states = (local_key_states * cos_u) + (rotate_half(local_key_states) * sin_u)
key_states = torch.cat([memory_key_states, local_key_states], dim=2)
value_states = torch.cat([memory_value_states, local_value_states], dim=2)
else:
key_states = memory_key_states
value_states = memory_value_states
attn_impl = _resolve_attn_impl(attn_module.config)
attention_interface: Callable = eager_attention_forward
if attn_impl != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
attn_kwargs = {}
replay_mask = None
restore_is_causal = None
if attn_impl == "flash_attention_2":
if module_type == "self":
memory_valid = memory_valid_mask
if memory_valid is None:
memory_valid = torch.ones(
(batch_size, memory_key_states.size(2)),
device=hidden_states.device,
dtype=torch.bool,
)
local_valid = local_valid_mask
if local_valid is None:
local_valid = torch.ones((batch_size, query_len), device=hidden_states.device, dtype=torch.bool)
replay_mask = torch.cat(
[
memory_valid.to(device=hidden_states.device, dtype=torch.long),
local_valid.to(device=hidden_states.device, dtype=torch.long),
],
dim=1,
)
else:
query_lengths = _tri_valid_lengths(
local_valid_mask,
batch_size=batch_size,
default_len=query_len,
device=hidden_states.device,
)
key_lengths = _tri_valid_lengths(
memory_valid_mask,
batch_size=batch_size,
default_len=memory_key_states.size(2),
device=hidden_states.device,
)
if bool(torch.all(query_lengths > 0).item()) and bool(torch.all(key_lengths > 0).item()):
cu_q, max_q = _tri_cu_seq_lens_from_lengths(query_lengths)
cu_k, max_k = _tri_cu_seq_lens_from_lengths(key_lengths)
attn_kwargs.update(
{
"cu_seq_lens_q": cu_q,
"cu_seq_lens_k": cu_k,
"max_length_q": max_q,
"max_length_k": max_k,
}
)
restore_is_causal = bool(attn_module.is_causal)
attn_module.is_causal = False
else:
attn_impl = "eager"
attention_interface = eager_attention_forward
if attn_impl != "flash_attention_2":
replay_mask = _tri_build_replay_attention_mask(
batch_size=batch_size,
query_len=query_len,
local_valid_mask=local_valid_mask,
memory_valid_mask=memory_valid_mask,
module_type=module_type,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
try:
attn_output, _ = attention_interface(
attn_module,
query_states,
key_states,
value_states,
replay_mask,
dropout=0.0 if not attn_module.training else attn_module.attention_dropout,
scaling=attn_module.scaling,
**attn_kwargs,
)
finally:
if restore_is_causal is not None:
attn_module.is_causal = restore_is_causal
attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = attn_module.o_proj(attn_output)
if query_replay_mask is not None:
query_replay_mask = query_replay_mask.to(device=hidden_states.device, dtype=torch.bool)
attn_output = attn_output.masked_fill(~query_replay_mask.unsqueeze(-1), 0)
return attn_output
def _tri_run_last_layer_with_replay(
*,
layer: "LlamaDecoderLayer",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor],
past_key_values: Optional[Cache],
use_cache: Optional[bool],
cache_position: Optional[torch.LongTensor],
position_embeddings: tuple[torch.Tensor, torch.Tensor],
memory_hidden_states: Optional[torch.Tensor],
memory_position_ids: Optional[torch.LongTensor],
memory_valid_mask: Optional[torch.Tensor],
local_valid_mask: Optional[torch.Tensor],
query_replay_mask: Optional[torch.Tensor],
module_type: str,
rotary_emb: "LlamaRotaryEmbedding",
use_checkpoint: bool,
) -> torch.Tensor:
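    """Run the final decoder layer with replayed memory injected.

    "self" mode feeds the projected memory K/V into self-attention through the
    extra_prefix_* kwargs; "cross" mode inserts a separate replay cross-attention block
    between self-attention and the MLP. The MLP may be gradient-checkpointed.
    """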
residual = hidden_states
hidden_states = layer.input_layernorm(hidden_states)
module_type = _normalize_last_layer_module(module_type)
has_memory = (
module_type != "none"
and isinstance(memory_hidden_states, torch.Tensor)
and memory_hidden_states.numel() > 0
and int(memory_hidden_states.size(1)) > 0
and isinstance(memory_position_ids, torch.Tensor)
)
attn_kwargs = {}
if has_memory and module_type == "self":
replay_position_embeddings = rotary_emb(memory_hidden_states, memory_position_ids)
memory_key_states, memory_value_states = _tri_project_memory_kv(
attn_module=layer.self_attn,
memory_hidden_states=memory_hidden_states,
memory_position_embeddings=replay_position_embeddings,
target_device=hidden_states.device,
target_dtype=hidden_states.dtype,
)
attn_kwargs.update(
{
"extra_prefix_kv": (memory_key_states, memory_value_states),
"extra_prefix_valid_mask": memory_valid_mask,
"extra_prefix_query_mask": query_replay_mask,
"extra_prefix_local_valid_mask": local_valid_mask,
}
)
hidden_states, _ = layer.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**attn_kwargs,
)
hidden_states = residual + hidden_states
if (
has_memory
and module_type == "cross"
):
replay_attn = layer.ensure_replay_cross_attn()
replay_residual = hidden_states
replay_hidden = layer.post_attention_layernorm(hidden_states)
replay_position_embeddings = rotary_emb(memory_hidden_states, memory_position_ids)
replay_out = _tri_replay_attention_forward(
attn_module=replay_attn,
hidden_states=replay_hidden,
position_embeddings=position_embeddings,
local_valid_mask=local_valid_mask,
query_replay_mask=query_replay_mask,
memory_hidden_states=memory_hidden_states,
memory_position_embeddings=replay_position_embeddings,
memory_valid_mask=memory_valid_mask,
module_type=module_type,
)
hidden_states = replay_residual + replay_out
residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
if use_checkpoint:
hidden_states = _checkpoint_mlp(layer.mlp, hidden_states)
else:
hidden_states = layer.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def _tri_upper_solo_v2_attention_forward(
attn_module: "LlamaAttention",
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
*,
with_bos: bool = False,
) -> torch.Tensor:
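    """Degenerate "solo" attention used in upper layers.

    Without with_bos each token attends only to itself, so its value projection passes
    straight to the output projection. With with_bos each token softmaxes over exactly
    two keys: the first (BOS) token and itself.
    """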
if hidden_states.size(1) == 0:
return hidden_states
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attn_module.head_dim)
value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = repeat_kv(value_states, attn_module.num_key_value_groups)
if not bool(with_bos):
attn_output = value_states.transpose(1, 2).contiguous()
attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_module.o_proj(attn_output)
query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_states = repeat_kv(key_states, attn_module.num_key_value_groups)
bos_hidden_states = hidden_states[:, :1, :]
bos_shape = (bos_hidden_states.size(0), bos_hidden_states.size(1), -1, attn_module.head_dim)
bos_key_states = attn_module.k_proj(bos_hidden_states).view(bos_shape).transpose(1, 2)
bos_value_states = attn_module.v_proj(bos_hidden_states).view(bos_shape).transpose(1, 2)
bos_cos = cos[:, :1, :]
bos_sin = sin[:, :1, :]
bos_key_states = (bos_key_states * bos_cos.unsqueeze(1)) + (rotate_half(bos_key_states) * bos_sin.unsqueeze(1))
bos_key_states = repeat_kv(bos_key_states, attn_module.num_key_value_groups)
bos_value_states = repeat_kv(bos_value_states, attn_module.num_key_value_groups)
bos_scores = (
query_states * bos_key_states.expand(-1, -1, query_states.size(2), -1)
).sum(dim=-1) * attn_module.scaling
self_scores = (query_states * key_states).sum(dim=-1) * attn_module.scaling
attn_weights = torch.stack([bos_scores, self_scores], dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights,
p=0.0 if not attn_module.training else attn_module.attention_dropout,
training=attn_module.training,
)
bos_weight = attn_weights[..., :1]
self_weight = attn_weights[..., 1:2]
attn_output = (
bos_weight * bos_value_states.expand(-1, -1, query_states.size(2), -1)
+ self_weight * value_states
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_module.o_proj(attn_output)
def _tri_run_decoder_layer_upper_solo_v2(
*,
layer: "LlamaDecoderLayer",
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
with_bos: bool,
use_checkpoint: bool,
) -> torch.Tensor:
residual = hidden_states
attn_input = layer.input_layernorm(hidden_states)
attn_output = _tri_upper_solo_v2_attention_forward(
layer.self_attn,
attn_input,
position_embeddings,
with_bos=bool(with_bos),
)
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
if use_checkpoint:
hidden_states = _checkpoint_mlp(layer.mlp, hidden_states)
else:
hidden_states = layer.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def _tri_compact_solo_v2_upper_training_inputs(
*,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
labels: torch.LongTensor,
valid_lens: torch.LongTensor,
logits_to_keep: int,
with_bos: bool,
):
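    """Pack the last logits_to_keep positions of each row (optionally keeping BOS) for solo-v2 upper training.

    Returns (hidden, position_ids, labels, valid_mask, lens), or None when with_bos is
    requested but some row has no tokens before its split.
    """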
valid_lens_device = valid_lens.to(device=hidden_states.device, dtype=torch.long)
keep_lens = torch.minimum(
valid_lens_device,
torch.full_like(valid_lens_device, int(logits_to_keep)),
)
split_starts = (valid_lens_device - keep_lens).clamp_min(0)
if bool(with_bos) and bool(torch.any(split_starts <= 0).item()):
return None
upper_hidden, upper_valid_mask, upper_lens, _ = _tri_pack_suffix_tensor(
hidden_states,
split_starts=split_starts,
valid_lens=valid_lens,
pad_value=0.0,
)
upper_position_ids, _, _, _ = _tri_pack_suffix_tensor(
position_ids.to(device=hidden_states.device),
split_starts=split_starts,
valid_lens=valid_lens,
pad_value=0,
)
upper_labels, _, _, _ = _tri_pack_suffix_tensor(
labels.to(device=hidden_states.device),
split_starts=split_starts,
valid_lens=valid_lens,
pad_value=-100,
)
if bool(with_bos):
bos_valid = torch.ones((hidden_states.size(0), 1), device=hidden_states.device, dtype=torch.bool)
upper_hidden = torch.cat([hidden_states[:, :1, :], upper_hidden], dim=1)
upper_position_ids = torch.cat(
[position_ids.to(device=hidden_states.device)[:, :1], upper_position_ids],
dim=1,
)
bos_labels = torch.full(
(labels.size(0), 1),
-100,
device=hidden_states.device,
dtype=labels.dtype,
)
upper_labels = torch.cat([bos_labels, upper_labels], dim=1)
upper_valid_mask = torch.cat([bos_valid, upper_valid_mask], dim=1)
upper_lens = upper_lens + 1
return upper_hidden, upper_position_ids, upper_labels, upper_valid_mask, upper_lens
def _build_tri_mask_prefix_full(B: int, prefix_len: int, total_len: int, device, dtype):
# Prefix-full: prefix tokens attend bidirectionally within prefix only,
# assistant tokens attend causally over prefix + assistant.
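    # Illustrative pattern (P=2, total_len=4; rows = queries, cols = keys, 1 = attend):
    #   row 0: 1 1 0 0
    #   row 1: 1 1 0 0
    #   row 2: 1 1 1 0
    #   row 3: 1 1 1 1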
if total_len <= 0:
return torch.zeros((B, 1, 0, 0), device=device, dtype=dtype)
P = max(0, min(int(prefix_len), int(total_len)))
i = torch.arange(total_len, device=device).unsqueeze(1)
j = torch.arange(total_len, device=device).unsqueeze(0)
allow = torch.where(i < P, j < P, j <= i)
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
mask = torch.where(allow, torch.tensor(0.0, device=device), neg_inf).to(dtype)
return mask.view(1, 1, total_len, total_len).expand(B, 1, total_len, total_len)
def _build_tri_mask_prefix_full_varlen(
prefix_lens: torch.LongTensor,
attention_mask: torch.Tensor,
dtype,
):
# Batched prefix-full mask with per-sample prefix lengths and right padding.
if attention_mask.dim() != 2:
raise ValueError("attention_mask must be rank-2 for variable-length prefix-full masking.")
B, total_len = attention_mask.shape
device = attention_mask.device
if total_len <= 0:
return torch.zeros((B, 1, 0, 0), device=device, dtype=dtype)
    # Out-of-place clamp: `.to()` may return the caller's tensor unchanged, and clamp_ would then mutate it.
    prefix_lens = prefix_lens.to(device=device, dtype=torch.long).clamp(min=0, max=total_len)
i = torch.arange(total_len, device=device, dtype=torch.long).view(1, total_len, 1)
j = torch.arange(total_len, device=device, dtype=torch.long).view(1, 1, total_len)
prefix = prefix_lens.view(B, 1, 1)
allow = torch.where(i < prefix, j < prefix, j <= i)
key_valid = attention_mask.to(device=device, dtype=torch.bool).view(B, 1, total_len)
allow = allow & key_valid
neg_inf = torch.tensor(torch.finfo(torch.float32).min, device=device)
mask = torch.where(allow, torch.tensor(0.0, device=device), neg_inf).to(dtype)
return mask.view(B, 1, total_len, total_len)
def _can_use_implicit_causal_mask(config) -> bool:
attn_impl = _resolve_attn_impl(config)
if attn_impl not in {"sdpa", "flash_attention_2", "flex_attention"}:
return False
# Llama path assumes no sliding attention; keep explicit masks otherwise.
return getattr(config, "sliding_window", None) is None
def _llopa_upper_indices(S: int, U: int, A: int, device) -> torch.LongTensor:
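    """Index positions of the upper-path (system + assistant) tokens, skipping the user gap.

    Illustrative example: S=2, U=3, A=2 -> [0, 1, 5, 6].
    """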
if U <= 0:
return torch.arange(0, S + A, device=device, dtype=torch.long)
if A <= 0:
return torch.arange(0, S, device=device, dtype=torch.long)
left = torch.arange(0, S, device=device, dtype=torch.long)
right = torch.arange(S + U, S + U + A, device=device, dtype=torch.long)
return torch.cat([left, right], dim=0)
def _llopa_default_attention_mask(token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.LongTensor:
if attention_mask is None:
return torch.ones_like(token_ids, dtype=torch.long)
return attention_mask.to(device=token_ids.device, dtype=torch.long)
def _llopa_mask_is_all_ones(attention_mask: Optional[torch.Tensor]) -> bool:
if attention_mask is None:
return True
return bool(torch.all(attention_mask > 0).item())
def _llopa_position_ids_from_mask(attention_mask: torch.LongTensor) -> torch.LongTensor:
if attention_mask.numel() == 0:
return attention_mask.to(dtype=torch.long)
position_ids = attention_mask.cumsum(dim=1, dtype=torch.long) - 1
return position_ids.clamp_min_(0)
def _llopa_profile_enabled(model) -> bool:
return bool(getattr(model, "_llopa_profile_memory_enabled", False))
def _llopa_profile_step(model) -> int:
return int(getattr(model, "_llopa_profile_memory_step", 0) or 0)
def _llopa_profile_rank() -> int:
if torch.distributed.is_available() and torch.distributed.is_initialized():
return int(torch.distributed.get_rank())
return 0
def _llopa_segment_length(token_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor]) -> int:
ref = attention_mask if attention_mask is not None else token_ids
if ref is None:
return 0
if attention_mask is not None and attention_mask.numel() > 0:
return int(attention_mask.sum(dim=1, dtype=torch.long).max().item())
if ref.dim() >= 2:
return int(ref.size(1))
return int(ref.numel())
def _llopa_memory_device(
*tensors: Optional[torch.Tensor],
) -> Optional[torch.device]:
for tensor in tensors:
if isinstance(tensor, torch.Tensor):
return tensor.device
return None
def _llopa_reset_peak_memory(*tensors: Optional[torch.Tensor]) -> None:
device = _llopa_memory_device(*tensors)
if device is None or device.type != "cuda":
return
torch.cuda.reset_peak_memory_stats(device=device)
def _llopa_log_memory(
model,
stage: str,
*,
system_ids: Optional[torch.Tensor] = None,
user_ids: Optional[torch.Tensor] = None,
assistant_ids: Optional[torch.Tensor] = None,
system_attention_mask: Optional[torch.Tensor] = None,
user_attention_mask: Optional[torch.Tensor] = None,
assistant_attention_mask: Optional[torch.Tensor] = None,
sequence_len: Optional[int] = None,
) -> None:
if not _llopa_profile_enabled(model):
return
device = _llopa_memory_device(
system_ids,
user_ids,
assistant_ids,
system_attention_mask,
user_attention_mask,
assistant_attention_mask,
)
if device is None or device.type != "cuda":
return
logger.info(
"[LLOPA_MEM] step=%s rank=%s stage=%s peak_scope=since_reset system_len=%s user_len=%s assistant_len=%s sequence_len=%s "
"allocated_GiB=%.3f reserved_GiB=%.3f max_allocated_GiB=%.3f max_reserved_GiB=%.3f",
_llopa_profile_step(model),
_llopa_profile_rank(),
stage,
_llopa_segment_length(system_ids, system_attention_mask),
_llopa_segment_length(user_ids, user_attention_mask),
_llopa_segment_length(assistant_ids, assistant_attention_mask),
-1 if sequence_len is None else int(sequence_len),
torch.cuda.memory_allocated(device=device) / 2**30,
torch.cuda.memory_reserved(device=device) / 2**30,
torch.cuda.max_memory_allocated(device=device) / 2**30,
torch.cuda.max_memory_reserved(device=device) / 2**30,
)
def _llopa_pack_upper_hidden(hidden_states: torch.Tensor, S: int, U: int) -> torch.Tensor:
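    """Drop the U user rows from a (B, S + U + A, H) tensor, keeping system and assistant rows."""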
if U <= 0:
return hidden_states
B, _, H = hidden_states.shape
A = hidden_states.size(1) - S - U
packed = hidden_states.new_empty((B, S + A, H))
if S > 0:
packed[:, :S, :] = hidden_states[:, :S, :]
if A > 0:
packed[:, S:, :] = hidden_states[:, S + U :, :]
return packed
def _llopa_pack_upper_tokens(token_tensor: torch.Tensor, S: int, U: int) -> torch.Tensor:
if U <= 0:
return token_tensor
B = token_tensor.size(0)
A = token_tensor.size(1) - S - U
packed = token_tensor.new_empty((B, S + A))
if S > 0:
packed[:, :S] = token_tensor[:, :S]
if A > 0:
packed[:, S:] = token_tensor[:, S + U :]
return packed
def _llopa_insert_specials(upper_hidden: torch.Tensor, specials: torch.Tensor, S: int) -> torch.Tensor:
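    """Insert the special slots between the system block and the packed assistant block."""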
B, compact_len, H = upper_hidden.shape
num_specials = specials.size(1)
layer_hidden = upper_hidden.new_empty((B, compact_len + num_specials, H))
if S > 0:
layer_hidden[:, :S, :] = upper_hidden[:, :S, :]
layer_hidden[:, S : S + num_specials, :] = specials
if compact_len > S:
layer_hidden[:, S + num_specials :, :] = upper_hidden[:, S:, :]
return layer_hidden
def _llopa_unpack_upper_hidden(layer_hidden: torch.Tensor, S: int, num_specials: int) -> torch.Tensor:
if num_specials <= 0:
return layer_hidden
B, expanded_len, H = layer_hidden.shape
compact = layer_hidden.new_empty((B, expanded_len - num_specials, H))
if S > 0:
compact[:, :S, :] = layer_hidden[:, :S, :]
if expanded_len > S + num_specials:
compact[:, S:, :] = layer_hidden[:, S + num_specials :, :]
return compact
def _llopa_remove_specials(layer_hidden: torch.Tensor, S: int, num_specials: int) -> torch.Tensor:
return _llopa_unpack_upper_hidden(layer_hidden, S, num_specials)
def _llopa_use_optimized_training_path(
model,
*,
attn_impl: str,
prefix_full: bool,
no_upper_attn: bool,
) -> bool:
return (
bool(getattr(model, "training", False))
and attn_impl == "flash_attention_2"
and not prefix_full
and not bool(no_upper_attn)
)
def _llopa_warn_once(model, flag: str, msg: str) -> None:
if getattr(model, flag, False):
return
try:
setattr(model, flag, True)
except Exception:
pass
logger.warning(msg)
def _llopa_split_system(system_ids: torch.Tensor, system_prefill: str):
"""Return (system_upper, system_lower_extra) based on system_prefill."""
sys_prefill = (system_prefill or "full").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "full"
if sys_prefill == "full":
return system_ids, system_ids[:, :0]
if sys_prefill == "no_system":
if system_ids.size(1) < 1:
return system_ids[:, :0], system_ids[:, :0]
return system_ids[:, :1], system_ids[:, 1:]
return system_ids[:, :0], system_ids
def _llopa_merge_user(system_ids: torch.Tensor, user_ids: torch.Tensor, system_prefill: str):
sys_upper, sys_lower_extra = _llopa_split_system(system_ids, system_prefill)
if sys_lower_extra.numel() == 0:
return sys_upper, user_ids
if user_ids.numel() == 0:
return sys_upper, sys_lower_extra
return sys_upper, torch.cat([sys_lower_extra, user_ids], dim=1)
def _llopa_effective_specials(model, gap_u: int) -> int:
num = int(getattr(model, "llopa_num_specials", 0) or 0)
if num <= 0:
return 0
if getattr(model, "llopa_specials", None) is None:
_llopa_warn_once(model, "_llopa_warned_no_specials",
"[LLOPA] llopa_num_specials set but llopa_specials missing; specials disabled.")
return 0
if gap_u <= 0:
_llopa_warn_once(model, "_llopa_warned_no_gap",
f"[LLOPA] num_specials={num} but gap=0 (U=0); specials disabled.")
return 0
if num > gap_u:
_llopa_warn_once(model, "_llopa_warned_short_gap",
f"[LLOPA] num_specials={num} exceeds gap U={gap_u}; using {gap_u} specials.")
return gap_u
return num
def _llopa_can_cache_special_kv(model) -> bool:
return (not bool(getattr(model, "training", False))) and (not torch.is_grad_enabled())
def _llopa_special_kv_for_layer(
model: "LlamaModel",
layer: "LlamaDecoderLayer",
layer_idx: int,
num_specials_eff: int,
start_pos: int,
batch_size: int,
dtype: torch.dtype,
device,
) -> Tuple[torch.Tensor, torch.Tensor]:
if num_specials_eff <= 0:
empty = torch.empty((batch_size, layer.self_attn.num_key_value_heads, 0, layer.self_attn.head_dim),
device=device, dtype=dtype)
return empty, empty
specials_all = getattr(model, "llopa_specials", None)
if specials_all is None:
empty = torch.empty((batch_size, layer.self_attn.num_key_value_heads, 0, layer.self_attn.head_dim),
device=device, dtype=dtype)
return empty, empty
can_cache = _llopa_can_cache_special_kv(model)
cache = None
cache_key = None
param_ptr = None
param_ver = None
if can_cache:
cache = getattr(model, "_llopa_special_kv_cache", None)
if not isinstance(cache, dict):
cache = {}
setattr(model, "_llopa_special_kv_cache", cache)
dev_index = -1 if getattr(device, "index", None) is None else int(device.index)
cache_key = (int(layer_idx), int(num_specials_eff), int(start_pos), str(dtype), str(device.type), dev_index)
try:
param_ptr = int(specials_all[layer_idx].data_ptr())
except Exception:
param_ptr = None
try:
param_ver = int(getattr(specials_all[layer_idx], "_version", 0))
except Exception:
param_ver = None
entry = cache.get(cache_key)
if isinstance(entry, dict) and entry.get("param_ptr") == param_ptr and entry.get("param_ver") == param_ver:
k_base = entry.get("k")
v_base = entry.get("v")
if isinstance(k_base, torch.Tensor) and isinstance(v_base, torch.Tensor):
if k_base.device == device and v_base.device == device:
if k_base.dtype == dtype and v_base.dtype == dtype:
if k_base.size(0) == 1 and batch_size > 1:
return k_base.expand(batch_size, -1, -1, -1), v_base.expand(batch_size, -1, -1, -1)
return k_base, v_base
specials = specials_all[layer_idx][:num_specials_eff].to(device=device, dtype=dtype)
if can_cache:
specials = specials.unsqueeze(0)
sp_pos = torch.arange(start_pos, start_pos + num_specials_eff, device=device, dtype=torch.long).unsqueeze(0)
else:
specials = specials.unsqueeze(0).expand(batch_size, -1, -1)
sp_pos = torch.arange(start_pos, start_pos + num_specials_eff, device=device, dtype=torch.long).unsqueeze(0)
sp_pos = sp_pos.expand(batch_size, -1)
specials = layer.input_layernorm(specials)
hidden_shape = (*specials.shape[:-1], -1, layer.self_attn.head_dim)
key_states = layer.self_attn.k_proj(specials).view(hidden_shape).transpose(1, 2)
value_states = layer.self_attn.v_proj(specials).view(hidden_shape).transpose(1, 2)
cos_sp, sin_sp = model.rotary_emb(specials, sp_pos)
key_states = (key_states * cos_sp.unsqueeze(1)) + (rotate_half(key_states) * sin_sp.unsqueeze(1))
if can_cache:
key_states = key_states.detach()
value_states = value_states.detach()
if cache is not None and cache_key is not None:
cache[cache_key] = {"k": key_states, "v": value_states, "param_ptr": param_ptr, "param_ver": param_ver}
try:
max_entries = int(os.environ.get("LLOPA_SPECIAL_KV_CACHE_MAX", "128") or 128)
except Exception:
max_entries = 128
if max_entries > 0:
while len(cache) > max_entries:
cache.pop(next(iter(cache)))
if batch_size > 1:
return key_states.expand(batch_size, -1, -1, -1), value_states.expand(batch_size, -1, -1, -1)
return key_states, value_states
def _tri_prefix_keep_lengths(system_lens: torch.LongTensor, system_prefill: str) -> torch.LongTensor:
system_lens = system_lens.to(dtype=torch.long)
mode = normalize_system_prefill(system_prefill)
if mode == "full":
return system_lens
if mode == "no_system":
return torch.minimum(system_lens, torch.ones_like(system_lens))
return torch.zeros_like(system_lens)
def _tri_effective_suffix_special_token_ids(model) -> list[int]:
config = getattr(model, "config", None)
num_suffix_specials = int(getattr(config, "capsule_num_suffix_specials", 0) or 0)
if num_suffix_specials <= 0:
return []
raw_ids = getattr(config, "capsule_suffix_special_token_ids", None)
if not isinstance(raw_ids, (list, tuple)) or len(raw_ids) == 0:
_tri_vanilla_warn_once(
model,
"_tri_warned_missing_suffix_special_ids",
"[TRI] capsule_num_suffix_specials is set but capsule_suffix_special_token_ids is missing; fusion specials disabled.",
)
return []
token_ids: list[int] = []
for raw_id in list(raw_ids)[:num_suffix_specials]:
try:
token_id = int(raw_id)
except Exception:
continue
if token_id >= 0:
token_ids.append(token_id)
if len(token_ids) < num_suffix_specials:
_tri_vanilla_warn_once(
model,
"_tri_warned_short_suffix_special_ids",
"[TRI] Some configured fusion special token ids are invalid; using the available subset.",
)
return token_ids
def _tri_normalize_fusion_mode(mode) -> str:
normalized = str(mode or "upper_only").strip().lower()
return normalized or "upper_only"
def _tri_effective_fusion_mode(model) -> str:
config = getattr(model, "config", None)
return _tri_normalize_fusion_mode(getattr(config, "capsule_fusion_mode", "upper_only"))
def _tri_build_suffix_special_position_ids(
split_starts: torch.LongTensor,
*,
num_suffix_specials: int,
device,
) -> torch.LongTensor:
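    """Place the suffix special tokens at the num_suffix_specials positions just before each split.

    Illustrative example: split_start=10 with 2 specials -> positions [8, 9].
    """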
split_starts = split_starts.to(device=device, dtype=torch.long)
if num_suffix_specials <= 0:
return torch.empty((split_starts.size(0), 0), device=device, dtype=torch.long)
offsets = torch.arange(num_suffix_specials, device=device, dtype=torch.long).unsqueeze(0)
return (split_starts.unsqueeze(1) - int(num_suffix_specials) + offsets).clamp_min_(0)
def _tri_repack_upper_with_suffix_specials(
model: "LlamaForCausalLM",
*,
upper_hidden: torch.Tensor,
upper_position_ids: torch.LongTensor,
upper_attention_mask: torch.Tensor,
decode_labels: torch.LongTensor,
prefix_keep_lens: torch.LongTensor,
split_starts: torch.LongTensor,
) -> tuple[torch.Tensor, torch.LongTensor, torch.Tensor, torch.LongTensor, torch.LongTensor]:
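    """Insert fusion suffix-special embeddings between each row's kept prefix and its suffix.

    Returns repacked hidden states, position ids, attention mask, labels, and per-row
    valid lengths; inputs pass through unchanged when fusion specials are not configured.
    """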
if _tri_effective_fusion_mode(model) != "upper_only":
valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long)
return upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, valid_lens
token_ids = _tri_effective_suffix_special_token_ids(model)
if not token_ids:
valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long)
return upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, valid_lens
B, _, H = upper_hidden.shape
device = upper_hidden.device
dtype = upper_hidden.dtype
prefix_keep_lens = prefix_keep_lens.to(device=device, dtype=torch.long)
split_starts = split_starts.to(device=device, dtype=torch.long)
valid_lens = upper_attention_mask.sum(dim=1, dtype=torch.long).to(device=device)
suffix_lens = (valid_lens - prefix_keep_lens).clamp_min_(0)
num_suffix_specials = len(token_ids)
token_id_tensor = torch.tensor(token_ids, device=device, dtype=torch.long).unsqueeze(0).expand(B, -1)
special_hidden = model.model.embed_tokens(token_id_tensor).to(dtype=dtype)
special_position_ids = _tri_build_suffix_special_position_ids(
split_starts,
num_suffix_specials=num_suffix_specials,
device=device,
)
repacked_valid_lens = prefix_keep_lens + num_suffix_specials + suffix_lens
total_len = int(repacked_valid_lens.max().item()) if B > 0 else 0
repacked_hidden = upper_hidden.new_zeros((B, total_len, H))
repacked_position_ids = upper_position_ids.new_zeros((B, total_len))
repacked_attention_mask = upper_attention_mask.new_zeros((B, total_len))
repacked_labels = decode_labels.new_full((B, total_len), -100)
for row in range(B):
prefix_keep = int(prefix_keep_lens[row].item())
suffix_len = int(suffix_lens[row].item())
special_start = prefix_keep
if prefix_keep > 0:
repacked_hidden[row, :prefix_keep, :] = upper_hidden[row, :prefix_keep, :]
repacked_position_ids[row, :prefix_keep] = upper_position_ids[row, :prefix_keep]
repacked_attention_mask[row, :prefix_keep] = upper_attention_mask[row, :prefix_keep]
repacked_labels[row, :prefix_keep] = decode_labels[row, :prefix_keep]
repacked_hidden[row, special_start : special_start + num_suffix_specials, :] = special_hidden[row]
repacked_position_ids[row, special_start : special_start + num_suffix_specials] = special_position_ids[row]
repacked_attention_mask[row, special_start : special_start + num_suffix_specials] = 1
if suffix_len > 0:
src_start = prefix_keep
dst_start = special_start + num_suffix_specials
repacked_hidden[row, dst_start : dst_start + suffix_len, :] = upper_hidden[row, src_start : src_start + suffix_len, :]
repacked_position_ids[row, dst_start : dst_start + suffix_len] = upper_position_ids[row, src_start : src_start + suffix_len]
repacked_attention_mask[row, dst_start : dst_start + suffix_len] = upper_attention_mask[row, src_start : src_start + suffix_len]
repacked_labels[row, dst_start : dst_start + suffix_len] = decode_labels[row, src_start : src_start + suffix_len]
return repacked_hidden, repacked_position_ids, repacked_attention_mask, repacked_labels, repacked_valid_lens
def _tri_prefill_suffix_specials_upper_cache(
model: "LlamaModel",
pkv: Cache,
*,
lower_k: int,
split_start: Optional[int] = None,
split_starts=None,
no_upper_attn: bool,
) -> int:
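    """Write the fusion suffix specials into the upper-layer KV cache (at most once per cache).

    The configured special token ids are embedded at the positions just before each
    split start and run through layers [lower_k, n_layers). Returns the number of
    special slots written (or the previously stored count if already written), and 0
    when disabled or inapplicable.
    """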
if _tri_effective_fusion_mode(model) != "upper_only":
return 0
token_ids = _tri_effective_suffix_special_token_ids(model)
if not token_ids:
return 0
if bool(no_upper_attn):
_tri_vanilla_warn_once(
model,
"_tri_warned_suffix_specials_no_upper_attn",
"[TRI] Fusion specials are enabled, but no_upper_attn=True disables upper-layer attention; skipping fusion specials.",
)
return 0
if bool(getattr(pkv, "_capsule_suffix_specials_written", False)):
return int(getattr(pkv, "_capsule_suffix_specials_count", 0) or 0)
n_layers = len(model.layers)
K = max(0, min(int(lower_k), n_layers))
if K >= n_layers:
return 0
split_positions: list[int] = []
if isinstance(split_starts, torch.Tensor):
split_positions = [int(v) for v in split_starts.view(-1).tolist()]
elif isinstance(split_starts, (list, tuple)):
split_positions = [int(v) for v in split_starts]
elif split_start is not None:
split_positions = [int(split_start)]
split_positions = sorted({int(v) for v in split_positions if int(v) >= 0})
if not split_positions:
return 0
device = model.embed_tokens.weight.device
token_id_blocks: list[int] = []
for _ in split_positions:
token_id_blocks.extend(token_ids)
token_id_tensor = torch.tensor(token_id_blocks, device=device, dtype=torch.long).unsqueeze(0)
hidden_states = model.embed_tokens(token_id_tensor)
num_suffix_specials = int(token_id_tensor.size(1))
split_start_tensor = torch.tensor(split_positions, device=device, dtype=torch.long)
cache_position = _tri_build_suffix_special_position_ids(
split_start_tensor,
num_suffix_specials=len(token_ids),
device=device,
).reshape(-1)
position_ids = cache_position.unsqueeze(0)
position_embeddings = model.rotary_emb(hidden_states, position_ids)
attn_impl = _resolve_attn_impl(model.config)
implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config)
past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)]
for li in range(K, n_layers):
layer = model.layers[li]
if implicit_causal:
attn_mask = None
else:
attn_mask = _build_tri_mask_local(
1,
num_suffix_specials,
int(past_lens[li]),
device,
hidden_states.dtype,
)
hidden_states = layer(
hidden_states,
attention_mask=attn_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
try:
setattr(pkv, "_tri_past_len_cache", None)
setattr(pkv, "_capsule_suffix_specials_written", True)
setattr(pkv, "_capsule_suffix_specials_count", num_suffix_specials)
except Exception:
pass
return num_suffix_specials
def _tri_prefill_upper_history_cache(
model: "LlamaModel",
pkv: Cache,
*,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
lower_k: int,
no_upper_attn: bool,
) -> int:
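    """Extend the upper-layer KV cache by running precomputed hidden states through layers [lower_k, n_layers).

    Returns the number of tokens written, or 0 when disabled or the inputs are empty or
    mismatched.
    """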
if bool(no_upper_attn):
return 0
if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) <= 0:
return 0
if not isinstance(position_ids, torch.Tensor) or position_ids.shape[:2] != hidden_states.shape[:2]:
return 0
n_layers = len(model.layers)
K = max(0, min(int(lower_k), n_layers))
if K >= n_layers:
return 0
device = hidden_states.device
token_count = int(hidden_states.size(1))
cache_position = position_ids.reshape(-1).to(device=device, dtype=torch.long)
layer_position_ids = position_ids.to(device=device, dtype=torch.long)
position_embeddings = model.rotary_emb(hidden_states, layer_position_ids)
attn_impl = _resolve_attn_impl(model.config)
implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config)
past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)]
for li in range(K, n_layers):
layer = model.layers[li]
if implicit_causal:
attn_mask = None
else:
attn_mask = _build_tri_mask_local(
int(hidden_states.size(0)),
token_count,
int(past_lens[li]),
device,
hidden_states.dtype,
)
hidden_states = layer(
hidden_states,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
try:
setattr(pkv, "_tri_past_len_cache", None)
except Exception:
pass
return token_count
def tri_prefill_system_all(
self: "LlamaModel",
system_ids: torch.LongTensor,
past_key_values: Optional[Cache] = None,
prefill_attn: str = "causal",
):
device = system_ids.device
if past_key_values is None:
past_key_values = _safe_dynamic_cache(self.config)
inputs_embeds = self.embed_tokens(system_ids)
start = past_key_values.get_seq_length()
cache_position = _tri_arange(start, inputs_embeds.shape[1], device)
position_ids = cache_position.unsqueeze(0)
attention_mask = torch.ones_like(system_ids, dtype=torch.long, device=device)
attn_impl = _resolve_attn_impl(self.config)
if (prefill_attn or "causal") == "full":
if attn_impl == "flash_attention_2":
raise ValueError("prefill_attn=full is not supported with flash_attention_2.")
causal_mask = _build_tri_mask_full(system_ids.size(0), inputs_embeds.shape[1], start, device, inputs_embeds.dtype)
elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config):
causal_mask = None
else:
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
def tri_prefill_user_lower(
self: "LlamaModel",
user_ids: torch.LongTensor,
lower_k: int,
past_key_values: Cache,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
return_pre_norm_hidden: bool = False,
):
device = user_ids.device
inputs_embeds = self.embed_tokens(user_ids)
start = past_key_values.get_seq_length()
cache_position = _tri_arange(start, inputs_embeds.shape[1], device)
position_ids = cache_position.unsqueeze(0)
attention_mask = torch.ones_like(user_ids, dtype=torch.long, device=device)
attn_impl = _resolve_attn_impl(self.config)
if (prefill_attn or "causal") == "full":
if attn_impl == "flash_attention_2":
raise ValueError("prefill_attn=full is not supported with flash_attention_2.")
causal_mask = _build_tri_mask_full(user_ids.size(0), inputs_embeds.shape[1], start, device, inputs_embeds.dtype)
elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config):
causal_mask = None
else:
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
select = _prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k)
for li in select:
layer = self.layers[li]
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
pre_norm_hidden = hidden_states
hidden_states = self.norm(hidden_states)
out = BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
if bool(return_pre_norm_hidden):
setattr(out, "pre_norm_hidden", pre_norm_hidden)
return out
def tri_build_caches(
self: "LlamaModel",
system_ids: torch.LongTensor,
user_ids: torch.LongTensor,
lower_k: int,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
) -> Tuple[Cache, int, int]:
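    """Prefill system through all layers and user through the selected lower layers.

    Returns (past_key_values, system_len, user_len).
    """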
out = self.tri_prefill_system_all(system_ids, past_key_values=None, prefill_attn=prefill_attn)
pkv = out.past_key_values
S = system_ids.size(1)
_ = self.tri_prefill_user_lower(
user_ids,
lower_k=lower_k,
past_key_values=pkv,
prefill_mode=prefill_mode,
prefill_attn=prefill_attn,
)
U = user_ids.size(1)
return pkv, S, U
def _tri_build_valid_cache_decode_mask(
valid_mask: torch.Tensor,
*,
query_len: int,
device,
dtype: torch.dtype,
) -> torch.Tensor:
valid_mask = valid_mask.to(device=device, dtype=torch.bool)
B, past_len = valid_mask.shape
query_len = int(query_len)
past_visible = valid_mask.unsqueeze(1).expand(B, query_len, past_len)
query_visible = torch.ones((B, query_len, query_len), device=device, dtype=torch.bool)
if query_len > 1:
query_visible = torch.tril(query_visible)
visible = torch.cat([past_visible, query_visible], dim=-1)
mask = torch.zeros((B, 1, query_len, past_len + query_len), device=device, dtype=dtype)
mask = mask.masked_fill(~visible.unsqueeze(1), torch.finfo(dtype).min)
return mask
def _tri_batch_packed_varlen_decode_enabled(hidden_states: torch.Tensor, *, batch_size: int, query_len: int) -> bool:
raw = str(os.environ.get("LLOPA_BATCH_PACKED_VARLEN", "1")).strip().lower()
if raw in {"0", "false", "no", "off"}:
return False
if _flash_attn_varlen_func is None:
return False
if int(batch_size) <= 1 or int(query_len) != 1:
return False
if not isinstance(hidden_states, torch.Tensor) or not hidden_states.is_cuda:
return False
if hidden_states.dtype not in (torch.float16, torch.bfloat16):
return False
return True
def _tri_packed_varlen_attention_decode(
attn_module: "LlamaAttention",
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
*,
past_key_values: Cache,
cache_position: torch.LongTensor,
layer_valid_mask: torch.Tensor,
) -> Optional[torch.Tensor]:
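    """Single-token decode that packs only the valid cached KV rows into a flash varlen call.

    Returns None when a precondition fails (missing kernel, unsupported device or dtype,
    query length != 1, or a stale valid mask) so the caller can fall back to the dense
    path; a cache-length mismatch after the KV update raises instead.
    """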
if _flash_attn_varlen_func is None:
return None
if hidden_states.dim() != 3 or int(hidden_states.size(1)) != 1:
return None
if not hidden_states.is_cuda or hidden_states.dtype not in (torch.float16, torch.bfloat16):
return None
if not isinstance(layer_valid_mask, torch.Tensor) or layer_valid_mask.dim() != 2:
return None
batch_size = int(hidden_states.size(0))
past_width = int(layer_valid_mask.size(1))
if batch_size <= 1 or int(layer_valid_mask.size(0)) != batch_size:
return None
if _layer_past_len(past_key_values, int(attn_module.layer_idx)) != past_width:
return None
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attn_module.head_dim)
query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
cos_u = cos.unsqueeze(1)
sin_u = sin.unsqueeze(1)
query_states = (query_states * cos_u) + (rotate_half(query_states) * sin_u)
key_states = (key_states * cos_u) + (rotate_half(key_states) * sin_u)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states,
value_states,
attn_module.layer_idx,
cache_kwargs,
)
key_width = int(key_states.size(2))
if key_width != past_width + 1:
raise RuntimeError(
f"packed varlen LLoPA cache length mismatch at layer {attn_module.layer_idx}: "
f"got {key_width}, expected {past_width + 1}"
)
old_valid = layer_valid_mask.to(device=hidden_states.device, dtype=torch.bool)
current_valid = torch.ones((batch_size, 1), device=hidden_states.device, dtype=torch.bool)
full_valid = torch.cat([old_valid, current_valid], dim=1)
key_lengths = full_valid.sum(dim=1, dtype=torch.int32)
if not bool(torch.all(key_lengths > 0).item()):
return None
key_seq = key_states.transpose(1, 2).contiguous()
value_seq = value_states.transpose(1, 2).contiguous()
packed_key = key_seq[full_valid].contiguous()
packed_value = value_seq[full_valid].contiguous()
packed_query = query_states.squeeze(2).contiguous()
cu_q = torch.arange(batch_size + 1, device=hidden_states.device, dtype=torch.int32)
cu_k = torch.cat(
[
torch.zeros((1,), device=hidden_states.device, dtype=torch.int32),
torch.cumsum(key_lengths, dim=0).to(dtype=torch.int32),
],
dim=0,
)
max_k = int(key_lengths.max().item())
attn_output = _flash_attn_varlen_func(
packed_query,
packed_key,
packed_value,
cu_seqlens_q=cu_q,
cu_seqlens_k=cu_k,
max_seqlen_q=1,
max_seqlen_k=max_k,
dropout_p=0.0 if not attn_module.training else attn_module.attention_dropout,
softmax_scale=attn_module.scaling,
causal=False,
)
if isinstance(attn_output, tuple):
attn_output = attn_output[0]
attn_output = attn_output.view(batch_size, 1, -1, attn_module.head_dim)
attn_output = _apply_sdpa_output_gate(attn_module, hidden_states, attn_output)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_module.o_proj(attn_output)
def _tri_run_decoder_layer_packed_varlen_decode(
*,
layer: "LlamaDecoderLayer",
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
past_key_values: Cache,
cache_position: torch.LongTensor,
layer_valid_mask: torch.Tensor,
use_checkpoint: bool,
) -> Optional[torch.Tensor]:
if use_checkpoint:
return None
residual = hidden_states
attn_input = layer.input_layernorm(hidden_states)
attn_output = _tri_packed_varlen_attention_decode(
layer.self_attn,
attn_input,
position_embeddings,
past_key_values=past_key_values,
cache_position=cache_position,
layer_valid_mask=layer_valid_mask,
)
if attn_output is None:
return None
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
hidden_states = layer.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def _tri_batch_packed_varlen_prefill_enabled(hidden_states: torch.Tensor, valid_mask: torch.Tensor) -> bool:
raw = str(os.environ.get("LLOPA_BATCH_PACKED_VARLEN", "1")).strip().lower()
if raw in {"0", "false", "no", "off"}:
return False
if _flash_attn_varlen_func is None:
return False
if not isinstance(hidden_states, torch.Tensor) or not hidden_states.is_cuda:
return False
if hidden_states.dtype not in (torch.float16, torch.bfloat16):
return False
if not isinstance(valid_mask, torch.Tensor) or valid_mask.dim() != 2 or int(valid_mask.size(0)) <= 1:
return False
lengths = valid_mask.to(device=hidden_states.device, dtype=torch.bool).sum(dim=1)
if not bool(torch.all(lengths > 0).item()):
return False
return True
def _tri_scatter_packed_hidden(
packed_hidden: torch.Tensor,
valid_mask: torch.Tensor,
*,
batch_size: int,
width: int,
) -> torch.Tensor:
out = torch.zeros(
(int(batch_size), int(width), int(packed_hidden.size(-1))),
device=packed_hidden.device,
dtype=packed_hidden.dtype,
)
out[valid_mask.to(device=packed_hidden.device, dtype=torch.bool)] = packed_hidden
return out
def _tri_scatter_packed_kv(
packed_key: torch.Tensor,
packed_value: torch.Tensor,
valid_mask: torch.Tensor,
*,
batch_size: int,
width: int,
) -> tuple[torch.Tensor, torch.Tensor]:
key_dense = torch.zeros(
(
int(batch_size),
int(width),
int(packed_key.size(1)),
int(packed_key.size(2)),
),
device=packed_key.device,
dtype=packed_key.dtype,
)
value_dense = torch.zeros(
(
int(batch_size),
int(width),
int(packed_value.size(1)),
int(packed_value.size(2)),
),
device=packed_value.device,
dtype=packed_value.dtype,
)
valid_mask = valid_mask.to(device=packed_key.device, dtype=torch.bool)
key_dense[valid_mask] = packed_key
value_dense[valid_mask] = packed_value
return key_dense.transpose(1, 2).contiguous(), value_dense.transpose(1, 2).contiguous()
def _tri_run_packed_varlen_prefill_layers(
*,
model: "LlamaModel",
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
valid_mask: torch.Tensor,
past_key_values: Cache,
start_layer: int,
end_layer: int,
use_checkpoint: bool,
) -> Optional[torch.Tensor]:
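    """Prefill layers [start_layer, end_layer) over a packed (tokens, H) batch with flash varlen attention.

    Dense, zero-padded K/V copies are still written into past_key_values so later decode
    steps see the usual per-layer cache layout. Returns None when the packed fast path is
    unavailable.
    """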
if use_checkpoint or _flash_attn_varlen_func is None:
return None
if hidden_states.dim() != 2 or position_ids.dim() != 1:
return None
if int(hidden_states.size(0)) != int(position_ids.numel()):
return None
if not _tri_batch_packed_varlen_prefill_enabled(hidden_states, valid_mask):
return None
valid_mask = valid_mask.to(device=hidden_states.device, dtype=torch.bool)
batch_size, width = int(valid_mask.size(0)), int(valid_mask.size(1))
lengths = valid_mask.sum(dim=1, dtype=torch.int32)
cu = torch.cat(
[
torch.zeros((1,), device=hidden_states.device, dtype=torch.int32),
torch.cumsum(lengths, dim=0).to(dtype=torch.int32),
],
dim=0,
)
max_len = int(lengths.max().item())
cache_position = torch.arange(width, device=hidden_states.device, dtype=torch.long)
for li in range(int(start_layer), int(end_layer)):
layer = model.layers[li]
residual = hidden_states
attn_input = layer.input_layernorm(hidden_states)
hidden_shape = (int(attn_input.size(0)), -1, layer.self_attn.head_dim)
query_states = layer.self_attn.q_proj(attn_input).view(hidden_shape)
key_states = layer.self_attn.k_proj(attn_input).view(hidden_shape)
value_states = layer.self_attn.v_proj(attn_input).view(hidden_shape)
cos, sin = model.rotary_emb(attn_input.view(1, int(attn_input.size(0)), -1), position_ids.view(1, -1))
cos = cos[0].unsqueeze(1)
sin = sin[0].unsqueeze(1)
query_states = (query_states * cos) + (rotate_half(query_states) * sin)
key_states = (key_states * cos) + (rotate_half(key_states) * sin)
key_dense, value_dense = _tri_scatter_packed_kv(
key_states,
value_states,
valid_mask,
batch_size=batch_size,
width=width,
)
cache_kwargs = {"sin": None, "cos": None, "cache_position": cache_position}
past_key_values.update(key_dense, value_dense, li, cache_kwargs)
attn_output = _flash_attn_varlen_func(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
cu_seqlens_q=cu,
cu_seqlens_k=cu,
max_seqlen_q=max_len,
max_seqlen_k=max_len,
dropout_p=0.0 if not layer.self_attn.training else layer.self_attn.attention_dropout,
softmax_scale=layer.self_attn.scaling,
causal=True,
)
if isinstance(attn_output, tuple):
attn_output = attn_output[0]
attn_output = attn_output.view(int(hidden_states.size(0)), 1, -1, layer.self_attn.head_dim)
attn_output = _apply_sdpa_output_gate(layer.self_attn, attn_input.view(int(hidden_states.size(0)), 1, -1), attn_output)
attn_output = attn_output.reshape(int(hidden_states.size(0)), -1).contiguous()
hidden_states = residual + layer.self_attn.o_proj(attn_output)
residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
hidden_states = layer.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
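# _tri_run_packed_varlen_prefill_layers returns None (and the caller falls back to the dense
# per-layer path) when gradient checkpointing is active, the flash-attn varlen kernel is missing,
# the inputs are not already flattened to (total_tokens, hidden) / (total_tokens,), or the
# enablement check above rejects the batch.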
def _tri_append_batch_cache_valid_masks(pkv: Cache, input_valid: torch.Tensor) -> None:
masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None)
if not isinstance(masks, list):
return
input_valid = input_valid.to(dtype=torch.bool)
new_masks = []
for layer_idx, mask in enumerate(masks):
if not isinstance(mask, torch.Tensor) or mask.dim() != 2:
new_masks.append(mask)
continue
layer_len = _layer_past_len(pkv, int(layer_idx))
cur_len = int(mask.size(1))
if layer_len > cur_len:
add_width = int(layer_len - cur_len)
add = input_valid.to(device=mask.device).view(mask.size(0), -1)
if int(add.size(1)) != add_width:
if int(add.size(1)) == 1:
add = add.expand(mask.size(0), add_width)
else:
add = add[:, :add_width]
if int(add.size(1)) < add_width:
pad = torch.zeros(
(mask.size(0), add_width - int(add.size(1))),
device=mask.device,
dtype=torch.bool,
)
add = torch.cat([add, pad], dim=1)
mask = torch.cat([mask, add], dim=1)
elif layer_len < cur_len:
mask = mask[:, :layer_len]
new_masks.append(mask)
with contextlib.suppress(Exception):
setattr(pkv, "_llopa_batch_layer_valid_masks", new_masks)
setattr(pkv, "_tri_past_len_cache", None)
def tri_forward_assistant(
self: "LlamaModel",
assistant_ids: torch.LongTensor,
lower_k: int,
pkv: Cache,
S: int,
U: int,
write_cache: bool = True,
prefill_mode: str = "lower",
no_upper_attn: bool = False,
align_cache_position_to_layer_past: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
):
device = assistant_ids.device
B, T = assistant_ids.shape
inputs_embeds = self.embed_tokens(assistant_ids)
prefix_global = pkv.get_seq_length()
cache_position = _tri_arange(prefix_global, T, device)
position_ids = cache_position.unsqueeze(0).expand(B, -1)
batch_valid_masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None)
batch_decode_position_ids = getattr(pkv, "_llopa_batch_decode_position_ids", None)
if isinstance(batch_decode_position_ids, torch.Tensor):
batch_decode_position_ids = batch_decode_position_ids.to(device=device, dtype=torch.long)
if batch_decode_position_ids.shape != (B, T):
batch_decode_position_ids = None
else:
batch_decode_position_ids = None
if not (
isinstance(batch_valid_masks, (list, tuple))
and len(batch_valid_masks) >= len(self.layers)
and all(isinstance(mask, torch.Tensor) and mask.dim() == 2 for mask in batch_valid_masks[: len(self.layers)])
):
batch_valid_masks = None
attn_impl = _resolve_attn_impl(self.config)
implicit_causal = attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config)
select = set(_prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k))
n_layers = len(self.layers)
has_lower = bool(select)
has_upper = len(select) < n_layers
replay_module = _normalize_replay_module(
replay_module or getattr(pkv, "_tri_replay_module", getattr(pkv, "_tri_last_layer_module", "none"))
)
replay_per_layers = _normalize_replay_per_layers(
replay_per_layers if replay_per_layers is not None else getattr(pkv, "_tri_replay_per_layers", -1)
)
replay_hidden = getattr(pkv, "_tri_last_layer_memory_hidden", None)
replay_position_ids = getattr(pkv, "_tri_last_layer_memory_position_ids", None)
replay_valid_mask = getattr(pkv, "_tri_last_layer_memory_valid_mask", None)
upper_layer_indices = [li for li in range(n_layers) if li not in select]
replay_layer_set = _tri_replay_layer_index_set(
upper_layer_indices=upper_layer_indices,
replay_per_layers=replay_per_layers,
)
replay_enabled = (
replay_module != "none"
and has_upper
and not bool(no_upper_attn)
and bool(replay_layer_set)
and isinstance(replay_hidden, torch.Tensor)
and replay_hidden.numel() > 0
and int(replay_hidden.size(1)) > 0
)
past_lens = None
layer_position_offsets = None
if batch_valid_masks is None and bool(align_cache_position_to_layer_past):
cached_offsets = getattr(pkv, "_tri_layer_position_offsets", None)
if isinstance(cached_offsets, (list, tuple)) and len(cached_offsets) == n_layers:
layer_position_offsets = [int(offset) for offset in cached_offsets]
need_past_lens = batch_valid_masks is None and (
(not implicit_causal and has_upper)
or (bool(align_cache_position_to_layer_past) and layer_position_offsets is None)
)
if need_past_lens:
cache = getattr(pkv, "_tri_past_len_cache", None)
if isinstance(cache, dict):
last = cache.get("last_seq_len")
cached = cache.get("past_lens")
if isinstance(cached, list) and len(cached) == n_layers:
if isinstance(last, int) and last == prefix_global:
past_lens = cached
elif isinstance(last, int) and last == prefix_global - T:
past_lens = [pl + T for pl in cached]
if past_lens is None:
past_lens = [_layer_past_len(pkv, li) for li in range(n_layers)]
try:
pkv._tri_past_len_cache = {"last_seq_len": prefix_global, "past_lens": past_lens}
except Exception:
pass
if bool(align_cache_position_to_layer_past) and layer_position_offsets is None:
layer_position_offsets = [int(prefix_global - int(pl)) for pl in past_lens]
with contextlib.suppress(Exception):
setattr(pkv, "_tri_layer_position_offsets", layer_position_offsets)
mask_lower = None
mask_upper = None
if batch_valid_masks is None and not implicit_causal:
if has_lower:
attn_mask_tokens = torch.ones((B, prefix_global + T), device=device, dtype=torch.long)
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attn_mask_tokens,
"cache_position": cache_position,
"past_key_values": pkv,
"position_ids": position_ids,
}
mask_lower = create_causal_mask(**mask_kwargs)
if has_upper:
upper_past_len = None
upper_uniform = True
for li in range(n_layers):
if li in select:
continue
past_len = past_lens[li] if past_lens is not None else _layer_past_len(pkv, li)
if upper_past_len is None:
upper_past_len = past_len
elif past_len != upper_past_len:
upper_uniform = False
break
if upper_uniform and upper_past_len is not None:
mask_upper = _build_tri_mask_local(B, T, upper_past_len, device, inputs_embeds.dtype)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
position_embeddings_by_offset = {0: position_embeddings} if batch_valid_masks is None and batch_decode_position_ids is None else None
use_checkpoint = should_use_checkpoint(self)
packed_varlen_decode = (
bool(write_cache)
and batch_valid_masks is not None
and not replay_enabled
and _tri_batch_packed_varlen_decode_enabled(inputs_embeds, batch_size=B, query_len=T)
)
for li, layer in enumerate(self.layers):
layer_cache_position = cache_position
layer_position_ids = position_ids
layer_position_offset = 0
if batch_decode_position_ids is not None:
layer_position_ids = batch_decode_position_ids
elif batch_valid_masks is not None:
valid_mask_li = batch_valid_masks[li].to(device=device, dtype=torch.bool)
layer_past_lens = valid_mask_li.sum(dim=1).to(device=device, dtype=torch.long)
layer_position_ids = layer_past_lens.unsqueeze(1) + torch.arange(T, device=device, dtype=torch.long).unsqueeze(0)
elif bool(align_cache_position_to_layer_past) and layer_position_offsets is not None:
offset = int(layer_position_offsets[li])
if offset:
layer_cache_position = cache_position - offset
layer_position_ids = position_ids - offset
layer_position_offset = offset
if bool(no_upper_attn) and (li not in select):
residual = hidden_states
ffn_input = layer.post_attention_layernorm(hidden_states)
if use_checkpoint:
ffn_out = _checkpoint_mlp(layer.mlp, ffn_input)
else:
ffn_out = layer.mlp(ffn_input)
hidden_states = residual + ffn_out
continue
if position_embeddings_by_offset is not None:
layer_position_embeddings = position_embeddings_by_offset.get(layer_position_offset)
if layer_position_embeddings is None:
layer_position_embeddings = self.rotary_emb(hidden_states, layer_position_ids)
position_embeddings_by_offset[layer_position_offset] = layer_position_embeddings
else:
layer_position_embeddings = (
self.rotary_emb(hidden_states, layer_position_ids)
if batch_valid_masks is not None or layer_position_ids is not position_ids
else position_embeddings
)
attn_mask = None
if batch_valid_masks is not None:
if bool(packed_varlen_decode):
packed_hidden_states = _tri_run_decoder_layer_packed_varlen_decode(
layer=layer,
hidden_states=hidden_states,
position_embeddings=layer_position_embeddings,
past_key_values=pkv,
cache_position=layer_cache_position,
layer_valid_mask=batch_valid_masks[li],
use_checkpoint=use_checkpoint,
)
if packed_hidden_states is not None:
hidden_states = packed_hidden_states
continue
attn_mask = _tri_build_valid_cache_decode_mask(
batch_valid_masks[li],
query_len=T,
device=device,
dtype=inputs_embeds.dtype,
)
elif implicit_causal:
attn_mask = None
elif li in select:
attn_mask = mask_lower
else:
attn_mask = mask_upper
if attn_mask is None:
past_len_li = past_lens[li] if past_lens is not None else _layer_past_len(pkv, li)
attn_mask = _build_tri_mask_local(B, T, past_len_li, device, inputs_embeds.dtype)
use_replay_layer = replay_enabled and (li in replay_layer_set)
if use_replay_layer:
hidden_states = _tri_run_last_layer_with_replay(
layer=layer,
hidden_states=hidden_states,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=write_cache,
cache_position=layer_cache_position,
position_embeddings=layer_position_embeddings,
memory_hidden_states=replay_hidden,
memory_position_ids=replay_position_ids,
memory_valid_mask=replay_valid_mask,
local_valid_mask=torch.ones((B, T), device=device, dtype=torch.bool),
query_replay_mask=torch.ones((B, T), device=device, dtype=torch.bool),
module_type=replay_module,
rotary_emb=self.rotary_emb,
use_checkpoint=use_checkpoint,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=write_cache,
cache_position=layer_cache_position,
position_embeddings=layer_position_embeddings,
)
hidden_states = self.norm(hidden_states)
if batch_decode_position_ids is not None:
with contextlib.suppress(Exception):
setattr(pkv, "_llopa_batch_decode_position_ids", None)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=pkv)
LlamaModel.tri_prefill_system_all = tri_prefill_system_all
LlamaModel.tri_prefill_user_lower = tri_prefill_user_lower
LlamaModel.tri_build_caches = tri_build_caches
LlamaModel.tri_forward_assistant = tri_forward_assistant
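# Usage sketch (hypothetical; the cache normally comes from tri_build_caches / the tri_prefill_*
# helpers attached above, which are defined earlier in this file and whose exact arguments may
# differ):
#   pkv = model.model.tri_build_caches(...)                 # prefill system/user segments
#   out = model.model.tri_forward_assistant(
#       assistant_ids=assistant_ids, lower_k=8, pkv=pkv,
#       S=system_len, U=user_len, write_cache=True,
#   )
#   hidden = out.last_hidden_state                          # (B, T, hidden), final norm applied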
def llopa_prefill_cache(
self: "LlamaModel",
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
return_last_assistant_hidden: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
replay_user_prefix_keep_len: int = 0,
replay_user_start: Optional[int] = None,
replay_user_len: Optional[int] = None,
) -> Cache | tuple[Cache, Optional[torch.Tensor]]:
# LLOPA prefill currently supports contiguous lower layers only (causal or prefix-full).
if (prefill_mode or "lower") != "lower":
raise ValueError("llopa_prefill_cache requires prefill_mode='lower' (periodic not supported).")
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("llopa_prefill_cache requires prefill_attn in {'causal','full'}.")
prefix_full = (attn_mode == "full")
replay_module = _normalize_replay_module(replay_module)
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids)
if ref is None:
raise ValueError("llopa_prefill_cache requires at least one non-empty input segment.")
if system_ids is None:
system_ids = ref[:, :0]
if user_ids is None:
user_ids = ref[:, :0]
if assistant_ids is None:
assistant_ids = ref[:, :0]
if system_ids.size(0) != ref.size(0) or user_ids.size(0) != ref.size(0) or assistant_ids.size(0) != ref.size(0):
raise ValueError("llopa_prefill_cache expects all segments to share the same batch size.")
n_layers = len(self.layers)
K = max(0, min(int(lower_k), n_layers))
# If there are no lower layers, user tokens are not represented at all (matches TRI behavior).
user_ids_eff = user_ids if K > 0 else user_ids[:, :0]
S, U, A = system_ids.size(1), user_ids_eff.size(1), assistant_ids.size(1)
full_ids = torch.cat([system_ids, user_ids_eff, assistant_ids], dim=1)
B, L = full_ids.shape
pkv = _safe_dynamic_cache(self.config)
def _finish(last_hidden_source: Optional[torch.Tensor]):
if not bool(return_last_assistant_hidden):
return pkv
if A <= 0 or not isinstance(last_hidden_source, torch.Tensor) or last_hidden_source.numel() == 0:
return pkv, None
return pkv, self.norm(last_hidden_source[:, -1:, :])
if L == 0:
return _finish(None)
device = full_ids.device
inputs_embeds = self.embed_tokens(full_ids)
cache_position = _tri_arange(0, L, device)
position_ids = cache_position.unsqueeze(0).expand(B, -1)
attn_impl = _resolve_attn_impl(self.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("llopa_prefill_cache prefill_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)):
lower_mask = None
elif prefix_full:
lower_mask = _build_tri_mask_prefix_full(B, S + U, L, device, inputs_embeds.dtype)
else:
attention_mask = torch.ones_like(full_ids, dtype=torch.long, device=device)
lower_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(self)
for li in range(K):
layer = self.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
# If all layers are lower, we're done.
if K >= n_layers:
return _finish(hidden_states)
if U == 0:
upper_hidden = hidden_states
# With model sharding, hidden_states can migrate across devices after lower layers.
# Keep index/position tensors on the same device as current activations.
upper_position_ids = position_ids.to(device=upper_hidden.device)
upper_cache_position = cache_position.to(device=upper_hidden.device)
else:
split_device = hidden_states.device
upper_idx = _llopa_upper_indices(S, U, A, split_device)
upper_hidden = hidden_states.index_select(dim=1, index=upper_idx)
upper_position_ids = upper_idx.unsqueeze(0).expand(B, -1)
upper_cache_position = upper_idx
if upper_hidden.size(1) == 0:
return _finish(None)
replay_hidden = None
replay_position_ids = None
replay_valid_mask = None
replay_query_mask = None
upper_layer_indices = list(range(K, n_layers))
replay_layer_set = _tri_replay_layer_index_set(
upper_layer_indices=upper_layer_indices,
replay_per_layers=replay_per_layers,
)
replay_enabled = replay_module != "none" and U > 0 and K < n_layers and bool(replay_layer_set)
if replay_enabled:
replay_prefix_keep_len, replay_start, replay_end = _tri_resolve_replay_user_window(
user_len=U,
prefix_keep_len=int(replay_user_prefix_keep_len or 0),
latest_user_start=replay_user_start,
latest_user_len=replay_user_len,
)
replay_hidden_parts = []
replay_position_parts = []
if replay_prefix_keep_len > 0:
replay_hidden_parts.append(hidden_states[:, S : S + replay_prefix_keep_len, :])
replay_position_parts.append(position_ids[:, S : S + replay_prefix_keep_len])
if replay_end > replay_start:
replay_hidden_parts.append(hidden_states[:, S + replay_start : S + replay_end, :])
replay_position_parts.append(position_ids[:, S + replay_start : S + replay_end])
if replay_hidden_parts:
replay_hidden = (
replay_hidden_parts[0]
if len(replay_hidden_parts) == 1
else torch.cat(replay_hidden_parts, dim=1)
).to(device=upper_hidden.device)
replay_position_ids = (
replay_position_parts[0]
if len(replay_position_parts) == 1
else torch.cat(replay_position_parts, dim=1)
).to(device=upper_hidden.device, dtype=torch.long)
replay_valid_mask = torch.ones(
(B, int(replay_hidden.size(1))),
device=upper_hidden.device,
dtype=torch.bool,
)
replay_query_mask = _tri_build_last_layer_query_mask(
prefix_keep_lens=torch.full((B,), int(S), device=upper_hidden.device, dtype=torch.long),
valid_lens=torch.full((B,), int(upper_hidden.size(1)), device=upper_hidden.device, dtype=torch.long),
total_len=int(upper_hidden.size(1)),
device=upper_hidden.device,
)
if isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0:
_tri_store_last_layer_memory(
pkv,
hidden_states=replay_hidden,
position_ids=replay_position_ids,
valid_mask=replay_valid_mask,
module_type=replay_module,
replay_per_layers=replay_per_layers,
)
num_specials_eff = _llopa_effective_specials(self, U)
if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)):
upper_mask = None
elif prefix_full:
upper_mask = _build_tri_mask_prefix_full(B, S, upper_hidden.size(1), device, upper_hidden.dtype)
else:
upper_mask = _build_tri_mask_local(B, upper_hidden.size(1), 0, device, upper_hidden.dtype)
upper_pos_emb = self.rotary_emb(upper_hidden, upper_position_ids)
if num_specials_eff > 0:
special_pos = torch.arange(
S,
S + num_specials_eff,
device=upper_cache_position.device,
dtype=upper_cache_position.dtype,
)
special_pos_b = special_pos.unsqueeze(0).expand(B, -1)
upper_position_ids_sp = torch.cat(
[upper_position_ids[:, :S], special_pos_b, upper_position_ids[:, S:]],
dim=1,
)
upper_cache_position_sp = torch.cat(
[upper_cache_position[:S], special_pos, upper_cache_position[S:]],
dim=0,
)
if attn_impl == "flash_attention_2" or (not prefix_full and _can_use_implicit_causal_mask(self.config)):
upper_mask_sp = None
elif prefix_full:
upper_mask_sp = _build_tri_mask_prefix_full(
B, S, upper_position_ids_sp.size(1), device, upper_hidden.dtype
)
else:
upper_mask_sp = _build_tri_mask_local(
B, upper_position_ids_sp.size(1), 0, device, upper_hidden.dtype
)
else:
upper_position_ids_sp = None
upper_cache_position_sp = None
upper_mask_sp = None
for li in range(K, n_layers):
layer = self.layers[li]
llopa_special_kv = None
llopa_special_span = None
if num_specials_eff > 0:
specials = self.llopa_specials[li][:num_specials_eff]
specials = specials.to(dtype=upper_hidden.dtype, device=upper_hidden.device)
specials = specials.unsqueeze(0).expand(B, -1, -1)
layer_hidden = torch.cat(
[upper_hidden[:, :S, :], specials, upper_hidden[:, S:, :]],
dim=1,
)
layer_position_ids = upper_position_ids_sp
layer_cache_position = upper_cache_position_sp
attn_mask = upper_mask_sp
pos_emb = self.rotary_emb(layer_hidden, layer_position_ids)
if _llopa_can_cache_special_kv(self):
llopa_special_kv = _llopa_special_kv_for_layer(
self,
layer,
layer_idx=li,
num_specials_eff=num_specials_eff,
start_pos=S,
batch_size=B,
dtype=upper_hidden.dtype,
device=upper_hidden.device,
)
llopa_special_span = (S, S + num_specials_eff)
else:
layer_hidden = upper_hidden
layer_position_ids = upper_position_ids
layer_cache_position = upper_cache_position
attn_mask = upper_mask
pos_emb = upper_pos_emb
use_replay_layer = (
replay_enabled
and li in replay_layer_set
and num_specials_eff == 0
and isinstance(replay_hidden, torch.Tensor)
and replay_hidden.numel() > 0
)
if use_replay_layer:
layer_out = _tri_run_last_layer_with_replay(
layer=layer,
hidden_states=layer_hidden,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=layer_cache_position,
position_embeddings=pos_emb,
memory_hidden_states=replay_hidden,
memory_position_ids=replay_position_ids,
memory_valid_mask=replay_valid_mask,
local_valid_mask=torch.ones((B, layer_hidden.size(1)), device=layer_hidden.device, dtype=torch.bool),
query_replay_mask=replay_query_mask,
module_type=replay_module,
rotary_emb=self.rotary_emb,
use_checkpoint=use_checkpoint,
)
elif use_checkpoint:
layer_out = checkpoint_with_cache(
layer.__call__,
layer_hidden,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=layer_cache_position,
position_embeddings=pos_emb,
llopa_special_kv=llopa_special_kv,
llopa_special_span=llopa_special_span,
)
else:
layer_out = layer(
layer_hidden,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=layer_cache_position,
position_embeddings=pos_emb,
llopa_special_kv=llopa_special_kv,
llopa_special_span=llopa_special_span,
)
if num_specials_eff > 0:
upper_hidden = torch.cat(
[layer_out[:, :S, :], layer_out[:, S + num_specials_eff :, :]],
dim=1,
)
else:
upper_hidden = layer_out
return _finish(upper_hidden)
def llopa_forward(
self: "LlamaModel",
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
system_attention_mask: Optional[torch.LongTensor] = None,
user_attention_mask: Optional[torch.LongTensor] = None,
assistant_attention_mask: Optional[torch.LongTensor] = None,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
no_upper_attn: bool = False,
):
# LLOPA currently supports contiguous lower layers only (prefill_mode='lower').
if (prefill_mode or "lower") != "lower":
raise ValueError("llopa requires prefill_mode='lower' (periodic not supported).")
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("llopa requires prefill_attn in {'causal','full'}.")
prefix_full = (attn_mode == "full")
ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids)
if ref is None:
raise ValueError("llopa_forward requires at least one non-empty input segment.")
if system_ids is None:
system_ids = ref[:, :0]
if user_ids is None:
user_ids = ref[:, :0]
if assistant_ids is None:
assistant_ids = ref[:, :0]
system_attention_mask = _llopa_default_attention_mask(system_ids, system_attention_mask)
user_attention_mask = _llopa_default_attention_mask(user_ids, user_attention_mask)
assistant_attention_mask = _llopa_default_attention_mask(assistant_ids, assistant_attention_mask)
if system_ids.size(0) != ref.size(0) or user_ids.size(0) != ref.size(0) or assistant_ids.size(0) != ref.size(0):
raise ValueError("llopa_forward expects all segments to share the same batch size.")
n_layers = len(self.layers)
K = max(0, min(int(lower_k), n_layers))
# If there are no lower layers, user tokens are not represented at all (matches TRI behavior).
user_ids_eff = user_ids if K > 0 else user_ids[:, :0]
user_attention_mask_eff = user_attention_mask if K > 0 else user_attention_mask[:, :0]
S, U, A = system_ids.size(1), user_ids_eff.size(1), assistant_ids.size(1)
full_ids = torch.cat([system_ids, user_ids_eff, assistant_ids], dim=1)
full_attention_mask = torch.cat([system_attention_mask, user_attention_mask_eff, assistant_attention_mask], dim=1)
B, L = full_ids.shape
if L == 0:
empty = full_ids.new_zeros((B, 0, self.config.hidden_size))
return BaseModelOutputWithPast(last_hidden_state=empty, past_key_values=None)
device = full_ids.device
_llopa_reset_peak_memory(full_ids)
inputs_embeds = self.embed_tokens(full_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
_llopa_log_memory(
self,
"after_embed",
system_ids=system_ids,
user_ids=user_ids_eff,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask_eff,
assistant_attention_mask=assistant_attention_mask,
sequence_len=L,
)
attn_impl = _resolve_attn_impl(self.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("llopa prefill_attn=full is not supported with flash_attention_2.")
if prefix_full and not _llopa_mask_is_all_ones(full_attention_mask):
raise ValueError("llopa padded batches require prefill_attn='causal'.")
if attn_impl == "flash_attention_2":
lower_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(full_attention_mask):
lower_mask = None
elif prefix_full:
lower_mask = _build_tri_mask_prefix_full(B, S + U, L, device, inputs_embeds.dtype)
else:
lower_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(self)
use_optimized_training_path = _llopa_use_optimized_training_path(
self,
attn_impl=attn_impl,
prefix_full=prefix_full,
no_upper_attn=bool(no_upper_attn),
)
use_custom_checkpoint = use_checkpoint and not use_optimized_training_path
_llopa_reset_peak_memory(hidden_states)
for li in range(K):
layer = self.layers[li]
if use_custom_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
_llopa_log_memory(
self,
"after_lower",
system_ids=system_ids,
user_ids=user_ids_eff,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask_eff,
assistant_attention_mask=assistant_attention_mask,
sequence_len=L,
)
# If all layers are lower, return assistant slice directly.
if K >= n_layers:
hidden_states = self.norm(hidden_states)
assistant_hidden = hidden_states[:, S + U :, :]
return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None)
# Upper stage: system + assistant only (gap preserved in position ids)
if U == 0:
upper_hidden = hidden_states
upper_attention_mask = full_attention_mask.to(device=upper_hidden.device)
upper_position_ids = position_ids.to(device=upper_hidden.device)
else:
_llopa_reset_peak_memory(hidden_states)
upper_hidden = _llopa_pack_upper_hidden(hidden_states, S, U)
upper_attention_mask = _llopa_pack_upper_tokens(full_attention_mask.to(device=hidden_states.device), S, U)
upper_position_ids = _llopa_pack_upper_tokens(position_ids.to(device=hidden_states.device), S, U)
upper_cache_position = _tri_arange(0, upper_hidden.size(1), upper_hidden.device)
_llopa_log_memory(
self,
"after_upper_compaction",
system_ids=system_ids,
user_ids=user_ids_eff,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask_eff,
assistant_attention_mask=assistant_attention_mask,
sequence_len=upper_hidden.size(1),
)
if bool(no_upper_attn):
for li in range(K, n_layers):
layer = self.layers[li]
residual = upper_hidden
ffn_input = layer.post_attention_layernorm(upper_hidden)
if use_checkpoint:
ffn_out = _checkpoint_mlp(layer.mlp, ffn_input)
else:
ffn_out = layer.mlp(ffn_input)
upper_hidden = residual + ffn_out
upper_hidden = self.norm(upper_hidden)
assistant_hidden = upper_hidden[:, S:, :]
return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None)
num_specials_eff = _llopa_effective_specials(self, U)
if attn_impl == "flash_attention_2":
upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(upper_attention_mask):
upper_mask = None
elif prefix_full:
upper_mask = _build_tri_mask_prefix_full(B, S, upper_hidden.size(1), device, upper_hidden.dtype)
else:
upper_mask = create_causal_mask(
config=self.config,
input_embeds=upper_hidden,
attention_mask=upper_attention_mask,
cache_position=upper_cache_position,
past_key_values=None,
position_ids=upper_position_ids,
)
upper_pos_emb = self.rotary_emb(upper_hidden, upper_position_ids)
if num_specials_eff > 0:
special_steps = torch.arange(num_specials_eff, device=upper_hidden.device, dtype=torch.long).unsqueeze(0)
system_lengths = system_attention_mask.sum(dim=1, dtype=torch.long).to(device=upper_hidden.device).unsqueeze(1)
special_pos_b = system_lengths + special_steps
upper_position_ids_sp = torch.cat(
[upper_position_ids[:, :S], special_pos_b, upper_position_ids[:, S:]],
dim=1,
)
upper_cache_position_sp = _tri_arange(0, upper_position_ids_sp.size(1), upper_hidden.device)
special_mask = torch.ones((B, num_specials_eff), device=upper_hidden.device, dtype=upper_attention_mask.dtype)
upper_attention_mask_sp = torch.cat(
[upper_attention_mask[:, :S], special_mask, upper_attention_mask[:, S:]],
dim=1,
)
if attn_impl == "flash_attention_2":
upper_mask_sp = None if _llopa_mask_is_all_ones(upper_attention_mask_sp) else upper_attention_mask_sp
elif not prefix_full and _can_use_implicit_causal_mask(self.config) and _llopa_mask_is_all_ones(upper_attention_mask_sp):
upper_mask_sp = None
elif prefix_full:
upper_mask_sp = _build_tri_mask_prefix_full(
B, S, upper_position_ids_sp.size(1), device, upper_hidden.dtype
)
else:
pad_hidden = upper_hidden.new_empty((B, upper_position_ids_sp.size(1), upper_hidden.size(-1)))
upper_mask_sp = create_causal_mask(
config=self.config,
input_embeds=pad_hidden,
attention_mask=upper_attention_mask_sp,
cache_position=upper_cache_position_sp,
past_key_values=None,
position_ids=upper_position_ids_sp,
)
else:
upper_position_ids_sp = None
upper_cache_position_sp = None
upper_mask_sp = None
for li in range(K, n_layers):
layer = self.layers[li]
_llopa_reset_peak_memory(upper_hidden)
if num_specials_eff > 0:
specials = self.llopa_specials[li][:num_specials_eff]
specials = specials.to(dtype=upper_hidden.dtype, device=upper_hidden.device)
specials = specials.unsqueeze(0).expand(B, -1, -1)
layer_hidden = _llopa_insert_specials(upper_hidden, specials, S)
layer_position_ids = upper_position_ids_sp
layer_cache_position = upper_cache_position_sp
attn_mask = upper_mask_sp
pos_emb = self.rotary_emb(layer_hidden, layer_position_ids)
else:
layer_hidden = upper_hidden
layer_position_ids = upper_position_ids
layer_cache_position = upper_cache_position
attn_mask = upper_mask
pos_emb = upper_pos_emb
if use_custom_checkpoint:
layer_out = checkpoint_with_cache(
layer.__call__,
layer_hidden,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=None,
use_cache=False,
cache_position=layer_cache_position,
position_embeddings=pos_emb,
)
else:
layer_out = layer(
layer_hidden,
attention_mask=attn_mask,
position_ids=layer_position_ids,
past_key_values=None,
use_cache=False,
cache_position=layer_cache_position,
position_embeddings=pos_emb,
)
if num_specials_eff > 0:
upper_hidden = _llopa_remove_specials(layer_out, S, num_specials_eff)
else:
upper_hidden = layer_out
_llopa_log_memory(
self,
f"after_upper_layer_{li}",
system_ids=system_ids,
user_ids=user_ids_eff,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask_eff,
assistant_attention_mask=assistant_attention_mask,
sequence_len=upper_hidden.size(1),
)
upper_hidden = self.norm(upper_hidden)
assistant_hidden = upper_hidden[:, S:, :]
return BaseModelOutputWithPast(last_hidden_state=assistant_hidden, past_key_values=None)
LlamaModel.llopa_forward = llopa_forward
LlamaModel.llopa_prefill_cache = llopa_prefill_cache
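# Usage sketch (hypothetical inputs; each segment is an already-tokenized 2-D LongTensor and
# lower_k=8 is an arbitrary choice):
#   out = model.model.llopa_forward(
#       system_ids=sys_ids, user_ids=usr_ids, assistant_ids=asst_ids, lower_k=8,
#   )                                   # training-style forward, returns assistant hidden states
#   pkv = model.model.llopa_prefill_cache(
#       system_ids=sys_ids, user_ids=usr_ids, assistant_ids=None, lower_k=8,
#   )                                   # inference prefill, returns a populated Cache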
def llopa_step_logits(
self: "LlamaForCausalLM",
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
system_attention_mask: Optional[torch.LongTensor] = None,
user_attention_mask: Optional[torch.LongTensor] = None,
assistant_attention_mask: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
labels: Optional[torch.LongTensor] = None,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
no_upper_attn: bool = False,
):
out = self.model.llopa_forward(
system_ids=system_ids,
system_attention_mask=system_attention_mask,
user_ids=user_ids,
user_attention_mask=user_attention_mask,
assistant_ids=assistant_ids,
assistant_attention_mask=assistant_attention_mask,
lower_k=lower_k,
prefill_mode=prefill_mode,
prefill_attn=prefill_attn,
no_upper_attn=no_upper_attn,
)
hidden_states = out.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
_llopa_reset_peak_memory(hidden_states)
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
_llopa_log_memory(
self.model,
"after_loss",
system_ids=system_ids,
user_ids=user_ids,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask,
assistant_attention_mask=assistant_attention_mask,
sequence_len=hidden_states.size(1),
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
)
LlamaForCausalLM.llopa_step_logits = llopa_step_logits
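# llopa_step_logits is a thin training wrapper: it calls llopa_forward, projects the kept logits
# window through lm_head, and computes the standard causal-LM loss when labels are provided
# (labels are expected to line up with the returned assistant hidden states).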
def _tri_pack_prefill_training_segments(
*,
model,
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
system_attention_mask: Optional[torch.Tensor] = None,
user_attention_mask: Optional[torch.Tensor] = None,
assistant_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
):
ref = system_ids if system_ids is not None else (user_ids if user_ids is not None else assistant_ids)
if ref is None:
raise ValueError("segmented prefill step requires at least one non-empty input segment.")
device = ref.device
batch_size = int(ref.size(0))
if system_ids is None:
system_ids = ref[:, :0]
if user_ids is None:
user_ids = ref[:, :0]
if assistant_ids is None:
assistant_ids = ref[:, :0]
system_attention_mask = _llopa_default_attention_mask(system_ids, system_attention_mask)
user_attention_mask = _llopa_default_attention_mask(user_ids, user_attention_mask)
assistant_attention_mask = _llopa_default_attention_mask(assistant_ids, assistant_attention_mask)
if labels is None:
labels = assistant_ids.new_full(assistant_ids.shape, -100)
if (
system_ids.size(0) != batch_size
or user_ids.size(0) != batch_size
or assistant_ids.size(0) != batch_size
or labels.size(0) != batch_size
):
raise ValueError("segmented prefill step expects all segments to share the same batch size.")
pad_token_id = getattr(model.config, "pad_token_id", None)
if pad_token_id is None:
pad_token_id = 0
packed_input_rows = []
packed_label_rows = []
split_starts = []
system_lens = []
for row in range(batch_size):
sys_len = int(system_attention_mask[row].sum().item())
user_len = int(user_attention_mask[row].sum().item())
assistant_len = int(assistant_attention_mask[row].sum().item())
if assistant_len <= 0:
continue
sys_row = system_ids[row : row + 1, :sys_len]
user_row = user_ids[row : row + 1, :user_len]
assistant_row = assistant_ids[row : row + 1, :assistant_len]
label_row = labels[row : row + 1, :assistant_len]
packed_input_rows.append(torch.cat([sys_row, user_row, assistant_row], dim=1))
packed_label_rows.append(
torch.cat(
[
torch.full((1, sys_len + user_len), -100, dtype=label_row.dtype, device=device),
label_row,
],
dim=1,
)
)
split_starts.append(sys_len + user_len)
system_lens.append(sys_len)
if not packed_input_rows:
return None
max_len = max(int(row.size(1)) for row in packed_input_rows)
input_ids = torch.full((len(packed_input_rows), max_len), int(pad_token_id), dtype=system_ids.dtype, device=device)
attention_mask = torch.zeros((len(packed_input_rows), max_len), dtype=torch.long, device=device)
full_labels = torch.full((len(packed_label_rows), max_len), -100, dtype=labels.dtype, device=device)
for row_idx, (input_row, label_row) in enumerate(zip(packed_input_rows, packed_label_rows)):
width = int(input_row.size(1))
input_ids[row_idx, :width] = input_row.squeeze(0)
attention_mask[row_idx, :width] = 1
full_labels[row_idx, :width] = label_row.squeeze(0)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": full_labels,
"split_starts": torch.tensor(split_starts, dtype=torch.long, device=device),
"system_lens": torch.tensor(system_lens, dtype=torch.long, device=device),
}
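# Packing layout produced above: each kept row is [system | user | assistant] with lengths taken
# from the attention masks, right-padded with pad_token_id; labels are -100 over the system+user
# prefix and copied over the assistant span. Rows with an empty assistant segment are dropped,
# split_starts records where each assistant segment begins, and system_lens the per-row system length.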
def segmented_prefill_lower_freeze_step_logits(
self: "LlamaForCausalLM",
*,
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
system_attention_mask: Optional[torch.LongTensor] = None,
user_attention_mask: Optional[torch.LongTensor] = None,
assistant_attention_mask: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
labels: Optional[torch.LongTensor] = None,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
):
packed = _tri_pack_prefill_training_segments(
model=self,
system_ids=system_ids,
user_ids=user_ids,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask,
assistant_attention_mask=assistant_attention_mask,
labels=labels,
)
if packed is None:
return None
return self.tri_vanilla_frozen_prefix_train_forward(
input_ids=packed["input_ids"],
attention_mask=packed["attention_mask"],
labels=packed["labels"],
use_cache=False,
logits_to_keep=logits_to_keep,
past_key_values=None,
prefill_lower_layers=int(lower_k),
prefill_attn=str(prefill_attn),
split_start_hint=packed["split_starts"],
system_len_hint=packed["system_lens"],
system_prefill=str(system_prefill),
)
def segmented_prefill_lower_solo_step_logits(
self: "LlamaForCausalLM",
*,
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
system_attention_mask: Optional[torch.LongTensor] = None,
user_attention_mask: Optional[torch.LongTensor] = None,
assistant_attention_mask: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
labels: Optional[torch.LongTensor] = None,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
):
packed = _tri_pack_prefill_training_segments(
model=self,
system_ids=system_ids,
user_ids=user_ids,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask,
assistant_attention_mask=assistant_attention_mask,
labels=labels,
)
if packed is None:
return None
return self.tri_vanilla_solo_prefix_train_forward(
input_ids=packed["input_ids"],
attention_mask=packed["attention_mask"],
labels=packed["labels"],
logits_to_keep=logits_to_keep,
prefill_lower_layers=int(lower_k),
prefill_attn=str(prefill_attn),
split_start_hint=packed["split_starts"],
system_len_hint=packed["system_lens"],
system_prefill=str(system_prefill),
)
def segmented_prefill_lower_solo_v2_step_logits(
self: "LlamaForCausalLM",
*,
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
system_attention_mask: Optional[torch.LongTensor] = None,
user_attention_mask: Optional[torch.LongTensor] = None,
assistant_attention_mask: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
labels: Optional[torch.LongTensor] = None,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
with_bos: bool = False,
):
packed = _tri_pack_prefill_training_segments(
model=self,
system_ids=system_ids,
user_ids=user_ids,
assistant_ids=assistant_ids,
system_attention_mask=system_attention_mask,
user_attention_mask=user_attention_mask,
assistant_attention_mask=assistant_attention_mask,
labels=labels,
)
if packed is None:
return None
return self.tri_vanilla_solo_attention_v2_train_forward(
input_ids=packed["input_ids"],
attention_mask=packed["attention_mask"],
labels=packed["labels"],
logits_to_keep=logits_to_keep,
prefill_lower_layers=int(lower_k),
prefill_attn=str(prefill_attn),
split_start_hint=packed["split_starts"],
system_len_hint=packed["system_lens"],
system_prefill=str(system_prefill),
with_bos=bool(with_bos),
)
LlamaForCausalLM.segmented_prefill_lower_freeze_step_logits = segmented_prefill_lower_freeze_step_logits
LlamaForCausalLM.segmented_prefill_lower_solo_step_logits = segmented_prefill_lower_solo_step_logits
LlamaForCausalLM.segmented_prefill_lower_solo_v2_step_logits = segmented_prefill_lower_solo_v2_step_logits
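# These segmented_prefill_* wrappers pack the three segments with
# _tri_pack_prefill_training_segments and dispatch to the corresponding tri_vanilla_* training
# forwards defined elsewhere in this file; they return None when every row in the batch has an
# empty assistant segment.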
def tri_step_logits(
self: "LlamaForCausalLM",
assistant_ids: torch.LongTensor,
lower_k: int,
pkv: Cache,
S: int,
U: int,
logits_to_keep: Union[int, torch.Tensor] = 0,
labels: Optional[torch.LongTensor] = None,
write_cache: bool = True,
prefill_mode: str = "lower",
no_upper_attn: bool = False,
align_cache_position_to_layer_past: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
):
out = self.model.tri_forward_assistant(
assistant_ids=assistant_ids,
lower_k=lower_k,
pkv=pkv,
S=S,
U=U,
write_cache=write_cache,
prefill_mode=prefill_mode,
no_upper_attn=no_upper_attn,
align_cache_position_to_layer_past=align_cache_position_to_layer_past,
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
hidden_states = out.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=pkv if write_cache else None,
hidden_states=hidden_states,
)
LlamaForCausalLM.tri_step_logits = tri_step_logits
LlamaForCausalLM.llopa_decode_step_logits = tri_step_logits
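# Decode usage sketch (hypothetical greedy loop; pkv, S and U come from the prefill step and
# lower_k=8 is an arbitrary choice):
#   ids = first_assistant_ids
#   for _ in range(max_new_tokens):
#       step = model.tri_step_logits(assistant_ids=ids, lower_k=8, pkv=pkv, S=S, U=U)
#       ids = step.logits[:, -1, :].argmax(dim=-1, keepdim=True)
# llopa_decode_step_logits is an alias of tri_step_logits, so LLOPA decoding reuses the same path.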
def tri_prefill_lower_cache(
self: "LlamaModel",
prefix_ids: torch.LongTensor,
lower_k: int,
prefill_mode: str = "lower",
prefill_attn: str = "causal",
) -> Cache:
pkv = _safe_dynamic_cache(self.config)
if prefix_ids is None or prefix_ids.size(1) == 0:
return pkv
device = prefix_ids.device
inputs_embeds = self.embed_tokens(prefix_ids)
cache_position = _tri_arange(0, inputs_embeds.shape[1], device)
position_ids = cache_position.unsqueeze(0).expand(prefix_ids.size(0), -1)
attention_mask = torch.ones_like(prefix_ids, dtype=torch.long, device=device)
attn_impl = _resolve_attn_impl(self.config)
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode == "full":
if attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
causal_mask = _build_tri_mask_full(prefix_ids.size(0), inputs_embeds.shape[1], 0, device, inputs_embeds.dtype)
elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(self.config):
causal_mask = None
else:
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(self)
for li in _prefill_layer_indices(self.config.num_hidden_layers, prefill_mode, lower_k):
layer = self.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
return pkv
LlamaModel.tri_prefill_lower_cache = tri_prefill_lower_cache
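# tri_prefill_lower_cache builds a fresh cache and prefills only the layers selected by
# prefill_mode/lower_k over prefix_ids; prefill_attn='full' (or 'prefix_full') uses a fully
# visible prefix mask and is rejected under flash_attention_2.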
def _tri_vanilla_warn_once(model, flag: str, msg: str) -> None:
if getattr(model, flag, False):
return
try:
setattr(model, flag, True)
except Exception:
pass
logger.warning(msg)
def _tri_parse_split_start_hint(split_start_hint, batch_size: int) -> Optional[int]:
if split_start_hint is None:
return None
if isinstance(split_start_hint, torch.Tensor):
if split_start_hint.numel() == 0:
return None
if split_start_hint.dim() == 0:
return int(split_start_hint.item())
if split_start_hint.dim() == 1 and split_start_hint.size(0) == batch_size:
return int(split_start_hint[0].item())
return None
if isinstance(split_start_hint, (list, tuple)):
if len(split_start_hint) == 0:
return None
return int(split_start_hint[0])
return int(split_start_hint)
def _tri_parse_split_start_hints(
split_start_hint,
batch_size: int,
*,
device,
) -> Optional[torch.LongTensor]:
if split_start_hint is None:
return None
if isinstance(split_start_hint, torch.Tensor):
if split_start_hint.numel() == 0:
return None
if split_start_hint.dim() == 0:
value = int(split_start_hint.item())
return torch.full((batch_size,), value, device=device, dtype=torch.long)
flat = split_start_hint.flatten()
if flat.numel() == 1:
value = int(flat[0].item())
return torch.full((batch_size,), value, device=device, dtype=torch.long)
if flat.dim() == 1 and flat.size(0) == batch_size:
return flat.to(device=device, dtype=torch.long)
return None
if isinstance(split_start_hint, (list, tuple)):
if len(split_start_hint) == 0:
return None
if len(split_start_hint) == 1:
value = int(split_start_hint[0])
return torch.full((batch_size,), value, device=device, dtype=torch.long)
if len(split_start_hint) == batch_size:
return torch.tensor(split_start_hint, device=device, dtype=torch.long)
return None
value = int(split_start_hint)
return torch.full((batch_size,), value, device=device, dtype=torch.long)
def _tri_parse_system_len_hint(system_len_hint, batch_size: int) -> Optional[int]:
if system_len_hint is None:
return None
if isinstance(system_len_hint, torch.Tensor):
if system_len_hint.numel() == 0:
return None
if system_len_hint.dim() == 0:
return int(system_len_hint.item())
if system_len_hint.dim() == 1 and system_len_hint.size(0) == batch_size:
return int(system_len_hint[0].item())
return None
if isinstance(system_len_hint, (list, tuple)):
if len(system_len_hint) == 0:
return None
return int(system_len_hint[0])
return int(system_len_hint)
def _tri_parse_system_len_hints(
system_len_hint,
batch_size: int,
*,
device,
) -> Optional[torch.LongTensor]:
if system_len_hint is None:
return None
if isinstance(system_len_hint, torch.Tensor):
if system_len_hint.numel() == 0:
return None
if system_len_hint.dim() == 0:
value = int(system_len_hint.item())
return torch.full((batch_size,), value, device=device, dtype=torch.long)
flat = system_len_hint.flatten()
if flat.numel() == 1:
value = int(flat[0].item())
return torch.full((batch_size,), value, device=device, dtype=torch.long)
if flat.dim() == 1 and flat.size(0) == batch_size:
return flat.to(device=device, dtype=torch.long)
return None
if isinstance(system_len_hint, (list, tuple)):
if len(system_len_hint) == 0:
return None
if len(system_len_hint) == 1:
value = int(system_len_hint[0])
return torch.full((batch_size,), value, device=device, dtype=torch.long)
if len(system_len_hint) == batch_size:
return torch.tensor(system_len_hint, device=device, dtype=torch.long)
return None
value = int(system_len_hint)
return torch.full((batch_size,), value, device=device, dtype=torch.long)
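# The *_hint parsers above normalize scalars, 0-d/1-d tensors, and Python sequences into either a
# single int (first row) or a per-row LongTensor of length batch_size; shapes that cannot be
# matched to the batch yield None so callers can fall back to their defaults.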
def _tri_parse_assistant_header_start_hints(
assistant_header_starts_hint,
assistant_header_start_mask_hint,
batch_size: int,
*,
device,
) -> Optional[tuple[torch.LongTensor, torch.BoolTensor]]:
if assistant_header_starts_hint is None:
return None
starts = assistant_header_starts_hint
if isinstance(starts, torch.Tensor):
if starts.numel() == 0:
return None
if starts.dim() == 1:
starts = starts.view(1, -1)
elif starts.dim() != 2:
return None
if starts.size(0) == 1 and batch_size > 1:
starts = starts.expand(batch_size, -1)
if starts.size(0) != batch_size:
return None
starts = starts.to(device=device, dtype=torch.long)
elif isinstance(starts, (list, tuple)):
if len(starts) == 0:
return None
if isinstance(starts[0], (list, tuple)):
if len(starts) != batch_size:
return None
max_turns = max((len(row) for row in starts), default=0)
starts_tensor = torch.full((batch_size, max_turns), -1, device=device, dtype=torch.long)
for row_idx, row in enumerate(starts):
if len(row) > 0:
starts_tensor[row_idx, : len(row)] = torch.tensor(list(row), device=device, dtype=torch.long)
starts = starts_tensor
else:
starts_tensor = torch.tensor(list(starts), device=device, dtype=torch.long)
starts = starts_tensor.view(1, -1).expand(batch_size, -1)
else:
return None
if assistant_header_start_mask_hint is None:
start_mask = starts >= 0
else:
start_mask = assistant_header_start_mask_hint
if isinstance(start_mask, torch.Tensor):
if start_mask.numel() == 0:
return None
if start_mask.dim() == 1:
start_mask = start_mask.view(1, -1)
elif start_mask.dim() != 2:
return None
if start_mask.size(0) == 1 and batch_size > 1:
start_mask = start_mask.expand(batch_size, -1)
if start_mask.size(0) != batch_size or start_mask.size(1) != starts.size(1):
return None
start_mask = start_mask.to(device=device, dtype=torch.bool)
elif isinstance(start_mask, (list, tuple)):
if len(start_mask) != batch_size:
return None
mask_tensor = torch.zeros_like(starts, dtype=torch.bool)
for row_idx, row in enumerate(start_mask):
if isinstance(row, (list, tuple)):
width = min(len(row), starts.size(1))
if width > 0:
mask_tensor[row_idx, :width] = torch.tensor(list(row)[:width], device=device, dtype=torch.bool)
elif bool(row):
mask_tensor[row_idx, 0] = True
start_mask = mask_tensor
else:
return None
start_mask = start_mask & (starts >= 0)
return starts, start_mask
def _tri_find_valid_label_segments(labels_row: torch.Tensor, valid_len: int) -> list[tuple[int, int]]:
if labels_row.dim() != 1:
raise ValueError("labels_row must be rank-1.")
limit = min(max(int(valid_len), 0), int(labels_row.numel()))
if limit <= 0:
return []
valid_mask = labels_row[:limit] != -100
segments: list[tuple[int, int]] = []
seg_start: Optional[int] = None
for idx in range(limit):
if bool(valid_mask[idx].item()):
if seg_start is None:
seg_start = idx
continue
if seg_start is not None:
segments.append((seg_start, idx))
seg_start = None
if seg_start is not None:
segments.append((seg_start, limit))
return segments
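# Example: labels_row = [-100, -100, 5, 6, -100, 7] with valid_len = 6
#   -> segments [(2, 4), (5, 6)] (half-open [start, end) runs where labels != -100).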
def _tri_resolve_assistant_turn_spans(
*,
labels_row: torch.Tensor,
valid_len: int,
assistant_header_starts: torch.Tensor,
assistant_header_start_mask: torch.Tensor,
) -> list[tuple[int, int]]:
valid_len = min(max(int(valid_len), 0), int(labels_row.numel()))
if valid_len <= 0:
return []
starts: list[int] = []
for col in range(int(assistant_header_starts.numel())):
if not bool(assistant_header_start_mask[col].item()):
continue
start = int(assistant_header_starts[col].item())
if 0 <= start < valid_len:
starts.append(start)
if not starts:
return []
starts = sorted(set(starts))
segments = _tri_find_valid_label_segments(labels_row, valid_len=valid_len)
if not segments:
return []
spans: list[tuple[int, int]] = []
seg_cursor = 0
for idx, start in enumerate(starts):
next_start = starts[idx + 1] if idx + 1 < len(starts) else valid_len
while seg_cursor < len(segments) and segments[seg_cursor][1] <= start:
seg_cursor += 1
probe = seg_cursor
while probe < len(segments):
seg_start, seg_end = segments[probe]
if seg_start >= next_start:
break
if seg_start >= start:
spans.append((start, seg_end))
seg_cursor = probe + 1
break
probe += 1
return spans
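# Example: header starts [3, 10], label segments [(5, 8), (12, 15)], valid_len = 16
#   -> spans [(3, 8), (10, 15)]: each span runs from an assistant header start to the end of the
#      first supervised label segment that starts inside that turn.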
def _tri_build_prefill_lower_multiturn_index_batch(
*,
turn_starts: torch.LongTensor,
turn_ends: torch.LongTensor,
system_lens: torch.LongTensor,
system_prefill: str,
device,
owner=None,
bucket_multiple: int = 0,
) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
turn_starts = turn_starts.to(device=device, dtype=torch.long)
turn_ends = turn_ends.to(device=device, dtype=torch.long)
system_lens = system_lens.to(device=device, dtype=torch.long)
row_count = int(turn_starts.size(0))
upper_indices = []
upper_lens = torch.zeros((row_count,), device=device, dtype=torch.long)
max_upper_len = 0
for row in range(row_count):
idx = build_prefill_lower_upper_indices(
sequence_len=int(turn_ends[row].item()),
split_start=int(turn_starts[row].item()),
system_len=int(system_lens[row].item()),
system_prefill=system_prefill,
device=device,
)
upper_indices.append(idx)
upper_lens[row] = int(idx.numel())
if idx.numel() > max_upper_len:
max_upper_len = int(idx.numel())
bucket_multiple = int(bucket_multiple or 0)
target_upper_len = int(max_upper_len)
if bucket_multiple > 0 and target_upper_len > 0:
target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple
gather_idx = None
valid_mask = None
if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)):
cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None)
if not isinstance(cache, dict):
cache = {}
key = ("multiturn", str(device), int(row_count), int(target_upper_len))
pair = cache.get(key)
if (
not isinstance(pair, tuple)
or len(pair) != 2
or not isinstance(pair[0], torch.Tensor)
or not isinstance(pair[1], torch.Tensor)
or pair[0].shape != (int(row_count), int(target_upper_len))
or pair[1].shape != (int(row_count), int(target_upper_len))
or pair[0].device != device
or pair[1].device != device
):
pair = (
torch.zeros((int(row_count), int(target_upper_len)), device=device, dtype=torch.long),
torch.zeros((int(row_count), int(target_upper_len)), device=device, dtype=torch.bool),
)
cache[key] = pair
with contextlib.suppress(Exception):
setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache)
gather_idx, valid_mask = pair
gather_idx.zero_()
valid_mask.zero_()
if gather_idx is None or valid_mask is None:
gather_idx = torch.zeros((row_count, target_upper_len), device=device, dtype=torch.long)
valid_mask = torch.zeros((row_count, target_upper_len), device=device, dtype=torch.bool)
for row, idx in enumerate(upper_indices):
width = int(idx.numel())
if width <= 0:
continue
gather_idx[row, :width] = idx
valid_mask[row, :width] = True
return gather_idx, valid_mask, upper_lens
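# Bucketing note: with bucket_multiple = 16 and a longest per-row index list of 57 entries,
# target_upper_len = ((57 + 15) // 16) * 16 = 64, so the padded gather/valid workspaces can be
# reused across eval calls via the per-owner cache instead of being reallocated per batch.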
def _tri_build_prefill_lower_past_assistant_index_batch(
*,
split_starts: torch.LongTensor,
valid_lens: torch.LongTensor,
system_lens: torch.LongTensor,
system_prefill: str,
assistant_header_starts: torch.LongTensor,
assistant_turn_ends: torch.LongTensor,
assistant_header_start_mask: torch.BoolTensor,
device,
owner=None,
bucket_multiple: int = 0,
) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
split_starts = split_starts.to(device=device, dtype=torch.long)
valid_lens = valid_lens.to(device=device, dtype=torch.long)
system_lens = system_lens.to(device=device, dtype=torch.long)
assistant_header_starts = assistant_header_starts.to(device=device, dtype=torch.long)
assistant_turn_ends = assistant_turn_ends.to(device=device, dtype=torch.long)
assistant_header_start_mask = assistant_header_start_mask.to(device=device, dtype=torch.bool)
batch_size = int(split_starts.size(0))
prefix_keep_lens = _tri_prefix_keep_lengths(system_lens, system_prefill)
upper_indices = []
upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long)
max_upper_len = 0
for row in range(batch_size):
valid_len = int(valid_lens[row].item())
split_start = max(0, min(int(split_starts[row].item()), valid_len))
prefix_keep = min(max(int(prefix_keep_lens[row].item()), 0), split_start)
row_indices: list[torch.Tensor] = []
if prefix_keep > 0:
row_indices.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long))
num_turns = min(int(assistant_header_starts.size(1)), int(assistant_turn_ends.size(1)))
for col in range(num_turns):
if not bool(assistant_header_start_mask[row, col].item()):
continue
turn_start = max(int(assistant_header_starts[row, col].item()), prefix_keep)
turn_end = int(assistant_turn_ends[row, col].item())
if turn_end <= turn_start:
continue
turn_end = min(turn_end, split_start)
if turn_end <= turn_start:
continue
row_indices.append(torch.arange(turn_start, turn_end, device=device, dtype=torch.long))
if split_start < valid_len:
row_indices.append(torch.arange(split_start, valid_len, device=device, dtype=torch.long))
idx = torch.cat(row_indices, dim=0) if row_indices else torch.empty((0,), device=device, dtype=torch.long)
upper_indices.append(idx)
upper_lens[row] = int(idx.numel())
if idx.numel() > max_upper_len:
max_upper_len = int(idx.numel())
bucket_multiple = int(bucket_multiple or 0)
target_upper_len = int(max_upper_len)
if bucket_multiple > 0 and target_upper_len > 0:
target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple
gather_idx = None
valid_mask = None
if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)):
cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None)
if not isinstance(cache, dict):
cache = {}
key = ("single_past_assistant", str(device), int(batch_size), int(target_upper_len))
pair = cache.get(key)
if (
not isinstance(pair, tuple)
or len(pair) != 2
or not isinstance(pair[0], torch.Tensor)
or not isinstance(pair[1], torch.Tensor)
or pair[0].shape != (int(batch_size), int(target_upper_len))
or pair[1].shape != (int(batch_size), int(target_upper_len))
or pair[0].device != device
or pair[1].device != device
):
pair = (
torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.long),
torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.bool),
)
cache[key] = pair
with contextlib.suppress(Exception):
setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache)
gather_idx, valid_mask = pair
gather_idx.zero_()
valid_mask.zero_()
if gather_idx is None or valid_mask is None:
gather_idx = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.long)
valid_mask = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.bool)
for row, idx in enumerate(upper_indices):
width = int(idx.numel())
if width <= 0:
continue
gather_idx[row, :width] = idx
valid_mask[row, :width] = True
return gather_idx, valid_mask, upper_lens
def _tri_has_effective_past_assistant_history(
*,
split_start: int,
assistant_header_hints,
assistant_turn_end_hints,
) -> bool:
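    """Return True if the single-row assistant-turn hints describe at least one non-empty
    assistant turn starting before ``split_start``; return False when the metadata is
    missing or malformed."""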
if assistant_header_hints is None or assistant_turn_end_hints is None:
return False
header_starts, header_start_mask = assistant_header_hints
turn_ends, _ = assistant_turn_end_hints
if (
not isinstance(header_starts, torch.Tensor)
or not isinstance(header_start_mask, torch.Tensor)
or not isinstance(turn_ends, torch.Tensor)
or header_starts.ndim != 2
or header_start_mask.ndim != 2
or turn_ends.ndim != 2
or header_starts.size(0) <= 0
or header_start_mask.size(0) <= 0
or turn_ends.size(0) <= 0
):
return False
split_start = int(split_start)
num_turns = min(int(header_starts.size(1)), int(turn_ends.size(1)), int(header_start_mask.size(1)))
for col in range(num_turns):
if not bool(header_start_mask[0, col].item()):
continue
turn_start = int(header_starts[0, col].item())
turn_end = int(turn_ends[0, col].item())
if turn_end <= turn_start:
continue
if turn_start < split_start:
return True
return False
def _tri_build_prefill_lower_seed_upper_index_batch(
*,
split_starts: torch.LongTensor,
system_lens: torch.LongTensor,
system_prefill: str,
assistant_header_starts: Optional[torch.LongTensor] = None,
assistant_turn_ends: Optional[torch.LongTensor] = None,
assistant_header_start_mask: Optional[torch.BoolTensor] = None,
see_past_assistant: bool = False,
device,
) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
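    """Build padded gather indices for the prompt tokens that seed the upper layers: the
    retained system prefix plus, when ``see_past_assistant`` metadata is provided, past
    assistant spans clipped to ``split_start``. Returns ``(gather_idx, valid_mask, upper_lens)``."""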
split_starts = split_starts.to(device=device, dtype=torch.long)
system_lens = system_lens.to(device=device, dtype=torch.long)
batch_size = int(split_starts.size(0))
prefix_keep_lens = _tri_prefix_keep_lengths(system_lens, system_prefill)
if isinstance(assistant_header_starts, torch.Tensor):
assistant_header_starts = assistant_header_starts.to(device=device, dtype=torch.long)
if isinstance(assistant_turn_ends, torch.Tensor):
assistant_turn_ends = assistant_turn_ends.to(device=device, dtype=torch.long)
if isinstance(assistant_header_start_mask, torch.Tensor):
assistant_header_start_mask = assistant_header_start_mask.to(device=device, dtype=torch.bool)
upper_indices: list[torch.Tensor] = []
upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long)
max_upper_len = 0
has_history_metadata = (
bool(see_past_assistant)
and isinstance(assistant_header_starts, torch.Tensor)
and isinstance(assistant_turn_ends, torch.Tensor)
and isinstance(assistant_header_start_mask, torch.Tensor)
and assistant_header_starts.dim() == 2
and assistant_turn_ends.dim() == 2
and assistant_header_start_mask.dim() == 2
)
for row in range(batch_size):
split_start = max(int(split_starts[row].item()), 0)
prefix_keep = min(max(int(prefix_keep_lens[row].item()), 0), split_start)
row_parts: list[torch.Tensor] = []
if prefix_keep > 0:
row_parts.append(torch.arange(0, prefix_keep, device=device, dtype=torch.long))
if has_history_metadata:
num_turns = min(int(assistant_header_starts.size(1)), int(assistant_turn_ends.size(1)))
for col in range(num_turns):
if not bool(assistant_header_start_mask[row, col].item()):
continue
turn_start = max(int(assistant_header_starts[row, col].item()), prefix_keep)
turn_end = min(int(assistant_turn_ends[row, col].item()), split_start)
if turn_end > turn_start:
row_parts.append(torch.arange(turn_start, turn_end, device=device, dtype=torch.long))
idx = torch.cat(row_parts, dim=0) if row_parts else torch.empty((0,), device=device, dtype=torch.long)
upper_indices.append(idx)
upper_lens[row] = int(idx.numel())
max_upper_len = max(max_upper_len, int(idx.numel()))
gather_idx = torch.zeros((batch_size, max_upper_len), device=device, dtype=torch.long)
valid_mask = torch.zeros((batch_size, max_upper_len), device=device, dtype=torch.bool)
for row, idx in enumerate(upper_indices):
width = int(idx.numel())
if width > 0:
gather_idx[row, :width] = idx
valid_mask[row, :width] = True
return gather_idx, valid_mask, upper_lens
def _tri_prefill_lower_prompt_cache(
model: "LlamaModel",
prefix_ids: torch.LongTensor,
*,
lower_k: int,
prefill_attn: str,
system_len: int,
system_prefill: str,
return_replay_memory: bool = False,
return_prefix_hidden: bool = False,
replay_user_prefix_keep_len: int = 0,
replay_user_start: Optional[int] = None,
replay_user_len: Optional[int] = None,
) -> Cache | tuple[Cache, Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[torch.LongTensor]]:
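    """Prefill the prompt prefix into a KV cache, routing segments through the lower-layer
    path. With ``system_prefill="full"`` the system segment runs through all layers; with
    ``"no_system"`` only its leading token does; otherwise the whole prefix stays in the
    first ``lower_k`` layers. Optionally also returns pre-norm hidden states and position
    ids for replay memory and/or the raw prefix."""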
system_len = min(max(int(system_len), 0), int(prefix_ids.size(1)))
if prefix_ids.size(1) == 0:
empty_cache = _safe_dynamic_cache(model.config)
if not bool(return_replay_memory) and not bool(return_prefix_hidden):
return empty_cache
empty_pos = torch.empty((prefix_ids.size(0), 0), device=prefix_ids.device, dtype=torch.long)
replay_hidden = model.embed_tokens(prefix_ids) if bool(return_replay_memory) else None
replay_pos = empty_pos if bool(return_replay_memory) else None
prefix_hidden = model.embed_tokens(prefix_ids) if bool(return_prefix_hidden) else None
prefix_pos = empty_pos if bool(return_prefix_hidden) else None
return empty_cache, replay_hidden, replay_pos, prefix_hidden, prefix_pos
system_ids = prefix_ids[:, :system_len]
user_ids = prefix_ids[:, system_len:]
mode = normalize_system_prefill(system_prefill)
batch_size = int(prefix_ids.size(0))
replay_hidden_parts: list[torch.Tensor] = []
replay_pos_parts: list[torch.LongTensor] = []
prefix_hidden_parts: list[torch.Tensor] = []
prefix_pos_parts: list[torch.LongTensor] = []
def _capture_replay_segment(hidden_states: Optional[torch.Tensor], start_pos: int) -> None:
if not bool(return_replay_memory):
return
if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0:
return
replay_hidden_parts.append(hidden_states)
replay_pos_parts.append(
_tri_segment_position_ids(
batch_size=batch_size,
start=int(start_pos),
length=int(hidden_states.size(1)),
device=hidden_states.device,
)
)
def _capture_prefix_segment(hidden_states: Optional[torch.Tensor], start_pos: int) -> None:
if not bool(return_prefix_hidden):
return
if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0:
return
prefix_hidden_parts.append(hidden_states)
prefix_pos_parts.append(
_tri_segment_position_ids(
batch_size=batch_size,
start=int(start_pos),
length=int(hidden_states.size(1)),
device=hidden_states.device,
)
)
def _capture_user_replay(hidden_states: Optional[torch.Tensor], start_pos: int) -> None:
if not bool(return_replay_memory):
return
if not isinstance(hidden_states, torch.Tensor) or hidden_states.numel() == 0 or int(hidden_states.size(1)) == 0:
return
replay_prefix_keep, replay_start, replay_end = _tri_resolve_replay_user_window(
user_len=int(hidden_states.size(1)),
prefix_keep_len=int(replay_user_prefix_keep_len or 0),
latest_user_start=replay_user_start,
latest_user_len=replay_user_len,
)
if replay_prefix_keep > 0:
_capture_replay_segment(hidden_states[:, :replay_prefix_keep, :], start_pos=start_pos)
if replay_end > replay_start:
_capture_replay_segment(
hidden_states[:, replay_start:replay_end, :],
start_pos=int(start_pos) + int(replay_start),
)
def _finalize(
pkv: Cache,
) -> Cache | tuple[Cache, Optional[torch.Tensor], Optional[torch.LongTensor], Optional[torch.Tensor], Optional[torch.LongTensor]]:
if not bool(return_replay_memory) and not bool(return_prefix_hidden):
return pkv
replay_hidden = None
replay_position_ids = None
prefix_hidden = None
prefix_position_ids = None
if replay_hidden_parts:
replay_hidden = replay_hidden_parts[0] if len(replay_hidden_parts) == 1 else torch.cat(replay_hidden_parts, dim=1)
replay_position_ids = replay_pos_parts[0] if len(replay_pos_parts) == 1 else torch.cat(replay_pos_parts, dim=1)
if prefix_hidden_parts:
prefix_hidden = prefix_hidden_parts[0] if len(prefix_hidden_parts) == 1 else torch.cat(prefix_hidden_parts, dim=1)
prefix_position_ids = prefix_pos_parts[0] if len(prefix_pos_parts) == 1 else torch.cat(prefix_pos_parts, dim=1)
return pkv, replay_hidden, replay_position_ids, prefix_hidden, prefix_position_ids
if mode == "full":
system_out = model.tri_prefill_system_all(
system_ids,
past_key_values=None,
prefill_attn=prefill_attn,
)
pkv = system_out.past_key_values
if user_ids.size(1) > 0:
user_out = model.tri_prefill_user_lower(
user_ids,
lower_k=int(lower_k),
past_key_values=pkv,
prefill_mode="lower",
prefill_attn=prefill_attn,
return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden),
)
_capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
_capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
return _finalize(pkv)
pkv = _safe_dynamic_cache(model.config)
if mode == "no_system":
if system_ids.size(1) > 0:
bos_ids = system_ids[:, :1]
rest_ids = system_ids[:, 1:]
out = model.tri_prefill_system_all(
bos_ids,
past_key_values=None,
prefill_attn=prefill_attn,
)
pkv = out.past_key_values
if rest_ids.size(1) > 0:
rest_out = model.tri_prefill_user_lower(
rest_ids,
lower_k=int(lower_k),
past_key_values=pkv,
prefill_mode="lower",
prefill_attn=prefill_attn,
return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden),
)
_capture_replay_segment(getattr(rest_out, "pre_norm_hidden", None), start_pos=1)
_capture_prefix_segment(getattr(rest_out, "pre_norm_hidden", None), start_pos=1)
if user_ids.size(1) > 0:
user_out = model.tri_prefill_user_lower(
user_ids,
lower_k=int(lower_k),
past_key_values=pkv,
prefill_mode="lower",
prefill_attn=prefill_attn,
return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden),
)
_capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
_capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
return _finalize(pkv)
if system_ids.size(1) > 0:
system_out = model.tri_prefill_user_lower(
system_ids,
lower_k=int(lower_k),
past_key_values=pkv,
prefill_mode="lower",
prefill_attn=prefill_attn,
return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden),
)
_capture_replay_segment(getattr(system_out, "pre_norm_hidden", None), start_pos=0)
_capture_prefix_segment(getattr(system_out, "pre_norm_hidden", None), start_pos=0)
if user_ids.size(1) > 0:
user_out = model.tri_prefill_user_lower(
user_ids,
lower_k=int(lower_k),
past_key_values=pkv,
prefill_mode="lower",
prefill_attn=prefill_attn,
return_pre_norm_hidden=bool(return_replay_memory or return_prefix_hidden),
)
_capture_user_replay(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
_capture_prefix_segment(getattr(user_out, "pre_norm_hidden", None), start_pos=system_len)
return _finalize(pkv)
def _tri_prefill_lower_prompt_hidden(
model: "LlamaModel",
prefix_ids: torch.LongTensor,
*,
lower_k: int,
prefill_attn: str,
) -> tuple[torch.Tensor, torch.LongTensor]:
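    """Embed ``prefix_ids`` and run only the first ``lower_k`` decoder layers (with causal
    or full attention), returning the resulting hidden states and position ids without
    building a cache."""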
if prefix_ids.size(1) == 0:
empty_hidden = model.embed_tokens(prefix_ids)
empty_pos = torch.empty((prefix_ids.size(0), 0), device=prefix_ids.device, dtype=torch.long)
return empty_hidden, empty_pos
n_layers = len(model.layers)
lower_k = max(0, min(int(lower_k), n_layers))
if lower_k <= 0:
raise ValueError("_tri_prefill_lower_prompt_hidden requires lower_k > 0.")
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in {"causal", "full"}:
raise ValueError("_tri_prefill_lower_prompt_hidden requires prefill_attn in {'causal', 'full'}.")
device = prefix_ids.device
inputs_embeds = model.embed_tokens(prefix_ids)
attention_mask = torch.ones_like(prefix_ids, dtype=torch.long, device=device)
cache_position = _tri_arange(0, inputs_embeds.shape[1], device)
position_ids = _llopa_position_ids_from_mask(attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if attn_mode == "full":
if attn_impl == "flash_attention_2":
raise ValueError("_tri_prefill_lower_prompt_hidden prefill_attn=full is not supported with flash_attention_2.")
lower_mask = _build_tri_mask_full(prefix_ids.size(0), inputs_embeds.shape[1], 0, device, inputs_embeds.dtype)
elif attn_impl == "flash_attention_2" or _can_use_implicit_causal_mask(model.config):
lower_mask = None
else:
lower_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
for li in range(lower_k):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
return hidden_states, position_ids
def _tri_build_prefill_lower_upper_index_batch(
*,
split_starts: torch.LongTensor,
valid_lens: torch.LongTensor,
system_lens: torch.LongTensor,
system_prefill: str,
device,
owner=None,
bucket_multiple: int = 0,
) -> tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
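    """Per-row wrapper around ``build_prefill_lower_upper_indices`` that packs the
    upper-visible positions into padded gather indices, a validity mask, and per-row
    lengths, optionally bucketing the width and reusing an inference-time workspace."""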
split_starts = split_starts.to(device=device, dtype=torch.long)
valid_lens = valid_lens.to(device=device, dtype=torch.long)
system_lens = system_lens.to(device=device, dtype=torch.long)
batch_size = int(split_starts.size(0))
upper_indices = []
upper_lens = torch.zeros((batch_size,), device=device, dtype=torch.long)
max_upper_len = 0
for row in range(batch_size):
idx = build_prefill_lower_upper_indices(
sequence_len=int(valid_lens[row].item()),
split_start=int(split_starts[row].item()),
system_len=int(system_lens[row].item()),
system_prefill=system_prefill,
device=device,
)
upper_indices.append(idx)
upper_lens[row] = int(idx.numel())
if idx.numel() > max_upper_len:
max_upper_len = int(idx.numel())
bucket_multiple = int(bucket_multiple or 0)
target_upper_len = int(max_upper_len)
if bucket_multiple > 0 and target_upper_len > 0:
target_upper_len = ((target_upper_len + bucket_multiple - 1) // bucket_multiple) * bucket_multiple
gather_idx = None
valid_mask = None
if owner is not None and bucket_multiple > 0 and not bool(getattr(owner, "training", False)):
cache = getattr(owner, "_optimized_llopa_upper_index_workspace_cache", None)
if not isinstance(cache, dict):
cache = {}
key = ("single", str(device), int(batch_size), int(target_upper_len))
pair = cache.get(key)
if (
not isinstance(pair, tuple)
or len(pair) != 2
or not isinstance(pair[0], torch.Tensor)
or not isinstance(pair[1], torch.Tensor)
or pair[0].shape != (int(batch_size), int(target_upper_len))
or pair[1].shape != (int(batch_size), int(target_upper_len))
or pair[0].device != device
or pair[1].device != device
):
pair = (
torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.long),
torch.zeros((int(batch_size), int(target_upper_len)), device=device, dtype=torch.bool),
)
cache[key] = pair
with contextlib.suppress(Exception):
setattr(owner, "_optimized_llopa_upper_index_workspace_cache", cache)
gather_idx, valid_mask = pair
gather_idx.zero_()
valid_mask.zero_()
if gather_idx is None or valid_mask is None:
gather_idx = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.long)
valid_mask = torch.zeros((batch_size, target_upper_len), device=device, dtype=torch.bool)
for row, idx in enumerate(upper_indices):
width = int(idx.numel())
if width <= 0:
continue
gather_idx[row, :width] = idx
valid_mask[row, :width] = True
return gather_idx, valid_mask, upper_lens
def _tri_pack_indexed_tensor(
tensor: torch.Tensor,
*,
gather_idx: torch.LongTensor,
valid_mask: torch.Tensor,
pad_value,
owner=None,
workspace_name: Optional[str] = None,
):
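    """Gather ``tensor`` along dim 1 with ``gather_idx`` and fill slots where ``valid_mask``
    is False with ``pad_value``. Returns the packed tensor and per-row packed lengths; in
    eval mode a per-owner output workspace keyed by shape/dtype/device can be reused."""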
if tensor.dim() < 2:
raise ValueError("Expected rank-2 or higher tensor for indexed packing.")
B, total_len = tensor.shape[:2]
device = tensor.device
gather_idx = gather_idx.to(device=device, dtype=torch.long)
valid_mask = valid_mask.to(device=device, dtype=torch.bool)
clamped_idx = gather_idx.clamp(min=0, max=max(total_len - 1, 0))
packed_lens = valid_mask.sum(dim=1, dtype=torch.long)
use_workspace = (
owner is not None
and workspace_name is not None
and not bool(getattr(owner, "training", False))
and int(clamped_idx.size(1)) > 0
)
packed = None
if use_workspace:
cache = getattr(owner, "_optimized_llopa_upper_pack_workspace_cache", None)
if not isinstance(cache, dict):
cache = {}
target_shape = (B, int(clamped_idx.size(1)), *tensor.shape[2:])
key = (str(workspace_name), str(device), str(tensor.dtype), tuple(int(x) for x in target_shape))
packed = cache.get(key)
if (
not isinstance(packed, torch.Tensor)
or packed.device != device
or packed.dtype != tensor.dtype
or tuple(int(x) for x in packed.shape) != tuple(int(x) for x in target_shape)
):
packed = torch.empty(target_shape, device=device, dtype=tensor.dtype)
cache[key] = packed
with contextlib.suppress(Exception):
setattr(owner, "_optimized_llopa_upper_pack_workspace_cache", cache)
if tensor.dim() == 2:
if packed is not None:
torch.gather(tensor, 1, clamped_idx, out=packed)
packed.masked_fill_(~valid_mask, pad_value)
else:
packed = tensor.gather(1, clamped_idx)
packed = packed.masked_fill(~valid_mask, pad_value)
else:
expand_shape = (*clamped_idx.shape, *tensor.shape[2:])
gather_index = clamped_idx.view(B, clamped_idx.size(1), *([1] * (tensor.dim() - 2))).expand(*expand_shape)
mask_view = ~valid_mask.view(B, clamped_idx.size(1), *([1] * (tensor.dim() - 2)))
if packed is not None:
torch.gather(tensor, 1, gather_index, out=packed)
packed.masked_fill_(mask_view, pad_value)
else:
packed = tensor.gather(1, gather_index)
packed = packed.masked_fill(mask_view, pad_value)
return packed, packed_lens
def _tri_find_last_subsequence_start(input_ids: torch.Tensor, pattern: torch.Tensor) -> Optional[int]:
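    """Return the start index of the last occurrence of ``pattern`` inside ``input_ids``
    (both 1-D or single-row 2-D), or ``None`` if there is no match. For example,
    ``_tri_find_last_subsequence_start(torch.tensor([1, 2, 3, 1, 2]), torch.tensor([1, 2]))``
    returns 3."""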
if not isinstance(input_ids, torch.Tensor) or not isinstance(pattern, torch.Tensor):
return None
if input_ids.dim() == 2:
if input_ids.size(0) != 1:
return None
seq = input_ids[0]
elif input_ids.dim() == 1:
seq = input_ids
else:
return None
if pattern.dim() == 2:
if pattern.size(0) != 1:
return None
pat = pattern[0]
elif pattern.dim() == 1:
pat = pattern
else:
return None
pat_len = int(pat.numel())
seq_len = int(seq.numel())
if pat_len <= 0 or seq_len < pat_len:
return None
windows = seq.unfold(0, pat_len, 1)
matches = (windows == pat).all(dim=-1).nonzero(as_tuple=False)
if matches.numel() == 0:
return None
return int(matches[-1].item())
def _tri_pack_suffix_tensor(
tensor: torch.Tensor,
*,
split_starts: torch.LongTensor,
valid_lens: torch.LongTensor,
pad_value,
):
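    """Left-pack each row's suffix ``[split_starts[i]:valid_lens[i])`` into a tensor padded
    to the longest suffix. Returns ``(packed, valid_mask, suffix_lens, gather_idx)`` and
    raises if any suffix is empty."""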
if tensor.dim() < 2:
raise ValueError("Expected rank-2 or higher tensor for suffix packing.")
B, total_len = tensor.shape[:2]
device = tensor.device
split_starts = split_starts.to(device=device, dtype=torch.long)
valid_lens = valid_lens.to(device=device, dtype=torch.long)
suffix_lens = valid_lens - split_starts
if torch.any(suffix_lens <= 0):
raise ValueError("All packed suffixes must have positive length.")
max_suffix_len = int(suffix_lens.max().item())
offsets = torch.arange(max_suffix_len, device=device, dtype=torch.long).unsqueeze(0)
gather_pos = split_starts.unsqueeze(1) + offsets
valid = gather_pos < valid_lens.unsqueeze(1)
gather_idx = gather_pos.clamp(max=max(total_len - 1, 0))
if tensor.dim() == 2:
packed = tensor.gather(1, gather_idx)
packed = packed.masked_fill(~valid, pad_value)
else:
expand_shape = (*gather_idx.shape, *tensor.shape[2:])
packed = tensor.gather(1, gather_idx.view(B, max_suffix_len, *([1] * (tensor.dim() - 2))).expand(*expand_shape))
packed = packed.masked_fill(~valid.view(B, max_suffix_len, *([1] * (tensor.dim() - 2))), pad_value)
return packed, valid, suffix_lens, gather_idx
def _tri_insert_suffix_specials_inband(
*,
token_ids: list[int],
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
labels: Optional[torch.LongTensor],
split_starts: Optional[torch.LongTensor] = None,
assistant_header_starts: Optional[torch.LongTensor] = None,
assistant_header_start_mask: Optional[torch.BoolTensor] = None,
) -> tuple[
torch.LongTensor,
torch.Tensor,
Optional[torch.LongTensor],
Optional[torch.LongTensor],
Optional[torch.LongTensor],
Optional[torch.BoolTensor],
]:
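    """Insert ``token_ids`` in-band immediately before each assistant-turn start (or before
    ``split_starts`` when only split hints are given), extend the attention mask and labels
    (inserted positions are labelled -100), and remap split starts and turn metadata to the
    shifted positions."""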
if not token_ids:
return input_ids, attention_mask, labels, split_starts, assistant_header_starts, assistant_header_start_mask
batch_size = int(input_ids.size(0))
device = input_ids.device
token_count = len(token_ids)
valid_lens = attention_mask.to(device=device, dtype=torch.long).sum(dim=1, dtype=torch.long)
starts_tensor = assistant_header_starts
starts_mask = assistant_header_start_mask
generated_starts = False
if starts_tensor is None and split_starts is not None:
starts_tensor = split_starts.view(batch_size, 1)
starts_mask = torch.ones((batch_size, 1), device=device, dtype=torch.bool)
generated_starts = True
elif starts_tensor is not None and starts_mask is None:
starts_mask = starts_tensor >= 0
if starts_tensor is not None:
starts_tensor = starts_tensor.to(device=device, dtype=torch.long)
if starts_tensor.dim() == 1:
starts_tensor = starts_tensor.view(batch_size, -1)
if starts_mask is not None:
starts_mask = starts_mask.to(device=device, dtype=torch.bool)
if starts_mask.dim() == 1:
starts_mask = starts_mask.view(batch_size, -1)
row_ids: list[torch.Tensor] = []
row_masks: list[torch.Tensor] = []
row_labels: list[torch.Tensor] = []
remapped_split_starts: list[int] = []
remapped_turn_starts: list[list[int]] = []
remapped_turn_masks: list[list[bool]] = []
max_len = 0
for row in range(batch_size):
valid_len = int(valid_lens[row].item())
ids_list = [int(v) for v in input_ids[row, :valid_len].tolist()]
mask_list = [int(v) for v in attention_mask[row, :valid_len].tolist()]
labels_list = None
if labels is not None:
labels_list = [int(v) for v in labels[row, :valid_len].tolist()]
positions_to_cols: dict[int, list[int]] = {}
num_turns = 0
if starts_tensor is not None and starts_mask is not None:
num_turns = int(starts_tensor.size(1))
for col in range(num_turns):
if not bool(starts_mask[row, col].item()):
continue
start = int(starts_tensor[row, col].item())
if 0 <= start < valid_len:
positions_to_cols.setdefault(start, []).append(col)
remapped_row_turns = [-1] * num_turns
remapped_row_mask = [False] * num_turns
offset = 0
for start in sorted(positions_to_cols.keys()):
insert_at = min(max(int(start) + offset, 0), len(ids_list))
ids_list[insert_at:insert_at] = list(token_ids)
mask_list[insert_at:insert_at] = [1] * token_count
if labels_list is not None:
labels_list[insert_at:insert_at] = [-100] * token_count
for col in positions_to_cols[start]:
remapped_row_turns[col] = insert_at
remapped_row_mask[col] = True
offset += token_count
if split_starts is not None:
base_split = int(split_starts[row].item())
remapped_split = base_split
            if 0 <= base_split <= valid_len:
                remapped_split = base_split + token_count * sum(1 for start in positions_to_cols.keys() if start < base_split)
remapped_split_starts.append(remapped_split)
row_valid_len = len(ids_list)
max_len = max(max_len, row_valid_len)
row_pad_id = 0
if valid_len < int(input_ids.size(1)):
row_pad_id = int(input_ids[row, valid_len].item())
row_id_tensor = input_ids.new_full((row_valid_len,), row_pad_id)
row_id_tensor[:] = torch.tensor(ids_list, device=device, dtype=input_ids.dtype)
row_mask_tensor = attention_mask.new_zeros((row_valid_len,))
row_mask_tensor[:] = torch.tensor(mask_list, device=device, dtype=attention_mask.dtype)
row_ids.append(row_id_tensor)
row_masks.append(row_mask_tensor)
if labels is not None and labels_list is not None:
row_label_tensor = labels.new_full((row_valid_len,), -100)
row_label_tensor[:] = torch.tensor(labels_list, device=device, dtype=labels.dtype)
row_labels.append(row_label_tensor)
if starts_tensor is not None:
remapped_turn_starts.append(remapped_row_turns)
remapped_turn_masks.append(remapped_row_mask)
new_input_ids = input_ids.new_zeros((batch_size, max_len))
new_attention_mask = attention_mask.new_zeros((batch_size, max_len))
new_labels = labels.new_full((batch_size, max_len), -100) if labels is not None else None
for row in range(batch_size):
width = int(row_ids[row].size(0))
new_input_ids[row, :width] = row_ids[row]
new_attention_mask[row, :width] = row_masks[row]
if new_labels is not None:
new_labels[row, :width] = row_labels[row]
new_split_starts = None
if split_starts is not None:
new_split_starts = torch.tensor(remapped_split_starts, device=device, dtype=torch.long)
if starts_tensor is None:
return new_input_ids, new_attention_mask, new_labels, new_split_starts, None, None
max_turns = int(starts_tensor.size(1))
new_turn_starts = torch.full((batch_size, max_turns), -1, device=device, dtype=torch.long)
new_turn_mask = torch.zeros((batch_size, max_turns), device=device, dtype=torch.bool)
for row in range(batch_size):
for col in range(max_turns):
if col >= len(remapped_turn_starts[row]):
continue
new_turn_starts[row, col] = int(remapped_turn_starts[row][col])
new_turn_mask[row, col] = bool(remapped_turn_masks[row][col])
if generated_starts:
return new_input_ids, new_attention_mask, new_labels, new_split_starts, None, None
return new_input_ids, new_attention_mask, new_labels, new_split_starts, new_turn_starts, new_turn_mask
def tri_vanilla_train_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
assistant_header_starts_hint=None,
assistant_turn_ends_hint=None,
assistant_header_start_mask_hint=None,
see_past_assistant: bool = False,
no_upper_attn: bool = False,
upper_attention_mode: str = "causal",
replay_module: str = "none",
replay_per_layers: int = -1,
) -> Optional[CausalLMOutputWithPast]:
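    """Prefill-lower training forward: run the first ``prefill_lower_layers`` layers over the
    full sequence, re-pack only the positions the upper layers should see (retained system
    prefix, optional past or per-turn assistant spans, and the supervised suffix), run the
    remaining layers over that packed view, and compute the LM loss on the kept labels.

    Returns ``None`` to signal a fallback to the vanilla forward when the inputs do not meet
    this path's assumptions.
    """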
if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2:
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full(
(batch_size,),
int(input_ids.size(1)),
device=input_ids.device,
dtype=torch.long,
)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
trimmed_labels = labels[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints(
assistant_header_starts_hint,
assistant_header_start_mask_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints(
assistant_turn_ends_hint,
assistant_header_start_mask_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if _tri_effective_fusion_mode(self) == "inband":
header_starts = None
header_start_mask = None
if parsed_assistant_header_hints is not None:
header_starts, header_start_mask = parsed_assistant_header_hints
(
trimmed_input_ids,
full_attention_mask,
trimmed_labels,
parsed_split_starts,
header_starts,
header_start_mask,
) = _tri_insert_suffix_specials_inband(
token_ids=_tri_effective_suffix_special_token_ids(self),
input_ids=trimmed_input_ids,
attention_mask=full_attention_mask,
labels=trimmed_labels,
split_starts=parsed_split_starts,
assistant_header_starts=header_starts,
assistant_header_start_mask=header_start_mask,
)
valid_lens = full_attention_mask.sum(dim=1, dtype=torch.long)
max_valid_len = int(valid_lens.max().item())
parsed_assistant_header_hints = (
(header_starts, header_start_mask) if header_starts is not None and header_start_mask is not None else None
)
valid_label_mask = trimmed_labels != -100
valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_counts <= 0):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_train_no_labels",
"[TRI] prefill_lower_layers training path found no supervised suffix; falling back to vanilla.",
)
return None
token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0)
first_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, max_valid_len),
).min(dim=1).values
last_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values
contiguous_suffix = valid_counts == (last_valid - first_valid + 1)
use_multiturn = not bool(torch.all(contiguous_suffix).item())
turn_sample_rows: list[int] = []
turn_starts_list: list[int] = []
turn_ends_list: list[int] = []
if use_multiturn:
if parsed_assistant_header_hints is None:
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_train_noncontiguous",
"[TRI] prefill_lower_layers training path found multi-turn labels without assistant turn metadata; falling back to vanilla.",
)
return None
header_starts, header_start_mask = parsed_assistant_header_hints
for row in range(batch_size):
spans = _tri_resolve_assistant_turn_spans(
labels_row=trimmed_labels[row],
valid_len=int(valid_lens[row].item()),
assistant_header_starts=header_starts[row],
assistant_header_start_mask=header_start_mask[row],
)
if not spans:
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_train_noncontiguous",
"[TRI] prefill_lower_layers training path could not resolve multi-turn assistant spans; falling back to vanilla.",
)
return None
for turn_start, turn_end in spans:
turn_sample_rows.append(row)
turn_starts_list.append(int(turn_start))
turn_ends_list.append(int(turn_end))
split_starts = (first_valid - 1).clamp_min_(0)
if parsed_split_starts is not None:
valid_split_mask = (
(parsed_split_starts >= 0)
& (parsed_split_starts < valid_lens)
& (parsed_split_starts <= first_valid)
)
if not bool(torch.all(valid_split_mask).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_train_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower split start for training; falling back to the label-derived boundary.",
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
upper_attention_mode = _normalize_upper_attention_mode(upper_attention_mode)
replay_module = _normalize_replay_module(replay_module)
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
if bool(no_upper_attn) and upper_attention_mode != "causal":
raise ValueError("upper_attention_mode=solo_v2 cannot be combined with no_upper_attn.")
if replay_module != "none" and upper_attention_mode != "causal":
raise ValueError("replay_module cannot be combined with upper_attention_mode=solo_v2.")
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
full_ids = trimmed_input_ids
B, L = full_ids.shape
device = full_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(full_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
lower_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
lower_mask = None
elif prefix_full:
lower_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
lower_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
use_optimized_training_path = _llopa_use_optimized_training_path(
model,
attn_impl=attn_impl,
prefix_full=prefix_full,
no_upper_attn=bool(no_upper_attn),
)
use_custom_checkpoint = use_checkpoint and not use_optimized_training_path
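    # Stage 1: run the first K (lower) decoder layers over the full padded sequence.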
for li in range(K):
layer = model.layers[li]
if use_custom_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
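    # Stage 2: gather only the positions the upper layers should attend over (retained system
    # prefix, optional assistant history, and the supervised suffix) into a packed view.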
split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long)
valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long)
system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long)
upper_prepare_mode = str(getattr(self, "_optimized_llopa_upper_prepare_mode", "exact") or "exact").strip().lower()
upper_bucket_multiple = int(getattr(self, "_optimized_llopa_upper_bucket_multiple", 0) or 0)
use_upper_workspace = (
not bool(getattr(self, "training", False))
and upper_prepare_mode == "bucketed_workspace"
and upper_bucket_multiple > 0
)
upper_prepare_owner = self if use_upper_workspace else None
upper_position_ids_src = position_ids.to(device=hidden_states.device, dtype=torch.long)
decode_labels_src = trimmed_labels.to(device=hidden_states.device)
if use_multiturn:
turn_sample_ids = torch.tensor(turn_sample_rows, device=hidden_states.device, dtype=torch.long)
turn_starts = torch.tensor(turn_starts_list, device=hidden_states.device, dtype=torch.long)
turn_ends = torch.tensor(turn_ends_list, device=hidden_states.device, dtype=torch.long)
turn_system_lens = system_lens_upper.index_select(0, turn_sample_ids)
prefix_keep_lens = _tri_prefix_keep_lengths(turn_system_lens, system_prefill)
split_starts_for_suffix_specials = turn_starts
upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_multiturn_index_batch(
turn_starts=turn_starts,
turn_ends=turn_ends,
system_lens=turn_system_lens,
system_prefill=system_prefill,
device=hidden_states.device,
owner=upper_prepare_owner,
bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0,
)
hidden_states_upper_src = hidden_states.index_select(0, turn_sample_ids)
upper_hidden, _ = _tri_pack_indexed_tensor(
hidden_states_upper_src,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0.0,
owner=upper_prepare_owner,
workspace_name="upper_hidden",
)
upper_position_ids_src = upper_position_ids_src.index_select(0, turn_sample_ids)
upper_position_ids, _ = _tri_pack_indexed_tensor(
upper_position_ids_src,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0,
owner=upper_prepare_owner,
workspace_name="upper_position_ids",
)
decode_labels_src = decode_labels_src.index_select(0, turn_sample_ids)
decode_labels, _ = _tri_pack_indexed_tensor(
decode_labels_src,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=-100,
owner=upper_prepare_owner,
workspace_name="upper_decode_labels",
)
else:
prefix_keep_lens = _tri_prefix_keep_lengths(system_lens_upper, system_prefill)
split_starts_for_suffix_specials = split_starts_upper
use_past_assistant = (
bool(see_past_assistant)
and parsed_assistant_header_hints is not None
and parsed_assistant_turn_end_hints is not None
)
if use_past_assistant:
header_starts, header_start_mask = parsed_assistant_header_hints
turn_ends, _ = parsed_assistant_turn_end_hints
upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_past_assistant_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
assistant_header_starts=header_starts,
assistant_turn_ends=turn_ends,
assistant_header_start_mask=header_start_mask,
device=hidden_states.device,
owner=upper_prepare_owner,
bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0,
)
else:
upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_upper_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
device=hidden_states.device,
owner=upper_prepare_owner,
bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0,
)
upper_hidden, _ = _tri_pack_indexed_tensor(
hidden_states,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0.0,
owner=upper_prepare_owner,
workspace_name="upper_hidden",
)
upper_position_ids, _ = _tri_pack_indexed_tensor(
upper_position_ids_src,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0,
owner=upper_prepare_owner,
workspace_name="upper_position_ids",
)
decode_labels, _ = _tri_pack_indexed_tensor(
decode_labels_src,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=-100,
owner=upper_prepare_owner,
workspace_name="upper_decode_labels",
)
upper_attention_mask = upper_valid_mask.to(dtype=full_attention_mask.dtype, device=hidden_states.device)
replay_hidden = None
replay_position_ids = None
replay_valid_mask = None
upper_layer_indices = list(range(K, n_layers))
replay_layer_set = _tri_replay_layer_index_set(
upper_layer_indices=upper_layer_indices,
replay_per_layers=replay_per_layers,
)
if not bool(no_upper_attn) and _tri_effective_fusion_mode(self) == "upper_only":
upper_hidden, upper_position_ids, upper_attention_mask, decode_labels, upper_lens = _tri_repack_upper_with_suffix_specials(
self,
upper_hidden=upper_hidden,
upper_position_ids=upper_position_ids,
upper_attention_mask=upper_attention_mask,
decode_labels=decode_labels,
prefix_keep_lens=prefix_keep_lens,
split_starts=split_starts_for_suffix_specials,
)
replay_query_mask = None
replay_enabled = (
replay_module != "none"
and not bool(no_upper_attn)
and K < n_layers
and bool(replay_layer_set)
)
if replay_enabled:
if use_multiturn:
replay_gather_idx, replay_valid_mask, _ = _tri_build_last_layer_memory_index_batch(
split_starts=turn_starts,
prefix_keep_lens=prefix_keep_lens,
device=hidden_states.device,
owner=upper_prepare_owner,
bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0,
workspace_key="multiturn",
)
replay_hidden, _ = _tri_pack_indexed_tensor(
hidden_states_upper_src,
gather_idx=replay_gather_idx,
valid_mask=replay_valid_mask,
pad_value=0.0,
owner=upper_prepare_owner,
workspace_name="last_layer_memory_hidden",
)
replay_position_ids, _ = _tri_pack_indexed_tensor(
upper_position_ids_src,
gather_idx=replay_gather_idx,
valid_mask=replay_valid_mask,
pad_value=0,
owner=upper_prepare_owner,
workspace_name="last_layer_memory_position_ids",
)
else:
replay_gather_idx, replay_valid_mask, _ = _tri_build_last_layer_memory_index_batch(
split_starts=split_starts_upper,
prefix_keep_lens=prefix_keep_lens,
device=hidden_states.device,
owner=upper_prepare_owner,
bucket_multiple=upper_bucket_multiple if use_upper_workspace else 0,
workspace_key="single",
)
replay_hidden, _ = _tri_pack_indexed_tensor(
hidden_states,
gather_idx=replay_gather_idx,
valid_mask=replay_valid_mask,
pad_value=0.0,
owner=upper_prepare_owner,
workspace_name="last_layer_memory_hidden",
)
replay_position_ids, _ = _tri_pack_indexed_tensor(
upper_position_ids_src,
gather_idx=replay_gather_idx,
valid_mask=replay_valid_mask,
pad_value=0,
owner=upper_prepare_owner,
workspace_name="last_layer_memory_position_ids",
)
if isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0:
replay_query_mask = _tri_build_last_layer_query_mask(
prefix_keep_lens=prefix_keep_lens,
valid_lens=upper_lens,
total_len=int(upper_hidden.size(1)),
device=upper_hidden.device,
)
else:
replay_hidden = None
replay_position_ids = None
replay_valid_mask = None
replay_enabled = False
upper_cache_position = _tri_arange(0, upper_hidden.size(1), upper_hidden.device)
if K < n_layers:
if bool(no_upper_attn):
for li in range(K, n_layers):
layer = model.layers[li]
residual = upper_hidden
ffn_input = layer.post_attention_layernorm(upper_hidden)
if use_checkpoint:
ffn_out = _checkpoint_mlp(layer.mlp, ffn_input)
else:
ffn_out = layer.mlp(ffn_input)
upper_hidden = residual + ffn_out
elif upper_attention_mode == "solo_v2":
upper_pos_emb = model.rotary_emb(upper_hidden, upper_position_ids)
for li in range(K, n_layers):
upper_hidden = _tri_run_decoder_layer_upper_solo_v2(
layer=model.layers[li],
hidden_states=upper_hidden,
position_embeddings=upper_pos_emb,
with_bos=False,
use_checkpoint=use_checkpoint,
)
else:
if attn_impl == "flash_attention_2":
upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask
elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(upper_attention_mask):
upper_mask = None
else:
upper_mask = create_causal_mask(
config=model.config,
input_embeds=upper_hidden,
attention_mask=upper_attention_mask,
cache_position=upper_cache_position,
past_key_values=None,
position_ids=upper_position_ids,
)
upper_pos_emb = model.rotary_emb(upper_hidden, upper_position_ids)
for li in range(K, n_layers):
layer = model.layers[li]
use_replay_layer = (
replay_enabled
and li in replay_layer_set
and isinstance(replay_hidden, torch.Tensor)
and replay_hidden.numel() > 0
)
if use_replay_layer:
upper_hidden = _tri_run_last_layer_with_replay(
layer=layer,
hidden_states=upper_hidden,
attention_mask=upper_mask,
position_ids=upper_position_ids,
past_key_values=None,
use_cache=False,
cache_position=upper_cache_position,
position_embeddings=upper_pos_emb,
memory_hidden_states=replay_hidden,
memory_position_ids=replay_position_ids,
memory_valid_mask=replay_valid_mask,
local_valid_mask=upper_attention_mask.to(device=upper_hidden.device, dtype=torch.bool),
query_replay_mask=replay_query_mask,
module_type=replay_module,
rotary_emb=model.rotary_emb,
use_checkpoint=use_checkpoint,
)
elif use_custom_checkpoint:
upper_hidden = checkpoint_with_cache(
layer.__call__,
upper_hidden,
attention_mask=upper_mask,
position_ids=upper_position_ids,
past_key_values=None,
use_cache=False,
cache_position=upper_cache_position,
position_embeddings=upper_pos_emb,
)
else:
upper_hidden = layer(
upper_hidden,
attention_mask=upper_mask,
position_ids=upper_position_ids,
past_key_values=None,
use_cache=False,
cache_position=upper_cache_position,
position_embeddings=upper_pos_emb,
)
upper_hidden = model.norm(upper_hidden)
kept_labels = decode_labels
if isinstance(logits_to_keep, int) and logits_to_keep > 0:
keep_lens = torch.minimum(
upper_lens,
torch.full_like(upper_lens, int(logits_to_keep)),
)
keep_hidden, _, _, _ = _tri_pack_suffix_tensor(
upper_hidden,
split_starts=(upper_lens - keep_lens).clamp_min_(0),
valid_lens=upper_lens,
pad_value=0.0,
)
kept_labels, _, _, _ = _tri_pack_suffix_tensor(
decode_labels,
split_starts=(upper_lens - keep_lens).clamp_min_(0),
valid_lens=upper_lens,
pad_value=-100,
)
logits = self.lm_head(keep_hidden)
else:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=upper_hidden.device)
logits = self.lm_head(upper_hidden[:, slice_indices, :])
if not isinstance(logits_to_keep, int):
kept_labels = decode_labels[:, slice_indices]
loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=upper_hidden,
)
def tri_vanilla_frozen_prefix_train_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
use_cache: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
past_key_values=None,
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
) -> Optional[CausalLMOutputWithPast]:
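    """Frozen-prefix training forward: run every layer over the full sequence, but above
    layer ``prefill_lower_layers`` re-inject the layer-K hidden states for prompt positions
    hidden from the upper view, so those positions are only shaped by the lower layers.
    ``use_cache`` and ``past_key_values`` are accepted for signature compatibility and
    ignored. Returns ``None`` to fall back to the vanilla forward when the inputs do not
    qualify."""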
if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2:
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full(
(batch_size,),
int(input_ids.size(1)),
device=input_ids.device,
dtype=torch.long,
)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
trimmed_labels = labels[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
valid_label_mask = trimmed_labels != -100
valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_counts <= 0):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_freeze_no_labels",
"[TRI] prefill_lower_freeze path found no supervised suffix; falling back to vanilla.",
)
return None
token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0)
first_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, max_valid_len),
).min(dim=1).values
last_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values
contiguous_suffix = valid_counts == (last_valid - first_valid + 1)
if not bool(torch.all(contiguous_suffix).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_freeze_noncontiguous",
"[TRI] prefill_lower_freeze expects a contiguous supervised suffix; falling back to vanilla.",
)
return None
split_starts = (first_valid - 1).clamp_min_(0)
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_split_starts is not None:
valid_split_mask = (
(parsed_split_starts >= 0)
& (parsed_split_starts < valid_lens)
& (parsed_split_starts <= first_valid)
)
if not bool(torch.all(valid_split_mask).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_freeze_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower freeze split start; falling back to the label-derived boundary.",
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
full_ids = trimmed_input_ids
B, L = full_ids.shape
device = full_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(full_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
full_mask = None
elif prefix_full:
full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
full_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
for li in range(K):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
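    # Above layer K, prompt positions outside the upper-visible set are re-injected with
    # their layer-K hidden states before every layer, so only the supervised suffix (and the
    # retained prefix) keeps evolving through the upper layers.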
if K < n_layers:
split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long)
valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long)
system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long)
upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
device=hidden_states.device,
)
upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool)
if upper_gather_idx.numel() > 0 and upper_valid_mask.any():
upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask)
valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1)
prefix_freeze_mask = valid_token_mask & (~upper_visible_mask)
prefix_freeze_mask = prefix_freeze_mask.unsqueeze(-1)
frozen_hidden_states = hidden_states
for li in range(K, n_layers):
layer = model.layers[li]
layer_input = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states)
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states)
hidden_states = model.norm(hidden_states)
if isinstance(logits_to_keep, int) and logits_to_keep > 0:
        keep_hidden, _, _, _ = _tri_pack_suffix_tensor(
hidden_states,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=0.0,
)
kept_labels, _, _, _ = _tri_pack_suffix_tensor(
trimmed_labels,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=-100,
)
logits = self.lm_head(keep_hidden)
loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size)
else:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
selected_labels = trimmed_labels[:, slice_indices] if isinstance(slice_indices, slice) else trimmed_labels.index_select(1, slice_indices)
loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
)
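# Expose the TRI vanilla training forwards as methods on LlamaForCausalLM.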
LlamaForCausalLM.tri_vanilla_train_forward = tri_vanilla_train_forward
LlamaForCausalLM.tri_vanilla_frozen_prefix_train_forward = tri_vanilla_frozen_prefix_train_forward
def tri_vanilla_frozen_prefix_infer_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
) -> Optional[CausalLMOutputWithPast]:
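    """Inference counterpart of the frozen-prefix path: run the lower layers over the prompt
    and, mirroring ``tri_vanilla_frozen_prefix_train_forward``, build the upper-visible mask
    used to freeze the hidden prompt positions in the upper layers."""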
if input_ids is None or input_ids.dim() != 2:
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full(
(batch_size,),
int(input_ids.size(1)),
device=input_ids.device,
dtype=torch.long,
)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
if parsed_split_starts is None:
split_starts = (valid_lens - 1).clamp_min_(0)
else:
valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens)
if not bool(torch.all(valid_split_mask).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_freeze_infer_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower freeze split start for inference; using the last prompt token.",
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0))
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
full_ids = trimmed_input_ids
B, L = full_ids.shape
device = full_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(full_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
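# Mask selection: flash_attention_2 takes the raw 2-D padding mask (or None when unpadded), the all-ones
# causal case can rely on the implicit mask, prefix-full builds a variable-length block mask from
# split_starts, and everything else falls back to an explicit causal mask.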
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
full_mask = None
elif prefix_full:
full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
full_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
for li in range(K):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
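# Upper layers: tokens outside the per-sample upper visibility index are frozen at their layer-K hidden
# states; torch.where re-applies the frozen values on both the input and the output of each upper layer.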
if K < n_layers:
split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long)
valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long)
system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long)
upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
device=hidden_states.device,
)
upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool)
if upper_gather_idx.numel() > 0 and upper_valid_mask.any():
upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask)
token_positions = torch.arange(L, device=hidden_states.device, dtype=torch.long).unsqueeze(0)
valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1)
prefix_freeze_mask = valid_token_mask & (~upper_visible_mask)
prefix_freeze_mask = prefix_freeze_mask.unsqueeze(-1)
frozen_hidden_states = hidden_states
for li in range(K, n_layers):
layer = model.layers[li]
layer_input = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states)
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = torch.where(prefix_freeze_mask, frozen_hidden_states, hidden_states)
hidden_states = model.norm(hidden_states)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
)
LlamaForCausalLM.tri_vanilla_frozen_prefix_infer_forward = tri_vanilla_frozen_prefix_infer_forward
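# tri_vanilla_solo_prefix_infer_forward: "solo" variant of the frozen-prefix inference path. Prefix tokens
# outside the upper visibility set are not frozen; instead they are re-run through each upper layer as
# length-1 sequences (attending only to themselves) and merged back into the full-sequence output.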
def tri_vanilla_solo_prefix_infer_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or input_ids.dim() != 2:
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
if parsed_split_starts is None:
split_starts = (valid_lens - 1).clamp_min_(0)
else:
valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens)
if not bool(torch.all(valid_split_mask).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_infer_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower solo split start for inference; using the last prompt token.",
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0))
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
B, L = trimmed_input_ids.shape
device = trimmed_input_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(trimmed_input_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
full_mask = None
elif prefix_full:
full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
full_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
for li in range(K):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
if K < n_layers:
split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long)
valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long)
system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long)
upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
device=hidden_states.device,
)
upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool)
if upper_gather_idx.numel() > 0 and upper_valid_mask.any():
upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask)
token_positions = torch.arange(L, device=hidden_states.device, dtype=torch.long).unsqueeze(0)
valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1)
prefix_self_mask = valid_token_mask & (~upper_visible_mask)
has_prefix_self = bool(prefix_self_mask.any().item())
if has_prefix_self:
prefix_row_idx, prefix_col_idx = prefix_self_mask.nonzero(as_tuple=True)
singleton_cache_position = _tri_arange(0, 1, hidden_states.device)
singleton_attention_mask = None
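# Per upper layer: run the normal full-sequence pass, then recompute the self-only positions as length-1
# sequences (no attention mask, rotary embeddings at their original positions) and overwrite those rows
# in the merged output.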
for li in range(K, n_layers):
layer = model.layers[li]
layer_input = hidden_states
if use_checkpoint:
upper_hidden = checkpoint_with_cache(
layer.__call__,
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
upper_hidden = layer(
layer_input,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
if has_prefix_self:
singleton_hidden = layer_input[prefix_row_idx, prefix_col_idx].unsqueeze(1)
singleton_position_ids = position_ids[prefix_row_idx, prefix_col_idx].unsqueeze(1)
singleton_pos_emb = model.rotary_emb(singleton_hidden, singleton_position_ids)
if use_checkpoint:
singleton_outputs = checkpoint_with_cache(
layer.__call__,
singleton_hidden,
attention_mask=singleton_attention_mask,
position_ids=singleton_position_ids,
past_key_values=None,
use_cache=False,
cache_position=singleton_cache_position,
position_embeddings=singleton_pos_emb,
)
else:
singleton_outputs = layer(
singleton_hidden,
attention_mask=singleton_attention_mask,
position_ids=singleton_position_ids,
past_key_values=None,
use_cache=False,
cache_position=singleton_cache_position,
position_embeddings=singleton_pos_emb,
)
merged_hidden = upper_hidden.clone()
merged_hidden[prefix_row_idx, prefix_col_idx] = singleton_outputs.squeeze(1)
hidden_states = merged_hidden
else:
hidden_states = upper_hidden
hidden_states = model.norm(hidden_states)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
)
LlamaForCausalLM.tri_vanilla_solo_prefix_infer_forward = tri_vanilla_solo_prefix_infer_forward
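# _tri_vanilla_solo_attention_v2_forward: shared implementation behind the solo_attention_v2 train and
# infer entry points. With labels it requires a contiguous supervised suffix; without labels it behaves
# like plain inference. When labels and an integer logits_to_keep are present, the upper layers run on a
# compacted suffix instead of the full sequence.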
def _tri_vanilla_solo_attention_v2_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
with_bos: bool = False,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or input_ids.dim() != 2:
return None
if labels is not None and (labels.dim() != 2 or labels.shape != input_ids.shape):
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
trimmed_labels = labels[:, :max_valid_len] if labels is not None else None
full_attention_mask = batch_attention_mask[:, :max_valid_len]
if trimmed_labels is not None:
valid_label_mask = trimmed_labels != -100
valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_counts <= 0):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_v2_no_labels",
"[TRI] prefill_lower_solo_attention_v2 path found no supervised suffix; falling back to vanilla.",
)
return None
token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0)
first_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, max_valid_len),
).min(dim=1).values
last_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values
contiguous_suffix = valid_counts == (last_valid - first_valid + 1)
if not bool(torch.all(contiguous_suffix).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_v2_noncontiguous",
"[TRI] prefill_lower_solo_attention_v2 expects a contiguous supervised suffix; falling back to vanilla.",
)
return None
split_starts = (first_valid - 1).clamp_min_(0)
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_split_starts is not None:
valid_split_mask = (
(parsed_split_starts >= 0)
& (parsed_split_starts < valid_lens)
& (parsed_split_starts <= first_valid)
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
else:
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
if parsed_split_starts is None:
split_starts = (valid_lens - 1).clamp_min_(0)
else:
valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens)
if not bool(torch.all(valid_split_mask).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_v2_infer_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower solo_v2 split start for inference; using the last prompt token.",
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, (valid_lens - 1).clamp_min_(0))
_ = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_input_ids.device,
)
_ = normalize_system_prefill(system_prefill)
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
_, L = trimmed_input_ids.shape
device = trimmed_input_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(trimmed_input_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
full_mask = None
elif prefix_full:
full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
full_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
use_optimized_training_path = _llopa_use_optimized_training_path(
model,
attn_impl=attn_impl,
prefix_full=prefix_full,
no_upper_attn=False,
)
use_custom_checkpoint = use_checkpoint and not use_optimized_training_path
for li in range(K):
layer = model.layers[li]
if use_custom_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=full_mask,
position_ids=position_ids,
past_key_values=None,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
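# Training-only compaction: shrink the upper-layer input to the supervised suffix (plus an optional BOS
# row) via _tri_compact_solo_v2_upper_training_inputs before running layers K..N-1.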
compacted_solo_v2_upper = False
compacted_labels = None
if trimmed_labels is not None and isinstance(logits_to_keep, int) and int(logits_to_keep) > 0:
compacted = _tri_compact_solo_v2_upper_training_inputs(
hidden_states=hidden_states,
position_ids=position_ids,
labels=trimmed_labels,
valid_lens=valid_lens,
logits_to_keep=int(logits_to_keep),
with_bos=bool(with_bos),
)
if compacted is not None:
hidden_states, position_ids, compacted_labels, _, _ = compacted
position_embeddings = model.rotary_emb(hidden_states, position_ids)
compacted_solo_v2_upper = True
if K < n_layers:
for li in range(K, n_layers):
hidden_states = _tri_run_decoder_layer_upper_solo_v2(
layer=model.layers[li],
hidden_states=hidden_states,
position_embeddings=position_embeddings,
with_bos=bool(with_bos),
use_checkpoint=use_custom_checkpoint,
)
hidden_states = model.norm(hidden_states)
if compacted_solo_v2_upper and compacted_labels is not None:
if bool(with_bos):
logits_hidden = hidden_states[:, 1:, :]
loss_labels = compacted_labels[:, 1:]
else:
logits_hidden = hidden_states
loss_labels = compacted_labels
logits = self.lm_head(logits_hidden)
loss = self.loss_function(logits=logits, labels=loss_labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states)
if trimmed_labels is None:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
)
if isinstance(logits_to_keep, int) and logits_to_keep > 0:
keep_hidden, _, _, _ = _tri_pack_suffix_tensor(
hidden_states,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=0.0,
)
kept_labels, _, _, _ = _tri_pack_suffix_tensor(
trimmed_labels,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=-100,
)
logits = self.lm_head(keep_hidden)
loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size)
else:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
selected_labels = (
trimmed_labels[:, slice_indices]
if isinstance(slice_indices, slice)
else trimmed_labels.index_select(1, slice_indices)
)
loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states)
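# Thin wrappers around the shared solo_attention_v2 helper: the infer variant passes labels=None, the
# train variant forwards labels through unchanged.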
def tri_vanilla_solo_attention_v2_infer_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
with_bos: bool = False,
) -> Optional[CausalLMOutputWithPast]:
return _tri_vanilla_solo_attention_v2_forward(
self,
input_ids=input_ids,
attention_mask=attention_mask,
labels=None,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_attn,
split_start_hint=split_start_hint,
system_len_hint=system_len_hint,
system_prefill=system_prefill,
with_bos=bool(with_bos),
)
LlamaForCausalLM.tri_vanilla_solo_attention_v2_infer_forward = tri_vanilla_solo_attention_v2_infer_forward
def tri_vanilla_solo_attention_v2_train_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
with_bos: bool = False,
) -> Optional[CausalLMOutputWithPast]:
return _tri_vanilla_solo_attention_v2_forward(
self,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
logits_to_keep=logits_to_keep,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_attn,
split_start_hint=split_start_hint,
system_len_hint=system_len_hint,
system_prefill=system_prefill,
with_bos=bool(with_bos),
)
LlamaForCausalLM.tri_vanilla_solo_attention_v2_train_forward = tri_vanilla_solo_attention_v2_train_forward
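# tri_vanilla_solo_prefix_train_forward: training counterpart of the solo-prefix path. It requires a
# contiguous supervised suffix; split_starts default to the token before the first labeled position and
# may be overridden by a valid split_start_hint.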
def tri_vanilla_solo_prefix_train_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
logits_to_keep: Union[int, torch.Tensor],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or labels is None or input_ids.dim() != 2 or labels.dim() != 2:
return None
batch_size = int(input_ids.size(0))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.long)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
valid_lens = torch.full((batch_size,), int(input_ids.size(1)), device=input_ids.device, dtype=torch.long)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
trimmed_labels = labels[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
valid_label_mask = trimmed_labels != -100
valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_counts <= 0):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_no_labels",
"[TRI] prefill_lower_solo_attention path found no supervised suffix; falling back to vanilla.",
)
return None
token_positions = torch.arange(max_valid_len, device=trimmed_labels.device, dtype=torch.long).unsqueeze(0)
first_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, max_valid_len),
).min(dim=1).values
last_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values
contiguous_suffix = valid_counts == (last_valid - first_valid + 1)
if not bool(torch.all(contiguous_suffix).item()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_solo_noncontiguous",
"[TRI] prefill_lower_solo_attention expects a contiguous supervised suffix; falling back to vanilla.",
)
return None
split_starts = (first_valid - 1).clamp_min_(0)
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_split_starts is not None:
valid_split_mask = (
(parsed_split_starts >= 0)
& (parsed_split_starts < valid_lens)
& (parsed_split_starts <= first_valid)
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=trimmed_labels.device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
B, L = trimmed_input_ids.shape
device = trimmed_input_ids.device
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode not in ("causal", "full"):
raise ValueError("prefill_lower_attn must be one of {'causal', 'full'}.")
prefix_full = attn_mode == "full"
inputs_embeds = model.embed_tokens(trimmed_input_ids)
cache_position = _tri_arange(0, L, device)
position_ids = _llopa_position_ids_from_mask(full_attention_mask)
attn_impl = _resolve_attn_impl(model.config)
if prefix_full and attn_impl == "flash_attention_2":
raise ValueError("prefill_lower_attn=full is not supported with flash_attention_2.")
if attn_impl == "flash_attention_2":
full_mask = None if _llopa_mask_is_all_ones(full_attention_mask) else full_attention_mask
elif not prefix_full and _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(full_attention_mask):
full_mask = None
elif prefix_full:
full_mask = _build_tri_mask_prefix_full_varlen(split_starts, full_attention_mask, inputs_embeds.dtype)
else:
full_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=full_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
use_checkpoint = should_use_checkpoint(model)
for li in range(K):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__, hidden_states, attention_mask=full_mask, position_ids=position_ids,
past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states, attention_mask=full_mask, position_ids=position_ids,
past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings,
)
if K < n_layers:
split_starts_upper = split_starts.to(device=hidden_states.device, dtype=torch.long)
valid_lens_upper = valid_lens.to(device=hidden_states.device, dtype=torch.long)
system_lens_upper = system_lens.to(device=hidden_states.device, dtype=torch.long)
upper_gather_idx, upper_valid_mask, _ = _tri_build_prefill_lower_upper_index_batch(
split_starts=split_starts_upper,
valid_lens=valid_lens_upper,
system_lens=system_lens_upper,
system_prefill=system_prefill,
device=hidden_states.device,
)
upper_visible_mask = torch.zeros((B, L), device=hidden_states.device, dtype=torch.bool)
if upper_gather_idx.numel() > 0 and upper_valid_mask.any():
upper_visible_mask.scatter_(1, upper_gather_idx, upper_valid_mask)
valid_token_mask = token_positions[:, :L] < valid_lens.unsqueeze(1)
prefix_self_mask = valid_token_mask & (~upper_visible_mask)
has_prefix_self = bool(prefix_self_mask.any().item())
if has_prefix_self:
prefix_row_idx, prefix_col_idx = prefix_self_mask.nonzero(as_tuple=True)
singleton_cache_position = _tri_arange(0, 1, device)
singleton_attention_mask = None
for li in range(K, n_layers):
layer = model.layers[li]
layer_input = hidden_states
if use_checkpoint:
upper_hidden = checkpoint_with_cache(
layer.__call__, layer_input, attention_mask=full_mask, position_ids=position_ids,
past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings,
)
else:
upper_hidden = layer(
layer_input, attention_mask=full_mask, position_ids=position_ids,
past_key_values=None, use_cache=False, cache_position=cache_position, position_embeddings=position_embeddings,
)
if has_prefix_self:
singleton_hidden = layer_input[prefix_row_idx, prefix_col_idx].unsqueeze(1)
singleton_position_ids = position_ids[prefix_row_idx, prefix_col_idx].unsqueeze(1)
singleton_pos_emb = model.rotary_emb(singleton_hidden, singleton_position_ids)
if use_checkpoint:
singleton_outputs = checkpoint_with_cache(
layer.__call__, singleton_hidden, attention_mask=singleton_attention_mask, position_ids=singleton_position_ids,
past_key_values=None, use_cache=False, cache_position=singleton_cache_position, position_embeddings=singleton_pos_emb,
)
else:
singleton_outputs = layer(
singleton_hidden, attention_mask=singleton_attention_mask, position_ids=singleton_position_ids,
past_key_values=None, use_cache=False, cache_position=singleton_cache_position, position_embeddings=singleton_pos_emb,
)
merged_hidden = upper_hidden.clone()
merged_hidden[prefix_row_idx, prefix_col_idx] = singleton_outputs.squeeze(1)
hidden_states = merged_hidden
else:
hidden_states = upper_hidden
hidden_states = model.norm(hidden_states)
if isinstance(logits_to_keep, int) and logits_to_keep > 0:
keep_hidden, _, _, _ = _tri_pack_suffix_tensor(
hidden_states,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=0.0,
)
kept_labels, _, _, _ = _tri_pack_suffix_tensor(
trimmed_labels,
split_starts=(valid_lens - torch.minimum(valid_lens, torch.full_like(valid_lens, int(logits_to_keep)))).clamp_min_(0),
valid_lens=valid_lens,
pad_value=-100,
)
logits = self.lm_head(keep_hidden)
loss = self.loss_function(logits=logits, labels=kept_labels, vocab_size=self.config.vocab_size)
else:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
if isinstance(slice_indices, torch.Tensor):
slice_indices = slice_indices.to(device=hidden_states.device)
logits = self.lm_head(hidden_states[:, slice_indices, :])
selected_labels = trimmed_labels[:, slice_indices] if isinstance(slice_indices, slice) else trimmed_labels.index_select(1, slice_indices)
loss = self.loss_function(logits=logits, labels=selected_labels, vocab_size=self.config.vocab_size)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=hidden_states)
LlamaForCausalLM.tri_vanilla_solo_prefix_train_forward = tri_vanilla_solo_prefix_train_forward
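# _tri_vanilla_prefill_decode_forward_batch_seed: batched seed path used only for batch size > 1 with an
# empty cache, causal lower attention, no replay module and no in-band suffix specials. Per-sample prompt
# prefixes are prefilled through the lower K layers into a fresh cache, the upper-visible seed tokens are
# prefilled above layer K, and the aligned decode windows (which must all share one width) are scored
# through tri_step_logits.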
def _tri_vanilla_prefill_decode_forward_batch_seed(
self: "LlamaForCausalLM",
*,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
use_cache: Optional[bool],
logits_to_keep: Union[int, torch.Tensor],
past_key_values: Optional[Cache],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
assistant_header_starts_hint=None,
assistant_turn_ends_hint=None,
assistant_header_start_mask_hint=None,
see_past_assistant: bool = False,
no_upper_attn: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) <= 1:
return None
if not _tri_cache_is_empty(past_key_values):
return None
replay_module = _normalize_replay_module(replay_module)
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
if replay_module != "none":
return None
model = self.model
n_layers = len(model.layers)
K = max(0, min(int(prefill_lower_layers), n_layers))
if K <= 0:
return None
attn_mode = (prefill_attn or "causal").strip().lower()
if attn_mode == "prefix_full":
attn_mode = "full"
if attn_mode != "causal":
return None
if _tri_effective_suffix_special_token_ids(self):
return None
batch_size = int(input_ids.size(0))
device = input_ids.device
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != batch_size:
return None
batch_attention_mask = attention_mask.to(device=device, dtype=torch.long)
else:
batch_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
valid_lens = batch_attention_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_lens <= 0):
return None
max_valid_len = int(valid_lens.max().item())
trimmed_input_ids = input_ids[:, :max_valid_len]
full_attention_mask = batch_attention_mask[:, :max_valid_len]
trimmed_labels = labels[:, :max_valid_len] if isinstance(labels, torch.Tensor) and labels.dim() == 2 else None
parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints(
assistant_header_starts_hint,
assistant_header_start_mask_hint,
batch_size=batch_size,
device=device,
)
parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints(
assistant_turn_ends_hint,
assistant_header_start_mask_hint,
batch_size=batch_size,
device=device,
)
parsed_split_starts = _tri_parse_split_start_hints(
split_start_hint,
batch_size=batch_size,
device=device,
)
if trimmed_labels is not None:
valid_label_mask = trimmed_labels != -100
valid_counts = valid_label_mask.sum(dim=1, dtype=torch.long)
if torch.any(valid_counts <= 0):
return None
token_positions = torch.arange(max_valid_len, device=device, dtype=torch.long).unsqueeze(0)
first_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, max_valid_len),
).min(dim=1).values
last_valid = torch.where(
valid_label_mask,
token_positions,
torch.full_like(token_positions, -1),
).max(dim=1).values
contiguous_suffix = valid_counts == (last_valid - first_valid + 1)
if not bool(torch.all(contiguous_suffix).item()):
return None
split_starts = (first_valid - 1).clamp_min_(0)
if parsed_split_starts is not None:
valid_split_mask = (
(parsed_split_starts >= 0)
& (parsed_split_starts < valid_lens)
& (parsed_split_starts <= first_valid)
)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
else:
split_starts = (valid_lens - 1).clamp_min_(0)
if parsed_split_starts is not None:
valid_split_mask = (parsed_split_starts >= 0) & (parsed_split_starts < valid_lens)
split_starts = torch.where(valid_split_mask, parsed_split_starts, split_starts)
parsed_system_lens = _tri_parse_system_len_hints(
system_len_hint,
batch_size=batch_size,
device=device,
)
if parsed_system_lens is None:
system_lens = torch.zeros_like(split_starts)
else:
system_lens = torch.minimum(parsed_system_lens.clamp_min(0), split_starts)
system_prefill = normalize_system_prefill(system_prefill)
decode_lens = valid_lens - split_starts
if torch.any(decode_lens <= 0):
return None
if not bool(torch.all(decode_lens == decode_lens[0]).item()):
return None
decode_width = int(decode_lens[0].item())
if decode_width <= 0:
return None
prefix_width = int(split_starts.max().item())
pkv = _safe_dynamic_cache(model.config)
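# Prefill the per-sample prefixes (left-aligned, right-padded to the longest split) through the lower K
# layers, writing their KV into pkv; a packed variable-length fast path is used when available.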
if prefix_width > 0:
prefix_offsets = torch.arange(prefix_width, device=device, dtype=torch.long).unsqueeze(0)
prefix_valid = prefix_offsets < split_starts.unsqueeze(1)
prefix_gather = prefix_offsets.expand(batch_size, -1).clamp(max=max(max_valid_len - 1, 0))
prefix_ids = trimmed_input_ids.gather(1, prefix_gather).masked_fill(~prefix_valid, 0)
prefix_attention_mask = prefix_valid.to(dtype=torch.long)
inputs_embeds = model.embed_tokens(prefix_ids)
cache_position = _tri_arange(0, prefix_width, device)
position_ids = _llopa_position_ids_from_mask(prefix_attention_mask)
use_checkpoint = should_use_checkpoint(model)
packed_lower_hidden = None
if _tri_batch_packed_varlen_prefill_enabled(inputs_embeds, prefix_valid):
packed_lower_hidden = _tri_run_packed_varlen_prefill_layers(
model=model,
hidden_states=inputs_embeds[prefix_valid],
position_ids=position_ids[prefix_valid],
valid_mask=prefix_valid,
past_key_values=pkv,
start_layer=0,
end_layer=K,
use_checkpoint=use_checkpoint,
)
if packed_lower_hidden is not None:
hidden_states = _tri_scatter_packed_hidden(
packed_lower_hidden,
prefix_valid,
batch_size=batch_size,
width=prefix_width,
)
else:
attn_impl = _resolve_attn_impl(model.config)
if attn_impl == "flash_attention_2":
lower_mask = None if _llopa_mask_is_all_ones(prefix_attention_mask) else prefix_attention_mask
elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(prefix_attention_mask):
lower_mask = None
else:
lower_mask = create_causal_mask(
config=model.config,
input_embeds=inputs_embeds,
attention_mask=prefix_attention_mask,
cache_position=cache_position,
past_key_values=None,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = model.rotary_emb(hidden_states, position_ids)
for li in range(K):
layer = model.layers[li]
if use_checkpoint:
hidden_states = checkpoint_with_cache(
layer.__call__,
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=lower_mask,
position_ids=position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
prefix_valid = torch.zeros((batch_size, 0), device=device, dtype=torch.bool)
hidden_states = model.embed_tokens(input_ids[:, :0])
position_ids = torch.empty((batch_size, 0), device=device, dtype=torch.long)
use_checkpoint = should_use_checkpoint(model)
if parsed_assistant_header_hints is not None:
header_starts, header_start_mask = parsed_assistant_header_hints
else:
header_starts = None
header_start_mask = None
if parsed_assistant_turn_end_hints is not None:
turn_ends, _ = parsed_assistant_turn_end_hints
else:
turn_ends = None
upper_gather_idx, upper_valid_mask, upper_lens = _tri_build_prefill_lower_seed_upper_index_batch(
split_starts=split_starts,
system_lens=system_lens,
system_prefill=system_prefill,
assistant_header_starts=header_starts,
assistant_turn_ends=turn_ends,
assistant_header_start_mask=header_start_mask,
see_past_assistant=bool(see_past_assistant),
device=device,
)
upper_width = int(upper_gather_idx.size(1))
if K < n_layers and upper_width > 0 and not bool(no_upper_attn):
upper_hidden, _ = _tri_pack_indexed_tensor(
hidden_states,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0.0,
)
upper_position_ids, _ = _tri_pack_indexed_tensor(
position_ids,
gather_idx=upper_gather_idx,
valid_mask=upper_valid_mask,
pad_value=0,
)
packed_upper_hidden = None
if _tri_batch_packed_varlen_prefill_enabled(upper_hidden, upper_valid_mask):
packed_upper_hidden = _tri_run_packed_varlen_prefill_layers(
model=model,
hidden_states=upper_hidden[upper_valid_mask],
position_ids=upper_position_ids[upper_valid_mask],
valid_mask=upper_valid_mask,
past_key_values=pkv,
start_layer=K,
end_layer=n_layers,
use_checkpoint=use_checkpoint,
)
if packed_upper_hidden is None:
upper_attention_mask = upper_valid_mask.to(dtype=torch.long)
upper_cache_position = _tri_arange(0, upper_width, device)
attn_impl = _resolve_attn_impl(model.config)
if attn_impl == "flash_attention_2":
upper_mask = None if _llopa_mask_is_all_ones(upper_attention_mask) else upper_attention_mask
elif _can_use_implicit_causal_mask(model.config) and _llopa_mask_is_all_ones(upper_attention_mask):
upper_mask = None
else:
upper_mask = create_causal_mask(
config=model.config,
input_embeds=upper_hidden,
attention_mask=upper_attention_mask,
cache_position=upper_cache_position,
past_key_values=None,
position_ids=upper_position_ids,
)
upper_position_embeddings = model.rotary_emb(upper_hidden, upper_position_ids)
for li in range(K, n_layers):
layer = model.layers[li]
if use_checkpoint:
upper_hidden = checkpoint_with_cache(
layer.__call__,
upper_hidden,
attention_mask=upper_mask,
position_ids=upper_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=upper_cache_position,
position_embeddings=upper_position_embeddings,
)
else:
upper_hidden = layer(
upper_hidden,
attention_mask=upper_mask,
position_ids=upper_position_ids,
past_key_values=pkv,
use_cache=True,
cache_position=upper_cache_position,
position_embeddings=upper_position_embeddings,
)
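# Record, per layer, which cached positions hold valid KV (the padded-prefix mask below layer K, the
# upper seed mask above) and stash the masks on the cache for the subsequent decode step.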
layer_valid_masks: list[torch.Tensor] = []
for li in range(n_layers):
if li < K:
layer_valid_masks.append(prefix_valid.to(device=device, dtype=torch.bool).clone())
else:
layer_valid_masks.append(upper_valid_mask.to(device=device, dtype=torch.bool).clone())
with contextlib.suppress(Exception):
setattr(pkv, "_llopa_batch_layer_valid_masks", layer_valid_masks)
setattr(pkv, "_tri_past_len_cache", None)
decode_offsets = torch.arange(decode_width, device=device, dtype=torch.long).unsqueeze(0)
decode_gather = split_starts.unsqueeze(1) + decode_offsets
decode_ids = trimmed_input_ids.gather(1, decode_gather)
decode_labels = trimmed_labels.gather(1, decode_gather) if trimmed_labels is not None else None
with contextlib.suppress(Exception):
setattr(pkv, "_llopa_batch_decode_position_ids", decode_gather.to(device=device, dtype=torch.long))
write_cache = bool(self.config.use_cache if use_cache is None else use_cache)
out = self.tri_step_logits(
assistant_ids=decode_ids,
lower_k=int(prefill_lower_layers),
pkv=pkv,
S=0,
U=0,
logits_to_keep=logits_to_keep,
labels=decode_labels,
write_cache=write_cache,
prefill_mode="lower",
no_upper_attn=bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
if write_cache:
_tri_append_batch_cache_valid_masks(
pkv,
torch.ones((batch_size, decode_width), device=device, dtype=torch.bool),
)
with contextlib.suppress(Exception):
setattr(
pkv,
"_tri_prefill_seed_meta",
{
"S": 0,
"U": 0,
"split_start": split_starts.detach(),
},
)
return out
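# tri_vanilla_prefill_decode_forward: main prefill/decode entry point. Batches larger than one are routed
# to the batch-seed helper; for a single sample with an empty cache, the prompt prefix is prefilled into
# the lower-layer cache (with optional in-band suffix specials, upper-only fusion, past-assistant history
# and replay memory) and the remaining tokens are decoded through tri_step_logits.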
def tri_vanilla_prefill_decode_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
use_cache: Optional[bool],
logits_to_keep: Union[int, torch.Tensor],
past_key_values: Optional[Cache],
prefill_lower_layers: int,
prefill_attn: str,
split_start_hint,
system_len_hint,
system_prefill: str,
assistant_header_starts_hint=None,
assistant_turn_ends_hint=None,
assistant_header_start_mask_hint=None,
replay_user_prefix_keep_len_hint=None,
replay_user_start_hint=None,
replay_user_len_hint=None,
see_past_assistant: bool = False,
no_upper_attn: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or input_ids.dim() != 2:
return None
if input_ids.size(0) != 1:
batch_out = _tri_vanilla_prefill_decode_forward_batch_seed(
self,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
past_key_values=past_key_values,
prefill_lower_layers=prefill_lower_layers,
prefill_attn=prefill_attn,
split_start_hint=split_start_hint,
system_len_hint=system_len_hint,
system_prefill=system_prefill,
assistant_header_starts_hint=assistant_header_starts_hint,
assistant_turn_ends_hint=assistant_turn_ends_hint,
assistant_header_start_mask_hint=assistant_header_start_mask_hint,
see_past_assistant=bool(see_past_assistant),
no_upper_attn=bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
if batch_out is not None:
return batch_out
return None
valid_len = int(input_ids.size(1))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0):
return None
valid_len = int(attention_mask[0].sum().item())
if valid_len <= 0:
return None
trimmed_input_ids = input_ids[:, :valid_len]
trimmed_labels = labels[:, :valid_len] if labels is not None else None
if _tri_cache_is_empty(past_key_values):
parsed_split_start = _tri_parse_split_start_hint(split_start_hint, batch_size=input_ids.size(0))
parsed_assistant_header_hints = _tri_parse_assistant_header_start_hints(
assistant_header_starts_hint,
assistant_header_start_mask_hint,
batch_size=input_ids.size(0),
device=trimmed_input_ids.device,
)
parsed_assistant_turn_end_hints = _tri_parse_assistant_header_start_hints(
assistant_turn_ends_hint,
assistant_header_start_mask_hint,
batch_size=input_ids.size(0),
device=trimmed_input_ids.device,
)
parsed_system_len = _tri_parse_system_len_hint(system_len_hint, batch_size=input_ids.size(0))
parsed_replay_user_prefix_keep_len = _tri_parse_system_len_hint(
replay_user_prefix_keep_len_hint,
batch_size=input_ids.size(0),
)
parsed_replay_user_start = _tri_parse_system_len_hint(
replay_user_start_hint,
batch_size=input_ids.size(0),
)
parsed_replay_user_len = _tri_parse_system_len_hint(
replay_user_len_hint,
batch_size=input_ids.size(0),
)
if trimmed_labels is not None:
valid_mask = trimmed_labels[0] != -100
valid_positions = valid_mask.nonzero(as_tuple=False).flatten()
if valid_positions.numel() == 0:
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_no_labels",
"[TRI] prefill_lower_layers found no supervised suffix; falling back to vanilla.",
)
return None
first_valid = int(valid_positions[0].item())
last_valid = int(valid_positions[-1].item())
if not bool(valid_mask[first_valid : last_valid + 1].all()):
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_noncontiguous",
"[TRI] prefill_lower_layers expects a contiguous supervised suffix; falling back to vanilla.",
)
return None
split_start = max(first_valid - 1, 0)
if parsed_split_start is not None:
if 0 <= parsed_split_start < valid_len and parsed_split_start <= first_valid:
split_start = int(parsed_split_start)
else:
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower split start; falling back to the label-derived boundary.",
)
else:
split_start = max(valid_len - 1, 0)
if parsed_split_start is not None:
if 0 <= parsed_split_start < valid_len:
split_start = int(parsed_split_start)
else:
_tri_vanilla_warn_once(
self,
"_tri_warned_prefill_lower_infer_invalid_split_start",
"[TRI] Ignoring invalid prefill_lower split start for inference; using the last prompt token.",
)
if parsed_system_len is None:
system_len = 0
else:
system_len = min(max(int(parsed_system_len), 0), split_start)
system_prefill = normalize_system_prefill(system_prefill)
if _tri_effective_fusion_mode(self) == "inband":
inband_attention_mask = torch.ones_like(trimmed_input_ids, dtype=torch.long, device=trimmed_input_ids.device)
header_starts = None
turn_ends = None
header_start_mask = None
if parsed_assistant_header_hints is not None:
header_starts, header_start_mask = parsed_assistant_header_hints
if parsed_assistant_turn_end_hints is not None:
turn_ends, _ = parsed_assistant_turn_end_hints
(
trimmed_input_ids,
inband_attention_mask,
trimmed_labels,
remapped_split_starts,
remapped_header_starts,
remapped_header_start_mask,
) = _tri_insert_suffix_specials_inband(
token_ids=_tri_effective_suffix_special_token_ids(self),
input_ids=trimmed_input_ids,
attention_mask=inband_attention_mask,
labels=trimmed_labels,
split_starts=torch.tensor([split_start], device=trimmed_input_ids.device, dtype=torch.long),
assistant_header_starts=header_starts,
assistant_header_start_mask=header_start_mask,
)
valid_len = int(inband_attention_mask[0].sum().item())
split_start = int(remapped_split_starts[0].item())
if remapped_header_starts is not None and remapped_header_start_mask is not None:
parsed_assistant_header_hints = (remapped_header_starts, remapped_header_start_mask)
if (
turn_ends is not None
and parsed_assistant_header_hints is not None
):
remapped_header_starts, remapped_header_start_mask = parsed_assistant_header_hints
remapped_turn_ends = turn_ends.clone()
token_count = len(_tri_effective_suffix_special_token_ids(self))
if token_count > 0:
valid_turn_count = int(remapped_turn_ends.size(1))
for col in range(valid_turn_count):
if not bool(remapped_header_start_mask[0, col].item()):
continue
remapped_turn_ends[0, col] += token_count * (col + 1)
parsed_assistant_turn_end_hints = (remapped_turn_ends, remapped_header_start_mask)
prefix_ids = trimmed_input_ids[:, :split_start]
decode_ids = trimmed_input_ids[:, split_start:valid_len]
decode_labels = trimmed_labels[:, split_start:valid_len] if trimmed_labels is not None else None
replay_module = _normalize_replay_module(replay_module)
replay_per_layers = _normalize_replay_per_layers(replay_per_layers)
if system_prefill == "full":
visible_prefix_len = system_len
elif system_prefill == "no_system":
visible_prefix_len = min(system_len, 1)
else:
visible_prefix_len = 0
need_replay_memory = replay_module != "none" and not bool(no_upper_attn) and int(split_start) > int(visible_prefix_len)
effective_see_past_assistant = bool(see_past_assistant) and _tri_has_effective_past_assistant_history(
split_start=int(split_start),
assistant_header_hints=parsed_assistant_header_hints,
assistant_turn_end_hints=parsed_assistant_turn_end_hints,
)
prefill_out = _tri_prefill_lower_prompt_cache(
self.model,
prefix_ids,
lower_k=int(prefill_lower_layers),
prefill_attn=prefill_attn,
system_len=system_len,
system_prefill=system_prefill,
return_replay_memory=need_replay_memory,
return_prefix_hidden=bool(effective_see_past_assistant),
replay_user_prefix_keep_len=0 if parsed_replay_user_prefix_keep_len is None else int(parsed_replay_user_prefix_keep_len),
replay_user_start=parsed_replay_user_start,
replay_user_len=parsed_replay_user_len,
)
if need_replay_memory or bool(effective_see_past_assistant):
pkv, replay_hidden, replay_position_ids, prefix_hidden, prefix_position_ids = prefill_out
else:
pkv = prefill_out
replay_hidden = None
replay_position_ids = None
prefix_hidden = None
prefix_position_ids = None
if _tri_effective_fusion_mode(self) == "upper_only":
split_positions = [int(split_start)]
if parsed_assistant_header_hints is not None:
header_starts, header_start_mask = parsed_assistant_header_hints
valid_turns = header_starts[0][header_start_mask[0]]
if valid_turns.numel() > 0:
split_positions = [int(v) for v in valid_turns.tolist()]
if int(split_start) not in split_positions:
split_positions.append(int(split_start))
_tri_prefill_suffix_specials_upper_cache(
self.model,
pkv,
lower_k=int(prefill_lower_layers),
split_starts=split_positions,
no_upper_attn=bool(no_upper_attn),
)
if bool(effective_see_past_assistant) and parsed_assistant_header_hints is not None and parsed_assistant_turn_end_hints is not None:
header_starts, header_start_mask = parsed_assistant_header_hints
turn_ends, _ = parsed_assistant_turn_end_hints
history_indices: list[torch.Tensor] = []
num_turns = min(int(header_starts.size(1)), int(turn_ends.size(1)))
for col in range(num_turns):
if not bool(header_start_mask[0, col].item()):
continue
turn_start = int(header_starts[0, col].item())
turn_end = int(turn_ends[0, col].item())
if turn_end <= turn_start or turn_start >= int(split_start):
continue
turn_end = min(turn_end, int(split_start))
if turn_end <= turn_start:
continue
history_indices.append(
torch.arange(turn_start, turn_end, device=trimmed_input_ids.device, dtype=torch.long)
)
if history_indices:
history_idx = torch.cat(history_indices, dim=0)
if isinstance(prefix_hidden, torch.Tensor) and isinstance(prefix_position_ids, torch.Tensor):
prefix_pos_row = prefix_position_ids[0]
history_mask = torch.isin(prefix_pos_row, history_idx)
history_local_idx = history_mask.nonzero(as_tuple=False).flatten()
if history_local_idx.numel() > 0:
history_hidden = prefix_hidden.index_select(1, history_local_idx)
history_position_ids = prefix_position_ids.index_select(1, history_local_idx)
_tri_prefill_upper_history_cache(
self.model,
pkv,
hidden_states=history_hidden,
position_ids=history_position_ids,
lower_k=int(prefill_lower_layers),
no_upper_attn=bool(no_upper_attn),
)
if need_replay_memory and isinstance(replay_hidden, torch.Tensor) and replay_hidden.numel() > 0 and int(replay_hidden.size(1)) > 0:
_tri_store_last_layer_memory(
pkv,
hidden_states=replay_hidden,
position_ids=replay_position_ids,
valid_mask=torch.ones(
(replay_hidden.size(0), replay_hidden.size(1)),
device=replay_hidden.device,
dtype=torch.bool,
),
module_type=replay_module,
replay_per_layers=replay_per_layers,
)
try:
setattr(
pkv,
"_tri_prefill_seed_meta",
{
"S": int(visible_prefix_len),
"U": int(max(split_start - visible_prefix_len, 0)),
"split_start": int(split_start),
},
)
except Exception:
pass
else:
split_start = 0
decode_ids = trimmed_input_ids
decode_labels = trimmed_labels
pkv = past_key_values
visible_prefix_len = 0
if decode_ids.size(1) == 0:
return None
write_cache = bool(self.config.use_cache if use_cache is None else use_cache)
return self.tri_step_logits(
assistant_ids=decode_ids,
lower_k=int(prefill_lower_layers),
pkv=pkv,
S=visible_prefix_len,
U=max(split_start - visible_prefix_len, 0),
logits_to_keep=logits_to_keep,
labels=decode_labels,
write_cache=write_cache,
prefill_mode="lower",
no_upper_attn=bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
LlamaForCausalLM.tri_vanilla_prefill_decode_forward = tri_vanilla_prefill_decode_forward
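# llopa_reference_prefill_seed: seeds a cache through model.llopa_prefill_cache for the reference LLOPA
# path. The system segment is split according to system_prefill, any lower-only remainder is prepended to
# the user segment, and the result is (cache, system_len, user_len, last-token logits).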
def llopa_reference_prefill_seed(
self: "LlamaForCausalLM",
*,
system_ids: Optional[torch.LongTensor],
user_ids: Optional[torch.LongTensor],
assistant_ids: Optional[torch.LongTensor],
lower_k: int,
prefill_attn: str,
system_prefill: str,
no_upper_attn: bool,
replay_module: str = "none",
replay_per_layers: int = -1,
replay_user_prefix_keep_len: int = 0,
replay_user_start: Optional[int] = None,
replay_user_len: Optional[int] = None,
):
if bool(no_upper_attn):
return None
if _tri_effective_fusion_mode(self) == "inband" and _tri_effective_suffix_special_token_ids(self):
return None
llopa_prefill_fn = getattr(self.model, "llopa_prefill_cache", None)
if not callable(llopa_prefill_fn):
return None
sys_upper, sys_lower_extra = _llopa_split_system(system_ids, system_prefill)
if sys_lower_extra.numel() == 0:
user_llopa = user_ids
elif user_ids.numel() == 0:
user_llopa = sys_lower_extra
else:
user_llopa = torch.cat([sys_lower_extra, user_ids], dim=1)
merged_replay_prefix_keep_len = int(sys_lower_extra.size(1)) + int(replay_user_prefix_keep_len or 0)
merged_replay_user_start = None if replay_user_start is None else int(sys_lower_extra.size(1)) + int(replay_user_start)
prefill_out = llopa_prefill_fn(
system_ids=sys_upper,
user_ids=user_llopa,
assistant_ids=assistant_ids,
lower_k=int(lower_k),
prefill_mode="lower",
prefill_attn=prefill_attn,
return_last_assistant_hidden=True,
replay_module=replay_module,
replay_per_layers=replay_per_layers,
replay_user_prefix_keep_len=merged_replay_prefix_keep_len,
replay_user_start=merged_replay_user_start,
replay_user_len=replay_user_len,
)
if not isinstance(prefill_out, tuple) or len(prefill_out) != 2:
return None
pkv, last_hidden = prefill_out
if pkv is None or not isinstance(last_hidden, torch.Tensor) or last_hidden.numel() == 0:
return None
logits = self.lm_head(last_hidden[:, -1:, :]).to(torch.float32)
system_len = int(sys_upper.size(1)) if isinstance(sys_upper, torch.Tensor) else 0
user_len = int(user_llopa.size(1)) if isinstance(user_llopa, torch.Tensor) else 0
return pkv, system_len, user_len, logits[:, -1, :]
LlamaForCausalLM.llopa_reference_prefill_seed = llopa_reference_prefill_seed
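# tri_reference_prefill_seed: runs a regular forward with the prefill_lower_* kwargs and unpacks the seed
# from the outputs: the cache, the S/U sizes recorded in _tri_prefill_seed_meta, and float32 logits for
# the last position.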
def tri_reference_prefill_seed(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
use_cache: Optional[bool],
logits_to_keep: Union[int, torch.Tensor],
lower_k: int,
prefill_attn: str,
system_prefill: str,
no_upper_attn: bool,
prefill_lower_split_start,
prefill_lower_system_len,
prefill_lower_replay_user_prefix_keep_len=None,
prefill_lower_replay_user_start=None,
prefill_lower_replay_user_len=None,
assistant_header_starts=None,
assistant_turn_ends=None,
assistant_header_start_mask=None,
prefill_lower_see_past_assistant: bool = False,
replay_module: str = "none",
replay_per_layers: int = -1,
):
outputs = self(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
prefill_lower_layers=int(lower_k),
prefill_lower_attn=prefill_attn,
prefill_lower_system_prefill=system_prefill,
prefill_lower_no_upper_attn=bool(no_upper_attn),
prefill_lower_split_start=prefill_lower_split_start,
prefill_lower_system_len=prefill_lower_system_len,
prefill_lower_replay_user_prefix_keep_len=prefill_lower_replay_user_prefix_keep_len,
prefill_lower_replay_user_start=prefill_lower_replay_user_start,
prefill_lower_replay_user_len=prefill_lower_replay_user_len,
assistant_header_starts=assistant_header_starts,
assistant_turn_ends=assistant_turn_ends,
assistant_header_start_mask=assistant_header_start_mask,
prefill_lower_see_past_assistant=bool(prefill_lower_see_past_assistant),
prefill_lower_replay_module=replay_module,
prefill_lower_replay_per_layers=replay_per_layers,
)
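# Validate the outputs: seeding requires a cache that carries _tri_prefill_seed_meta
# plus non-empty logits; otherwise fall back by returning None.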
if outputs is None or outputs.past_key_values is None:
return None
logits = getattr(outputs, "logits", None)
if not isinstance(logits, torch.Tensor) or logits.numel() == 0:
return None
meta = getattr(outputs.past_key_values, "_tri_prefill_seed_meta", None)
if not isinstance(meta, dict):
return None
try:
S = int(meta.get("S", 0))
U = int(meta.get("U", 0))
except Exception:
return None
return outputs.past_key_values, S, U, logits[:, -1, :].to(torch.float32)
LlamaForCausalLM.tri_reference_prefill_seed = tri_reference_prefill_seed
LlamaForCausalLM.llopa_full_prompt_prefill_seed = tri_reference_prefill_seed
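# Runtime entry point for single-sequence (batch size 1) LLoPA prompt prefill.
# When the cache is empty, the prompt is split at the last occurrence of the registered
# assistant header (_runtime_llopa_header_ids) and prefilled via
# self.model.llopa_prefill_cache; later calls decode against the stored cache through
# tri_step_logits.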
def tri_runtime_llopa_prompt_prefill_forward(
self: "LlamaForCausalLM",
*,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.Tensor],
use_cache: Optional[bool],
logits_to_keep: Union[int, torch.Tensor],
past_key_values: Optional[Cache],
lower_k: int,
prefill_attn: str,
no_upper_attn: bool,
replay_module: str = "none",
replay_per_layers: int = -1,
) -> Optional[CausalLMOutputWithPast]:
if input_ids is None or input_ids.dim() != 2:
return None
if input_ids.size(0) != 1:
return None
header_ids = getattr(self, "_runtime_llopa_header_ids", None)
if not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0:
return None
valid_len = int(input_ids.size(1))
if attention_mask is not None:
if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0):
return None
valid_len = int(attention_mask[0].sum().item())
if valid_len <= 0:
return None
trimmed_input_ids = input_ids[:, :valid_len]
write_cache = bool(self.config.use_cache if use_cache is None else use_cache)
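# Empty cache: locate the last assistant header, prefill everything before it as the
# user prefix, and run the LLoPA prefill over the assistant tokens.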
if _tri_cache_is_empty(past_key_values):
hdr = header_ids.to(device=trimmed_input_ids.device, dtype=trimmed_input_ids.dtype)
assistant_start = _tri_find_last_subsequence_start(trimmed_input_ids, hdr)
if assistant_start is None:
return None
prefix_ids = trimmed_input_ids[:, :assistant_start]
assistant_ids = trimmed_input_ids[:, assistant_start:]
if assistant_ids.numel() == 0:
return None
prefill_out = self.model.llopa_prefill_cache(
system_ids=trimmed_input_ids[:, :0],
user_ids=prefix_ids,
assistant_ids=assistant_ids,
lower_k=int(lower_k),
prefill_mode="lower",
prefill_attn=prefill_attn,
return_last_assistant_hidden=not bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
last_hidden = None
if isinstance(prefill_out, tuple):
pkv, last_hidden = prefill_out
else:
pkv = prefill_out
try:
setattr(pkv, "_runtime_llopa_user_len", int(prefix_ids.size(1)))
except Exception:
pass
if isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(last_hidden[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=pkv if write_cache else None,
hidden_states=last_hidden,
)
last = assistant_ids[:, -1:]
return self.tri_step_logits(
assistant_ids=last,
lower_k=int(lower_k),
pkv=pkv,
S=0,
U=int(prefix_ids.size(1)),
logits_to_keep=logits_to_keep,
labels=None,
write_cache=write_cache,
prefill_mode="lower",
no_upper_attn=bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
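# Non-empty cache: treat the incoming ids as assistant continuation tokens and decode
# them against the previously prefilled user prefix recorded on the cache.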
user_len = int(getattr(past_key_values, "_runtime_llopa_user_len", 0) or 0)
return self.tri_step_logits(
assistant_ids=trimmed_input_ids,
lower_k=int(lower_k),
pkv=past_key_values,
S=0,
U=user_len,
logits_to_keep=logits_to_keep,
labels=None,
write_cache=write_cache,
prefill_mode="lower",
no_upper_attn=bool(no_upper_attn),
replay_module=replay_module,
replay_per_layers=replay_per_layers,
)
LlamaForCausalLM.tri_runtime_llopa_prompt_prefill_forward = tri_runtime_llopa_prompt_prefill_forward
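# Hypothetical usage sketch (illustrative only; variable names and argument values
# below are assumptions, not part of this file):
#
#     out = model.tri_runtime_llopa_prompt_prefill_forward(
#         input_ids=prompt_ids,        # shape (1, seq_len); must contain the assistant header
#         attention_mask=prompt_mask,  # optional (1, seq_len) padding mask
#         use_cache=True,
#         logits_to_keep=1,
#         past_key_values=None,        # an empty cache selects the prefill path
#         lower_k=8,                   # assumed count of lower prefill layers
#         prefill_attn="lower",        # assumed placeholder; accepted values are defined elsewhere
#         no_upper_attn=False,
#     )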
__all__ = [
"LlamaForCausalLM",
"LlamaModel",
"LlamaPreTrainedModel",
"LlamaForSequenceClassification",
"LlamaForQuestionAnswering",
"LlamaForTokenClassification",
]