LiteRT-LM / runtime /executor /llm_litert_compiled_model_cache_utils.cc
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/executor/llm_litert_compiled_model_cache_utils.h"
#include <cstdint>
#include <cstring>
#include <utility>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/cc/litert_element_type.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_macros.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "runtime/util/convert_tensor_buffer.h"
#include "runtime/util/status_macros.h" // IWYU pragma: keep
namespace litert::lm {
using ::litert::Expected;
using ::litert::TensorBuffer;
::litert::Expected<bool> ShouldDeleteKVCacheTokens(int current_step,
int start_position,
size_t context_size) {
if (current_step - start_position < 0) {
return ::litert::Unexpected(
kLiteRtStatusErrorInvalidArgument,
"Deleted more tokens than the number of model steps processed.");
}
if (current_step < 0) {
return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument,
"current_step step is negative.");
}
if (start_position < 0) {
return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument,
"start_position is negative.");
}
if (current_step - start_position >= context_size - 1) {
return true;
}
return false;
}
// Function to dump the ring buffer.
::litert::Expected<void> DeleteTokensFromKvCache(
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>*
input_kv_cache_buffers,
int num_tokens_to_drop, int init_tokens_to_retain) {
// A k cache buffer is a 4D tensor with shape
// [1, heads, context_size, embedding_size]
// A v cache buffer is a 4D tensor with shape
// [1, heads, embedding_size, context_size]
for (auto& [input_name, input_buffer] : *input_kv_cache_buffers) {
LITERT_ASSIGN_OR_RETURN(auto type, input_buffer.TensorType());
const int axis = absl::StrContains(input_name, "cache_k_") ? 2
: absl::StrContains(input_name, "cache_v_") ? 3
: -1;
if (axis == -1) {
return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument,
"Unsupported input name.");
}
switch (type.ElementType()) {
case ::litert::ElementType::Int8:
LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int8_t>(
input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain));
break;
case ::litert::ElementType::Int16:
LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int16_t>(
input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain));
break;
case ::litert::ElementType::Int32:
LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int32_t>(
input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain));
break;
case ::litert::ElementType::Float32:
LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<float>(
input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain));
break;
default:
return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument,
"Unsupported element type.");
}
}
return {};
}
::litert::Expected<bool> DeleteTokensIfNeeded(
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>*
input_kv_cache_buffers,
int num_tokens_to_drop, int init_tokens_to_retain, int current_step,
int& start_position, size_t context_size) {
LITERT_ASSIGN_OR_RETURN(
bool should_delete_tokens,
ShouldDeleteKVCacheTokens(current_step, start_position, context_size));
if (should_delete_tokens) {
LITERT_RETURN_IF_ERROR(DeleteTokensFromKvCache(
input_kv_cache_buffers,
/*num_tokens_to_drop=*/num_tokens_to_drop,
/*init_tokens_to_retain=*/init_tokens_to_retain));
start_position += num_tokens_to_drop;
return true;
}
return false;
}
bool IsKVCacheTensor(absl::string_view tensor_name) {
return absl::StartsWith(tensor_name, "kv_cache_") ||
absl::StartsWith(tensor_name, "k_cache_") ||
absl::StartsWith(tensor_name, "v_cache_");
}
} // namespace litert::lm