Spaces:
Running
Running
| // Copyright 2024 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| // The LLM Executor serves as a lightweight and portable wrapper around various | |
| // converted LLM model formats, i.e. LiteRT. It aims to provide a general, | |
| // minimal-dependency interface for the users, abstracting the complexities of | |
| // executing models across diverse hardware accelerators like CPUs, GPUs, and | |
| // other ML accelerators. | |
| class LlmExecutorBase { | |
| public: | |
| virtual ~LlmExecutorBase() = default; | |
| // ------------Input APIs------------: | |
| // Basic API to trigger the "prefill" or "prefix" process. | |
| // Input is token ids with shape `[batch, sequence_length]` | |
| virtual absl::Status Prefill(const ExecutorInputs& inputs) = 0; | |
| // Advanced API to allow customized query parameters. | |
| // Input is token ids with shape `[batch, sequence_length]` | |
| virtual absl::Status Prefill(const ExecutorInputs& inputs, | |
| const ExecutorPrefillParams& prefill_params) { | |
| return absl::UnimplementedError(absl::StrCat( | |
| "Prefill with prefill params not implemented for backend: ", | |
| ExecutorBackendName())); | |
| }; | |
| // ------------Output APIs------------: | |
| // Basic API to trigger the "decode" process. On success, will return a vector | |
| // of token ids of generated output tokens, one per candidate. | |
| virtual absl::StatusOr<std::vector<std::vector<int>>> Decode() = 0; | |
| // Advanced API to trigger the decode and sampling process with custom | |
| // parameters. On success, will return a vector of token ids of generated | |
| // output tokens, one per candidate. | |
| virtual absl::StatusOr<std::vector<std::vector<int>>> Decode( | |
| const ExecutorDecodeParams& decode_params) { | |
| return absl::UnimplementedError( | |
| absl::StrCat("Decode with decode params not implemented for backend: ", | |
| ExecutorBackendName())); | |
| } | |
| // [Deprecated]Basic API to trigger the "decode" process but without sampling. | |
| // Input is token ids with shape `[batch, sequence_length]` | |
| // Output is logits with shape `[batch, sequence_length, vocab_size]` of | |
| // float32_t. | |
| // TODO(b/416293708): remove this API once the new API is ready and tested. | |
| virtual absl::Status Decode(const ExecutorInputs& inputs, | |
| ::litert::TensorBuffer& output_logits) { | |
| return absl::UnimplementedError( | |
| absl::StrCat("Decode for logits output not implemented for backend: ", | |
| ExecutorBackendName())); | |
| }; | |
| // Basic API to trigger the "decode" process but without sampling. | |
| // Input is token ids with shape `[batch, sequence_length]` | |
| // Output is logits with shape `[batch, sequence_length, vocab_size]` of | |
| // float32_t on the host memory. | |
| virtual absl::StatusOr<::litert::TensorBuffer> DecodeLogits( | |
| const ExecutorInputs& inputs) { | |
| return absl::UnimplementedError( | |
| absl::StrCat("Decode for logits output not implemented for backend: ", | |
| ExecutorBackendName())); | |
| }; | |
| virtual absl::string_view ExecutorBackendName() const = 0; | |
| // Get vocabulary size used to build tensor buffers for decode functions. | |
| virtual absl::StatusOr<int> GetVocabSize() { | |
| return absl::UnimplementedError(absl::StrCat( | |
| "GetVocabSize not implemented for backend: ", ExecutorBackendName())); | |
| }; | |
| // Gets the current step of the executor. | |
| virtual absl::StatusOr<int> GetCurrentStep() const { | |
| return absl::UnimplementedError(absl::StrCat( | |
| "GetCurrentStep not implemented for backend: ", ExecutorBackendName())); | |
| }; | |
| virtual absl::Status SetCurrentStep(int current_step) { | |
| return absl::UnimplementedError(absl::StrCat( | |
| "SetCurrentStep not implemented for backend: ", ExecutorBackendName())); | |
| }; | |
| // Gets the executor settings of the executor. | |
| virtual absl::StatusOr<LlmExecutorSettings> GetExecutorSettings() const { | |
| return absl::UnimplementedError( | |
| absl::StrCat("GetExecutorSettings not implemented for backend: ", | |
| ExecutorBackendName())); | |
| }; | |
| // Gets the litert environment used by the executor. | |
| // This is used by AICore's EmbeddingModelMldrift to convert tokens to | |
| // TensorBuffers. | |
| virtual ::litert::Environment* GetEnvironment() const { return nullptr; }; | |
| // Updates the executor settings. | |
| virtual absl::Status UpdateExecutorSettings( | |
| const LlmExecutorSettings& executor_settings) { | |
| return absl::OkStatus(); | |
| } | |
| // ------------Vision APIs------------: | |
| // This function will populate the GPU tensors with the vision embeddings and | |
| // vision per layer embeddings. This should only be used before the | |
| // prefill/prefix stage. | |
| // vision_input: The vision embeddings to populate the GPU tensors with. | |
| // image_index: The index of the image in the batch. It must be non-negative | |
| // and less than `max_num_images`. This will overwrite the vision embeddings | |
| // for the given image index if it is already populated. | |
| virtual absl::Status FillVisionEmbeddings( | |
| const ExecutorVisionData& vision_input, int image_index) { | |
| return absl::UnimplementedError( | |
| absl::StrCat("FillVisionEmbeddings not implemented for backend: ", | |
| ExecutorBackendName())); | |
| }; | |
| // Resets all of the internal states (e.g. KVCache). Loaded and used LoRA | |
| // models are not affected (remain loaded and in use). | |
| virtual absl::Status Reset() { | |
| return absl::UnimplementedError(absl::StrCat( | |
| "Reset not implemented for backend: ", ExecutorBackendName())); | |
| }; | |
| }; | |
| } // namespace litert::lm | |