// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_ #include #include #include "absl/functional/any_invocable.h" // from @com_google_absl #include "absl/log/absl_log.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/time/time.h" // from @com_google_absl #include "runtime/components/tokenizer.h" #include "runtime/engine/engine_settings.h" #include "runtime/engine/io_types.h" namespace litert::lm { // Engine is the interface for the LLM runtime. It is responsible for // - Initializing the LLM model and related resources, e.g. tokenizer, // embedder, etc. // - Providing the APIs to create the Session. // // The Session is responsible for hosting the internal state (e.g. conversation // history) of each separate interaction with LLM. It is created by the Engine // and is responsible for: // - Generating content from the input prompt/query. // - Running the prefill and decode processes. // // Example usage: // // Create the model assets. // auto model_assets = ModelAssets::Create(model_path); // CHECK_OK(model_assets); // // // Create the engine. // auto engine = Engine::CreateEngine(EngineSettings::CreateDefault( // model_assets, litert::lm::Backend::CPU)); // CHECK_OK(engine); // // // Create the session. // auto session = engine->CreateSession(SessionConfig::CreateDefault()); // CHECK_OK(session); // // // Run generate content. // auto responses = (*session)->GenerateContent({InputText("What's the tallest // building in the world?")}); // CHECK_OK(responses); // // // Print the response. // std::cout << *responses << std::endl; class Engine { public: virtual ~Engine() = default; // Session is responsible for hosting the internal state (e.g. conversation // history) of each separate interaction with LLM. class Session { public: // The TaskController is responsible for controlling the async task // execution. class TaskController { public: TaskController() = default; // The TaskController is not copyable. This is to avoid // the user from accidentally copying the TaskController and calling the // CancelProcess function multiple times. TaskController(const TaskController&) = delete; TaskController& operator=(const TaskController&) = delete; // The TaskController is movable. TaskController(TaskController&&) = default; TaskController& operator=(TaskController&&) = default; // The TaskController destructor. virtual ~TaskController() = default; // Waits until all the tasks are done or the timeout is reached. The // function will return error if the timeout is reached. virtual absl::Status WaitUntilDone(absl::Duration timeout) { return absl::UnimplementedError("Not implemented."); }; // Cancels the ongoing inference process. Note that if this function is // called after the inference process is done, the function will be a // no-op. virtual absl::Status Cancel() { return absl::UnimplementedError("Not implemented."); }; }; virtual ~Session() = default; // High-level API to generate content from the input prompt/query. This // function will handle the prefill and decode processes internally and // the usage is similar to the Gemini Text Generation API // (https://ai.google.dev/gemini-api/docs/text-generation). // - contents: The input data for generation. virtual absl::StatusOr GenerateContent( const std::vector& contents) = 0; // This is a not blocking call and the function will return right away. The // result will be streamed through the callback. // // - contents: The input data for generation. // - callback: Callback to receive streamed results. // Note: // - If the generation is done successfully, the callback will be // called with empty responses to signal the completion. // - If there is an error during the streaming process, the callback // will be called with the error status and no further results will be // sent. // - If the generation is cancelled, the callback will be called // with a Cancellation error. virtual absl::Status GenerateContentStream( const std::vector& contents, absl::AnyInvocable)> callback) = 0; // Same as above, but with a custom decode config. // - decode_config: configuration for the model decode process. virtual absl::Status GenerateContentStream( const std::vector& contents, absl::AnyInvocable)> callback, const DecodeConfig& decode_config) = 0; // Scores the target text after the prefill process is done. This function // will only run the decode process to fetch the decode output logits, which // is used to calculate the target text's score and update the model memory // using the target_text tokens. // This function should be called after the prefill process is done. // - target_text: The target text to score. // - store_token_lengths: Whether to store the token lengths of the target // texts in `Responses`. // - returns: This function returns the score associated with the target // text after the model has been prefilled. The returned score is the sum of // the negative log probability of seeing the target text during decode. virtual absl::StatusOr RunTextScoring( const std::vector& target_text, bool store_token_lengths) = 0; // Similar to the above RunTextScoring function, but this is a not blocking // call and the function will return right away. The processing status will // be signaled through the callback. // - target_text: The target text to score. // - callback: Callback to receive the scoring results. // - store_token_lengths: Whether to store the token lengths of the target // texts in `Responses`. virtual absl::StatusOr> RunTextScoringAsync( const std::vector& target_text, absl::AnyInvocable)> callback, bool store_token_lengths) { return absl::UnimplementedError("Not implemented."); } // Adds the input prompt/query to the model for starting the prefilling // process. Note that the user can break down their prompt/query into // multiple chunks and call this function multiple times. // // This is a blocking call and the function will return when the prefill // process is done. virtual absl::Status RunPrefill(const std::vector& contents) = 0; // This is a not blocking call and the function will return right away. The // processing status will be signaled through the callback. virtual absl::StatusOr> RunPrefillAsync( const std::vector& contents, absl::AnyInvocable)> callback) { return absl::UnimplementedError("Not implemented."); } // Starts the decoding process for the model to predict the response based // on the input prompt/query added after using RunPrefill* functions. // This is a blocking call and the function will return when the decoding // process is done. virtual absl::StatusOr RunDecode() = 0; // Same as above, but with a custom decode config. // - decode_config: configuration for the model decode process. virtual absl::StatusOr RunDecode( const DecodeConfig& decode_config) = 0; // Startes the decoding process for the model to predict the response based // on the input prompt/query added after using RunPrefill* functions. // This is a not blocking call and the function will return right away. The // result will be streamed through the callback. // - callback: Callback to receive streamed results. virtual absl::StatusOr> RunDecodeAsync( absl::AnyInvocable)> callback) { return absl::UnimplementedError("Not implemented."); } // Same as above, but with a custom decode config. // - decode_config: configuration for the model decode process. virtual absl::StatusOr> RunDecodeAsync( absl::AnyInvocable)> callback, const DecodeConfig& decode_config) { return absl::UnimplementedError("Not implemented."); } // Returns the benchmark info for the session. Returns error if the // benchmark is not enabled. virtual absl::StatusOr GetBenchmarkInfo() = 0; // Returns the mutable benchmark info for the session. Returns error if the // benchmark is not enabled. virtual absl::StatusOr GetMutableBenchmarkInfo() = 0; // Cancels the ongoing inference process. Note that if this function is // called, the inference process will return with a kCancelled error. The // session could still be used after afterwards. virtual void CancelProcess() { ABSL_LOG(FATAL) << "CancelProcess is not implemented."; } // Waits until all the tasks are done or the default timeout is reached. virtual absl::Status WaitUntilDone() = 0; // Clones the session. // The cloned session have all the settings and context // of the original session up to the point that the clone function is // called. // - callback: Callback to when the streamed results. // // Example usage: // Session session1 = engine->CreateSession(...); // session1->Prefill("What is the tallest building "); // Session session2 = session1->Clone(); // session1->Prefill("in the world?"); // session1->Decode(); // session2->Prefill("in France?"); // session2->Decode(); virtual absl::StatusOr> Clone() { return absl::UnimplementedError("Not implemented."); }; // Clones the session asynchronously. // The cloned session have all the settings and context // of the original session up to the point that the clone function is // called. // - callback: Callback to when the streamed results. // // Example usage: // Session session1 = engine->CreateSession(...); // session1->RunPrefillAsync("What is the tallest building ", ...); // Session session2 = session1->CloneAsync(...); // session1->RunPrefillAsync("in the world?", ...); // session1->RunDecodeAsync(...); // session2->RunPrefillAsync("in France?", ...); // session2->RunDecodeAsync(...); virtual absl::StatusOr> CloneAsync( absl::AnyInvocable)> callback) { return absl::UnimplementedError("Not implemented."); }; // Save the current step with the name `label`. You can later rewind to this // checkpoint using `RewindToCheckpoint(label)`. If the checkpoint name // already exists, the step number will be overwritten. virtual absl::Status SaveCheckpoint(absl::string_view label) { return absl::UnimplementedError("SaveCheckpoint not implemented."); } // Rewinds the session to the given checkpoint. Checkpoints after the // restored step will be removed. Returns an error if the checkpoint name // does not exist. virtual absl::Status RewindToCheckpoint(absl::string_view label) { return absl::UnimplementedError("RewindToCheckpoint not implemented."); } // Get the current step of the session. virtual absl::StatusOr GetCurrentStep() const { return absl::UnimplementedError("GetCurrentStep not implemented."); } // Get the reference to the session config for the session. virtual const SessionConfig& GetSessionConfig() const = 0; }; // Method to create the Session. virtual absl::StatusOr> CreateSession( const SessionConfig& session_config) = 0; // Waits until the engine is done with all the tasks. The function will // return error if the timeout is reached. virtual absl::Status WaitUntilDone(absl::Duration timeout) { return absl::UnimplementedError("Not implemented."); } // Returns the EngineSettings currently used by the engine. virtual const EngineSettings& GetEngineSettings() const = 0; // Get the reference to the tokenizer for the engine. virtual const Tokenizer& GetTokenizer() const = 0; // Get the audio model properties for the session. This is only available // if the engine is created with audio modality enabled. virtual absl::StatusOr GetAudioExecutorProperties() const = 0; // Get the vision model properties for the session. This is only available // if the engine is created with vision modality enabled. virtual absl::StatusOr GetVisionExecutorProperties() const = 0; // Default timeout duration for the engine/session processes. static constexpr absl::Duration kDefaultTimeout = absl::Minutes(10); }; } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_