// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/engine/engine_settings.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/nullability.h"  // from @com_google_absl
#include "absl/log/absl_log.h"  // from @com_google_absl
#include "absl/log/absl_vlog.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/match.h"  // from @com_google_absl
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "absl/strings/str_split.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "runtime/components/tokenizer.h"
#include "runtime/executor/audio_executor_settings.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/executor/llm_executor_settings.h"
#include "runtime/executor/vision_executor_settings.h"
#include "runtime/proto/engine.pb.h"
#include "runtime/proto/llm_metadata.pb.h"
#include "runtime/proto/llm_model_type.pb.h"
#include "runtime/proto/sampler_params.pb.h"
#include "runtime/proto/token.pb.h"
#include "runtime/util/model_type_utils.h"
#include "runtime/util/scoped_file.h"
#include "runtime/util/status_macros.h"  // IWYU pragma: keep

namespace litert::lm {
namespace {

// Margin for the default prefill batch size assuming the tokens to indicate the
// start and end of the input prompt.
constexpr int kDefaultPrefillBatchSizeMargin = 2; std::ostream& operator<<(std::ostream& os, const std::vector& vec) { constexpr int newline_num = 10; os << "vector size: " << vec.size() << ": ["; for (int i = 0; i < vec.size(); ++i) { os << vec[i]; if (i < vec.size() - 1) { os << ", "; } if ((i + 1) % newline_num == 0) { os << "\n"; } } os << "]"; return os; } absl::Status ValidateBackendConstraint( ExecutorSettingsBase& executor_settings, // Polymorphic executor settings. const std::optional& backend_constraint, absl::string_view modality_name) { if (backend_constraint.has_value()) { // When both the executor settings and the backend constraint are set, we // check if the backend constraint contains the backend of the executor // settings. std::string backend_constraint_str = backend_constraint.value(); std::string backend = GetBackendString(executor_settings.GetBackend()); std::vector constraints = absl::StrSplit(backend_constraint_str, ','); bool found = std::any_of(constraints.begin(), constraints.end(), [&](absl::string_view constraint) { return absl::EqualsIgnoreCase(constraint, backend); }); if (!found) { return absl::InvalidArgumentError( absl::StrCat(modality_name, " backend constraint mismatch. 
Model requires one of [", backend_constraint_str, "] but ", modality_name, " backend is ", backend)); } ABSL_LOG(INFO) << "The " << modality_name << " backend constraint is matched: " << backend; } else { ABSL_LOG(INFO) << "The " << modality_name << " backend constraint is not set."; } return absl::OkStatus(); } } // namespace // static absl::StatusOr EngineSettings::CreateDefault( ModelAssets model_assets, Backend backend, std::optional vision_backend, std::optional audio_backend, std::optional sampler_backend) { ASSIGN_OR_RETURN( // NOLINT auto executor_settings, LlmExecutorSettings::CreateDefault( model_assets, backend, sampler_backend)); std::optional vision_executor_settings; if (vision_backend.has_value()) { ASSIGN_OR_RETURN( vision_executor_settings, VisionExecutorSettings::CreateDefault( model_assets, /*encoder_backend=*/vision_backend.value(), // Vision adapter can only run on CPU. /*adapter_backend=*/Backend::CPU)); } std::optional audio_executor_settings; if (audio_backend.has_value()) { ASSIGN_OR_RETURN(audio_executor_settings, AudioExecutorSettings::CreateDefault( model_assets, executor_settings.GetMaxNumTokens(), audio_backend.value())); } return EngineSettings(std::move(executor_settings), std::move(vision_executor_settings), std::move(audio_executor_settings)); } // TODO(b/488067258): Refactor the method to smaller methods. // For now, support 2 use cases: // 1. The tokenizer is available. // 2. The tokenizer is not available, when it is nullptr. absl::Status EngineSettings::MaybeUpdateAndValidate( Tokenizer* tokenizer, const proto::LlmMetadata* absl_nullable metadata_from_file, absl::string_view input_prompt_as_hint, const std::optional& text_backend_constraint, const std::optional& vision_backend_constraint, const std::optional& audio_backend_constraint) { proto::LlmMetadata& metadata = GetMutableLlmMetadata(); // Copy the metadata from the file if it is provided. 
if (metadata_from_file != nullptr) { metadata = *metadata_from_file; } // Convert the start/stop tokens from string to token ids. if (tokenizer != nullptr) { for (auto& stop_token : *metadata.mutable_stop_tokens()) { if (stop_token.has_token_str()) { auto stop_token_id = tokenizer->TokenToId(stop_token.token_str()); if (stop_token_id.ok()) { stop_token.mutable_token_ids()->mutable_ids()->Add(*stop_token_id); } else { auto stop_token_ids = tokenizer->TextToTokenIds(stop_token.token_str()); if (stop_token_ids.ok()) { stop_token.mutable_token_ids()->mutable_ids()->Add( stop_token_ids->begin(), stop_token_ids->end()); } } } } if (metadata.start_token().has_token_str()) { auto start_token_id = tokenizer->TokenToId(metadata.start_token().token_str()); if (start_token_id.ok()) { metadata.mutable_start_token()->mutable_token_ids()->mutable_ids()->Add( *start_token_id); } else { auto start_token_ids = tokenizer->TextToTokenIds(metadata.start_token().token_str()); if (start_token_ids.ok()) { metadata.mutable_start_token() ->mutable_token_ids() ->mutable_ids() ->Add(start_token_ids->begin(), start_token_ids->end()); } } } } int num_prompt_tokens = 0; if (!input_prompt_as_hint.empty()) { if (tokenizer == nullptr) { // If the tokenizer is not available, we estimate the number of tokens // in the input prompt by dividing the number of characters by 4. num_prompt_tokens = 1 + input_prompt_as_hint.size() / 4; } else { num_prompt_tokens = tokenizer->TextToTokenIds(input_prompt_as_hint) .value_or(std::vector()) .size(); } } // Load the max num tokens from the model file. // If not set, we set the default value to one based on the number of tokens // in the prompt. if (main_executor_settings_.GetMaxNumTokens() == 0) { // The default maximum number of tokens is set to the smallest multiple of // 4096 greater than the number of tokens in the prompt plus the default // decode length, 1024. 
int max_num_tokens = ((num_prompt_tokens + 1023) / 4096 + 1) * 4096; if (metadata.max_num_tokens() > 0) { max_num_tokens = metadata.max_num_tokens(); } main_executor_settings_.SetMaxNumTokens(max_num_tokens); } // By default, the audio executor is configured to use the same max num // tokens as the main executor. if (audio_executor_settings_.has_value() && audio_executor_settings_->GetMaxSequenceLength() == 0) { audio_executor_settings_->SetMaxSequenceLength( main_executor_settings_.GetMaxNumTokens()); } if (num_prompt_tokens > 0) { AdvancedSettings advanced_settings; if (main_executor_settings_.GetAdvancedSettings()) { advanced_settings = *main_executor_settings_.GetAdvancedSettings(); } if (advanced_settings.prefill_batch_sizes.empty()) { // If the prefill batch size is not set, set it to the number of tokens // in the input prompt with some margin. advanced_settings.prefill_batch_sizes.insert( num_prompt_tokens + kDefaultPrefillBatchSizeMargin); main_executor_settings_.SetAdvancedSettings(advanced_settings); } } // Set the default values for the sampler params. 
if (!metadata.has_sampler_params()) { proto::SamplerParameters& sampler_params = *metadata.mutable_sampler_params(); Backend backend = main_executor_settings_.GetBackend(); if (backend == Backend::NPU || backend == Backend::GPU_ARTISAN) { sampler_params.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); } else if (backend == Backend::CPU || backend == Backend::GPU ) { sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_p(0.95f); sampler_params.set_temperature(1.0f); sampler_params.set_seed(0); } else { return absl::InvalidArgumentError( absl::StrCat("Not recognized backend: ", backend)); } } if (!metadata.has_llm_model_type()) { const auto& model_assets = main_executor_settings_.GetModelAssets(); auto model_path = model_assets.GetPath(); if (tokenizer != nullptr) { ASSIGN_OR_RETURN(*metadata.mutable_llm_model_type(), InferLlmModelType(metadata, tokenizer)); } else { return absl::InvalidArgumentError( "Tokenizer is null and LLM model type is not set."); } } // Set allow_src_quantized_fc_conv_ops to default values depending on the // model type if it is not set. auto advanced_settings = AdvancedSettings(); if (main_executor_settings_.GetAdvancedSettings()) { advanced_settings = *main_executor_settings_.GetAdvancedSettings(); } if (!advanced_settings.allow_src_quantized_fc_conv_ops.has_value()) { // Disable src quantized fc conv ops for generic models. If it's well-known, // the quality is acceptable with int8 quantized fc/conv ops. advanced_settings.allow_src_quantized_fc_conv_ops = metadata.has_llm_model_type() && !metadata.llm_model_type().has_generic_model(); main_executor_settings_.SetAdvancedSettings(advanced_settings); } if (!advanced_settings.hint_waiting_for_completion.has_value()) { // Enable a hint for waiting for completion for generic models on GPU. 
advanced_settings.hint_waiting_for_completion = metadata.has_llm_model_type() && metadata.llm_model_type().has_generic_model(); main_executor_settings_.SetAdvancedSettings(advanced_settings); } // TODO: b/482450588 - Remove this once the bug is fixed. if (metadata.has_llm_model_type() && metadata.llm_model_type().has_function_gemma()) { advanced_settings.convert_weights_on_gpu = false; main_executor_settings_.SetAdvancedSettings(advanced_settings); } // Disable delegate clustering for Gemma 4 models. if (metadata.has_llm_model_type() && metadata.llm_model_type().has_gemma4()) { advanced_settings.disable_delegate_clustering = true; main_executor_settings_.SetAdvancedSettings(advanced_settings); } if (IsBenchmarkEnabled()) { advanced_settings.is_benchmark = true; main_executor_settings_.SetAdvancedSettings(advanced_settings); } else if (!advanced_settings.gpu_context_low_priority.has_value()) { // When we are not in benchmark mode, we set the OpenCL context low priority // for generic models, such that the UI thread can be smoother. advanced_settings.gpu_context_low_priority = metadata.has_llm_model_type() && metadata.llm_model_type().has_generic_model(); main_executor_settings_.SetAdvancedSettings(advanced_settings); } if (!metadata.has_jinja_prompt_template()) { ASSIGN_OR_RETURN(*metadata.mutable_jinja_prompt_template(), GetDefaultJinjaPromptTemplate(metadata.prompt_templates(), metadata.llm_model_type())); } // If the executor settings is set, then check if the input backend // constraint is compatible with the executor settings. 
RETURN_IF_ERROR(ValidateBackendConstraint(main_executor_settings_, text_backend_constraint, "Main")); if (vision_executor_settings_.has_value()) { RETURN_IF_ERROR(ValidateBackendConstraint(vision_executor_settings_.value(), vision_backend_constraint, "Vision")); } if (audio_executor_settings_.has_value()) { RETURN_IF_ERROR(ValidateBackendConstraint( audio_executor_settings_.value(), audio_backend_constraint, "Audio")); } ABSL_VLOG(5) << "The llm metadata: " << metadata.DebugString(); ABSL_LOG(INFO) << "The validated engine settings: " << *this; return absl::OkStatus(); } EngineSettings::EngineSettings( LlmExecutorSettings executor_settings, std::optional vision_executor_settings, std::optional audio_executor_settings, std::optional benchmark_params) : main_executor_settings_(std::move(executor_settings)), vision_executor_settings_(std::move(vision_executor_settings)), audio_executor_settings_(std::move(audio_executor_settings)), benchmark_params_(benchmark_params) {} const LlmExecutorSettings& EngineSettings::GetMainExecutorSettings() const { return main_executor_settings_; } LlmExecutorSettings& EngineSettings::GetMutableMainExecutorSettings() { return main_executor_settings_; } const std::optional& EngineSettings::GetVisionExecutorSettings() const { return vision_executor_settings_; } std::optional& EngineSettings::GetMutableVisionExecutorSettings() { return vision_executor_settings_; } const std::optional& EngineSettings::GetAudioExecutorSettings() const { return audio_executor_settings_; } std::optional& EngineSettings::GetMutableAudioExecutorSettings() { return audio_executor_settings_; } // Benchmark parameters: // Returns true if the benchmark is enabled. bool EngineSettings::IsBenchmarkEnabled() const { return benchmark_params_.has_value(); } // Returns the benchmark parameters. const std::optional& EngineSettings::GetBenchmarkParams() const { return benchmark_params_; } // Returns the mutable benchmark parameters. 
proto::BenchmarkParams& EngineSettings::GetMutableBenchmarkParams() { if (!benchmark_params_.has_value()) { benchmark_params_ = proto::BenchmarkParams(); } return benchmark_params_.value(); } const std::optional& EngineSettings::GetLlmMetadata() const { return metadata_; } std::ostream& operator<<(std::ostream& os, const EngineSettings& settings) { os << "EngineSettings: " << std::endl; os << " MainExecutorSettings: " << settings.GetMainExecutorSettings(); if (settings.GetLlmMetadata().has_value()) { os << " LlmMetadata: " << settings.GetLlmMetadata().value().DebugString(); } else { os << " LlmMetadata: Not set" << std::endl; } if (settings.GetBenchmarkParams().has_value()) { os << " BenchmarkParams: " << settings.GetBenchmarkParams().value().DebugString(); } else { os << " BenchmarkParams: Not set" << std::endl; } if (settings.GetVisionExecutorSettings().has_value()) { os << " VisionExecutorSettings: " << settings.GetVisionExecutorSettings().value(); } else { os << " VisionExecutorSettings: Not set" << std::endl; } if (settings.GetAudioExecutorSettings().has_value()) { os << " AudioExecutorSettings: " << settings.GetAudioExecutorSettings().value(); } else { os << " AudioExecutorSettings: Not set" << std::endl; } os << " ParallelFileSectionLoading: " << settings.GetParallelFileSectionLoading() << std::endl; return os; } proto::LlmMetadata& EngineSettings::GetMutableLlmMetadata() { if (!metadata_.has_value()) { metadata_ = proto::LlmMetadata(); } return metadata_.value(); } bool EngineSettings::GetParallelFileSectionLoading() const { return parallel_file_section_loading_; } void EngineSettings::SetParallelFileSectionLoading( bool parallel_file_section_loading) { parallel_file_section_loading_ = parallel_file_section_loading; } SessionConfig SessionConfig::CreateDefault() { proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); auto config = SessionConfig(sampler_params); config.SetNumOutputCandidates(1); // 
Default to -1 to indicate the start token is not set. This is to be // overridden by the EngineSettings. config.SetStartTokenId(-1); return config; } absl::Status SessionConfig::MaybeUpdateAndValidate( const EngineSettings& engine_settings) { if ((stop_token_ids_.empty()) && !engine_settings.GetLlmMetadata().has_value()) { return absl::InvalidArgumentError( "Required: set stop tokens, or provide LlmMetadata."); } // Update the parameters from the engine settings when the LlmMetadata is // present. if (engine_settings.GetLlmMetadata().has_value()) { const auto llm_metadata = engine_settings.GetLlmMetadata().value(); proto::SamplerParameters& sampler_params = GetMutableSamplerParams(); // Update the sampler params if the session config does not have a sampler // params and the engine settings has a sampler params (probably read from // the model file). if ((sampler_params.type() == proto::SamplerParameters::TYPE_UNSPECIFIED)) { if (llm_metadata.has_sampler_params()) { sampler_params = engine_settings.GetLlmMetadata()->sampler_params(); } } // Set and validate the start token. if (start_token_id_ == -1) { if (llm_metadata.has_start_token()) { if (llm_metadata.start_token().token_ids().ids_size() > 1) { ABSL_LOG(WARNING) << "The start token has more than one token ids: "; } start_token_id_ = llm_metadata.start_token().token_ids().ids(0); } } // Set and validate the stop tokens. if (stop_token_ids_.empty()) { for (const auto& stop_token : llm_metadata.stop_tokens()) { if (stop_token.has_token_ids() && stop_token.token_ids().ids_size() > 0) { std::vector stop_token_ids(stop_token.token_ids().ids().begin(), stop_token.token_ids().ids().end()); stop_token_ids_.push_back(stop_token_ids); } } } // Set the prompt template from LlmMetadata, if not provided in // SessionConfig. // // Hack: use the user field to check if the prompt template is being set. // To use the empty prompt_template, set the user field with empty prefix. 
// // TODO(b/439648399): Remove this logic when LiteRT-LM no longer use // template in Session level. if (!prompt_templates_.has_user() && llm_metadata.has_prompt_templates()) { prompt_templates_ = llm_metadata.prompt_templates(); } if (llm_model_type_.model_type_case() == proto::LlmModelType::MODEL_TYPE_NOT_SET) { llm_model_type_ = llm_metadata.llm_model_type(); } } // Validating the required fields are set correctly. if (stop_token_ids_.empty()) { return absl::InvalidArgumentError( "Stop tokens are required. Either set the stop token ids or " "provide " "a valid stop token in the model file/engine settings."); } if (num_output_candidates_ < 1) { return absl::InvalidArgumentError(absl::StrCat( "Number of output candidates need to be at least 1, but got: ", num_output_candidates_)); } if (sampler_backend_ == Backend::UNSPECIFIED) { if (engine_settings.GetMainExecutorSettings().GetBackend() == Backend::GPU) { sampler_backend_ = Backend::GPU; } else { sampler_backend_ = Backend::CPU; } } ABSL_VLOG(5) << "The validated session config: " << *this; return absl::OkStatus(); } SessionConfig::SessionConfig(const proto::SamplerParameters& sampler_params) : sampler_params_(sampler_params) {} const proto::SamplerParameters& SessionConfig::GetSamplerParams() const { return sampler_params_; } proto::SamplerParameters& SessionConfig::GetMutableSamplerParams() { return sampler_params_; } const std::vector>& SessionConfig::GetStopTokenIds() const { return stop_token_ids_; } std::vector>& SessionConfig::GetMutableStopTokenIds() { return stop_token_ids_; } int SessionConfig::GetStartTokenId() const { return start_token_id_; } void SessionConfig::SetStartTokenId(int start_token_id) { start_token_id_ = start_token_id; } int SessionConfig::GetNumOutputCandidates() const { return num_output_candidates_; } void SessionConfig::SetNumOutputCandidates(int num_output_candidates) { num_output_candidates_ = num_output_candidates; } const proto::PromptTemplates& 
SessionConfig::GetPromptTemplates() const { return prompt_templates_; } proto::PromptTemplates& SessionConfig::GetMutablePromptTemplates() { return prompt_templates_; } const proto::LlmModelType& SessionConfig::GetLlmModelType() const { return llm_model_type_; } proto::LlmModelType& SessionConfig::GetMutableLlmModelType() { return llm_model_type_; } std::shared_ptr SessionConfig::GetScopedLoraFile() const { return scoped_lora_file_; } void SessionConfig::SetScopedLoraFile( std::shared_ptr scoped_lora_file) { scoped_lora_file_ = std::move(scoped_lora_file); } std::ostream& operator<<(std::ostream& os, const SessionConfig& config) { os << "SessionConfig: " << std::endl; os << " AudioModalityEnabled: " << config.AudioModalityEnabled() << std::endl; os << " VisionModalityEnabled: " << config.VisionModalityEnabled() << std::endl; os << " SamplerParams: " << config.GetSamplerParams().DebugString() << std::endl; os << " SamplerBackend: " << config.GetSamplerBackend() << std::endl; os << " StartTokenId: " << config.GetStartTokenId() << std::endl; os << " StopTokenIds: " << std::endl; for (const auto& stop_token_ids : config.GetStopTokenIds()) { os << " " << stop_token_ids << std::endl; } os << " NumOutputCandidates: " << config.GetNumOutputCandidates() << std::endl; os << " LlmModelType: " << config.GetLlmModelType().DebugString() << std::endl; os << " PromptTemplates: " << config.GetPromptTemplates().DebugString() << std::endl; os << " ApplyPromptTemplatesInSession: " << config.GetApplyPromptTemplateInSession() << std::endl; os << " ScopedLoraFile: " << (config.GetScopedLoraFile() != nullptr ? "Present" : "Not present") << std::endl; return os; } Backend SessionConfig::GetSamplerBackend() const { return sampler_backend_; } void SessionConfig::SetSamplerBackend(Backend sampler_backend) { sampler_backend_ = sampler_backend; } } // namespace litert::lm