Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| using ::litert::Expected; | |
| using ::litert::TensorBuffer; | |
| ::litert::Expected<bool> ShouldDeleteKVCacheTokens(int current_step, | |
| int start_position, | |
| size_t context_size) { | |
| if (current_step - start_position < 0) { | |
| return ::litert::Unexpected( | |
| kLiteRtStatusErrorInvalidArgument, | |
| "Deleted more tokens than the number of model steps processed."); | |
| } | |
| if (current_step < 0) { | |
| return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument, | |
| "current_step step is negative."); | |
| } | |
| if (start_position < 0) { | |
| return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument, | |
| "start_position is negative."); | |
| } | |
| if (current_step - start_position >= context_size - 1) { | |
| return true; | |
| } | |
| return false; | |
| } | |
| // Function to dump the ring buffer. | |
| ::litert::Expected<void> DeleteTokensFromKvCache( | |
| absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>* | |
| input_kv_cache_buffers, | |
| int num_tokens_to_drop, int init_tokens_to_retain) { | |
| // A k cache buffer is a 4D tensor with shape | |
| // [1, heads, context_size, embedding_size] | |
| // A v cache buffer is a 4D tensor with shape | |
| // [1, heads, embedding_size, context_size] | |
| for (auto& [input_name, input_buffer] : *input_kv_cache_buffers) { | |
| LITERT_ASSIGN_OR_RETURN(auto type, input_buffer.TensorType()); | |
| const int axis = absl::StrContains(input_name, "cache_k_") ? 2 | |
| : absl::StrContains(input_name, "cache_v_") ? 3 | |
| : -1; | |
| if (axis == -1) { | |
| return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument, | |
| "Unsupported input name."); | |
| } | |
| switch (type.ElementType()) { | |
| case ::litert::ElementType::Int8: | |
| LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int8_t>( | |
| input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain)); | |
| break; | |
| case ::litert::ElementType::Int16: | |
| LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int16_t>( | |
| input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain)); | |
| break; | |
| case ::litert::ElementType::Int32: | |
| LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<int32_t>( | |
| input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain)); | |
| break; | |
| case ::litert::ElementType::Float32: | |
| LITERT_RETURN_IF_ERROR(DropTokensfromTensorBuffer<float>( | |
| input_buffer, num_tokens_to_drop, axis, init_tokens_to_retain)); | |
| break; | |
| default: | |
| return ::litert::Unexpected(kLiteRtStatusErrorInvalidArgument, | |
| "Unsupported element type."); | |
| } | |
| } | |
| return {}; | |
| } | |
| ::litert::Expected<bool> DeleteTokensIfNeeded( | |
| absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>* | |
| input_kv_cache_buffers, | |
| int num_tokens_to_drop, int init_tokens_to_retain, int current_step, | |
| int& start_position, size_t context_size) { | |
| LITERT_ASSIGN_OR_RETURN( | |
| bool should_delete_tokens, | |
| ShouldDeleteKVCacheTokens(current_step, start_position, context_size)); | |
| if (should_delete_tokens) { | |
| LITERT_RETURN_IF_ERROR(DeleteTokensFromKvCache( | |
| input_kv_cache_buffers, | |
| /*num_tokens_to_drop=*/num_tokens_to_drop, | |
| /*init_tokens_to_retain=*/init_tokens_to_retain)); | |
| start_position += num_tokens_to_drop; | |
| return true; | |
| } | |
| return false; | |
| } | |
| bool IsKVCacheTensor(absl::string_view tensor_name) { | |
| return absl::StartsWith(tensor_name, "kv_cache_") || | |
| absl::StartsWith(tensor_name, "k_cache_") || | |
| absl::StartsWith(tensor_name, "v_cache_"); | |
| } | |
| } // namespace litert::lm | |