Spaces:
Running
Running
| // Copyright 2026 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| // TODO(b/417209286): Remove this once the model assets are stored in the | |
| // litertlm file format. | |
| namespace litert::lm { | |
| namespace { | |
| // Gets the singleton Environment, initializing it on the first call | |
| // with the provided settings. This ensures we maintain the same LiteRT | |
| // environment during the whole application lifetime. This is required for GPU | |
| // LiteRT environment. See b/454383477 for more details. | |
//
// The environment is held in a function-local static, so the initializer
// lambda below runs only on the first call. Arguments passed on later calls
// are not consulted, and the outcome of that first initialization (success or
// error) is what every later call observes.
//
// `engine_settings` supplies the main executor settings (backend, advanced
// settings, dispatch library directory, model path); `model_resources` is
// forwarded to MagicNumberConfigsHelper when magic-number options are
// requested.
//
// Returns a reference to the shared Environment, or the stored error status
// if the first-call initialization failed.
| absl::StatusOr<Environment&> GetEnvironment(EngineSettings& engine_settings, | |
| ModelResources& model_resources) { | |
| // Helper must be available until LlmLiteRtCompiledModelExecutor::Create() is | |
| // called. Since env is used multiple times, it should also be static. | |
| static absl::NoDestructor<MagicNumberConfigsHelper> helper; | |
// The lambda is invoked immediately; its result (value or error) becomes the
// stored static.
| static absl::NoDestructor<absl::StatusOr<Environment>> kEnvironment( | |
| [&]() -> absl::StatusOr<Environment> { | |
| std::vector<Environment::Option> env_options; | |
| const auto& main_executor_settings = | |
| engine_settings.GetMainExecutorSettings(); | |
| if ((main_executor_settings.GetBackend() == Backend::CPU) || | |
| (main_executor_settings.GetBackend() == Backend::GPU)) { | |
// Magic-number options are requested when no advanced settings are present,
// or when they are present with `configure_magic_numbers` set.
| if (!main_executor_settings | |
| .GetAdvancedSettings() || // Default is true. | |
| main_executor_settings.GetAdvancedSettings() | |
| ->configure_magic_numbers) { | |
| env_options = helper->GetLiteRtEnvOptions(model_resources, | |
| main_executor_settings); | |
| } | |
| } else { | |
// NOTE(review): as listed, the return below is unconditional within this
// branch, so the dispatch-library setup that follows it (through the end of
// this else-branch) is unreachable. This looks like a conditional-compilation
// guard (#if / #else / #endif) was lost from this listing - confirm against
// the original file before relying on either path.
| return absl::InvalidArgumentError( | |
| "Only CPU and GPU backends are supported."); | |
| if (!main_executor_settings.GetLitertDispatchLibDir().empty()) { | |
| // If the dispatch library directory is provided, use it. | |
| env_options.push_back(::litert::Environment::Option{ | |
| ::litert::Environment::OptionTag::DispatchLibraryDir, | |
| main_executor_settings.GetLitertDispatchLibDir()}); | |
| ABSL_LOG(INFO) << "Setting dispatch library path from " | |
| "main_executor_settings: " | |
| << main_executor_settings.GetLitertDispatchLibDir(); | |
| } else { | |
| // Otherwise, use the directory of the model file. | |
| std::string model_path( | |
| main_executor_settings.GetModelAssets().GetPath().value_or("")); | |
| std::filesystem::path path(model_path); | |
| // Note: Existence check for path was here, but it's better to check | |
| // before calling this function if needed. | |
// The option below is built from a string_view over this string; presumably
// it is static so that the view stays valid after the lambda returns - TODO
// confirm how Environment::Option stores the value.
| static const absl::NoDestructor<std::string> kDispatchLibraryPath( | |
| path.parent_path().string()); | |
| if (!kDispatchLibraryPath->empty()) { | |
| ABSL_LOG(INFO) | |
| << "Setting dispatch library path: " << *kDispatchLibraryPath; | |
| env_options.push_back(::litert::Environment::Option{ | |
| ::litert::Environment::OptionTag::DispatchLibraryDir, | |
| absl::string_view(*kDispatchLibraryPath)}); | |
| } else { | |
| ABSL_LOG(INFO) << "No dispatch library path provided."; | |
| } | |
| } | |
| } | |
// Create the environment from whatever options were collected above.
| LITERT_ASSIGN_OR_RETURN(auto env, Environment::Create(env_options)); | |
| return std::move(env); | |
| }()); | |
// Surface a failed initialization as a status; otherwise hand back a
// reference to the stored Environment.
| if (!kEnvironment->ok()) { | |
| return kEnvironment->status(); | |
| } | |
| return **kEnvironment; | |
| } | |
| } // namespace | |
| class EngineAdvancedImpl : public Engine { | |
| public: | |
| ~EngineAdvancedImpl() override { | |
| ABSL_QCHECK_OK(WaitUntilDone(Engine::kDefaultTimeout)); | |
| } | |
| static absl::StatusOr<std::unique_ptr<Engine>> Create( | |
| EngineSettings engine_settings, absl::string_view input_prompt_as_hint); | |
| EngineAdvancedImpl(EngineSettings engine_settings, | |
| std::unique_ptr<ModelResources> litert_model_resources, | |
| std::unique_ptr<Tokenizer> tokenizer, | |
| std::unique_ptr<ExecutionManager> execution_manager, | |
| std::optional<BenchmarkInfo> benchmark_info) | |
| : engine_settings_(std::move(engine_settings)), | |
| litert_model_resources_(std::move(litert_model_resources)), | |
| tokenizer_(std::move(tokenizer)), | |
| execution_manager_(std::move(execution_manager)), | |
| benchmark_info_(std::move(benchmark_info)) {} | |
| // Method to create the Session. | |
| absl::StatusOr<std::unique_ptr<Session>> CreateSession( | |
| const SessionConfig& session_config) override { | |
| std::optional<BenchmarkInfo> session_benchmark_info; | |
| if (benchmark_info_.has_value()) { | |
| // Each session will have its own benchmark info, which will be populated | |
| // with the session-specific information. | |
| session_benchmark_info = benchmark_info_; | |
| RETURN_IF_ERROR(session_benchmark_info->TimeInitPhaseStart( | |
| BenchmarkInfo::InitPhase::kSession)); | |
| } | |
| SessionConfig config = session_config; | |
| // TODO(b/418794726): Move this logics to be part of the SessionConfig | |
| // class. | |
| RETURN_IF_ERROR(config.MaybeUpdateAndValidate(engine_settings_)); | |
| ABSL_CHECK(litert_model_resources_ != nullptr); | |
| ASSIGN_OR_RETURN( | |
| auto session, | |
| InitializeSessionAdvanced(execution_manager_, tokenizer_.get(), config, | |
| std::move(session_benchmark_info))); | |
| if (benchmark_info_.has_value()) { | |
| auto session_benchmark_info_or = session->GetMutableBenchmarkInfo(); | |
| if (session_benchmark_info_or.ok()) { | |
| RETURN_IF_ERROR(session_benchmark_info_or.value()->TimeInitPhaseEnd( | |
| BenchmarkInfo::InitPhase::kSession)); | |
| } | |
| } | |
| return session; | |
| } | |
| absl::Status WaitUntilDone(absl::Duration timeout) override { | |
| return execution_manager_->WaitUntilAllDone(timeout); | |
| } | |
| const EngineSettings& GetEngineSettings() const override { | |
| return engine_settings_; | |
| } | |
| const Tokenizer& GetTokenizer() const override { return *tokenizer_; } | |
| absl::StatusOr<AudioExecutorProperties> GetAudioExecutorProperties() | |
| const override { | |
| return GetAudioExecutorPropertiesFromModelResources( | |
| *litert_model_resources_); | |
| } | |
| absl::StatusOr<VisionExecutorProperties> GetVisionExecutorProperties() | |
| const override { | |
| return GetVisionExecutorPropertiesFromModelResources( | |
| *litert_model_resources_); | |
| } | |
| private: | |
| // Stored engine settings. | |
| EngineSettings engine_settings_; | |
| // Model resources, which must outlive `executor_`. | |
| std::unique_ptr<ModelResources> litert_model_resources_; | |
| // Tokenizer shared by all sessions. | |
| std::unique_ptr<Tokenizer> tokenizer_; | |
| // Execution manager for the engine. | |
| std::shared_ptr<ExecutionManager> execution_manager_; | |
| // Benchmark info for the engine. | |
| std::optional<BenchmarkInfo> benchmark_info_; | |
| }; | |
| // Method to create Engine. | |
| absl::StatusOr<std::unique_ptr<Engine>> EngineAdvancedImpl::Create( | |
| EngineSettings engine_settings, absl::string_view input_prompt_as_hint) { | |
| std::optional<BenchmarkInfo> benchmark_info = | |
| engine_settings.IsBenchmarkEnabled() | |
| ? std::make_optional<BenchmarkInfo>( | |
| engine_settings.GetBenchmarkParams().value()) | |
| : std::nullopt; | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR( | |
| benchmark_info->TimeInitPhaseStart(BenchmarkInfo::InitPhase::kTotal)); | |
| RETURN_IF_ERROR(benchmark_info->TimeInitPhaseStart( | |
| BenchmarkInfo::InitPhase::kModelAssets)); | |
| } | |
| const auto& model_assets = | |
| engine_settings.GetMutableMainExecutorSettings().GetModelAssets(); | |
| ASSIGN_OR_RETURN(auto model_resources, | |
| BuildLiteRtCompiledModelResources(model_assets)); | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR(benchmark_info->TimeInitPhaseEnd( | |
| BenchmarkInfo::InitPhase::kModelAssets)); | |
| } | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR(benchmark_info->TimeInitPhaseStart( | |
| BenchmarkInfo::InitPhase::kLlmMetadata)); | |
| } | |
| ASSIGN_OR_RETURN(auto* llm_metadata, model_resources->GetLlmMetadata()); | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR(benchmark_info->TimeInitPhaseEnd( | |
| BenchmarkInfo::InitPhase::kLlmMetadata)); | |
| } | |
| bool hasLlmModelType = llm_metadata->has_llm_model_type(); | |
| absl::Duration tokenizer_duration = absl::ZeroDuration(); | |
| // This lambda is used to create the tokenizer asynchronously if the model | |
| // type is available, such that the tokenizer can be created in parallel with | |
| // the executor. | |
| auto create_tokenizer = | |
| [&tokenizer_duration, | |
| &model_resources]() -> absl::StatusOr<std::unique_ptr<Tokenizer>> { | |
| absl::Time start_time = absl::Now(); | |
| ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> tokenizer, | |
| model_resources->GetTokenizer()); | |
| tokenizer_duration = absl::Now() - start_time; | |
| return tokenizer; | |
| }; | |
| const auto& main_executor_settings = | |
| engine_settings.GetMainExecutorSettings(); | |
| std::future<absl::StatusOr<std::unique_ptr<Tokenizer>>> tokenizer_future; | |
| std::unique_ptr<Tokenizer> tokenizer; | |
| if (!hasLlmModelType) { | |
| ABSL_LOG(INFO) | |
| << "Legacy model files don't have LlmModelType, loading tokenizer now"; | |
| ASSIGN_OR_RETURN(tokenizer, create_tokenizer()); | |
| // Update and load the parameters from the model file and convert the | |
| // tokens to ids. | |
| RETURN_IF_ERROR(engine_settings.MaybeUpdateAndValidate( | |
| tokenizer.get(), llm_metadata, input_prompt_as_hint, | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLitePrefillDecode), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteVisionEncoder), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteAudioEncoderHw))); | |
| } else { | |
| // If the model type is available, wait for the tokenizer to be created | |
| // after the model is loaded. | |
| ABSL_LOG(INFO) << "New model files have LlmModelType, loading tokenizer " | |
| "asynchronously"; | |
| if (engine_settings.GetParallelFileSectionLoading()) { | |
| // Launch the tokenizer creation in a separate thread in parallel with the | |
| // model loading. | |
| tokenizer_future = std::async(std::launch::async, create_tokenizer); | |
| } else { | |
| // Launch the tokenizer creation in the same thread. | |
| tokenizer_future = std::async(std::launch::deferred, create_tokenizer); | |
| } | |
| RETURN_IF_ERROR(engine_settings.MaybeUpdateAndValidate( | |
| nullptr, llm_metadata, input_prompt_as_hint, | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLitePrefillDecode), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteVisionEncoder), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteAudioEncoderHw))); | |
| } | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR(benchmark_info->TimeInitPhaseStart( | |
| BenchmarkInfo::InitPhase::kExecutor)); | |
| } | |
| ASSIGN_OR_RETURN(auto& litert_env, | |
| GetEnvironment(engine_settings, *model_resources)); | |
| std::unique_ptr<LlmExecutor> executor; | |
| switch (main_executor_settings.GetBackend()) { | |
| default: { | |
| ASSIGN_OR_RETURN( | |
| executor, CreateLlmLiteRtCompiledModelExecutor( | |
| main_executor_settings, litert_env, *model_resources)); | |
| } | |
| }; | |
| std::unique_ptr<VisionExecutorSettings> vision_executor_settings_ptr; | |
| if (engine_settings.GetVisionExecutorSettings().has_value()) { | |
| vision_executor_settings_ptr = std::make_unique<VisionExecutorSettings>( | |
| std::move(engine_settings.GetVisionExecutorSettings().value())); | |
| if (vision_executor_settings_ptr->GetAdapterBackend() != Backend::CPU) { | |
| ABSL_LOG(WARNING) << "Vision adapter backend is not CPU, which may cause " | |
| "precision loss."; | |
| } | |
| } | |
| std::unique_ptr<AudioExecutorSettings> audio_executor_settings_ptr; | |
| if (engine_settings.GetAudioExecutorSettings().has_value()) { | |
| audio_executor_settings_ptr = std::make_unique<AudioExecutorSettings>( | |
| std::move(engine_settings.GetAudioExecutorSettings().value())); | |
| } | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR( | |
| benchmark_info->TimeInitPhaseEnd(BenchmarkInfo::InitPhase::kExecutor)); | |
| } | |
| if (hasLlmModelType) { | |
| // Now load the tokenizer and update the engine settings. | |
| ASSIGN_OR_RETURN(tokenizer, tokenizer_future.get()); | |
| RETURN_IF_ERROR(engine_settings.MaybeUpdateAndValidate( | |
| tokenizer.get(), llm_metadata, input_prompt_as_hint, | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLitePrefillDecode), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteVisionEncoder), | |
| model_resources->GetTFLiteModelBackendConstraint( | |
| ModelType::kTfLiteAudioEncoderHw))); | |
| // As we load the tokenizer asynchronously, we need to update the executor | |
| // settings after the tokenizer is loaded. | |
| RETURN_IF_ERROR(executor->UpdateExecutorSettings( | |
| engine_settings.GetMainExecutorSettings())); | |
| } | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR(benchmark_info->InitPhaseRecord( | |
| BenchmarkInfo::InitPhase::kTokenizer, tokenizer_duration)); | |
| } | |
| ASSIGN_OR_RETURN( | |
| auto execution_manager, | |
| ExecutionManager::Create( | |
| tokenizer.get(), model_resources.get(), std::move(executor), | |
| std::move(vision_executor_settings_ptr), | |
| std::move(audio_executor_settings_ptr), &litert_env)); | |
| if (benchmark_info.has_value()) { | |
| RETURN_IF_ERROR( | |
| benchmark_info->TimeInitPhaseEnd(BenchmarkInfo::InitPhase::kTotal)); | |
| } | |
| auto llm_impl = std::make_unique<EngineAdvancedImpl>( | |
| std::move(engine_settings), std::move(model_resources), | |
| std::move(tokenizer), std::move(execution_manager), | |
| std::move(benchmark_info)); | |
| return llm_impl; | |
| }; | |
| LITERT_LM_REGISTER_ENGINE( | |
| EngineFactory::EngineType::kAdvancedLiteRTCompiledModel, | |
| [](EngineSettings settings, absl::string_view input_prompt_as_hint) { | |
| return EngineAdvancedImpl::Create(std::move(settings), | |
| input_prompt_as_hint); | |
| }); | |
| } // namespace litert::lm | |