LiteRT-LM / runtime /executor /llm_executor_base.h
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_BASE_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_BASE_H_
#include <vector>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "runtime/executor/llm_executor_io_types.h"
#include "runtime/executor/llm_executor_settings.h"
namespace litert::lm {
// The LLM Executor serves as a lightweight and portable wrapper around various
// converted LLM model formats, i.e. LiteRT. It aims to provide a general,
// minimal-dependency interface for the users, abstracting the complexities of
// executing models across diverse hardware accelerators like CPUs, GPUs, and
// other ML accelerators.
class LlmExecutorBase {
public:
virtual ~LlmExecutorBase() = default;
// ------------Input APIs------------:
// Basic API to trigger the "prefill" or "prefix" process.
// Input is token ids with shape `[batch, sequence_length]`
virtual absl::Status Prefill(const ExecutorInputs& inputs) = 0;
// Advanced API to allow customized query parameters.
// Input is token ids with shape `[batch, sequence_length]`
virtual absl::Status Prefill(const ExecutorInputs& inputs,
const ExecutorPrefillParams& prefill_params) {
return absl::UnimplementedError(absl::StrCat(
"Prefill with prefill params not implemented for backend: ",
ExecutorBackendName()));
};
// ------------Output APIs------------:
// Basic API to trigger the "decode" process. On success, will return a vector
// of token ids of generated output tokens, one per candidate.
virtual absl::StatusOr<std::vector<std::vector<int>>> Decode() = 0;
// Advanced API to trigger the decode and sampling process with custom
// parameters. On success, will return a vector of token ids of generated
// output tokens, one per candidate.
virtual absl::StatusOr<std::vector<std::vector<int>>> Decode(
const ExecutorDecodeParams& decode_params) {
return absl::UnimplementedError(
absl::StrCat("Decode with decode params not implemented for backend: ",
ExecutorBackendName()));
}
// [Deprecated]Basic API to trigger the "decode" process but without sampling.
// Input is token ids with shape `[batch, sequence_length]`
// Output is logits with shape `[batch, sequence_length, vocab_size]` of
// float32_t.
// TODO(b/416293708): remove this API once the new API is ready and tested.
virtual absl::Status Decode(const ExecutorInputs& inputs,
::litert::TensorBuffer& output_logits) {
return absl::UnimplementedError(
absl::StrCat("Decode for logits output not implemented for backend: ",
ExecutorBackendName()));
};
// Basic API to trigger the "decode" process but without sampling.
// Input is token ids with shape `[batch, sequence_length]`
// Output is logits with shape `[batch, sequence_length, vocab_size]` of
// float32_t on the host memory.
virtual absl::StatusOr<::litert::TensorBuffer> DecodeLogits(
const ExecutorInputs& inputs) {
return absl::UnimplementedError(
absl::StrCat("Decode for logits output not implemented for backend: ",
ExecutorBackendName()));
};
virtual absl::string_view ExecutorBackendName() const = 0;
// Get vocabulary size used to build tensor buffers for decode functions.
virtual absl::StatusOr<int> GetVocabSize() {
return absl::UnimplementedError(absl::StrCat(
"GetVocabSize not implemented for backend: ", ExecutorBackendName()));
};
// Gets the current step of the executor.
virtual absl::StatusOr<int> GetCurrentStep() const {
return absl::UnimplementedError(absl::StrCat(
"GetCurrentStep not implemented for backend: ", ExecutorBackendName()));
};
virtual absl::Status SetCurrentStep(int current_step) {
return absl::UnimplementedError(absl::StrCat(
"SetCurrentStep not implemented for backend: ", ExecutorBackendName()));
};
// Gets the executor settings of the executor.
virtual absl::StatusOr<LlmExecutorSettings> GetExecutorSettings() const {
return absl::UnimplementedError(
absl::StrCat("GetExecutorSettings not implemented for backend: ",
ExecutorBackendName()));
};
// Gets the litert environment used by the executor.
// This is used by AICore's EmbeddingModelMldrift to convert tokens to
// TensorBuffers.
virtual ::litert::Environment* GetEnvironment() const { return nullptr; };
// Updates the executor settings.
virtual absl::Status UpdateExecutorSettings(
const LlmExecutorSettings& executor_settings) {
return absl::OkStatus();
}
// ------------Vision APIs------------:
// This function will populate the GPU tensors with the vision embeddings and
// vision per layer embeddings. This should only be used before the
// prefill/prefix stage.
// vision_input: The vision embeddings to populate the GPU tensors with.
// image_index: The index of the image in the batch. It must be non-negative
// and less than `max_num_images`. This will overwrite the vision embeddings
// for the given image index if it is already populated.
virtual absl::Status FillVisionEmbeddings(
const ExecutorVisionData& vision_input, int image_index) {
return absl::UnimplementedError(
absl::StrCat("FillVisionEmbeddings not implemented for backend: ",
ExecutorBackendName()));
};
// Resets all of the internal states (e.g. KVCache). Loaded and used LoRA
// models are not affected (remain loaded and in use).
virtual absl::Status Reset() {
return absl::UnimplementedError(absl::StrCat(
"Reset not implemented for backend: ", ExecutorBackendName()));
};
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_BASE_H_