Spaces:
Running
Running
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
| namespace litert::lm { | |
// Note on development conventions:
// 1. Any optional field should use std::optional.
// 2. All member variables should be private and have their corresponding
//    getters and setters.
// 3. For basic types, e.g. int, float, bool, etc., the getters and setters
//    should be Get*() and Set*().
// 4. For complex types, e.g. proto::BenchmarkParams, the getters and setters
//    should be Get*() and GetMutable*().
// 5. For optional fields, the mutable getter should create a default instance
//    if the field is not set, but the non-mutable getter should return a
//    const reference to the std::optional<T> field.

// Settings used for initializing the LiteRT LM Engine.
// This class encapsulates the model-specific settings that are used for
// initializing the LiteRT LM. These settings are typically fixed for a given
// model and are not expected to change during the inference process.
//
// The user should create an EngineSettings object and then call the
// MaybeUpdateAndValidate() method to validate the settings. If the validation
// fails, the user should not use the EngineSettings object.
//
// Example:
//
//   ASSIGN_OR_RETURN(ModelAssets model_assets,
//                    ModelAssets::Create(model_path));
//   ASSIGN_OR_RETURN(EngineSettings engine_settings,
//                    EngineSettings::CreateDefault(model_assets));
//   ...initialize the Engine...
//   ASSIGN_OR_RETURN(std::unique_ptr<Engine> engine,
//                    Engine::CreateEngine(engine_settings));
//
// TODO(b/397975034): Add overloading << operator for debugging. (NOTE(review):
// an operator<< declaration exists below — confirm whether this TODO is stale.)
class EngineSettings {
 public:
  // Creates a default EngineSettings with the given model assets. `backend`
  // selects the backend for the main (text) executor; the optional vision,
  // audio and sampler backends leave the corresponding executor unconfigured
  // when std::nullopt.
  static absl::StatusOr<EngineSettings> CreateDefault(
      ModelAssets model_assets, Backend backend = Backend::CPU,
      std::optional<Backend> vision_backend = std::nullopt,
      std::optional<Backend> audio_backend = std::nullopt,
      std::optional<Backend> sampler_backend = std::nullopt);

  // Updates the EngineSettings fields by loading the metadata from the model
  // assets. The function also validates to check if all of the required
  // fields are set correctly. Returns an error if the validation fails.
  //
  // `metadata_from_file` may be null (annotated absl_nullable). The optional
  // *_backend_constraint strings, when set, constrain backend selection for
  // the text/vision/audio executors — NOTE(review): the accepted constraint
  // format is defined in the implementation; confirm there.
  absl::Status MaybeUpdateAndValidate(
      Tokenizer* tokenizer,
      const proto::LlmMetadata* absl_nullable metadata_from_file,
      absl::string_view input_prompt_as_hint = "",
      const std::optional<std::string>& text_backend_constraint = std::nullopt,
      const std::optional<std::string>& vision_backend_constraint =
          std::nullopt,
      const std::optional<std::string>& audio_backend_constraint =
          std::nullopt);

  // Returns the LlmExecutorSettings for the main (text) executor.
  const LlmExecutorSettings& GetMainExecutorSettings() const;
  // Returns the mutable LlmExecutorSettings for the main (text) executor.
  LlmExecutorSettings& GetMutableMainExecutorSettings();

  // Returns the VisionExecutorSettings for the vision model (std::nullopt
  // when no vision executor is configured).
  const std::optional<VisionExecutorSettings>& GetVisionExecutorSettings()
      const;
  // Returns the mutable VisionExecutorSettings for the vision model.
  std::optional<VisionExecutorSettings>& GetMutableVisionExecutorSettings();

  // Returns the AudioExecutorSettings for the audio model (std::nullopt when
  // no audio executor is configured).
  const std::optional<AudioExecutorSettings>& GetAudioExecutorSettings() const;
  // Returns the mutable AudioExecutorSettings for the audio model.
  std::optional<AudioExecutorSettings>& GetMutableAudioExecutorSettings();

  // Benchmark parameters:
  // Returns true if the benchmark is enabled.
  bool IsBenchmarkEnabled() const;
  // Returns the benchmark parameters (std::nullopt when benchmarking is not
  // enabled).
  const std::optional<proto::BenchmarkParams>& GetBenchmarkParams() const;
  // Returns the mutable benchmark parameters. Per the optional-field
  // convention of this file, a default instance is expected to be created if
  // the field is not set — confirm in the implementation.
  proto::BenchmarkParams& GetMutableBenchmarkParams();

  // Returns the LlmMetadata parameters loaded from the model assets
  // (std::nullopt when none was present).
  const std::optional<proto::LlmMetadata>& GetLlmMetadata() const;
  // Returns the mutable LlmMetadata parameters. Note that if metadata_ is not
  // set (i.e. std::nullopt), then a default LlmMetadata will be created and
  // returned.
  proto::LlmMetadata& GetMutableLlmMetadata();

  // Returns true if the engine may load different sections of the litertlm
  // file in parallel.
  bool GetParallelFileSectionLoading() const;
  // Sets whether the engine should load different sections of the litertlm
  // file in parallel.
  void SetParallelFileSectionLoading(bool parallel_file_section_loading);

 private:
  // Private constructor; use CreateDefault() to build an instance.
  explicit EngineSettings(
      LlmExecutorSettings executor_settings,
      std::optional<VisionExecutorSettings> vision_executor_settings,
      std::optional<AudioExecutorSettings> audio_executor_settings,
      std::optional<proto::BenchmarkParams> benchmark_params = std::nullopt);

  // Settings for the main executor.
  LlmExecutorSettings main_executor_settings_;
  // Settings for the vision executor.
  std::optional<VisionExecutorSettings> vision_executor_settings_;
  // Settings for the audio executor.
  std::optional<AudioExecutorSettings> audio_executor_settings_;
  // Parameters used to configure the benchmarking process.
  std::optional<proto::BenchmarkParams> benchmark_params_;
  // Default metadata for the model. This is loaded from the model assets (if
  // present).
  std::optional<proto::LlmMetadata> metadata_;
  // Whether the engine should load different sections of the litertlm file in
  // parallel.
  bool parallel_file_section_loading_ = true;
};
// Streams a human-readable representation of the settings (for debugging and
// logging). Defined in the implementation file.
std::ostream& operator<<(std::ostream& os, const EngineSettings& settings);
// Configurations used for the session.
// This class encapsulates the session-specific configurations that are used
// for creating a LiteRT LM session.
class SessionConfig {
 public:
  // Creates a default SessionConfig. Use this factory rather than the
  // (private) constructor.
  static SessionConfig CreateDefault();

  // Updates the SessionConfig fields from the EngineSettings when not set.
  // The function also validates to check if all of the required fields are
  // set correctly. Returns an error if the validation fails.
  absl::Status MaybeUpdateAndValidate(const EngineSettings& engine_settings);

  // Configures the audio modality in the session.
  bool AudioModalityEnabled() const { return audio_modality_enabled_; }
  void SetAudioModalityEnabled(bool enable_audio_modality) {
    audio_modality_enabled_ = enable_audio_modality;
  }

  // Configures the vision modality in the session.
  bool VisionModalityEnabled() const { return vision_modality_enabled_; }
  void SetVisionModalityEnabled(bool enable_vision_modality) {
    vision_modality_enabled_ = enable_vision_modality;
  }

  // Sampler parameters:
  // Getters for the parameters that configure the sampling process.
  const proto::SamplerParameters& GetSamplerParams() const;
  proto::SamplerParameters& GetMutableSamplerParams();

  // Stop token ids:
  // Getters for the stop token ids. A stop token may be a sequence of token
  // ids rather than a single id (see stop_token_ids_ below).
  const std::vector<std::vector<int>>& GetStopTokenIds() const;
  std::vector<std::vector<int>>& GetMutableStopTokenIds();

  // Start token id:
  // Getter/setter for the start token id.
  int GetStartTokenId() const;
  void SetStartTokenId(int start_token_id);

  // Number of output candidates:
  // Getters for the number of output candidates to generate. Values greater
  // than 1 require the model to support batching.
  int GetNumOutputCandidates() const;
  void SetNumOutputCandidates(int num_output_candidates);

  // Sampler backend:
  // Getters for the backend used for sampling.
  Backend GetSamplerBackend() const;
  void SetSamplerBackend(Backend sampler_backend);

  // Prompt templates:
  // Getters for the prompt templates (loaded from the model assets when
  // present).
  const proto::PromptTemplates& GetPromptTemplates() const;
  proto::PromptTemplates& GetMutablePromptTemplates();

  // Llm model type:
  // Getters for the LLM model type (loaded from the model assets when
  // present).
  const proto::LlmModelType& GetLlmModelType() const;
  proto::LlmModelType& GetMutableLlmModelType();

  // Whether to apply the basic prompt templates in the session.
  bool GetApplyPromptTemplateInSession() const {
    return apply_prompt_template_in_session_;
  }
  void SetApplyPromptTemplateInSession(bool apply_prompt_template_in_session) {
    apply_prompt_template_in_session_ = apply_prompt_template_in_session;
  }

  // Whether to use an external sampler. NOTE: only used in the advanced
  // engine.
  bool UseExternalSampler() const { return use_external_sampler_; }
  void SetUseExternalSampler(bool use_external_sampler) {
    use_external_sampler_ = use_external_sampler;
  }

  // Scoped LoRA file:
  // Getters for the scoped file holding the LoRA weights.
  std::shared_ptr<ScopedFile> GetScopedLoraFile() const;
  void SetScopedLoraFile(std::shared_ptr<ScopedFile> scoped_lora_file);

  // The maximum number of tokens to generate in a single request (see
  // max_output_tokens_ below for how this differs from the KV-cache limit).
  int GetMaxOutputTokens() const { return max_output_tokens_; }
  void SetMaxOutputTokens(int max_output_tokens) {
    max_output_tokens_ = max_output_tokens;
  }

 private:
  // Private constructor for the SessionConfig. The user should use the
  // CreateDefault() method to create a SessionConfig.
  explicit SessionConfig(const proto::SamplerParameters& sampler_params);

  // Whether to enable audio modality in the session.
  bool audio_modality_enabled_ = false;
  // Whether to enable vision modality in the session.
  bool vision_modality_enabled_ = false;
  // Parameters used to configure the sampling process.
  proto::SamplerParameters sampler_params_;
  // Stop token ids for the session. Note that a stop token could be a
  // sequence of token ids (as opposed to a single token id). The first
  // dimension is the index of the stop token in the session, and the second
  // dimension is the sequence of token ids that constitutes the stop token.
  std::vector<std::vector<int>> stop_token_ids_;
  // Start token id for the session (defaults to -1, presumably meaning
  // "unset" — confirm in the implementation).
  int start_token_id_ = -1;
  // Prompt templates for the session. This is loaded from the model assets
  // (if present).
  proto::PromptTemplates prompt_templates_;
  // Llm model type for the session. This is loaded from the model assets (if
  // present).
  proto::LlmModelType llm_model_type_;
  // The number of output candidates to generate. Default value is 1 and
  // setting it to a value greater than 1 will require the model to support
  // batching.
  int num_output_candidates_ = 1;
  // Backend to use for sampling.
  Backend sampler_backend_ = Backend::UNSPECIFIED;
  // Whether to apply the prompt templates in the session.
  bool apply_prompt_template_in_session_ = true;
  // Whether to use external sampler.
  // NOTE: this is only used in the advanced engine.
  bool use_external_sampler_ = false;
  // Scoped file for the LoRA weights.
  std::shared_ptr<ScopedFile> scoped_lora_file_;
  // The maximum number of tokens to generate in a single request. This limits
  // the number of decoding steps for a request, as opposed to
  // LlmExecutorSettings::GetMaxNumTokens(), which limits the total number of
  // tokens (input + output) stored in the KV cache over the lifetime of a
  // session.
  int max_output_tokens_ = std::numeric_limits<int>::max();
};
// Streams a human-readable representation of the config (for debugging and
// logging). Defined in the implementation file.
std::ostream& operator<<(std::ostream& os, const SessionConfig& config);
| } // namespace litert::lm | |