// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/executor/llm_litert_compiled_model_executor.h"
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_common.h" // from @litert
#include "litert/cc/litert_compiled_model.h" // from @litert
#include "litert/cc/litert_element_type.h" // from @litert
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_layout.h" // from @litert
#include "litert/cc/litert_macros.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_options.h" // from @litert
#include "litert/cc/litert_ranked_tensor_type.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "litert/cc/litert_tensor_buffer_types.h" // from @litert
#include "litert/cc/options/litert_cpu_options.h" // from @litert
#include "litert/cc/options/litert_runtime_options.h" // from @litert
#include "runtime/components/embedding_lookup/embedding_lookup_manager.h"
#include "runtime/components/model_resources.h"
#include "runtime/components/sampler_factory.h"
#include "runtime/executor/common_utils.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/executor/litert_compiled_model_executor_utils.h"
#include "runtime/executor/llm_executor_io_types.h"
#include "runtime/executor/llm_executor_processed_tokens.h"
#include "runtime/executor/llm_executor_settings.h"
#include "runtime/executor/llm_executor_settings_utils.h"
#include "runtime/executor/llm_litert_compiled_model_cache_utils.h"
#include "runtime/executor/llm_litert_mtp_drafter.h"
#include "runtime/util/convert_tensor_buffer.h"
#include "runtime/util/log_tensor_buffer.h"
#include "runtime/util/lora_util.h"
#include "runtime/util/scoped_file.h"
#include "runtime/util/status_macros.h" // IWYU pragma: keep
#include "tflite/delegates/xnnpack/xnnpack_delegate.h" // from @litert
#include "tflite/types/half.h" // from @litert
namespace litert::lm {
namespace {
using ::absl::Span;
// Names of the signature runners, used to look up the signatures on the
// compiled model.
constexpr absl::string_view kPrefillSignatureRunner = "prefill";
constexpr absl::string_view kDecodeSignatureRunner = "decode";
constexpr int kDynamicDimValue = -1;
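// Creates the text and per-layer embedding lookup managers from the model
// resources, when the corresponding TFLite embedder models are present. The
// end-of-audio and end-of-vision models, if available, are registered as
// end-of-multi-modal embedding models on the text embedding lookup.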
absl::Status InitializeEmbeddingLookups(
litert::Environment& env,
ModelResources& resources,
std::unique_ptr<EmbeddingLookupManager>& embedding_lookup,
std::unique_ptr<EmbeddingLookupManager>& per_layer_embedding_lookup) {
absl::flat_hash_map<int, const Model*> end_of_multi_modal_embedding_models;
{
auto end_of_audio_model =
resources.GetTFLiteModel(ModelType::kTfLiteEndOfAudio);
if (end_of_audio_model.ok()) {
end_of_multi_modal_embedding_models.insert(
{ExecutorAudioData::kEndToken, end_of_audio_model.value()});
}
}
{
auto end_of_vision_model =
resources.GetTFLiteModel(ModelType::kTfLiteEndOfVision);
if (end_of_vision_model.ok()) {
end_of_multi_modal_embedding_models.insert(
{ExecutorVisionData::kEndToken, end_of_vision_model.value()});
}
}
auto text_embedder_model =
resources.GetTFLiteModel(ModelType::kTfLiteEmbedder);
if (text_embedder_model.ok()) {
ASSIGN_OR_RETURN(
embedding_lookup,
EmbeddingLookupManager::Create(*text_embedder_model,
end_of_multi_modal_embedding_models,
/*fully_supports_multi_modal=*/true,
/*signature_key=*/std::nullopt, &env));
}
// Create per layer embedding lookups from the resources.
auto per_layer_embedder_model =
resources.GetTFLiteModel(ModelType::kTfLitePerLayerEmbedder);
if (per_layer_embedder_model.ok()) {
ASSIGN_OR_RETURN(
per_layer_embedding_lookup,
EmbeddingLookupManager::Create(*per_layer_embedder_model,
/*fully_supports_multi_modal=*/false,
/*signature_key=*/std::nullopt, &env));
}
return absl::OkStatus();
}
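// Copies KV cache contents between a single-batch buffer set and a
// decode_batch_size-batch buffer set. When src_index_to_copy_on_prefill >= 0
// (first prefill after decode), only that batch index of the source is copied
// into the destination; otherwise (first decode after prefill) the source is
// broadcast into all decode_batch_size batches of the destination.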
absl::Status CopyKvCacheBuffers(
size_t decode_batch_size, int src_index_to_copy_on_prefill,
const absl::flat_hash_map<absl::string_view, TensorBuffer>&
src_kv_cache_buffers,
const absl::flat_hash_map<absl::string_view, TensorBuffer>&
dst_kv_cache_buffers) {
for (const auto& [name, src_buffer] : src_kv_cache_buffers) {
if (!dst_kv_cache_buffers.contains(name)) {
return absl::FailedPreconditionError(
absl::StrCat("KV cache buffer ", name, " not found."));
}
const auto& dst_buffer = dst_kv_cache_buffers.at(name);
LITERT_ASSIGN_OR_RETURN(auto src_buffer_lock_and_addr,
TensorBufferScopedLock::Create(
src_buffer, TensorBuffer::LockMode::kRead));
LITERT_ASSIGN_OR_RETURN(size_t src_buffer_size, src_buffer.PackedSize());
const char* src_buffer_ptr =
static_cast<const char*>(src_buffer_lock_and_addr.second);
LITERT_ASSIGN_OR_RETURN(auto dst_buffer_lock_and_addr,
TensorBufferScopedLock::Create(
dst_buffer, TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(size_t dst_buffer_size, dst_buffer.PackedSize());
char* dst_buffer_ptr =
static_cast<char*>(const_cast<void*>(dst_buffer_lock_and_addr.second));
// This copy assumes that the KV cache buffers are laid out as
// [batch * X, ...] or [1, batch * X, ...] where X is 1 or more and values are
// not interleaved across batches, which holds for the current LLM models on
// all backends.
if (src_index_to_copy_on_prefill >= 0) {
// This is the case of the first prefill after decode. It reduces the KV
// cache size to one by copying only the cache content of the given index.
RET_CHECK_EQ(src_buffer_size, dst_buffer_size * decode_batch_size);
RET_CHECK_LT(src_index_to_copy_on_prefill, decode_batch_size);
src_buffer_ptr += src_index_to_copy_on_prefill * dst_buffer_size;
memcpy(dst_buffer_ptr, src_buffer_ptr, dst_buffer_size);
} else {
// This is the case of the first decode after prefill. It broadcasts the
// KV cache contents to all the batches.
RET_CHECK_EQ(src_buffer_size * decode_batch_size, dst_buffer_size);
for (int i = 0; i < decode_batch_size; ++i) {
memcpy(dst_buffer_ptr, src_buffer_ptr, src_buffer_size);
dst_buffer_ptr += src_buffer_size;
}
}
}
return absl::OkStatus();
}
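// Returns the index of the dynamic (-1) dimension of the given input tensor,
// or an error if the tensor has no dynamic dimension.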
absl::StatusOr<int> GetDynamicDimIndex(const Model& model,
absl::string_view signature,
absl::string_view tensor_name) {
LITERT_ASSIGN_OR_RETURN(const SimpleSignature& sig,
model.FindSignature(signature));
LITERT_ASSIGN_OR_RETURN(const SimpleTensor& tensor,
sig.InputTensor(tensor_name));
LITERT_ASSIGN_OR_RETURN(const RankedTensorType ranked_tensor_type,
tensor.RankedTensorType());
auto dimensions = ranked_tensor_type.Layout().Dimensions();
for (int i = 0; i < dimensions.size(); ++i) {
if (dimensions[i] == kDynamicDimValue) {
return i;
}
}
return absl::InvalidArgumentError("No dynamic dimension found.");
}
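// Returns true if the given input tensor of the signature has at least one
// dynamic dimension.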
absl::StatusOr<bool> HasDynamicDim(const Model& model,
absl::string_view signature,
absl::string_view tensor_name) {
LITERT_ASSIGN_OR_RETURN(const SimpleSignature& sig,
model.FindSignature(signature));
LITERT_ASSIGN_OR_RETURN(const SimpleTensor& tensor,
sig.InputTensor(tensor_name));
LITERT_ASSIGN_OR_RETURN(const RankedTensorType ranked_tensor_type,
tensor.RankedTensorType());
auto dimensions = ranked_tensor_type.Layout().Dimensions();
for (int i = 0; i < dimensions.size(); ++i) {
if (dimensions[i] == kDynamicDimValue) {
return true;
}
}
return false;
}
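// Resizes the given input tensor so that every dynamic dimension is replaced
// with new_value. This is a no-op if the tensor is fully static.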
absl::Status ResolveDynamicShape(const Model& model,
CompiledModel& compiled_model,
absl::string_view signature,
absl::string_view tensor_name, int new_value) {
LITERT_ASSIGN_OR_RETURN(const SimpleSignature& sig,
model.FindSignature(signature));
LITERT_ASSIGN_OR_RETURN(const SimpleTensor& tensor,
sig.InputTensor(tensor_name));
LITERT_ASSIGN_OR_RETURN(const RankedTensorType ranked_tensor_type,
tensor.RankedTensorType());
auto dimensions = ranked_tensor_type.Layout().Dimensions();
bool has_dynamic_dim = false;
std::vector<int> new_shape;
new_shape.reserve(dimensions.size());
for (int i = 0; i < dimensions.size(); ++i) {
if (dimensions[i] == kDynamicDimValue) {
has_dynamic_dim = true;
new_shape.push_back(new_value);
} else {
new_shape.push_back(dimensions[i]);
}
}
if (has_dynamic_dim) {
LITERT_RETURN_IF_ERROR(
compiled_model.ResizeInputTensor(signature, tensor_name, new_shape));
}
return absl::OkStatus();
}
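// Allocates a new tensor buffer whose dimension at dynamic_dim_index is grown
// by num_entries_to_insert and copies the contents of tensor_buffer into it.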
absl::StatusOr<TensorBuffer> ResizeKVCacheTensorBuffer(
Environment& env, TensorBuffer& tensor_buffer, int dynamic_dim_index,
int num_entries_to_insert) {
LITERT_ASSIGN_OR_RETURN(const RankedTensorType& tensor_type,
tensor_buffer.TensorType());
RET_CHECK(!tensor_type.Layout().HasStrides());
auto dimensions = tensor_type.Layout().Dimensions();
std::vector<int> new_dimensions;
new_dimensions.reserve(dimensions.size());
for (int i = 0; i < dimensions.size(); ++i) {
if (i == dynamic_dim_index) {
new_dimensions.push_back(dimensions[i] + num_entries_to_insert);
} else {
new_dimensions.push_back(dimensions[i]);
}
}
LITERT_ASSIGN_OR_RETURN(TensorBufferType buffer_type,
tensor_buffer.BufferType());
Layout new_layout(Dimensions(new_dimensions.begin(), new_dimensions.end()));
auto new_out_type =
RankedTensorType(tensor_type.ElementType(), std::move(new_layout));
LITERT_ASSIGN_OR_RETURN(size_t new_size, new_out_type.Bytes());
LITERT_ASSIGN_OR_RETURN(
TensorBuffer new_tensor_buffer,
TensorBuffer::CreateManaged(env, buffer_type, new_out_type, new_size));
LITERT_ASSIGN_OR_RETURN(auto tensor_buffer_lock_and_addr,
TensorBufferScopedLock::Create(
tensor_buffer, TensorBuffer::LockMode::kRead));
auto* tensor_buffer_ptr =
static_cast<uint8_t*>(tensor_buffer_lock_and_addr.second);
LITERT_ASSIGN_OR_RETURN(
auto new_tensor_buffer_lock_and_addr,
TensorBufferScopedLock::Create(new_tensor_buffer,
TensorBuffer::LockMode::kWrite));
auto* new_tensor_buffer_ptr =
static_cast<uint8_t*>(new_tensor_buffer_lock_and_addr.second);
std::optional<size_t> element_size = GetByteWidth(tensor_type.ElementType());
RET_CHECK(element_size.has_value());
RETURN_IF_ERROR(ExpandBuffer(tensor_buffer_ptr, dimensions,
new_tensor_buffer_ptr, new_dimensions,
element_size.value()));
return new_tensor_buffer;
}
// Builds the output tensor type for the embedding lookup. The result is the
// same as the single-token output tensor type, except that the token (second)
// dimension is set to the number of tokens.
absl::StatusOr<RankedTensorType> GetEmbeddingLookupOutputTensorType(
int num_tokens, const RankedTensorType& output_element_type) {
if (num_tokens <= 0) {
return absl::InvalidArgumentError(
"Number of tokens must be greater than 0.");
}
if (num_tokens == 1) {
return output_element_type;
}
const auto& dims = output_element_type.Layout().Dimensions();
if (dims.size() < 3) {
return absl::InvalidArgumentError("Tensor type must have rank 3 or more.");
}
if (dims[0] != 1 || dims[1] != 1) {
return absl::InvalidArgumentError(
"Output tensor type must have its first two dimensions equal to 1.");
}
Dimensions embedding_dims(dims.begin(), dims.end());
embedding_dims[1] = num_tokens;
return RankedTensorType(output_element_type.ElementType(),
Layout(std::move(embedding_dims)));
}
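// A tensor buffer that either wraps caller-owned host memory (wrapped == true)
// or owns newly allocated host memory (wrapped == false).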
struct MaybeWrappedTensorBuffer {
TensorBuffer buffer;
bool wrapped;
};
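// Tries to wrap the given host memory with a TensorBuffer; if wrapping is not
// supported, creates a managed host-memory buffer of the same size instead.
// In the latter case the data is not copied and the caller must fill the
// buffer.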
template <typename T>
absl::StatusOr<MaybeWrappedTensorBuffer> WrapOrCreateTensorBufferFromHostMemory(
RankedTensorType tensor_type, absl::Span<T> data) {
size_t size = data.size() * sizeof(T);
// First try to wrap the memory with a TensorBuffer.
auto wrapped_buffer =
TensorBuffer::CreateFromHostMemory(tensor_type, data.data(), size);
if (wrapped_buffer.HasValue()) {
return MaybeWrappedTensorBuffer{.buffer = std::move(*wrapped_buffer),
.wrapped = true};
}
LITERT_ASSIGN_OR_RETURN(
auto new_buffer,
TensorBuffer::CreateManagedHostMemory(tensor_type, size));
return MaybeWrappedTensorBuffer{.buffer = std::move(new_buffer),
.wrapped = false};
}
// Returns a subspan of the given span for a chunk at the given index.
template <typename T>
absl::Span<const T> GetSpanForChunk(absl::Span<T> span, int num_chunks,
int chunk_index) {
size_t total_size = span.size();
size_t chunk_size = total_size / num_chunks;
return span.subspan(chunk_size * chunk_index, chunk_size);
}
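// Creates an FP16 output tensor buffer for the given signature output, using
// the compiled model's runtime-resolved layout, strides and buffer
// requirements.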
absl::StatusOr<TensorBuffer> CreateFP16OutputBuffer(
Environment& env, CompiledModel& compiled_model, size_t signature_index,
absl::string_view output_name, size_t output_index) {
LITERT_ASSIGN_OR_RETURN(
std::vector<Layout> runtime_layouts,
compiled_model.GetOutputTensorLayouts(signature_index,
/*update_allocation=*/true));
// Use runtime layout.
Layout runtime_layout = runtime_layouts[output_index];
LITERT_ASSIGN_OR_RETURN(
auto requirements,
compiled_model.GetOutputBufferRequirements(signature_index, output_name));
LITERT_ASSIGN_OR_RETURN(auto strides, requirements.Strides());
if (!strides.empty()) {
auto dims = runtime_layout.Dimensions();
runtime_layout = Layout(litert::Dimensions(dims.begin(), dims.end()),
litert::Strides(strides.begin(), strides.end()));
}
RankedTensorType new_tensor_type(litert::ElementType::Float16,
std::move(runtime_layout));
LITERT_ASSIGN_OR_RETURN(size_t size, requirements.BufferSize());
LITERT_ASSIGN_OR_RETURN(auto buffer_types, requirements.SupportedTypes());
if (buffer_types.empty()) {
return absl::InternalError("No supported buffer types found.");
}
auto buffer_type = buffer_types[0];
LITERT_ASSIGN_OR_RETURN(
auto buffer, TensorBuffer::CreateManaged(
env, buffer_type, std::move(new_tensor_type), size));
return buffer;
}
} // namespace
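// Creates the input buffers for the given prefill signature, resizing any
// dynamic dimensions to the sequence/context length first: tokens or
// embeddings (plus optional per-layer embeddings), positions, the optional
// attention mask, and the optional int32 parameter tensor used by the GPU
// single-buffer KV cache.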
absl::Status LlmLiteRtCompiledModelExecutorBase::CreatePrefillInputBuffers(
absl::string_view prefill_signature, int sequence_length,
int context_length,
absl::flat_hash_map<absl::string_view, TensorBuffer>&
prefill_input_buffers) {
auto dyn_shape_resolver = [&](absl::string_view tensor_name) -> absl::Status {
return ResolveDynamicShape(model_, compiled_model_, prefill_signature,
tensor_name, sequence_length);
};
// Create input_token, positions and attn_mask buffers after determining
// the prefill length.
if (!signatures_.input_tokens.empty()) {
RETURN_IF_ERROR(dyn_shape_resolver(signatures_.input_tokens));
auto tokens_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_tokens);
prefill_input_buffers[signatures_.input_tokens] = std::move(*tokens_buffer);
} else {
// If input_tokens is empty, we must have input_embeddings.
if (!signatures_.input_embeddings.has_value()) {
return absl::FailedPreconditionError(
"Input tokens or embeddings must be provided.");
}
if (embedding_lookup_ == nullptr) {
return absl::FailedPreconditionError(
"Input embeddings required by signature but embedding lookup "
"model is not initialized.");
}
RETURN_IF_ERROR(dyn_shape_resolver(signatures_.input_embeddings.value()));
auto embeddings_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_embeddings.value());
prefill_input_buffers[signatures_.input_embeddings.value()] =
std::move(*embeddings_buffer);
// We may have per layer embedding as well.
if (signatures_.input_per_layer_embeddings.has_value()) {
if (embedding_lookup_ == nullptr) {
return absl::FailedPreconditionError(
"Input per layer embeddings required by signature but "
"embedding lookup model is not initialized.");
}
RETURN_IF_ERROR(
dyn_shape_resolver(signatures_.input_per_layer_embeddings.value()));
auto per_layer_embeddings_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_per_layer_embeddings.value());
prefill_input_buffers[signatures_.input_per_layer_embeddings.value()] =
std::move(*per_layer_embeddings_buffer);
}
}
RETURN_IF_ERROR(dyn_shape_resolver(signatures_.input_positions));
auto positions_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_positions);
prefill_input_buffers[signatures_.input_positions] =
std::move(*positions_buffer);
if (signatures_.input_attn_mask.has_value()) {
ASSIGN_OR_RETURN(bool is_attn_dyn,
HasDynamicDim(model_, prefill_signature,
signatures_.input_attn_mask.value()));
if (is_attn_dyn) {
std::vector<int> new_shape = {1, 1, sequence_length, context_length};
LITERT_RETURN_IF_ERROR(compiled_model_.ResizeInputTensor(
prefill_signature, signatures_.input_attn_mask.value(), new_shape));
}
auto attn_mask_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_attn_mask.value());
prefill_input_buffers[signatures_.input_attn_mask.value()] =
std::move(*attn_mask_buffer);
}
if (signatures_.input_int32_param.has_value()) {
gpu_optimized_single_buffer_cache_ = true;
auto param_tensor_buffer = compiled_model_.CreateInputBuffer(
prefill_signature, signatures_.input_int32_param.value());
prefill_input_buffers[signatures_.input_int32_param.value()] =
std::move(*param_tensor_buffer);
}
return absl::OkStatus();
}
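// Fills the input buffer with the given tokens, one slot (stride) per token:
// the token id when the token has no embedding, otherwise its (per-layer)
// embedding bytes. The remainder of each slot is zero-padded.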
absl::Status LlmLiteRtCompiledModelExecutorBase::FillInputBufferWithToken(
const std::vector<std::shared_ptr<TokenData>>& unprocessed_token,
TensorBuffer& input_buffer, bool is_per_layer_embedding) {
if (unprocessed_token.empty()) {
return absl::InvalidArgumentError("Unprocessed token list is empty.");
}
LITERT_ASSIGN_OR_RETURN(auto input_buffer_lock_and_addr,
TensorBufferScopedLock::Create(
input_buffer, TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(size_t packed_size, input_buffer.PackedSize());
size_t stride = packed_size / unprocessed_token.size();
char* input_buffer_ptr =
static_cast<char*>(input_buffer_lock_and_addr.second);
for (const auto& token : unprocessed_token) {
size_t size_to_fill = 0;
if (token->embedding().empty()) {
size_to_fill = sizeof(int32_t);
RET_CHECK_GE(stride, size_to_fill);
// If the token has no embedding, the input buffer takes the token id.
*reinterpret_cast<int32_t*>(input_buffer_ptr) = token->id();
} else if (is_per_layer_embedding) {
size_to_fill = token->per_layer_embedding().size() * sizeof(float);
RET_CHECK_GE(stride, size_to_fill);
memcpy(input_buffer_ptr, token->per_layer_embedding().data(),
size_to_fill);
} else {
size_to_fill = token->embedding().size() * sizeof(float);
RET_CHECK_GE(stride, size_to_fill);
memcpy(input_buffer_ptr, token->embedding().data(), size_to_fill);
}
if (stride > size_to_fill) {
memset(input_buffer_ptr + size_to_fill, 0, stride - size_to_fill);
}
input_buffer_ptr += stride;
}
return absl::OkStatus();
}
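// Rolls the processed tokens back to the executor's current step. The token
// at the previous step is re-added as a pending input token (or as a
// processed token for multimodal ids), and the sampler input handling is
// reset when the sampler owns the decode inputs.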
absl::Status LlmLiteRtCompiledModelExecutorBase::RollBackProcessedTokens() {
int current_step = llm_context_->runtime_state().current_step;
ProcessedTokens& processed_tokens =
llm_context_->processed_context().processed_tokens();
if (current_step == processed_tokens.TokenCount()) {
return absl::OkStatus();
}
if (current_step == 0) {
RETURN_IF_ERROR(processed_tokens.RollBackToStep(0));
} else {
auto token_at_step = processed_tokens.GetTokenAtStep(current_step - 1);
RETURN_IF_ERROR(processed_tokens.RollBackToStep(current_step - 1));
if (!token_at_step.empty()) {
RET_CHECK_EQ(token_at_step.size(), 1);
// Multimodal input cannot become a pending input token.
if (token_at_step.at(0) > 0) {
RETURN_IF_ERROR(processed_tokens.AddPendingInputToken(
{std::make_shared<TokenData>(token_at_step.at(0))}));
} else {
processed_tokens.AddProcessedTokens({token_at_step.at(0)});
}
}
}
// Reset sampler input handling as the step is rolled back.
if (sampler_ != nullptr && sampler_->HandlesInput()) {
RETURN_IF_ERROR(SetSamplerInputHandling(/*reset=*/true));
}
return absl::OkStatus();
}
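// Prepares the executor for the first prefill after a decode. With more than
// one output head, the token candidates and the KV cache are reduced to the
// single candidate at token_index_to_reduce, and the sampler input handling
// is reset if the sampler owns the decode inputs.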
absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstPrefillAfterDecode(
int token_index_to_reduce) {
if (!llm_context_->runtime_state().ran_decode && !force_prepare_needed_) {
return absl::OkStatus();
}
force_prepare_needed_ = false;
llm_context_->runtime_state().ran_decode = false;
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
if (output_heads > 1) {
LITERT_RETURN_IF_ERROR(llm_context_->processed_context()
.processed_tokens()
.ReduceTokenCandidates(token_index_to_reduce));
LITERT_RETURN_IF_ERROR(
CopyKvCacheBuffers(output_heads, token_index_to_reduce,
*input_kv_cache_buffers_, kv_cache_buffers_1_));
input_kv_cache_buffers_ = &kv_cache_buffers_1_;
output_kv_cache_buffers_ = &kv_cache_buffers_2_;
}
// Reset sampler input handling if it handles input for next decode.
if (sampler_ != nullptr && sampler_->HandlesInput()) {
RETURN_IF_ERROR(SetSamplerInputHandling(/*reset=*/true));
}
return absl::OkStatus();
}
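// Runs one prefill call on the given signature: fills the positions, tokens
// or embeddings, attention mask and single-buffer-cache parameter, consumes
// any pending input token, keeps the last input id as the next pending input
// token, and then binds the tensors and runs the model (unless there is
// nothing left to prefill).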
absl::Status LlmLiteRtCompiledModelExecutorBase::PrefillInternal(
absl::string_view prefill_signature,
absl::flat_hash_map<absl::string_view, TensorBuffer>& prefill_input_buffers,
Span<const int> ids, bool async) {
RETURN_IF_ERROR(RollBackProcessedTokens());
{
// Fill the input buffers with scoped locks.
auto& prefill_input_pos =
prefill_input_buffers[signatures_.input_positions];
LITERT_ASSIGN_OR_RETURN(auto prefill_input_pos_size,
prefill_input_pos.PackedSize());
LITERT_ASSIGN_OR_RETURN(
auto prefill_input_pos_lock_and_addr,
TensorBufferScopedLock::Create(prefill_input_pos,
TensorBuffer::LockMode::kWrite));
auto* prefill_input_pos_ptr =
static_cast<int32_t*>(prefill_input_pos_lock_and_addr.second);
memset(prefill_input_pos_ptr, 0, prefill_input_pos_size);
if (signatures_.input_attn_mask.has_value()) {
RETURN_IF_ERROR(InitializeAttentionMask(
prefill_input_buffers[signatures_.input_attn_mask.value()],
use_fp16_precision_));
}
// TODO(b/425396146): Add the unit tests for checking the prefill length.
// We always hold one pending token in the input ids for the next
// prefill or decode step.
int prefill_length = ids.size() - 1;
// Check if we have a pending input token. Note that 'internal_start_step' is
// always equal to the number of processed tokens plus 1.
auto [internal_start_step, pending_input_token] =
llm_context_->processed_context()
.processed_tokens()
.GetNextUnprocessedToken();
RET_CHECK_LE(pending_input_token.size(), 1);
const int start_step = internal_start_step;
const bool has_pending_input_token = !pending_input_token.empty();
const bool use_token_as_lookup = !signatures_.input_tokens.empty();
const bool use_per_layer_embedding =
signatures_.input_per_layer_embeddings.has_value();
// If there is no pending input token and no tokens to prefill, we can skip
// the model call and only store the last input token as the pending input
// token.
bool skip_prefill = !has_pending_input_token && prefill_length == 0;
if (!skip_prefill) {
int input_idx = 0;
if (has_pending_input_token) {
if (use_token_as_lookup) {
RETURN_IF_ERROR(FillInputBufferWithToken(
pending_input_token,
prefill_input_buffers[signatures_.input_tokens]));
} else {
RETURN_IF_ERROR(FillInputBufferWithToken(
pending_input_token,
prefill_input_buffers[signatures_.input_embeddings.value()]));
if (use_per_layer_embedding) {
RETURN_IF_ERROR(FillInputBufferWithToken(
pending_input_token,
prefill_input_buffers[signatures_.input_per_layer_embeddings
.value()],
/*is_per_layer_embedding=*/true));
}
}
prefill_input_pos_ptr[input_idx] = internal_start_step;
RETURN_IF_ERROR(llm_context_->processed_context()
.processed_tokens()
.MarkPendingInputTokenAsProcessed());
++prefill_input_pos_ptr;
++input_idx;
}
std::transform(prefill_input_pos_ptr,
prefill_input_pos_ptr + prefill_length,
prefill_input_pos_ptr, [&](int token) mutable {
return llm_context_->runtime_state().current_step++;
});
std::vector<int> processed_input_tokens(ids.begin(),
ids.begin() + prefill_length);
llm_context_->processed_context().processed_tokens().AddProcessedTokens(
processed_input_tokens);
if (use_token_as_lookup) {
auto& prefill_input_buffer =
prefill_input_buffers[signatures_.input_tokens];
LITERT_ASSIGN_OR_RETURN(
auto prefill_input_lock_and_addr,
TensorBufferScopedLock::Create(prefill_input_buffer,
TensorBuffer::LockMode::kWrite));
int32_t* prefill_input_ptr =
static_cast<int32_t*>(prefill_input_lock_and_addr.second);
if (!has_pending_input_token) {
LITERT_ASSIGN_OR_RETURN(auto prefill_input_size,
prefill_input_buffer.PackedSize());
// If there is a pending input token, the zeros and the pending input
// token id are already filled in the above
// FillInputBufferWithToken() function, so we cannot zero out the
// whole prefill input buffer here.
//
// If there is no pending input token, we need to zero out the whole
// prefill input buffer.
memset(prefill_input_ptr, 0, prefill_input_size);
}
memcpy(prefill_input_ptr + input_idx, processed_input_tokens.data(),
processed_input_tokens.size() * sizeof(int32_t));
} else {
// If not using token as lookup, we must have input_embeddings. There is
// no need to create input_embeddings_ptr because TensorBuffer locking
// and filling is handled by the embedding lookup.
TensorBuffer* prefill_input_embeddings_buffer =
&(prefill_input_buffers[signatures_.input_embeddings.value()]);
RETURN_IF_ERROR(embedding_lookup_->LookupPrefill(
processed_input_tokens, prefill_input_embeddings_buffer,
/*offset=*/input_idx));
// We may have per layer embedding as well.
if (signatures_.input_per_layer_embeddings) {
TensorBuffer* prefill_input_per_layer_embeddings_buffer =
&(prefill_input_buffers[signatures_.input_per_layer_embeddings
.value()]);
RETURN_IF_ERROR(per_layer_embedding_lookup_->LookupPrefill(
processed_input_tokens, prefill_input_per_layer_embeddings_buffer,
/*offset=*/input_idx));
}
}
if (signatures_.input_attn_mask.has_value()) {
RETURN_IF_ERROR(FillAttentionMask(
prefill_input_buffers[signatures_.input_attn_mask.value()],
start_step,
/*steps=*/prefill_length + input_idx));
}
if (gpu_optimized_single_buffer_cache_) {
LITERT_RETURN_IF_ERROR(signatures_.input_int32_param.has_value());
RETURN_IF_ERROR(FillSingleBufferCacheParamTensor(
prefill_input_buffers[signatures_.input_int32_param.value()],
start_step, ids.size()));
}
}
// Add the last token of the current input as a pending input token, to be
// used in the next prefill or decode.
auto last_input_token = std::make_shared<TokenData>(ids.back());
if (!use_token_as_lookup) {
// Look up the embeddings for the last token so they can be used in the
// next prefill or decode. This has to be done now in the case of
// multi-modal prefill so the embeddings are used in the correct order.
RETURN_IF_ERROR(embedding_lookup_->LookupPrefill(
last_input_token->id(), last_input_token->mutable_embedding()));
if (use_per_layer_embedding) {
RETURN_IF_ERROR(per_layer_embedding_lookup_->LookupPrefill(
last_input_token->id(),
last_input_token->mutable_per_layer_embedding()));
}
}
// Add the last input token to the pending input token list.
RETURN_IF_ERROR(llm_context_->processed_context()
.processed_tokens()
.AddPendingInputToken({std::move(last_input_token)}));
++llm_context_->runtime_state().current_step;
if (skip_prefill) {
return absl::OkStatus();
}
}
return BindTensorsAndRunPrefill(prefill_signature, prefill_input_buffers,
async);
}
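// Duplicates the prefill input buffers and the current KV cache buffers,
// binds them to the prefill signature and runs it (asynchronously when
// requested). The input/output KV cache buffer sets are swapped afterwards
// unless the GPU single-buffer cache is used.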
absl::Status LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunPrefill(
absl::string_view prefill_signature,
absl::flat_hash_map<absl::string_view, TensorBuffer>& prefill_input_buffers,
bool async) {
absl::flat_hash_map<absl::string_view, TensorBuffer> input_buffers;
for (const auto& [input_name, input_buffer] : prefill_input_buffers) {
LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
input_buffers[input_name] = std::move(input_buffer_dup);
}
for (const auto& [input_name, input_buffer] : *input_kv_cache_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
input_buffers[input_name] = std::move(input_buffer_dup);
}
absl::flat_hash_map<absl::string_view, TensorBuffer> output_buffers;
for (const auto& [output_name, output_buffer] : *output_kv_cache_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer_dup, output_buffer.Duplicate());
output_buffer_dup.ClearEvent();
output_buffers[output_name] = std::move(output_buffer_dup);
}
if (async) {
LITERT_RETURN_IF_ERROR(compiled_model_.RunAsync(
prefill_signature, input_buffers, output_buffers, async));
} else {
LITERT_RETURN_IF_ERROR(
compiled_model_.Run(prefill_signature, input_buffers, output_buffers));
}
if (!gpu_optimized_single_buffer_cache_) {
std::swap(input_kv_cache_buffers_, output_kv_cache_buffers_);
}
return absl::OkStatus();
}
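// Determines the token(s) to decode next: token ids passed in the inputs
// replace the pending input token; otherwise the pending token from the
// previous prefill/decode is used. Embeddings are looked up here for tokens
// that do not have them yet.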
absl::StatusOr<ProcessedTokens::StepAndToken>
LlmLiteRtCompiledModelExecutorBase::GetTokenToDecode(
const ExecutorInputs& inputs) {
RETURN_IF_ERROR(RollBackProcessedTokens());
if (inputs.GetTextDataPtr().ok()) {
LITERT_ASSIGN_OR_RETURN(auto token_ids_buffer, inputs.GetTextTokenIdsPtr());
auto input_tensor_size = token_ids_buffer->PackedSize();
if (input_tensor_size && *input_tensor_size != 0) {
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
// Input token ids are provided, so use them regardless of whether the next
// input token id is set.
RET_CHECK_EQ(*input_tensor_size, output_heads * sizeof(int32_t));
LITERT_ASSIGN_OR_RETURN(
auto ids, ReferTensorBufferAsSpan<int32_t>(*token_ids_buffer));
if (ids[0] >= 0) {
// If the input token id is >= 0, it means the input token is provided
// by the user. In this case, we should invalidate the pending input
// token and add the input token as a pending input token.
llm_context_->processed_context()
.processed_tokens()
.InvalidatePendingInputToken();
std::vector<std::shared_ptr<TokenData>> token;
token.reserve(output_heads);
for (int i = 0; i < output_heads; ++i) {
token.push_back(std::make_shared<TokenData>(ids[i]));
}
RETURN_IF_ERROR(llm_context_->processed_context()
.processed_tokens()
.AddPendingInputToken(token));
}
}
}
// Here we must have a pending input token to decode that's either coming from
// the previous prefill or decode, or we just added one from the inputs.
for (const auto& token : llm_context_->processed_context()
.processed_tokens()
.GetNextUnprocessedToken()
.token) {
// If the token has no embedding, we will look up the embedding for the
// token here. This reduces the complexity for internal or external
// sampling.
if (signatures_.input_embeddings.has_value() &&
token->mutable_embedding().empty()) {
RETURN_IF_ERROR(embedding_lookup_->LookupDecode(
token->id(), token->mutable_embedding()));
if (signatures_.input_per_layer_embeddings.has_value()) {
RETURN_IF_ERROR(per_layer_embedding_lookup_->LookupDecode(
token->id(), token->mutable_per_layer_embedding()));
}
}
}
return llm_context_->processed_context()
.processed_tokens()
.GetNextUnprocessedToken();
}
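// Marks the pending input token as processed. If there is no pending token,
// the given token ids are appended to the processed tokens and the current
// step is advanced instead.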
absl::Status
LlmLiteRtCompiledModelExecutorBase::ConsumePendingOrAddProcessedToken(
const std::vector<std::shared_ptr<TokenData>>& token) {
auto status = llm_context_->processed_context()
.processed_tokens()
.MarkPendingInputTokenAsProcessed();
if (status.ok() || status.code() != absl::StatusCode::kNotFound) {
return status;
}
// If the pending input token was not used, we should add the token to the
// processed tokens.
std::vector<int> processed_tokens;
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
processed_tokens.reserve(output_heads);
for (const auto& t : token) {
processed_tokens.push_back(t->id());
}
llm_context_->processed_context().processed_tokens().AddProcessedTokens(
processed_tokens);
++llm_context_->runtime_state().current_step;
return absl::OkStatus();
}
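// Fills the decode input buffers (tokens or embeddings, positions, attention
// mask, single-buffer-cache parameter) for the given token(s) and runs one
// decode step into output_logits. If the sampler already drove decode for
// this step, only checks that output_logits is the buffer it used.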
absl::Status LlmLiteRtCompiledModelExecutorBase::DecodeInternal(
const std::vector<std::shared_ptr<TokenData>>& token,
TensorBuffer& output_logits) {
int step = llm_context_->runtime_state().current_step - 1;
if (sampler_ && sampler_->HandlesInput()) {
// The sampler has already been running decode for this step. Check if
// output_logits is the one used last time, i.e. by
// BindTensorsAndRunDecodeStatic().
LITERT_RETURN_IF_ERROR(
output_logits.Get() ==
decode_output_buffers_[signatures_.output_logits].Get());
return absl::OkStatus();
}
const bool use_token_as_lookup = !signatures_.input_tokens.empty();
const bool use_per_layer_embedding =
signatures_.input_per_layer_embeddings.has_value();
// Fill the input buffers with scoped locks.
if (use_token_as_lookup) {
RETURN_IF_ERROR(FillInputBufferWithToken(
token, decode_input_buffers_[signatures_.input_tokens]));
} else {
if (!signatures_.input_embeddings.has_value()) {
return absl::InvalidArgumentError(
"Input tokens or embeddings must be provided.");
}
RETURN_IF_ERROR(FillInputBufferWithToken(
token, decode_input_buffers_[signatures_.input_embeddings.value()]));
if (use_per_layer_embedding) {
RETURN_IF_ERROR(FillInputBufferWithToken(
token,
decode_input_buffers_[signatures_.input_per_layer_embeddings.value()],
/*is_per_layer_embedding=*/true));
}
}
{
LITERT_ASSIGN_OR_RETURN(
auto input_pos_type,
decode_input_buffers_[signatures_.input_positions].TensorType());
LITERT_ASSIGN_OR_RETURN(
auto input_pos_lock_and_addr,
TensorBufferScopedLock::Create(
decode_input_buffers_[signatures_.input_positions],
TensorBuffer::LockMode::kWrite));
auto* input_pos_ptr = static_cast<int32_t*>(input_pos_lock_and_addr.second);
if (input_pos_type.Layout().Dimensions()[0] == 1) {
*input_pos_ptr = step;
} else {
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
RET_CHECK_EQ(input_pos_type.Layout().Dimensions()[0], output_heads);
LITERT_ASSIGN_OR_RETURN(
auto input_pos_size,
decode_input_buffers_[signatures_.input_positions].PackedSize());
size_t offset = input_pos_size / output_heads / sizeof(int32_t);
for (int i = 0; i < output_heads; ++i) {
input_pos_ptr[i * offset] = step;
}
}
}
if (signatures_.input_attn_mask.has_value()) {
RETURN_IF_ERROR(InitializeAttentionMask(
decode_input_buffers_[signatures_.input_attn_mask.value()],
use_fp16_precision_));
RETURN_IF_ERROR(FillAttentionMask(
decode_input_buffers_[signatures_.input_attn_mask.value()], step,
/*steps=*/1));
}
if (gpu_optimized_single_buffer_cache_) {
LITERT_RETURN_IF_ERROR(signatures_.input_int32_param.has_value());
RETURN_IF_ERROR(FillSingleBufferCacheParamTensor(
decode_input_buffers_[signatures_.input_int32_param.value()], step, 1));
}
return BindTensorsAndRunDecode(&output_logits);
}
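// Duplicates the decode input/output buffers and the current KV cache
// buffers, substituting output_logits for the default logits buffer when
// provided, and runs the decode signature asynchronously. The KV cache buffer
// sets are swapped afterwards unless the GPU single-buffer cache is used.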
absl::Status LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecode(
TensorBuffer* output_logits) {
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_input_buffers;
for (const auto& [input_name, input_buffer] : decode_input_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
decode_input_buffers[input_name] = std::move(input_buffer_dup);
}
for (const auto& [input_name, input_buffer] : *input_kv_cache_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
decode_input_buffers[input_name] = std::move(input_buffer_dup);
}
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_output_buffers;
for (const auto& [output_name, output_buffer] : decode_output_buffers_) {
// LITERT_ASSIGN_OR_RETURN() causes a compilation error on Windows.
auto output_buffer_dup =
output_logits && output_name == signatures_.output_logits
? output_logits->Duplicate()
: output_buffer.Duplicate();
RET_CHECK(output_buffer_dup) << "Failed to duplicate output buffer.";
output_buffer_dup->ClearEvent();
decode_output_buffers[output_name] = std::move(*output_buffer_dup);
}
for (const auto& [output_name, output_buffer] : *output_kv_cache_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer_dup, output_buffer.Duplicate());
output_buffer_dup.ClearEvent();
decode_output_buffers[output_name] = std::move(output_buffer_dup);
}
bool async = true;
LITERT_RETURN_IF_ERROR(
compiled_model_.RunAsync(kDecodeSignatureRunner, decode_input_buffers,
decode_output_buffers, async));
if (!gpu_optimized_single_buffer_cache_) {
std::swap(input_kv_cache_buffers_, output_kv_cache_buffers_);
}
return absl::OkStatus();
}
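// C-style trampoline used as the sampler's inference callback. Runs
// BindTensorsAndRunDecode() with the default output logits buffer and returns
// the raw status code.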
int LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecodeStatic(
void* arg) {
auto self = static_cast<LlmLiteRtCompiledModelExecutorBase*>(arg);
// Run decode with default output_logits.
auto status = self->BindTensorsAndRunDecode(/*output_logits=*/nullptr);
if (!status.ok()) {
ABSL_LOG(ERROR) << "Failed to bind tensors and run decode: " << status;
}
return status.raw_code();
}
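// Prepares the executor for the first decode after a prefill. With more than
// one output head, the token candidates and the prefill KV cache are
// broadcast into the batched decode KV cache buffers.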
absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstDecode() {
if (llm_context_->runtime_state().ran_decode && !force_prepare_needed_) {
return absl::OkStatus();
}
force_prepare_needed_ = false;
// Mark that we have run decode at least once.
llm_context_->runtime_state().ran_decode = true;
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
if (output_heads <= 1) {
return absl::OkStatus();
}
LITERT_RETURN_IF_ERROR(llm_context_->processed_context()
.processed_tokens()
.BroadcastTokenCandidates(output_heads));
LITERT_RETURN_IF_ERROR(decode_kv_cache_buffers_1_.has_value());
LITERT_RETURN_IF_ERROR(decode_kv_cache_buffers_2_.has_value());
// Broadcast the prefill kv cache buffers to the decode kv cache buffers.
// This is only needed when decode batch size > 1.
LITERT_RETURN_IF_ERROR(CopyKvCacheBuffers(
output_heads, /*src_index_to_copy_on_prefill=*/-1,
*input_kv_cache_buffers_, *decode_kv_cache_buffers_1_));
input_kv_cache_buffers_ = &decode_kv_cache_buffers_1_.value();
output_kv_cache_buffers_ = &decode_kv_cache_buffers_2_.value();
return absl::OkStatus();
}
absl::StatusOr<std::vector<std::vector<int>>>
LlmLiteRtCompiledModelExecutorBase::Decode() {
return Decode(ExecutorDecodeParams());
}
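// Runs one decode step (or, when an MTP drafter is attached, a draft step)
// and returns the sampled token ids per output head, updating the processed
// and pending tokens accordingly.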
absl::StatusOr<std::vector<std::vector<int>>>
LlmLiteRtCompiledModelExecutorBase::Decode(
const ExecutorDecodeParams& decode_params) {
std::vector<std::vector<int>> output_tokens_vector;
if (mtp_drafter_ == nullptr) {
ASSIGN_OR_RETURN(auto decoded_logits,
DecodeLogits(ExecutorInputs(), decode_params));
std::optional<TensorBuffer> output_tokens;
{
LITERT_ASSIGN_OR_RETURN(auto decoded_logits_type,
decoded_logits.TensorType());
auto dimensions = decoded_logits_type.Layout().Dimensions();
// Shape of decoded_logits is [batch_size, token_length, vocab_size].
RET_CHECK_EQ(dimensions.size(), 3);
LITERT_ASSIGN_OR_RETURN(
output_tokens,
CreateTensorBuffer<int>({dimensions[0], dimensions[1]}));
}
RETURN_IF_ERROR(SampleLogits(decoded_logits, *output_tokens));
LITERT_ASSIGN_OR_RETURN(output_tokens_vector,
CopyFromTensorBuffer2D<int>(*output_tokens));
} else {
// MTP keeps internal state from the last time it was called and will use
// those projected activations to kick off the next draft steps. As such, we
// need to do a single decode step on the first decode call after prefill and
// provide the projected activations to the MTP drafter only once.
bool last_run_is_decode = llm_context_->runtime_state().ran_decode;
if (last_run_is_decode) {
ASSIGN_OR_RETURN(auto step_and_token, GetTokenToDecode(ExecutorInputs()));
RETURN_IF_ERROR(ConsumePendingOrAddProcessedToken(step_and_token.token));
// Output: [Batch, drafted and verified tokens]
LITERT_ASSIGN_OR_RETURN(output_tokens_vector,
mtp_drafter_->Draft(step_and_token.step,
step_and_token.token[0]->id(),
/*activations=*/std::nullopt,
*input_kv_cache_buffers_,
*output_kv_cache_buffers_));
RET_CHECK_EQ(output_tokens_vector.size(), 1);
llm_context_->runtime_state().current_step +=
output_tokens_vector[0].size();
} else {
int token_id = -1;
{
ASSIGN_OR_RETURN(auto decoded_logits,
DecodeLogits(ExecutorInputs(), decode_params));
LITERT_ASSIGN_OR_RETURN(auto decoded_logits_type,
decoded_logits.TensorType());
auto dimensions = decoded_logits_type.Layout().Dimensions();
// Shape of decoded_logits is [batch_size, token_length, vocab_size].
RET_CHECK_EQ(dimensions.size(), 3);
LITERT_ASSIGN_OR_RETURN(
auto output_tokens,
CreateTensorBuffer<int>({dimensions[0], dimensions[1]}));
RETURN_IF_ERROR(SampleLogits(decoded_logits, output_tokens));
LITERT_ASSIGN_OR_RETURN(output_tokens_vector,
CopyFromTensorBuffer2D<int>(output_tokens));
RET_CHECK_EQ(output_tokens_vector.size(), 1);
RET_CHECK_EQ(output_tokens_vector[0].size(), 1);
token_id = output_tokens_vector[0][0];
}
RET_CHECK(decode_output_buffers_.contains("activations"));
LITERT_ASSIGN_OR_RETURN(
auto activations, decode_output_buffers_["activations"].Duplicate());
// Note: Position remains the same as the prefill step. However,
// current_step is incremented in DecodeLogits and as such needs to be
// decremented.
LITERT_ASSIGN_OR_RETURN(
output_tokens_vector,
mtp_drafter_->Draft(llm_context_->runtime_state().current_step - 1,
token_id, std::move(activations),
*input_kv_cache_buffers_,
*output_kv_cache_buffers_));
llm_context_->runtime_state().current_step +=
output_tokens_vector[0].size();
output_tokens_vector[0].insert(output_tokens_vector[0].begin(), token_id);
}
}
// Check for invalid token ids and, if any are found, set them to zero.
bool has_invalid_output_token = false;
for (int batch = 0; batch < output_tokens_vector.size(); ++batch) {
for (int token_idx = 0; token_idx < output_tokens_vector[batch].size();
++token_idx) {
if (output_tokens_vector[batch][token_idx] < 0) {
has_invalid_output_token = true;
output_tokens_vector[batch][token_idx] = 0;
}
}
}
if (has_invalid_output_token) {
ABSL_LOG(WARNING) << "Invalid decode and sample result. The sampled token "
"is set to 0 to avoid a crash.";
}
// Update the context with the assumption that there is one output per head.
// This must change when doing drafter-based decoding.
std::vector<int> processed_tokens;
std::vector<std::shared_ptr<TokenData>> pending_tokens;
for (auto& output_head_tokens : output_tokens_vector) {
for (int i = 0; i < output_head_tokens.size(); ++i) {
// Last token is reserved as pending input token.
if (i == output_head_tokens.size() - 1) {
pending_tokens.push_back(
std::make_shared<TokenData>(output_head_tokens[i]));
} else {
processed_tokens.push_back(output_head_tokens[i]);
}
}
}
if (!processed_tokens.empty()) {
llm_context_->processed_context().processed_tokens().AddProcessedTokens(
processed_tokens);
}
RETURN_IF_ERROR(
llm_context_->processed_context().processed_tokens().AddPendingInputToken(
pending_tokens));
return output_tokens_vector;
}
absl::Status LlmLiteRtCompiledModelExecutorBase::Decode(
const ExecutorInputs& inputs, TensorBuffer& output_logits) {
RETURN_IF_ERROR(PrepareFirstDecode());
ASSIGN_OR_RETURN(auto step_and_token, GetTokenToDecode(inputs));
RETURN_IF_ERROR(DecodeInternal(step_and_token.token, output_logits));
RETURN_IF_ERROR(ConsumePendingOrAddProcessedToken(step_and_token.token));
++llm_context_->runtime_state().current_step;
return absl::OkStatus();
}
absl::StatusOr<TensorBuffer> LlmLiteRtCompiledModelExecutorBase::DecodeLogits(
const ExecutorInputs& inputs) {
return DecodeLogits(inputs, ExecutorDecodeParams());
}
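// Runs one decode step and returns the output logits, applying
// constrained-decoding masking when a constraint decoder is provided in the
// decode params.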
absl::StatusOr<TensorBuffer> LlmLiteRtCompiledModelExecutorBase::DecodeLogits(
const ExecutorInputs& inputs, const ExecutorDecodeParams& decode_params) {
LITERT_ASSIGN_OR_RETURN(
auto output_logits,
decode_output_buffers_[signatures_.output_logits].Duplicate());
bool last_run_is_decode = llm_context_->runtime_state().ran_decode;
RETURN_IF_ERROR(PrepareFirstDecode());
ASSIGN_OR_RETURN(auto step_and_token, GetTokenToDecode(inputs));
RETURN_IF_ERROR(DecodeInternal(step_and_token.token, output_logits));
RETURN_IF_ERROR(ConsumePendingOrAddProcessedToken(step_and_token.token));
if (decode_params.HasConstraintDecoder() && !step_and_token.token.empty()) {
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
RET_CHECK_EQ(step_and_token.token.size(), output_heads);
std::vector<int> current_token_ids;
current_token_ids.reserve(output_heads);
for (const auto& token : step_and_token.token) {
current_token_ids.push_back(token->id());
}
// Update constraint state only with decode ids.
if (last_run_is_decode) {
RETURN_IF_ERROR(
decode_params.GetConstraintDecoder()->UpdateConstraintState(
absl::MakeSpan(current_token_ids)));
}
LITERT_ASSIGN_OR_RETURN(auto output_logits_buffer_type,
output_logits.BufferType());
// If the output logits are already in host memory, use the buffer
// directly.
if (output_logits_buffer_type == TensorBufferType::kHostMemory) {
// Mask logits based on the current constraint state.
RETURN_IF_ERROR(
decode_params.GetConstraintDecoder()->MaskLogits(output_logits));
} else {
// For GPU, we always copy the logits to CPU and mask them, then write
// them back to GPU.
LITERT_ASSIGN_OR_RETURN(RankedTensorType logits_tensor_type,
output_logits.TensorType());
if (logits_tensor_type.ElementType() == ElementType::Float32) {
// Copy the logits from the tensor buffer to a vector.
LITERT_ASSIGN_OR_RETURN(auto logits_vector,
CopyFromTensorBuffer<float>(output_logits));
// Mask logits based on the current constraint state.
RETURN_IF_ERROR(decode_params.GetConstraintDecoder()->MaskLogits(
absl::MakeSpan(logits_vector.data(), logits_vector.size()),
logits_tensor_type.Layout().Dimensions()));
// Write the masked logits back to the tensor buffer.
output_logits.Write(
absl::MakeConstSpan(logits_vector.data(), logits_vector.size()));
} else if (logits_tensor_type.ElementType() ==
litert::ElementType::Float16) {
// Copy the logits from the tensor buffer to a vector.
LITERT_ASSIGN_OR_RETURN(
auto logits_vector,
CopyFromTensorBuffer<tflite::half>(output_logits));
// Mask logits based on the current constraint state.
RETURN_IF_ERROR(decode_params.GetConstraintDecoder()->MaskLogits(
absl::MakeSpan(logits_vector.data(), logits_vector.size()),
logits_tensor_type.Layout().Dimensions()));
// Write the masked logits back to the tensor buffer.
output_logits.Write(
absl::MakeConstSpan(logits_vector.data(), logits_vector.size()));
} else {
return absl::InvalidArgumentError(
"Output logits are not in float32 or float16 type.");
}
}
}
++llm_context_->runtime_state().current_step;
const auto& settings = executor_settings_.GetAdvancedSettings();
if (settings && settings->num_logits_to_print_after_decode > 0) {
LogTensor(output_logits, settings->num_logits_to_print_after_decode,
"Logits")
.IgnoreError();
}
return output_logits;
}
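// Creates the sampler for the configured backend and number of output heads
// (a top-p sampler configured for greedy selection) and, when supported,
// wires it up to drive the decode input tensors itself.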
absl::Status LlmLiteRtCompiledModelExecutorBase::InitializeSampler(
std::optional<ActivationDataType> logits_data_type) {
if (sampler_ != nullptr) {
return absl::OkStatus();
}
// Use the provided activation data type if available, otherwise fall back to
// the member variable.
auto data_type = logits_data_type.value_or(logits_data_type_);
ASSIGN_OR_RETURN(auto vocab_size, GetVocabSize());
ASSIGN_OR_RETURN(auto sampler_backend, GetSamplerBackend(executor_settings_));
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
proto::SamplerParameters sampler_params;
sampler_params.set_type(proto::SamplerParameters::TOP_P);
sampler_params.set_k(1);
sampler_params.set_p(0.0f);
sampler_params.set_temperature(1.0f);
sampler_params.set_seed(0);
ASSIGN_OR_RETURN(
sampler_,
CreateSampler(sampler_backend, output_heads, std::move(sampler_params),
env_.Get(), /*sequence_size=*/1, vocab_size, data_type));
// If the sampler can handle input, prepare the input tensors for it.
sampler_handles_input_ =
(!executor_settings_.GetAdvancedSettings().has_value() ||
executor_settings_.GetAdvancedSettings()->sampler_handles_input) &&
sampler_->CanHandleInput() && !signatures_.input_tokens.empty();
if (sampler_handles_input_) {
ABSL_LOG(INFO) << "Sampler will handle decode input tensors.";
if (!decode_prev_input_pos_) {
LITERT_ASSIGN_OR_RETURN(
decode_prev_input_pos_,
compiled_model_.CreateInputBuffer(kDecodeSignatureRunner,
signatures_.input_positions));
}
if (!decode_prev_mask_ && signatures_.input_attn_mask.has_value()) {
LITERT_ASSIGN_OR_RETURN(
decode_prev_mask_,
compiled_model_.CreateInputBuffer(kDecodeSignatureRunner,
*signatures_.input_attn_mask));
}
// Set, then reset, the input handling to get the underlying model ready
// without leaving the input tensors bound.
RETURN_IF_ERROR(SetSamplerInputHandling(/*reset=*/false));
RETURN_IF_ERROR(SetSamplerInputHandling(/*reset=*/true));
}
return absl::OkStatus();
}
absl::Status LlmLiteRtCompiledModelExecutorBase::SwapSamplerInputTensors() {
bool has_input_attn_mask = signatures_.input_attn_mask.has_value();
// Move the input_pos and mask to previous ones.
std::swap(decode_prev_input_pos_,
decode_input_buffers_[signatures_.input_positions]);
if (has_input_attn_mask) {
std::swap(decode_prev_mask_,
decode_input_buffers_[*signatures_.input_attn_mask]);
}
return SetSamplerInputHandling(/*reset=*/false);
}
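// Passes the decode input tensors and the decode inference callback to the
// sampler, or clears them when reset is true, so that the sampler can run the
// decode step itself.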
absl::Status LlmLiteRtCompiledModelExecutorBase::SetSamplerInputHandling(
bool reset) {
if (reset) {
return sampler_->SetInputTensorsAndInferenceFunc(
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
}
bool has_input_attn_mask = signatures_.input_attn_mask.has_value();
return sampler_->SetInputTensorsAndInferenceFunc(
&decode_input_buffers_[signatures_.input_tokens], &decode_prev_input_pos_,
&decode_input_buffers_[signatures_.input_positions],
has_input_attn_mask ? &decode_prev_mask_ : nullptr,
has_input_attn_mask ? &decode_input_buffers_[*signatures_.input_attn_mask]
: nullptr,
BindTensorsAndRunDecodeStatic, this);
}
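// Samples token ids from the given logits into ids_tensor, lazily
// initializing the sampler from the logits element type and swapping the
// sampler-owned decode input tensors when the sampler handles decode input.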
absl::Status LlmLiteRtCompiledModelExecutorBase::SampleLogits(
const TensorBuffer& logits, TensorBuffer& ids_tensor) {
if (sampler_ == nullptr) {
LITERT_ASSIGN_OR_RETURN(auto logits_tensor_type, logits.TensorType());
ActivationDataType logits_data_type;
if (logits_tensor_type.ElementType() == ElementType::Float16) {
logits_data_type = ActivationDataType::FLOAT16;
} else if (logits_tensor_type.ElementType() == ElementType::Float32) {
logits_data_type = ActivationDataType::FLOAT32;
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported logits data type for sampler: ",
static_cast<int>(logits_tensor_type.ElementType())));
}
RETURN_IF_ERROR(InitializeSampler(logits_data_type));
}
if (sampler_handles_input_) {
RETURN_IF_ERROR(SwapSamplerInputTensors());
}
RETURN_IF_ERROR(sampler_->SampleToIdAndScoreBuffer(
logits, ids_tensor, /*scores_tensor=*/nullptr));
return absl::OkStatus();
}
absl::Status LlmLiteRtCompiledModelExecutorBase::UpdateExecutorSettings(
const LlmExecutorSettings& executor_settings) {
executor_settings_ = executor_settings;
return absl::OkStatus();
}
absl::Status LlmLiteRtCompiledModelExecutorBase::SetCurrentStep(int new_step) {
ASSIGN_OR_RETURN(auto old_step, GetCurrentStep());
if (old_step == new_step) {
return absl::OkStatus();
}
int max_step = old_step;
RET_CHECK_LE(new_step, max_step).SetCode(absl::StatusCode::kInvalidArgument)
<< "New step cannot be greater than the max step: " << max_step;
RET_CHECK_GE(new_step, 0).SetCode(absl::StatusCode::kInvalidArgument)
<< "New step cannot be negative.";
if (new_step == max_step) {
llm_context_->runtime_state().current_step = new_step;
return absl::OkStatus();
}
if (new_step < 0) {
// Current step is negative after rolling back. This can only happen when
// the user wants to set the step to 0 while there is a pending input token.
// Thus we can roll back executor state to step 0.
return Reset();
}
llm_context_->runtime_state().current_step = new_step;
return absl::OkStatus();
}
absl::Status LlmLiteRtCompiledModelExecutorBase::Reset() {
llm_context_->runtime_state().current_step = 0;
return absl::OkStatus();
}
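// Returns the vocabulary size, taken from the last dimension of the decode
// output logits tensor.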
absl::StatusOr<int> LlmLiteRtCompiledModelExecutorBase::GetVocabSize() {
if (!decode_output_buffers_.contains(signatures_.output_logits)) {
return absl::NotFoundError("Output logits info not found.");
}
LITERT_ASSIGN_OR_RETURN(
auto logits_tensor_type,
decode_output_buffers_[signatures_.output_logits].TensorType());
RET_CHECK_EQ(logits_tensor_type.Layout().Dimensions().size(), 3);
return logits_tensor_type.Layout().Dimensions()[2];
}
/* ===========================================================================*/
/* LlmLiteRtCompiledModelExecutorStatic */
/* ===========================================================================*/
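// Prefills the model with the input token ids, splitting the work into the
// optimized prefill signature work groups and creating the per-signature
// input buffers on first use.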
absl::Status LlmLiteRtCompiledModelExecutorStatic::Prefill(
const ExecutorInputs& inputs, const ExecutorPrefillParams& params) {
int output_heads = 1;
if (llm_context_->runtime_config().output_heads.has_value()) {
output_heads = llm_context_->runtime_config().output_heads.value();
}
// For now, we reduce the input and processed tokens for prefill to only the
// first candidate. This should be updated once the user can select a decode
// output candidate.
constexpr int kTokenIndexToReduce = 0;
LITERT_RETURN_IF_ERROR(PrepareFirstPrefillAfterDecode(kTokenIndexToReduce));
LITERT_ASSIGN_OR_RETURN(auto token_ids_buffer, inputs.GetTextTokenIdsPtr());
LITERT_ASSIGN_OR_RETURN(auto tensor_type, token_ids_buffer->TensorType());
// Accept a batch size of 1 or output_heads, though prefill handles only the
// first batch element.
int32_t input_batch_size = tensor_type.Layout().Dimensions()[0];
if (input_batch_size != 1) {
RET_CHECK_EQ(input_batch_size, output_heads);
}
RET_CHECK_GT(tensor_type.Layout().Dimensions()[1], 0)
<< "Prefill token ids must be non-empty.";
if (embedding_lookup_ != nullptr) {
RETURN_IF_ERROR(embedding_lookup_->UpdateMultiModalEmbeddings(inputs));
}
LITERT_ASSIGN_OR_RETURN(auto ids,
ReferTensorBufferAsSpan<int32_t>(*token_ids_buffer));
// Reduce the input ids only with one user selected.
auto input_length = ids.size() / input_batch_size;
ids = ids.subspan(kTokenIndexToReduce * input_length, input_length);
ASSIGN_OR_RETURN(auto work_groups, GetOptimizedPrefillWorkGroups(
prefill_signature_map_, ids.size()));
for (int i = 0; i < work_groups.size(); ++i) {
const auto& prefill_signature = work_groups[i].first;
int prefill_length = work_groups[i].second;
// Keep track of the signatures that have already had their buffers
// created so that they are only created once.
if (!prefill_input_buffers_.contains(prefill_signature)) {
prefill_input_buffers_[prefill_signature] = {};
RETURN_IF_ERROR(CreatePrefillInputBuffers(
prefill_signature, prefill_length, prefill_length,
prefill_input_buffers_[prefill_signature]));
}
// TODO(b/494284915): Switch to use async prefill for Metal backend.
if (!do_prefill_sync_.has_value()) {
do_prefill_sync_ = std::any_of(
prefill_input_buffers_[prefill_signature].begin(),
prefill_input_buffers_[prefill_signature].end(),
[](const auto& pair) { return pair.second.IsMetalMemory(); });
}
bool async = !*do_prefill_sync_ &&
(i < work_groups.size() - 1 || !params.GetWaitForCompletion());
RETURN_IF_ERROR(PrefillInternal(
prefill_signature, prefill_input_buffers_[prefill_signature],
ids.subspan(/*pos=*/0, prefill_length), async));
ids = ids.subspan(/*pos=*/prefill_length);
}
RET_CHECK_EQ(ids.size(), 0).SetCode(absl::StatusCode::kInternal)
<< "Work groups not covering the entire prefill input.";
if (embedding_lookup_ != nullptr) {
RETURN_IF_ERROR(embedding_lookup_->CleanupMultiModalEmbeddings());
}
return absl::OkStatus();
}
// static
// Creates a LlmLiteRtCompiledModelExecutorStatic from a LiteRt model.
absl::StatusOr<std::unique_ptr<LlmLiteRtCompiledModelExecutorStatic>>
LlmLiteRtCompiledModelExecutorStatic::Create(
LlmExecutorSettings executor_settings, Environment& lrt_env,
ModelResources& resources) {
ASSIGN_OR_RETURN(auto litert_model,
resources.GetTFLiteModel(ModelType::kTfLitePrefillDecode));
std::string cache_path = executor_settings.GetCacheDir();
auto activation_data_type = ActivationDataType::FLOAT16;
// TODO(b/433590109): Some GPUs do not support FP16, so we need to check the
// capabilities of the GPU and set the activation data type accordingly.
if (executor_settings.GetActivationDataType().has_value()) {
activation_data_type = executor_settings.GetActivationDataType().value();
}
const Backend backend = executor_settings.GetBackend();
bool use_fp16_precision =
activation_data_type == ActivationDataType::FLOAT16 &&
backend == Backend::GPU;
if (!litert_model || !*litert_model) {
return absl::InternalError("Failed to build LiteRt model");
}
absl::string_view prefill_signature_key = "";
for (int i = 0; i < litert_model->GetNumSignatures(); ++i) {
LITERT_ASSIGN_OR_RETURN(auto sig, litert_model->GetSignature(i));
absl::string_view key = sig.Key();
if (absl::StartsWith(key, kPrefillSignatureRunner)) {
prefill_signature_key = key;
break;
}
}
LITERT_ASSIGN_OR_RETURN(auto prefill_signature,
litert_model->FindSignature(prefill_signature_key));
std::string kv_cache_k_root_name;
std::string kv_cache_v_root_name;
RETURN_IF_ERROR(GetKVCacheRootNames(
prefill_signature.InputNames(), prefill_signature.OutputNames(),
kv_cache_k_root_name, kv_cache_v_root_name));
LITERT_ASSIGN_OR_RETURN(auto decode_signature,
litert_model->FindSignature(kDecodeSignatureRunner));
ASSIGN_OR_RETURN(
ModelSignatures signatures,
GetModelSignaturesFromInputOutputNames(decode_signature.InputNames(),
decode_signature.OutputNames()));
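  // When the model exposes an int32 parameter input on the GPU backend, a
  // single-buffer KV cache layout is used: only output-side KV cache buffers
  // are created below and the input-side buffers are skipped.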
bool gpu_optimized_single_buffer_cache =
backend == Backend::GPU && signatures.input_int32_param.has_value();
LITERT_ASSIGN_OR_RETURN(
auto compilation_options,
CreateCompilationOptions(executor_settings, activation_data_type,
&signatures));
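  // If the model file packs its weights in a dedicated section, map that byte
  // range as an external weight file for the compiled model.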
auto section_offset =
resources.GetWeightsSectionOffset(ModelType::kTfLitePrefillDecode);
if (section_offset.ok()) {
if (backend != Backend::GPU) {
return absl::InvalidArgumentError(
"Weights section offset is only "
"supported for GPU backend.");
}
Options::ScopedWeightSectionMap section_map;
section_map["tflite_weights"] = {
section_offset.value().first,
section_offset.value().second - section_offset.value().first};
ABSL_LOG(INFO) << "section_map: " << section_map["tflite_weights"].offset
<< " " << section_map["tflite_weights"].length;
LITERT_ASSIGN_OR_RETURN(auto scoped_file, resources.GetScopedFile());
compilation_options.SetExternalWeightScopedFile(scoped_file.get(),
section_map);
  }
LITERT_ASSIGN_OR_RETURN(
auto compiled_model,
CompiledModel::Create(lrt_env, litert_model->Get(), compilation_options));
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_input_buffers;
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_output_buffers;
absl::flat_hash_map<absl::string_view, TensorBuffer> input_kv_cache_buffers;
absl::flat_hash_map<absl::string_view, TensorBuffer> output_kv_cache_buffers;
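  // By default (unless disabled via the advanced settings), KV cache buffers
  // are cleared at creation time so that prefill starts from a zeroed cache.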
bool clear_kv_cache_before_prefill =
!executor_settings.GetAdvancedSettings() ||
executor_settings.GetAdvancedSettings()->clear_kv_cache_before_prefill;
for (auto input_name : prefill_signature.InputNames()) {
    // Skip creating buffers for the input tokens, positions and attn mask;
    // they are created inside the prefill function based on the ids size.
if (!IsKVCacheTensor(input_name) || gpu_optimized_single_buffer_cache) {
continue;
}
LITERT_ASSIGN_OR_RETURN(
auto input_buffer,
compiled_model.CreateInputBuffer(prefill_signature_key, input_name));
if (clear_kv_cache_before_prefill) {
LITERT_RETURN_IF_ERROR(input_buffer.Clear());
}
if (backend == Backend::CPU) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer, input_buffer.Duplicate());
output_kv_cache_buffers[input_name] = std::move(output_buffer);
}
input_kv_cache_buffers[input_name] = std::move(input_buffer);
}
for (auto output_name : prefill_signature.OutputNames()) {
if (IsKVCacheTensor(output_name)) {
if (backend == Backend::GPU) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer,
compiled_model.CreateOutputBuffer(
prefill_signature_key, output_name));
if (clear_kv_cache_before_prefill &&
gpu_optimized_single_buffer_cache) {
LITERT_RETURN_IF_ERROR(output_buffer.Clear());
}
output_kv_cache_buffers[output_name] = std::move(output_buffer);
}
      // For CPU, a single buffer is used for the kv cache input and output to
      // improve performance and memory usage.
} else {
      // TODO b/444063139 - Support non-kv_cache tensors as prefill outputs.
      // This should be done once we have a model that has non-kv_cache tensors
      // as prefill outputs, and in the same place where the prefill inputs are
      // created.
return absl::UnimplementedError(absl::StrCat(
"Failed to create prefill output buffer for '", output_name,
"'. Only kv_cache tensors are supported as outputs to "
"prefill at the moment."));
}
}
for (auto input_name : decode_signature.InputNames()) {
if (IsLoRAInputName(input_name)) {
// We let LoraManager handle LoRA inputs.
continue;
}
if (IsKVCacheTensor(input_name)) {
continue;
}
LITERT_ASSIGN_OR_RETURN(
auto input_buffer,
compiled_model.CreateInputBuffer(kDecodeSignatureRunner, input_name));
decode_input_buffers[input_name] = std::move(input_buffer);
}
auto output_names = decode_signature.OutputNames();
for (int i = 0; i < output_names.size(); ++i) {
auto output_name = output_names[i];
if (IsKVCacheTensor(output_name)) {
continue;
}
    // If we are using the GPU sampler and the model is compiled with FP16
    // precision, we force the output logits to be FP16, since the GPU sampler
    // supports FP16 inputs. If we use the CPU sampler, or the model executes
    // with FP32 / mixed precision, we keep the logits in FP32.
auto sampler_backend = GetSamplerBackend(executor_settings);
if (output_name == signatures.output_logits && use_fp16_precision &&
sampler_backend.ok() && *sampler_backend == Backend::GPU) {
LITERT_ASSIGN_OR_RETURN(
size_t signature_index,
compiled_model.GetSignatureIndex(kDecodeSignatureRunner));
LITERT_ASSIGN_OR_RETURN(
auto output_buffer,
CreateFP16OutputBuffer(lrt_env, compiled_model, signature_index,
output_name, i));
decode_output_buffers[output_name] = std::move(output_buffer);
} else {
LITERT_ASSIGN_OR_RETURN(auto output_buffer,
compiled_model.CreateOutputBuffer(
kDecodeSignatureRunner, output_name));
decode_output_buffers[output_name] = std::move(output_buffer);
}
}
LITERT_ASSIGN_OR_RETURN(
auto output_logits_buffer,
decode_output_buffers[signatures.output_logits].Duplicate());
LITERT_ASSIGN_OR_RETURN(auto output_logits_buffer_tensor_type,
output_logits_buffer.TensorType());
RET_CHECK(output_logits_buffer_tensor_type.Layout().Dimensions().size() == 3)
<< "Output logits must be (batch, seq, vocab)";
int batch_size = output_logits_buffer_tensor_type.Layout().Dimensions()[0];
std::optional<absl::flat_hash_map<absl::string_view, TensorBuffer>>
decode_input_kv_cache_buffers;
std::optional<absl::flat_hash_map<absl::string_view, TensorBuffer>>
decode_output_kv_cache_buffers;
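  // With a decode batch size larger than 1, the decode signature uses its own
  // KV cache buffers instead of sharing the prefill KV cache buffers.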
if (batch_size > 1) {
ABSL_LOG(INFO) << "Decode batch size is larger than 1. Allocate decode "
<< "only KV cache buffers.";
decode_input_kv_cache_buffers =
absl::flat_hash_map<absl::string_view, TensorBuffer>();
decode_output_kv_cache_buffers =
absl::flat_hash_map<absl::string_view, TensorBuffer>();
for (auto input_name : decode_signature.InputNames()) {
if (absl::StartsWith(input_name, kv_cache_k_root_name) ||
absl::StartsWith(input_name, kv_cache_v_root_name)) {
LITERT_ASSIGN_OR_RETURN(auto input_buffer,
compiled_model.CreateInputBuffer(
kDecodeSignatureRunner, input_name));
if (clear_kv_cache_before_prefill) {
LITERT_RETURN_IF_ERROR(input_buffer.Clear());
}
(*decode_input_kv_cache_buffers)[input_name] = std::move(input_buffer);
}
}
for (auto output_name : decode_signature.OutputNames()) {
if (absl::StartsWith(output_name, kv_cache_k_root_name) ||
absl::StartsWith(output_name, kv_cache_v_root_name)) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer,
compiled_model.CreateOutputBuffer(
kDecodeSignatureRunner, output_name));
(*decode_output_kv_cache_buffers)[output_name] =
std::move(output_buffer);
}
}
}
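  // Collect the prefill runners available in the model; prefill later routes
  // each chunk of input tokens to one of these signatures.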
ASSIGN_OR_RETURN(auto prefill_runner_set,
GetPrefillRunnerSetFromModel(
*litert_model, kPrefillSignatureRunner,
/*input_positions_name=*/signatures.input_positions));
RET_CHECK(!prefill_runner_set.empty()) << "No prefill runner available.";
std::unique_ptr<EmbeddingLookupManager> embedding_lookup;
std::unique_ptr<EmbeddingLookupManager> per_layer_embedding_lookup;
RETURN_IF_ERROR(InitializeEmbeddingLookups(
lrt_env, resources, embedding_lookup, per_layer_embedding_lookup));
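  // When speculative decoding is enabled in the advanced settings, an MTP
  // drafter is built on top of a separate compiled model instance and the
  // embedding lookup managers.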
std::unique_ptr<LlmLiteRtMtpDrafter> mtp_drafter;
{
const auto& advanced_settings = executor_settings.GetAdvancedSettings();
if (advanced_settings.has_value() &&
advanced_settings->enable_speculative_decoding) {
RET_CHECK_NE(embedding_lookup, nullptr);
RET_CHECK_NE(per_layer_embedding_lookup, nullptr);
LITERT_ASSIGN_OR_RETURN(
auto base_compiled_model,
CompiledModel::Create(lrt_env, litert_model->Get(),
compilation_options));
ASSIGN_OR_RETURN(mtp_drafter,
LlmLiteRtMtpDrafter::Create(
lrt_env, resources, executor_settings,
std::move(base_compiled_model), *embedding_lookup,
*per_layer_embedding_lookup));
}
}
return absl::WrapUnique(new LlmLiteRtCompiledModelExecutorStatic(
std::move(executor_settings), lrt_env, litert_model,
std::move(compiled_model), std::move(decode_input_buffers),
std::move(decode_output_buffers), std::move(input_kv_cache_buffers),
std::move(output_kv_cache_buffers),
std::move(decode_input_kv_cache_buffers),
std::move(decode_output_kv_cache_buffers), std::move(prefill_runner_set),
signatures, batch_size, std::move(cache_path),
std::move(embedding_lookup), std::move(per_layer_embedding_lookup),
use_fp16_precision, activation_data_type, std::move(mtp_drafter)));
}
/* ===========================================================================*/
/* LlmLiteRtCompiledModelExecutorDynamic */
/* ===========================================================================*/
absl::Status LlmLiteRtCompiledModelExecutorDynamic::Prefill(
const ExecutorInputs& inputs, const ExecutorPrefillParams& params) {
// Only accept batch size 1 for now.
LITERT_RETURN_IF_ERROR(PrepareFirstPrefillAfterDecode(0));
LITERT_ASSIGN_OR_RETURN(auto token_ids_buffer, inputs.GetTextTokenIdsPtr());
LITERT_ASSIGN_OR_RETURN(auto tensor_type, token_ids_buffer->TensorType());
RET_CHECK_EQ(tensor_type.Layout().Dimensions()[0], 1);
RET_CHECK_GT(tensor_type.Layout().Dimensions()[1], 0)
<< "Prefill token ids must be non-empty.";
LITERT_ASSIGN_OR_RETURN(
absl::Span<int> ids, ReferTensorBufferAsSpan<int32_t>(*token_ids_buffer));
if (prefill_chunk_size_ <= 0) {
return PrefillInternal(ids, params);
}
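  // A positive prefill_chunk_size_ caps how many tokens are prefilled per
  // call, so process the input in consecutive chunks of at most that size.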
while (!ids.empty()) {
int chunk_size =
std::min(static_cast<int>(ids.size()), prefill_chunk_size_);
absl::Span<int> chunk_ids = ids.first(chunk_size);
ids = ids.subspan(chunk_size);
RETURN_IF_ERROR(PrefillInternal(chunk_ids, params));
}
return absl::OkStatus();
}
absl::Status LlmLiteRtCompiledModelExecutorDynamic::PrefillInternal(
absl::Span<int> ids, const ExecutorPrefillParams& params) {
RETURN_IF_ERROR(RollBackProcessedTokens());
  // Check if we have a pending input token. Note that 'internal_start_step' is
  // always equal to the number of processed tokens plus 1.
ProcessedTokens::StepAndToken step_and_token =
llm_context_->processed_context()
.processed_tokens()
.GetNextUnprocessedToken();
bool has_pending_input_token = !step_and_token.token.empty();
int prefill_length = has_pending_input_token ? ids.size() : ids.size() - 1;
  // If there is no pending input token and the input contains only a single
  // token, there is nothing to prefill yet; store that token as a pending
  // input token and return early.
if (!has_pending_input_token && prefill_length == 0) {
RETURN_IF_ERROR(
llm_context_->processed_context()
.processed_tokens()
.AddPendingInputToken({std::make_shared<TokenData>(ids[0])}));
return absl::OkStatus();
}
int kv_length = 0;
if (kv_cache_buffers_1_.empty()) {
kv_length = prefill_length;
// First time prefilling, allocate KV cache buffers.
bool clear_kv_cache_before_prefill =
!executor_settings_.GetAdvancedSettings() ||
executor_settings_.GetAdvancedSettings()->clear_kv_cache_before_prefill;
for (const auto& k_cache_input_name : key_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "prefill",
k_cache_input_name, prefill_length));
LITERT_ASSIGN_OR_RETURN(
auto input_buffer,
compiled_model_.CreateInputBuffer("prefill", k_cache_input_name));
if (clear_kv_cache_before_prefill) {
LITERT_RETURN_IF_ERROR(input_buffer.Clear());
}
kv_cache_buffers_1_[k_cache_input_name] = std::move(input_buffer);
}
for (const auto& v_cache_input_name : value_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "prefill",
v_cache_input_name, prefill_length));
LITERT_ASSIGN_OR_RETURN(
auto input_buffer,
compiled_model_.CreateInputBuffer("prefill", v_cache_input_name));
if (clear_kv_cache_before_prefill) {
LITERT_RETURN_IF_ERROR(input_buffer.Clear());
}
kv_cache_buffers_1_[v_cache_input_name] = std::move(input_buffer);
}
} else {
{
RET_CHECK(!kv_cache_buffers_1_.empty());
const TensorBuffer& key_buffer =
kv_cache_buffers_1_[key_cache_input_names_[0]];
LITERT_ASSIGN_OR_RETURN(const RankedTensorType& key_buffer_tensor_type,
key_buffer.TensorType());
kv_length =
key_buffer_tensor_type.Layout().Dimensions()[key_dynamic_dim_index_];
}
int free_kv_entries = kv_length - step_and_token.step;
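    // If the remaining KV cache capacity cannot hold the new tokens, grow the
    // dynamic dimension of every K/V buffer so that the whole prefill fits.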
if (prefill_length > free_kv_entries) {
int new_kv_seq_len = kv_length + prefill_length;
int entries_to_add = new_kv_seq_len - kv_length;
for (const auto& k_cache_input_name : key_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "prefill",
k_cache_input_name,
new_kv_seq_len));
ASSIGN_OR_RETURN(kv_cache_buffers_1_[k_cache_input_name],
ResizeKVCacheTensorBuffer(
env_, kv_cache_buffers_1_[k_cache_input_name],
key_dynamic_dim_index_, entries_to_add));
}
for (const auto& v_cache_input_name : value_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "prefill",
v_cache_input_name,
new_kv_seq_len));
ASSIGN_OR_RETURN(kv_cache_buffers_1_[v_cache_input_name],
ResizeKVCacheTensorBuffer(
env_, kv_cache_buffers_1_[v_cache_input_name],
value_dynamic_dim_index_, entries_to_add));
}
kv_length = new_kv_seq_len;
}
}
absl::flat_hash_map<absl::string_view, TensorBuffer> prefill_input_buffers;
RETURN_IF_ERROR(CreatePrefillInputBuffers("prefill", prefill_length,
kv_length, prefill_input_buffers));
input_kv_cache_buffers_ = &kv_cache_buffers_1_;
output_kv_cache_buffers_ = &kv_cache_buffers_1_;
bool async = !params.GetWaitForCompletion();
return LlmLiteRtCompiledModelExecutorBase::PrefillInternal(
"prefill", prefill_input_buffers, ids, async);
}
absl::Status LlmLiteRtCompiledModelExecutorDynamic::DecodeInternal(
const std::vector<std::shared_ptr<TokenData>>& token,
TensorBuffer& output_logits) {
int current_kv_len = 0;
{
RET_CHECK(!kv_cache_buffers_1_.empty());
const TensorBuffer& key_buffer =
kv_cache_buffers_1_[key_cache_input_names_[0]];
LITERT_ASSIGN_OR_RETURN(const RankedTensorType& key_buffer_tensor_type,
key_buffer.TensorType());
current_kv_len =
key_buffer_tensor_type.Layout().Dimensions()[key_dynamic_dim_index_];
}
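  // Grow the KV cache by kv_increament_size_ entries when the current
  // capacity cannot hold the next decode step.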
if (current_kv_len <= llm_context_->runtime_state().current_step - 1) {
int entries_to_add = kv_increament_size_;
int new_kv_len = current_kv_len + entries_to_add;
for (const auto& k_cache_input_name : key_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "decode",
k_cache_input_name, new_kv_len));
ASSIGN_OR_RETURN(kv_cache_buffers_1_[k_cache_input_name],
ResizeKVCacheTensorBuffer(
env_, kv_cache_buffers_1_[k_cache_input_name],
key_dynamic_dim_index_, entries_to_add));
}
for (const auto& v_cache_input_name : value_cache_input_names_) {
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "decode",
v_cache_input_name, new_kv_len));
ASSIGN_OR_RETURN(kv_cache_buffers_1_[v_cache_input_name],
ResizeKVCacheTensorBuffer(
env_, kv_cache_buffers_1_[v_cache_input_name],
value_dynamic_dim_index_, entries_to_add));
}
current_kv_len = new_kv_len;
}
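  // The attention mask length must track the (possibly resized) KV cache, so
  // re-resolve its dynamic shape and recreate the corresponding decode input
  // buffer.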
RETURN_IF_ERROR(ResolveDynamicShape(model_, compiled_model_, "decode",
signatures_.input_attn_mask.value(),
current_kv_len));
LITERT_ASSIGN_OR_RETURN(
decode_input_buffers_[signatures_.input_attn_mask.value()],
compiled_model_.CreateInputBuffer("decode",
signatures_.input_attn_mask.value()));
return LlmLiteRtCompiledModelExecutorBase::DecodeInternal(token,
output_logits);
}
// static
// Creates a LlmLiteRtCompiledModelExecutorDynamic from a LiteRt model.
absl::StatusOr<std::unique_ptr<LlmLiteRtCompiledModelExecutorDynamic>>
LlmLiteRtCompiledModelExecutorDynamic::Create(
LlmExecutorSettings executor_settings, Environment& lrt_env,
ModelResources& resources) {
ASSIGN_OR_RETURN(auto litert_model,
resources.GetTFLiteModel(ModelType::kTfLitePrefillDecode));
ASSIGN_OR_RETURN(
auto compilation_options,
CreateCompilationOptions(executor_settings, ActivationDataType::FLOAT32,
/*signatures=*/std::nullopt));
std::string weight_cache_path = executor_settings.GetCacheDir();
const Backend backend = executor_settings.GetBackend();
RET_CHECK_EQ(backend, Backend::CPU)
<< "LlmLiteRtCompiledModelExecutorDynamic only supports CPU backend.";
uint32_t kv_increament_size = 0;
int prefill_chunk_size = -1;
{
LITERT_ASSIGN_OR_RETURN(auto& cpu_compilation_options,
compilation_options.GetCpuOptions());
ASSIGN_OR_RETURN(const auto& cpu_config,
executor_settings.GetBackendConfig<CpuConfig>());
kv_increament_size = cpu_config.kv_increment_size;
prefill_chunk_size = cpu_config.prefill_chunk_size;
cpu_compilation_options.SetNumThreads(cpu_config.number_of_threads);
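    // The XNNPack weight cache may be supplied either as a file path or as an
    // already opened file; the latter is duplicated and handed over as a raw
    // file descriptor.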
auto weight_cache_file =
executor_settings.GetWeightCacheFile(".xnnpack_cache");
if (weight_cache_file.ok()) {
if (std::holds_alternative<std::string>(*weight_cache_file)) {
weight_cache_path = std::get<std::string>(*weight_cache_file);
cpu_compilation_options.SetXNNPackWeightCachePath(
weight_cache_path.c_str());
} else {
auto scoped_cache_file =
std::get<std::shared_ptr<ScopedFile>>(*weight_cache_file);
ASSIGN_OR_RETURN(auto duplicated, scoped_cache_file->Duplicate());
ASSIGN_OR_RETURN(int fd, duplicated.Release());
cpu_compilation_options.SetXNNPackWeightCacheFileDescriptor(fd);
}
}
RET_CHECK_GT(kv_increament_size, 0)
<< "KV increment size must be greater than 0.";
auto default_xnn_options = TfLiteXNNPackDelegateOptionsDefault();
cpu_compilation_options.SetXNNPackFlags(
default_xnn_options.flags |
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS);
LITERT_ASSIGN_OR_RETURN(auto& runtime_options,
compilation_options.GetRuntimeOptions());
runtime_options.SetCompressQuantizationZeroPoints(true);
compilation_options.SetHardwareAccelerators(HwAccelerators::kCpu);
}
LITERT_ASSIGN_OR_RETURN(
auto compiled_model,
CompiledModel::Create(lrt_env, litert_model->Get(), compilation_options));
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_input_buffers;
absl::flat_hash_map<absl::string_view, TensorBuffer> decode_output_buffers;
LITERT_ASSIGN_OR_RETURN(auto decode_signature,
litert_model->FindSignature(kDecodeSignatureRunner));
std::string kv_cache_k_root_name;
std::string kv_cache_v_root_name;
RETURN_IF_ERROR(GetKVCacheRootNames(
decode_signature.InputNames(), decode_signature.OutputNames(),
kv_cache_k_root_name, kv_cache_v_root_name));
ASSIGN_OR_RETURN(
ModelSignatures signatures,
GetModelSignaturesFromInputOutputNames(decode_signature.InputNames(),
decode_signature.OutputNames()));
std::vector<std::string> key_cache_input_names;
std::vector<std::string> value_cache_input_names;
for (auto input_name : decode_signature.InputNames()) {
bool is_key_cache_input =
absl::StartsWith(input_name, kv_cache_k_root_name);
if (is_key_cache_input) {
key_cache_input_names.push_back(std::string(input_name));
}
bool is_value_cache_input =
absl::StartsWith(input_name, kv_cache_v_root_name);
if (is_value_cache_input) {
value_cache_input_names.push_back(std::string(input_name));
}
bool is_kv_cache_input = is_key_cache_input || is_value_cache_input;
bool is_attn_mask_input =
signatures.input_attn_mask.has_value() &&
absl::StartsWith(input_name, signatures.input_attn_mask.value());
if (!is_kv_cache_input && !is_attn_mask_input) {
LITERT_ASSIGN_OR_RETURN(
auto input_buffer,
compiled_model.CreateInputBuffer(kDecodeSignatureRunner, input_name));
decode_input_buffers[input_name] = std::move(input_buffer);
}
}
for (auto output_name : decode_signature.OutputNames()) {
if (!absl::StartsWith(output_name, kv_cache_k_root_name) &&
!absl::StartsWith(output_name, kv_cache_v_root_name)) {
LITERT_ASSIGN_OR_RETURN(auto output_buffer,
compiled_model.CreateOutputBuffer(
kDecodeSignatureRunner, output_name));
decode_output_buffers[output_name] = std::move(output_buffer);
}
}
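  // Locate the dynamic dimension of the K and V cache tensors, i.e. the
  // dimension that is resized as the context grows.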
ASSIGN_OR_RETURN(
int k_dynamic_dim,
GetDynamicDimIndex(*litert_model, "prefill", key_cache_input_names[0]));
ASSIGN_OR_RETURN(
int v_dynamic_dim,
GetDynamicDimIndex(*litert_model, "prefill", value_cache_input_names[0]));
LITERT_ASSIGN_OR_RETURN(
auto output_logits_buffer,
decode_output_buffers[signatures.output_logits].Duplicate());
LITERT_ASSIGN_OR_RETURN(auto output_logits_buffer_tensor_type,
output_logits_buffer.TensorType());
RET_CHECK(output_logits_buffer_tensor_type.Layout().Dimensions().size() == 3)
<< "Output logits must be (batch, seq, vocab)";
int batch_size = output_logits_buffer_tensor_type.Layout().Dimensions()[0];
RET_CHECK_EQ(batch_size, 1) << "Only support batch size 1 for now.";
std::unique_ptr<EmbeddingLookupManager> embedding_lookup;
std::unique_ptr<EmbeddingLookupManager> per_layer_embedding_lookup;
RETURN_IF_ERROR(InitializeEmbeddingLookups(
lrt_env, resources, embedding_lookup, per_layer_embedding_lookup));
return absl::WrapUnique(new LlmLiteRtCompiledModelExecutorDynamic(
std::move(executor_settings), lrt_env, litert_model,
std::move(compiled_model), std::move(decode_input_buffers),
std::move(decode_output_buffers), prefill_chunk_size, k_dynamic_dim,
v_dynamic_dim, kv_increament_size, std::move(key_cache_input_names),
std::move(value_cache_input_names), signatures, batch_size,
std::move(weight_cache_path), std::move(embedding_lookup),
std::move(per_layer_embedding_lookup), /*use_fp16_precision=*/false,
/*logits_data_type=*/LogitsDataType::FLOAT32));
}
} // namespace litert::lm