// Copyright 2026 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LITERT_KV_CACHE_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LITERT_KV_CACHE_H_ #include #include #include #include #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "litert/cc/litert_compiled_model.h" // from @litert #include "litert/cc/litert_environment.h" // from @litert #include "litert/cc/litert_model.h" // from @litert #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "runtime/executor/kv_cache_interface.h" namespace litert::lm { class LitertKVCache : public KVCacheInterface { public: static absl::StatusOr> Create( Environment& env, const Model& model, absl::string_view signature_name, CompiledModel& compiled_model, bool inplace_update); int GetNumEntries() const override { return num_entries_; }; int GetBatchSize() const override { return batch_size_; }; absl::StatusOr Serialize() const override { return absl::UnimplementedError("Not implemented"); } absl::Status Load(absl::string_view serialized_kv_cache) override { return absl::UnimplementedError("Not implemented"); } absl::Status SelectAndCopyFrom(KVCacheInterface& other, int batch_index) override; absl::Status BroadcastAndCopyFrom(KVCacheInterface& other) override; absl::StatusOr> DeepCopy() const override; // Resizes the KV cache to the given number of entries (sequence length). // Note: Resize is a no-op if the requested size is smaller than the current // size. absl::Status Resize(int num_entries); struct KVCacheBuffers { absl::flat_hash_map input_buffers; absl::flat_hash_map output_buffers; }; // For backends that support inplace update, this returns a single set of KV // cache buffers that can be used for both input and output (i.e, // input_buffers and output_buffers point to the same data). // For backends that don't support inplace update, this returns two distinct // sets of KV cache buffers, one for input and one for output. On each call, // the input/output buffers will be swapped. absl::StatusOr GetKVCacheBuffers(); private: LitertKVCache( int batch_size, int num_entries, std::optional k_dynamic_dim, std::optional v_dynamic_dim, Environment& env, absl::flat_hash_map bank_1_key_cache_buffers, absl::flat_hash_map bank_1_value_cache_buffers, std::optional> bank_2_key_cache_buffers, std::optional> bank_2_value_cache_buffers) : batch_size_(batch_size), num_entries_(num_entries), k_dynamic_dim_(std::move(k_dynamic_dim)), v_dynamic_dim_(std::move(v_dynamic_dim)), env_(env), bank_1_key_cache_buffers_(std::move(bank_1_key_cache_buffers)), bank_1_value_cache_buffers_(std::move(bank_1_value_cache_buffers)), bank_2_key_cache_buffers_(std::move(bank_2_key_cache_buffers)), bank_2_value_cache_buffers_(std::move(bank_2_value_cache_buffers)) {} // Batch size of the KV cache buffers. int batch_size_; // Number of entries in the KV cache. int num_entries_; // Dynamic dimension index of the KV cache buffers (i.e., sequence dimension). std::optional k_dynamic_dim_; std::optional v_dynamic_dim_; // Environment to create new TensorBuffers (required for resizing). Environment& env_; // Primary KV cache buffers. absl::flat_hash_map bank_1_key_cache_buffers_; absl::flat_hash_map bank_1_value_cache_buffers_; // Secondary KV cache buffers - only used when inplace update is not // supported. std::optional> bank_2_key_cache_buffers_; std::optional> bank_2_value_cache_buffers_; bool bank_1_is_input_ = true; }; } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LITERT_KV_CACHE_H_