// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_SETTINGS_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_SETTINGS_H_ #include #include #include #include #include #include #include "absl/base/nullability.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "runtime/components/tokenizer.h" #include "runtime/executor/audio_executor_settings.h" #include "runtime/executor/executor_settings_base.h" #include "runtime/executor/llm_executor_settings.h" #include "runtime/executor/vision_executor_settings.h" #include "runtime/proto/engine.pb.h" #include "runtime/proto/llm_metadata.pb.h" #include "runtime/proto/llm_model_type.pb.h" #include "runtime/proto/sampler_params.pb.h" #include "runtime/util/scoped_file.h" namespace litert::lm { // Note for development conventions: // 1. Any optional field should use std::optional. // 2. All member variables should be private and have their corresponding // getters and setters. // 3. For basic types, e.g. int, float, bool, etc., the getters and setters // should be Get*() and Set*(). // 4. For complex types, e.g. proto::BenchmarkParams, the getters and setters // should be Get*() and GetMutable*(). // 5. For optional fields, the mutable getter should create a default instance // if the field is not set. But the non-mutable getter should return a // const reference to the std::optional field. // Settings used for initializing LiteRT LM Engine. // This class encapsulates the model-specific settings that are used for // initializing the LiteRT LM. These settings are typically fixed for a given // model and are not expected to change during the inference process. // // This class is used to initialize the LiteRT LM Engine. The user should // create an EngineSettings object and then call the MaybeUpdateAndValidate() // method to validate the settings. If the validation fails, the user should // not use the EngineSettings object. // // Example: // // ASSIGN_OR_RETURN(ModelAssets model_assets, // ModelAssets::Create(model_path)); // ASSIGN_OR_RETURN(EngineSettings engine_settings, // EngineSettings::CreateDefault(model_assets)); // ...initialize the Engine... // ASSIGN_OR_RETURN(std::unique_ptr engine, // Engine::CreateEngine(engine_settings)); // TODO(b/397975034) Add overloading << operator for debugging. class EngineSettings { public: // Creates a default EngineSettings with the given model assets and specified // backend. static absl::StatusOr CreateDefault( ModelAssets model_assets, Backend backend = Backend::CPU, std::optional vision_backend = std::nullopt, std::optional audio_backend = std::nullopt, std::optional sampler_backend = std::nullopt); // Updates the EngineSettings fields by loading the metadata from the model // assets. The function also validates to check if all of the required fields // are set correctly. Returns an error if the validation fails. absl::Status MaybeUpdateAndValidate( Tokenizer* tokenizer, const proto::LlmMetadata* absl_nullable metadata_from_file, absl::string_view input_prompt_as_hint = "", const std::optional& text_backend_constraint = std::nullopt, const std::optional& vision_backend_constraint = std::nullopt, const std::optional& audio_backend_constraint = std::nullopt); // Returns the LlmExecutorSettings. const LlmExecutorSettings& GetMainExecutorSettings() const; // Returns the mutable LlmExecutorSettings. LlmExecutorSettings& GetMutableMainExecutorSettings(); // Returns the VisionExecutorSettings for the vision model. const std::optional& GetVisionExecutorSettings() const; // Returns the mutable VisionExecutorSettings for the vision model. std::optional& GetMutableVisionExecutorSettings(); // Returns the AudioExecutorSettings for the audio model. const std::optional& GetAudioExecutorSettings() const; // Returns the mutable AudioExecutorSettings for the audio model. std::optional& GetMutableAudioExecutorSettings(); // Benchmark parameters: // Returns true if the benchmark is enabled. bool IsBenchmarkEnabled() const; // Returns the benchmark parameters. const std::optional& GetBenchmarkParams() const; // Returns the mutable benchmark parameters. proto::BenchmarkParams& GetMutableBenchmarkParams(); // Returns the LlmMetadata parameters. const std::optional& GetLlmMetadata() const; // Returns the mutable LlmMetadata parameters. Note that is the metadata_ is // not set (i.e. std::nullopt), then the default LlmMetadata will be // created and returned. proto::LlmMetadata& GetMutableLlmMetadata(); // Returns true if the engine may load different sections of the litertlm file // in parallel. bool GetParallelFileSectionLoading() const; // Sets whether the engine should load different sections of the litertlm file // in parallel. void SetParallelFileSectionLoading(bool parallel_file_section_loading); private: explicit EngineSettings( LlmExecutorSettings executor_settings, std::optional vision_executor_settings, std::optional audio_executor_settings, std::optional benchmark_params = std::nullopt); // Settings for the main executor. LlmExecutorSettings main_executor_settings_; // Settings for the vision executor. std::optional vision_executor_settings_; // Settings for the audio executor. std::optional audio_executor_settings_; // Parameters used to configure the benchmarking process. std::optional benchmark_params_; // Default metadata for the model. This is loaded from the model assets (if // present). std::optional metadata_; // Whether the engine should load different sections of the litertlm file in // parallel. bool parallel_file_section_loading_ = true; }; std::ostream& operator<<(std::ostream& os, const EngineSettings& settings); // Configurations used for the session. // This class encapsulates the session-specific configurations that are used for // creating a LiteRT LM session. class SessionConfig { public: // Creates a default SessionConfig. static SessionConfig CreateDefault(); // Updates the SessionConfig fields from the EngineSettings when not set. The // function also validates to check if all of the required fields are set // correctly. Returns an error if the validation fails. absl::Status MaybeUpdateAndValidate(const EngineSettings& engine_settings); // Configures the audio modality in the session. bool AudioModalityEnabled() const { return audio_modality_enabled_; } void SetAudioModalityEnabled(bool enable_audio_modality) { audio_modality_enabled_ = enable_audio_modality; } // Configures the vision modality in the session. bool VisionModalityEnabled() const { return vision_modality_enabled_; } void SetVisionModalityEnabled(bool enable_vision_modality) { vision_modality_enabled_ = enable_vision_modality; } // Sampler parameters: // Getters for the sampler parameters. const proto::SamplerParameters& GetSamplerParams() const; proto::SamplerParameters& GetMutableSamplerParams(); // Stop token ids: // Getters for the stop token ids. const std::vector>& GetStopTokenIds() const; std::vector>& GetMutableStopTokenIds(); // Set the start token ids. int GetStartTokenId() const; void SetStartTokenId(int start_token_id); // Number of output candidates: // Getters for the number of output candidates. int GetNumOutputCandidates() const; void SetNumOutputCandidates(int num_output_candidates); // Sampler backend: // Getters for the backend of the sampler. Backend GetSamplerBackend() const; void SetSamplerBackend(Backend sampler_backend); // Prompt templates: // Getters for the prompt templates. const proto::PromptTemplates& GetPromptTemplates() const; proto::PromptTemplates& GetMutablePromptTemplates(); // Llm model type: // Getters for the LLM model type. const proto::LlmModelType& GetLlmModelType() const; proto::LlmModelType& GetMutableLlmModelType(); // Whether to apply the basic prompt templates in the session. bool GetApplyPromptTemplateInSession() const { return apply_prompt_template_in_session_; } void SetApplyPromptTemplateInSession(bool apply_prompt_template_in_session) { apply_prompt_template_in_session_ = apply_prompt_template_in_session; } // Whether to use external sampler. bool UseExternalSampler() const { return use_external_sampler_; } void SetUseExternalSampler(bool use_external_sampler) { use_external_sampler_ = use_external_sampler; } // Scoped LoRA file: // Getters for the scoped LoRA file. std::shared_ptr GetScopedLoraFile() const; void SetScopedLoraFile(std::shared_ptr scoped_lora_file); // The maximum number of tokens to generate in a single request: // Getters for the max output tokens. int GetMaxOutputTokens() const { return max_output_tokens_; } void SetMaxOutputTokens(int max_output_tokens) { max_output_tokens_ = max_output_tokens; } private: // Private constructor for the SessionConfig. The user should use the // CreateDefault() method to create a SessionConfig. explicit SessionConfig(const proto::SamplerParameters& sampler_params); // Whether to enable audio modality in the session. bool audio_modality_enabled_ = false; // Whether to enable vision modality in the session. bool vision_modality_enabled_ = false; // Parameters used to configure the sampling process. proto::SamplerParameters sampler_params_; // Stop token ids for the session. Note that the stop token could be a // sequence of token ids (as opposed to a single token id). The first // dimension is the index of the stop token in the session, and the second // dimension is the sequence of token ids that constitutes the stop token. std::vector> stop_token_ids_; // Start token id for the session. int start_token_id_ = -1; // Prompt templates for the session. This is loaded from the model assets (if // present). proto::PromptTemplates prompt_templates_; // Llm model type for the session. This is loaded from the model assets (if // present). proto::LlmModelType llm_model_type_; // The number of output candidates to generate. Default value is 1 and setting // it to a value greater than 1 will require the model to support batching. int num_output_candidates_ = 1; // Backend to use for sampling. Backend sampler_backend_ = Backend::UNSPECIFIED; // Whether to apply the prompt templates in the session. bool apply_prompt_template_in_session_ = true; // Whether to use external sampler. // notice: this is only used in advanced engine. bool use_external_sampler_ = false; // Scoped file for the LoRA weights. std::shared_ptr scoped_lora_file_; // The maximum number of tokens to generate in a single request. This limits // the number of decoding steps for a request, as opposed to // LlmExecutorSettings::GetMaxNumTokens(), which limits the total number of // tokens (input + output) stored in the KV cache over the lifetime of a // session. int max_output_tokens_ = std::numeric_limits::max(); }; std::ostream& operator<<(std::ostream& os, const SessionConfig& config); } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_SETTINGS_H_