// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/core/session_basic.h"

// NOTE(review): the eight standard-library includes below lost their header
// names during extraction. They were reconstructed from usage in this file
// (std::unique_ptr, std::optional, std::string, std::move, std::get_if,
// std::vector, atomics, fixed-width ints) — confirm against the original.
#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/attributes.h"  // from @com_google_absl
#include "absl/base/const_init.h"  // from @com_google_absl
#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/container/flat_hash_set.h"  // from @com_google_absl
#include "absl/functional/any_invocable.h"  // from @com_google_absl
#include "absl/log/absl_log.h"  // from @com_google_absl
#include "absl/memory/memory.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/synchronization/mutex.h"  // from @com_google_absl
#include "litert/cc/litert_layout.h"  // from @litert
#include "litert/cc/litert_macros.h"  // from @litert
#include "litert/cc/litert_tensor_buffer.h"  // from @litert
#include "runtime/components/sampler.h"
#include "runtime/components/sampler_factory.h"
#include "runtime/components/stop_token_detector.h"
#include "runtime/components/tokenizer.h"
#include "runtime/core/pipeline.h"
#include "runtime/core/session_utils.h"
#include "runtime/engine/engine.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/audio_executor.h"
"runtime/executor/audio_executor.h" #include "runtime/executor/executor_settings_base.h" #include "runtime/executor/llm_executor.h" #include "runtime/executor/llm_executor_io_types.h" #include "runtime/executor/vision_executor.h" #include "runtime/framework/threadpool.h" #include "runtime/proto/sampler_params.pb.h" #include "runtime/util/convert_tensor_buffer.h" #include "runtime/util/executor_data_util.h" #include "runtime/util/status_macros.h" // IWYU pragma: keep #include "runtime/util/tensor_buffer_util.h" namespace litert::lm { namespace { using TaskController = Engine::Session::TaskController; } absl::flat_hash_set* SessionBasic::occupied_executors_ = new absl::flat_hash_set(); ABSL_CONST_INIT absl::Mutex SessionBasic::occupied_executors_mu_( absl::kConstInit); // static absl::StatusOr> SessionBasic::Create( LlmExecutor* executor, Tokenizer* tokenizer, VisionExecutor* vision_executor, AudioExecutor* audio_executor, const SessionConfig& session_config, std::optional benchmark_info, ThreadPool* worker_thread_pool) { // Check if the session already exists. absl::MutexLock lock(occupied_executors_mu_); // NOLINT if (occupied_executors_->contains(executor)) { return absl::FailedPreconditionError( "A session already exists. Only one session is supported at a time. " "Please delete the existing session before creating a new one."); } bool enable_speculative_decoding = false; { ASSIGN_OR_RETURN(auto executor_settings, executor->GetExecutorSettings()); auto advanced_settings = executor_settings.GetAdvancedSettings(); if (advanced_settings.has_value()) { enable_speculative_decoding = advanced_settings->enable_speculative_decoding; } } auto sampler_backend = session_config.GetSamplerBackend(); std::unique_ptr sampler; // If use CPU sampling, we create it here; For GPU sampling and when // speculative decoding is enabled, we let executor create it internally. 
if (sampler_backend == Backend::CPU) { if (!enable_speculative_decoding) { ASSIGN_OR_RETURN(sampler, CreateSampler(sampler_backend, session_config.GetNumOutputCandidates(), session_config.GetSamplerParams())); } } else if (sampler_backend != Backend::GPU && sampler_backend != Backend::NPU) { return absl::InvalidArgumentError( absl::StrCat("Unsupported sampler backend: ", sampler_backend)); } if (benchmark_info.has_value()) { ABSL_LOG(INFO) << "Benchmark is enabled."; } StopTokenDetector stop_token_detector( session_config.GetNumOutputCandidates()); for (const auto& stop_token_sequence : session_config.GetStopTokenIds()) { RETURN_IF_ERROR( stop_token_detector.AddStopTokenSequence(stop_token_sequence)); } occupied_executors_->insert(executor); return absl::WrapUnique(new SessionBasic( executor, tokenizer, vision_executor, audio_executor, std::move(sampler), session_config, benchmark_info, worker_thread_pool, stop_token_detector)); } SessionBasic::~SessionBasic() { WaitUntilDone().IgnoreError(); auto status = executor_.Reset(); if (!status.ok()) { ABSL_LOG(ERROR) << "Failed to reset executor: " << status; } if (audio_executor_ != nullptr) { status = audio_executor_->Reset(); if (!status.ok()) { ABSL_LOG(ERROR) << "Failed to reset audio executor: " << status; } } absl::MutexLock lock(occupied_executors_mu_); // NOLINT occupied_executors_->erase(&executor_); } absl::StatusOr SessionBasic::ProcessAndCombineContents( const std::vector& preprocessed_contents) { std::vector combined_token_ids; std::vector all_image_data; std::vector all_audio_data; for (const auto& preprocessed_content : preprocessed_contents) { if (const auto* input_text = std::get_if(&preprocessed_content)) { ASSIGN_OR_RETURN(const auto* token_ids, input_text->GetPreprocessedTextTensor()); if (token_ids == nullptr) { return absl::InvalidArgumentError( "Token IDs is null in preprocessed_contents."); } LITERT_ASSIGN_OR_RETURN(auto ids_buffer_span, ReferTensorBufferAsSpan(*token_ids)); 
combined_token_ids.insert(combined_token_ids.end(), ids_buffer_span.begin(), ids_buffer_span.end()); } else if (const auto* input_image = std::get_if(&preprocessed_content)) { if (benchmark_info_.has_value()) { RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("vision_executor")); } ExecutorVisionData single_image_data; if (input_image->IsTensorBuffer()) { ASSIGN_OR_RETURN(auto tensor_buffer, input_image->GetPreprocessedImageTensor()); ASSIGN_OR_RETURN(single_image_data, vision_executor_->Encode(*tensor_buffer)); } else if (input_image->IsTensorBufferMap()) { ASSIGN_OR_RETURN(auto tensor_buffer_map, input_image->GetPreprocessedImageTensorMap()); ASSIGN_OR_RETURN(single_image_data, vision_executor_->Encode(*tensor_buffer_map)); } else { return absl::FailedPreconditionError( "The image is not preprocessed and does not have a tensor."); } if (benchmark_info_.has_value()) { RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("vision_executor")); } ASSIGN_OR_RETURN(auto embeddings_ptr, single_image_data.GetEmbeddingsPtr()); const auto& dimensions = TensorBufferDims(*embeddings_ptr); // The last two dimensions are [..., image_token_num, model_dimension]. 
const int image_token_num = dimensions.at(dimensions.size() - 2); combined_token_ids.insert(combined_token_ids.end(), image_token_num, ExecutorVisionData::kSpecialToken); all_image_data.push_back(std::move(single_image_data)); } else if (const auto* input_image_end = std::get_if(&preprocessed_content)) { combined_token_ids.push_back(ExecutorVisionData::kEndToken); } else if (const auto* input_audio = std::get_if(&preprocessed_content)) { ASSIGN_OR_RETURN(const auto* spectrogram_tensor, input_audio->GetPreprocessedAudioTensor()); if (benchmark_info_.has_value()) { RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("audio_executor")); } ASSIGN_OR_RETURN(auto single_audio_data, audio_executor_->Encode(*spectrogram_tensor)); if (benchmark_info_.has_value()) { RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("audio_executor")); } const int num_audio_tokens = single_audio_data.GetValidTokens(); all_audio_data.push_back(std::move(single_audio_data)); combined_token_ids.insert(combined_token_ids.end(), num_audio_tokens, ExecutorAudioData::kSpecialToken); } else if (const auto* input_audio_end = std::get_if(&preprocessed_content)) { combined_token_ids.push_back(ExecutorAudioData::kEndToken); } else { return absl::InvalidArgumentError( "Unsupported input data type in preprocessed_contents."); } } if (combined_token_ids.empty()) { return absl::InvalidArgumentError( "No token IDs found in preprocessed_contents."); } std::optional combined_image_data = std::nullopt; if (!all_image_data.empty()) { ASSIGN_OR_RETURN(combined_image_data, CombineExecutorVisionData(all_image_data)); } std::optional combined_audio_data = std::nullopt; if (!all_audio_data.empty()) { ASSIGN_OR_RETURN(combined_audio_data, CombineExecutorAudioData(all_audio_data)); } ASSIGN_OR_RETURN(auto token_ids_buffer, tokenizer_.TokenIdsToTensorBuffer(combined_token_ids)); ExecutorInputs inputs(ExecutorTextData(std::move(token_ids_buffer)), std::move(combined_image_data), std::move(combined_audio_data)); return inputs; } 
absl::Status SessionBasic::SaveCheckpoint(absl::string_view label) { ASSIGN_OR_RETURN(int current_step, executor_.GetCurrentStep()); checkpoint_map_[label] = current_step; return absl::OkStatus(); } absl::Status SessionBasic::RewindToCheckpoint(absl::string_view label) { if (auto it = checkpoint_map_.find(label); it != checkpoint_map_.end()) { ASSIGN_OR_RETURN(int current_step, executor_.GetCurrentStep()); if (it->second > current_step) { // This shouldn't ever happen because we remove all checkpoints after the // current step when we rewind to a checkpoint. return absl::InvalidArgumentError( absl::StrCat("Cannot rewind to a future step: ", it->second)); } // Set the current step of the executor to the checkpoint step. RETURN_IF_ERROR(executor_.SetCurrentStep(it->second)); // Remove all checkpoints after the current step. absl::erase_if(checkpoint_map_, [current_step = it->second](const auto& pair) { return pair.second > current_step; }); return absl::OkStatus(); } return absl::NotFoundError(absl::StrCat("Checkpoint not found: ", label)); } absl::StatusOr SessionBasic::GetCurrentStep() const { return executor_.GetCurrentStep(); } absl::Status SessionBasic::PrefillInternal( const std::vector& preprocessed_contents, bool wait_for_completion) { ASSIGN_OR_RETURN(ExecutorInputs inputs, ProcessAndCombineContents(preprocessed_contents)); ASSIGN_OR_RETURN( last_prefill_token_id_, Prefill(executor_, inputs, wait_for_completion, benchmark_info_)); session_state_ = SessionState::kPrefilled; return absl::OkStatus(); } absl::Status SessionBasic::RunPrefill(const std::vector& contents) { if (contents.empty()) { return absl::InvalidArgumentError("Input is empty."); } ABSL_LOG(INFO) << "RunPrefill: "; for (const auto& content : contents) { ABSL_LOG(INFO) << content; } if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. 
cancelled_ = false; } std::vector preprocessed_contents; if (benchmark_info_.has_value() && benchmark_info_->GetBenchmarkParams().num_prefill_tokens() > 0) { ASSIGN_OR_RETURN(preprocessed_contents, PreprocessContents(contents, session_config_, tokenizer_, benchmark_info_)); } else { bool is_first_turn = session_state_ == SessionState::kFresh; ContentType content_type; if (session_config_.GetApplyPromptTemplateInSession()) { content_type = (is_first_turn || session_state_ == SessionState::kDecoded) ? ContentType::kFirst : ContentType::kMiddle; } else { content_type = ContentType::kNA; } ASSIGN_OR_RETURN( std::vector templated_contents, ApplyPromptTemplates(contents, content_type, session_config_, tokenizer_, is_first_turn)); ASSIGN_OR_RETURN(preprocessed_contents, PreprocessContents(templated_contents, session_config_, tokenizer_, benchmark_info_)); } return PrefillInternal(preprocessed_contents, /*wait_for_completion=*/true); } absl::StatusOr> SessionBasic::RunPrefillAsync( const std::vector& contents, absl::AnyInvocable)> callback) { if (contents.empty()) { return absl::InvalidArgumentError("Input is empty."); } ABSL_LOG(INFO) << "RunPrefillAsync: "; for (const auto& content : contents) { ABSL_LOG(INFO) << content; } if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. cancelled_ = false; } std::vector preprocessed_contents; if (benchmark_info_.has_value() && benchmark_info_->GetBenchmarkParams().num_prefill_tokens() > 0) { ASSIGN_OR_RETURN(preprocessed_contents, PreprocessContents(contents, session_config_, tokenizer_, benchmark_info_)); } else { bool is_first_turn = session_state_ == SessionState::kFresh; ContentType content_type; if (session_config_.GetApplyPromptTemplateInSession()) { content_type = (is_first_turn || session_state_ == SessionState::kDecoded) ? 
ContentType::kFirst : ContentType::kMiddle; } else { content_type = ContentType::kNA; } ASSIGN_OR_RETURN( std::vector templated_contents, ApplyPromptTemplates(contents, content_type, session_config_, tokenizer_, is_first_turn)); ASSIGN_OR_RETURN(preprocessed_contents, PreprocessContents(templated_contents, session_config_, tokenizer_, benchmark_info_)); } RETURN_IF_ERROR(worker_thread_pool_.Schedule( [this, preprocessed_contents = std::move(preprocessed_contents), callback = std::move(callback)]() mutable { absl::Status status = this->PrefillInternal( preprocessed_contents, /*wait_for_completion=*/false); ABSL_LOG(INFO) << "RunPrefillAsync status: " << status; if (cancelled_.load()) { callback( absl::CancelledError("Session is cancelled during prefill.")); return; } if (!status.ok()) { callback(status); } else { callback(Responses(TaskState::kDone)); } })); return nullptr; } absl::StatusOr SessionBasic::DecodeInternal( const DecodeConfig& decode_config) { if (session_state_ != SessionState::kPrefilled) { return absl::InternalError("Session is not prefilled yet."); } // We need to do a last prefill before initializing the decode, to make sure // the prompt is correctly set up for decode. 
if (session_config_.GetApplyPromptTemplateInSession()) { std::vector contents; contents.emplace_back(InputText("")); ASSIGN_OR_RETURN( std::vector templated_contents, ApplyPromptTemplates(contents, ContentType::kLast, session_config_, tokenizer_, /*is_first_turn=*/false)); if (!templated_contents.empty()) { ASSIGN_OR_RETURN(std::vector preprocessed_contents, PreprocessContents(templated_contents, session_config_, tokenizer_, benchmark_info_)); RETURN_IF_ERROR(PrefillInternal(preprocessed_contents, /*wait_for_completion=*/true)); } } session_state_ = SessionState::kDecoded; if (sampler_ == nullptr) { ASSIGN_OR_RETURN( auto responses, Decode(executor_, tokenizer_, stop_token_detector_, session_config_.GetNumOutputCandidates(), decode_config.GetConstraint(), benchmark_info_, &cancelled_, decode_config.GetMaxOutputTokens().value_or( session_config_.GetMaxOutputTokens()))); return responses; } else { std::vector decoded_ids(session_config_.GetNumOutputCandidates(), last_prefill_token_id_); LITERT_ASSIGN_OR_RETURN( auto decoded_ids_buffer, CopyToTensorBuffer(decoded_ids, {session_config_.GetNumOutputCandidates(), 1})); ASSIGN_OR_RETURN( auto responses, DecodeCustomSampling(executor_, tokenizer_, stop_token_detector_, session_config_.GetNumOutputCandidates(), *sampler_, std::move(decoded_ids_buffer), decode_config.GetConstraint(), benchmark_info_, &cancelled_, decode_config.GetMaxOutputTokens().value_or( session_config_.GetMaxOutputTokens()))); return responses; } } absl::Status SessionBasic::DecodeInternalStreaming( absl::AnyInvocable)> callback, const DecodeConfig& decode_config) { if (sampler_ == nullptr) { RETURN_IF_ERROR(DecodeStreaming( executor_, tokenizer_, stop_token_detector_, session_config_.GetNumOutputCandidates(), decode_config.GetConstraint(), benchmark_info_, std::move(callback), &cancelled_, decode_config.GetMaxOutputTokens().value_or( session_config_.GetMaxOutputTokens()))); } else { std::vector decoded_ids(session_config_.GetNumOutputCandidates(), 
last_prefill_token_id_); LITERT_ASSIGN_OR_RETURN( auto decoded_ids_buffer, CopyToTensorBuffer(decoded_ids, {session_config_.GetNumOutputCandidates(), 1})); RETURN_IF_ERROR(DecodeCustomSamplingStreaming( executor_, tokenizer_, stop_token_detector_, session_config_.GetNumOutputCandidates(), *sampler_, std::move(decoded_ids_buffer), decode_config.GetConstraint(), benchmark_info_, std::move(callback), &cancelled_, decode_config.GetMaxOutputTokens().value_or( session_config_.GetMaxOutputTokens()))); } return absl::OkStatus(); } absl::StatusOr SessionBasic::RunDecode() { return RunDecode(DecodeConfig::CreateDefault()); } absl::StatusOr SessionBasic::RunDecode( const DecodeConfig& decode_config) { ABSL_LOG(INFO) << "RunDecodeSync"; if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. cancelled_ = false; } return DecodeInternal(decode_config); } absl::StatusOr> SessionBasic::RunDecodeAsync( absl::AnyInvocable)> callback) { return RunDecodeAsync(std::move(callback), DecodeConfig::CreateDefault()); } absl::StatusOr> SessionBasic::RunDecodeAsync( absl::AnyInvocable)> callback, const DecodeConfig& decode_config) { ABSL_LOG(INFO) << "RunDecodeAsync"; if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. cancelled_ = false; } RETURN_IF_ERROR(worker_thread_pool_.Schedule( [this, callback = std::move(callback), decode_config]() mutable { this->DecodeInternalStreaming(std::move(callback), decode_config) .IgnoreError(); })); return nullptr; } absl::StatusOr SessionBasic::GenerateContent( const std::vector& contents) { if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. 
cancelled_ = false; } RETURN_IF_ERROR(RunPrefill(contents)); return RunDecode(DecodeConfig::CreateDefault()); } absl::StatusOr SessionBasic::RunTextScoring( const std::vector& target_text, bool store_token_lengths) { absl::StatusOr collected_responses; auto scoring_sync_callback = [&collected_responses](absl::StatusOr responses) { collected_responses = std::move(responses); }; ASSIGN_OR_RETURN( auto task_controller, RunTextScoringAsync(target_text, std::move(scoring_sync_callback), store_token_lengths)); RETURN_IF_ERROR(worker_thread_pool_.WaitUntilDone(Engine::kDefaultTimeout)); return collected_responses; } absl::StatusOr> SessionBasic::RunTextScoringAsync( const std::vector& target_text, absl::AnyInvocable)> callback, bool store_token_lengths) { if (target_text.size() != 1) { return absl::InvalidArgumentError("Target text size should be 1."); } // TODO(b/435040163): Handle the temperature. Should it be calculated from // the sampler or the sampler parameters? For now, hardcode it to 1.0f for // testing. 
auto temperature = 1.0f; RETURN_IF_ERROR(worker_thread_pool_.Schedule( [this, callback = std::move(callback), target_text, store_token_lengths, temperature]() mutable { std::vector decoded_ids(session_config_.GetNumOutputCandidates(), last_prefill_token_id_); auto decoded_ids_buffer = CopyToTensorBuffer( decoded_ids, {session_config_.GetNumOutputCandidates(), 1}); if (!decoded_ids_buffer.HasValue()) { callback(absl::InternalError(decoded_ids_buffer.Error().Message())); return; } callback(ScoreCustomSampling( executor_, tokenizer_, target_text, temperature, std::move(decoded_ids_buffer.Value()), store_token_lengths)); })); return nullptr; } absl::Status SessionBasic::GenerateContentStream( const std::vector& contents, absl::AnyInvocable)> callback) { return GenerateContentStream(contents, std::move(callback), DecodeConfig::CreateDefault()); } absl::Status SessionBasic::GenerateContentStream( const std::vector& contents, absl::AnyInvocable)> callback, const DecodeConfig& decode_config) { if (cancelled_.load()) { // Reset the cancelled flag before processing the next turn. cancelled_ = false; } ASSIGN_OR_RETURN( auto task_controller, RunPrefillAsync( contents, [this, callback = std::move(callback), decode_config = decode_config]( absl::StatusOr responses) mutable { if (!responses.ok()) { callback(responses.status()); } else { if (cancelled_.load()) { callback(absl::CancelledError( "Session is cancelled during prefill.")); return; } auto status = RunDecodeAsync(std::move(callback), decode_config); } })); return absl::OkStatus(); } absl::StatusOr SessionBasic::GetBenchmarkInfo() { if (benchmark_info_.has_value()) { return benchmark_info_.value(); } return absl::InternalError( "Benchmark is not enabled. Please make sure the BenchmarkParams is set " "in the EngineSettings."); } absl::StatusOr SessionBasic::GetMutableBenchmarkInfo() { if (benchmark_info_.has_value()) { return &benchmark_info_.value(); } return absl::InternalError( "Benchmark is not enabled. 
Please make sure the BenchmarkParams is set " "in the EngineSettings."); } } // namespace litert::lm