Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
namespace litert::lm {
namespace {
// Shorthand for the async-task handle type returned by the Run*Async APIs.
using TaskController = Engine::Session::TaskController;
}  // namespace

// Set of executors currently backing a live session; Create() consults it to
// enforce the one-session-per-executor limit and ~SessionBasic() removes the
// entry. Guarded by occupied_executors_mu_. Heap-allocated and never freed
// (the usual leak-on-purpose pattern for function-lifetime globals).
absl::flat_hash_set<LlmExecutor*>* SessionBasic::occupied_executors_ =
    new absl::flat_hash_set<LlmExecutor*>();
ABSL_CONST_INIT absl::Mutex SessionBasic::occupied_executors_mu_(
    absl::kConstInit);
// static
//
// Factory for SessionBasic. Fails with FailedPrecondition when the executor
// is already owned by another live session, and with InvalidArgument for an
// unsupported sampler backend. On success the executor is marked as occupied
// until the session is destroyed. The whole body runs under
// occupied_executors_mu_, so concurrent Create() calls are serialized.
absl::StatusOr<std::unique_ptr<SessionBasic>> SessionBasic::Create(
    LlmExecutor* executor, Tokenizer* tokenizer,
    VisionExecutor* vision_executor, AudioExecutor* audio_executor,
    const SessionConfig& session_config,
    std::optional<BenchmarkInfo> benchmark_info,
    ThreadPool* worker_thread_pool) {
  // Check if the session already exists.
  absl::MutexLock lock(occupied_executors_mu_);  // NOLINT
  if (occupied_executors_->contains(executor)) {
    return absl::FailedPreconditionError(
        "A session already exists. Only one session is supported at a time. "
        "Please delete the existing session before creating a new one.");
  }
  // Speculative decoding is an executor-side setting; read it from the
  // executor's advanced settings (when present) because it decides whether
  // the session creates its own CPU sampler below.
  bool enable_speculative_decoding = false;
  {
    ASSIGN_OR_RETURN(auto executor_settings, executor->GetExecutorSettings());
    auto advanced_settings = executor_settings.GetAdvancedSettings();
    if (advanced_settings.has_value()) {
      enable_speculative_decoding =
          advanced_settings->enable_speculative_decoding;
    }
  }
  auto sampler_backend = session_config.GetSamplerBackend();
  std::unique_ptr<Sampler> sampler;
  // If use CPU sampling, we create it here; For GPU sampling and when
  // speculative decoding is enabled, we let executor create it internally.
  if (sampler_backend == Backend::CPU) {
    if (!enable_speculative_decoding) {
      ASSIGN_OR_RETURN(sampler,
                       CreateSampler(sampler_backend,
                                     session_config.GetNumOutputCandidates(),
                                     session_config.GetSamplerParams()));
    }
  } else if (sampler_backend != Backend::GPU &&
             sampler_backend != Backend::NPU) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unsupported sampler backend: ", sampler_backend));
  }
  if (benchmark_info.has_value()) {
    ABSL_LOG(INFO) << "Benchmark is enabled.";
  }
  // The detector is sized by the number of output candidates and loaded with
  // every configured stop-token sequence.
  StopTokenDetector stop_token_detector(
      session_config.GetNumOutputCandidates());
  for (const auto& stop_token_sequence : session_config.GetStopTokenIds()) {
    RETURN_IF_ERROR(
        stop_token_detector.AddStopTokenSequence(stop_token_sequence));
  }
  // Mark the executor as occupied only after all fallible setup succeeded,
  // so failed creations do not leave it permanently reserved.
  occupied_executors_->insert(executor);
  return absl::WrapUnique(new SessionBasic(
      executor, tokenizer, vision_executor, audio_executor, std::move(sampler),
      session_config, benchmark_info, worker_thread_pool, stop_token_detector));
}
| SessionBasic::~SessionBasic() { | |
| WaitUntilDone().IgnoreError(); | |
| auto status = executor_.Reset(); | |
| if (!status.ok()) { | |
| ABSL_LOG(ERROR) << "Failed to reset executor: " << status; | |
| } | |
| if (audio_executor_ != nullptr) { | |
| status = audio_executor_->Reset(); | |
| if (!status.ok()) { | |
| ABSL_LOG(ERROR) << "Failed to reset audio executor: " << status; | |
| } | |
| } | |
| absl::MutexLock lock(occupied_executors_mu_); // NOLINT | |
| occupied_executors_->erase(&executor_); | |
| } | |
// Flattens already-preprocessed multimodal contents into one ExecutorInputs:
// a single combined token-id sequence plus (at most) one combined vision and
// one combined audio payload. Image/audio content is represented in the token
// stream by repeated kSpecialToken placeholders — one per embedding token —
// and InputImageEnd/InputAudioEnd markers append the matching kEndToken.
// Fails if the contents contain no token ids at all.
absl::StatusOr<ExecutorInputs> SessionBasic::ProcessAndCombineContents(
    const std::vector<InputData>& preprocessed_contents) {
  std::vector<int> combined_token_ids;
  std::vector<ExecutorVisionData> all_image_data;
  std::vector<ExecutorAudioData> all_audio_data;
  for (const auto& preprocessed_content : preprocessed_contents) {
    if (const auto* input_text =
            std::get_if<InputText>(&preprocessed_content)) {
      // Text: append its token ids verbatim.
      ASSIGN_OR_RETURN(const auto* token_ids,
                       input_text->GetPreprocessedTextTensor());
      if (token_ids == nullptr) {
        return absl::InvalidArgumentError(
            "Token IDs is null in preprocessed_contents.");
      }
      LITERT_ASSIGN_OR_RETURN(auto ids_buffer_span,
                              ReferTensorBufferAsSpan<int>(*token_ids));
      combined_token_ids.insert(combined_token_ids.end(),
                                ids_buffer_span.begin(), ids_buffer_span.end());
    } else if (const auto* input_image =
                   std::get_if<InputImage>(&preprocessed_content)) {
      // Image: encode via the vision executor (bracketed by benchmark time
      // marks when enabled), then reserve one placeholder token per image
      // embedding token.
      if (benchmark_info_.has_value()) {
        RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("vision_executor"));
      }
      ExecutorVisionData single_image_data;
      if (input_image->IsTensorBuffer()) {
        ASSIGN_OR_RETURN(auto tensor_buffer,
                         input_image->GetPreprocessedImageTensor());
        ASSIGN_OR_RETURN(single_image_data,
                         vision_executor_->Encode(*tensor_buffer));
      } else if (input_image->IsTensorBufferMap()) {
        ASSIGN_OR_RETURN(auto tensor_buffer_map,
                         input_image->GetPreprocessedImageTensorMap());
        ASSIGN_OR_RETURN(single_image_data,
                         vision_executor_->Encode(*tensor_buffer_map));
      } else {
        return absl::FailedPreconditionError(
            "The image is not preprocessed and does not have a tensor.");
      }
      if (benchmark_info_.has_value()) {
        RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("vision_executor"));
      }
      ASSIGN_OR_RETURN(auto embeddings_ptr,
                       single_image_data.GetEmbeddingsPtr());
      const auto& dimensions = TensorBufferDims(*embeddings_ptr);
      // The last two dimensions are [..., image_token_num, model_dimension].
      const int image_token_num = dimensions.at(dimensions.size() - 2);
      combined_token_ids.insert(combined_token_ids.end(), image_token_num,
                                ExecutorVisionData::kSpecialToken);
      all_image_data.push_back(std::move(single_image_data));
    } else if (const auto* input_image_end =
                   std::get_if<InputImageEnd>(&preprocessed_content)) {
      combined_token_ids.push_back(ExecutorVisionData::kEndToken);
    } else if (const auto* input_audio =
                   std::get_if<InputAudio>(&preprocessed_content)) {
      // Audio: encode the spectrogram; the placeholder count comes from the
      // encoder's reported number of valid tokens.
      ASSIGN_OR_RETURN(const auto* spectrogram_tensor,
                       input_audio->GetPreprocessedAudioTensor());
      if (benchmark_info_.has_value()) {
        RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("audio_executor"));
      }
      ASSIGN_OR_RETURN(auto single_audio_data,
                       audio_executor_->Encode(*spectrogram_tensor));
      if (benchmark_info_.has_value()) {
        RETURN_IF_ERROR(benchmark_info_->TimeMarkDelta("audio_executor"));
      }
      const int num_audio_tokens = single_audio_data.GetValidTokens();
      all_audio_data.push_back(std::move(single_audio_data));
      combined_token_ids.insert(combined_token_ids.end(), num_audio_tokens,
                                ExecutorAudioData::kSpecialToken);
    } else if (const auto* input_audio_end =
                   std::get_if<InputAudioEnd>(&preprocessed_content)) {
      combined_token_ids.push_back(ExecutorAudioData::kEndToken);
    } else {
      return absl::InvalidArgumentError(
          "Unsupported input data type in preprocessed_contents.");
    }
  }
  if (combined_token_ids.empty()) {
    return absl::InvalidArgumentError(
        "No token IDs found in preprocessed_contents.");
  }
  // Merge the per-item modality payloads into at most one payload each for
  // the executor.
  std::optional<ExecutorVisionData> combined_image_data = std::nullopt;
  if (!all_image_data.empty()) {
    ASSIGN_OR_RETURN(combined_image_data,
                     CombineExecutorVisionData(all_image_data));
  }
  std::optional<ExecutorAudioData> combined_audio_data = std::nullopt;
  if (!all_audio_data.empty()) {
    ASSIGN_OR_RETURN(combined_audio_data,
                     CombineExecutorAudioData(all_audio_data));
  }
  ASSIGN_OR_RETURN(auto token_ids_buffer,
                   tokenizer_.TokenIdsToTensorBuffer(combined_token_ids));
  ExecutorInputs inputs(ExecutorTextData(std::move(token_ids_buffer)),
                        std::move(combined_image_data),
                        std::move(combined_audio_data));
  return inputs;
}
| absl::Status SessionBasic::SaveCheckpoint(absl::string_view label) { | |
| ASSIGN_OR_RETURN(int current_step, executor_.GetCurrentStep()); | |
| checkpoint_map_[label] = current_step; | |
| return absl::OkStatus(); | |
| } | |
| absl::Status SessionBasic::RewindToCheckpoint(absl::string_view label) { | |
| if (auto it = checkpoint_map_.find(label); it != checkpoint_map_.end()) { | |
| ASSIGN_OR_RETURN(int current_step, executor_.GetCurrentStep()); | |
| if (it->second > current_step) { | |
| // This shouldn't ever happen because we remove all checkpoints after the | |
| // current step when we rewind to a checkpoint. | |
| return absl::InvalidArgumentError( | |
| absl::StrCat("Cannot rewind to a future step: ", it->second)); | |
| } | |
| // Set the current step of the executor to the checkpoint step. | |
| RETURN_IF_ERROR(executor_.SetCurrentStep(it->second)); | |
| // Remove all checkpoints after the current step. | |
| absl::erase_if(checkpoint_map_, | |
| [current_step = it->second](const auto& pair) { | |
| return pair.second > current_step; | |
| }); | |
| return absl::OkStatus(); | |
| } | |
| return absl::NotFoundError(absl::StrCat("Checkpoint not found: ", label)); | |
| } | |
| absl::StatusOr<int> SessionBasic::GetCurrentStep() const { | |
| return executor_.GetCurrentStep(); | |
| } | |
| absl::Status SessionBasic::PrefillInternal( | |
| const std::vector<InputData>& preprocessed_contents, | |
| bool wait_for_completion) { | |
| ASSIGN_OR_RETURN(ExecutorInputs inputs, | |
| ProcessAndCombineContents(preprocessed_contents)); | |
| ASSIGN_OR_RETURN( | |
| last_prefill_token_id_, | |
| Prefill(executor_, inputs, wait_for_completion, benchmark_info_)); | |
| session_state_ = SessionState::kPrefilled; | |
| return absl::OkStatus(); | |
| } | |
// Synchronously prefills the session with `contents`: applies prompt
// templates (unless disabled, or benchmarking with a fixed prefill-token
// budget), preprocesses/tokenizes, then feeds the executor and blocks until
// the prefill completes.
absl::Status SessionBasic::RunPrefill(const std::vector<InputData>& contents) {
  if (contents.empty()) {
    return absl::InvalidArgumentError("Input is empty.");
  }
  ABSL_LOG(INFO) << "RunPrefill: ";
  for (const auto& content : contents) {
    ABSL_LOG(INFO) << content;
  }
  if (cancelled_.load()) {
    // Reset the cancelled flag before processing the next turn.
    cancelled_ = false;
  }
  std::vector<InputData> preprocessed_contents;
  if (benchmark_info_.has_value() &&
      benchmark_info_->GetBenchmarkParams().num_prefill_tokens() > 0) {
    // Benchmark mode with an explicit prefill-token count: skip prompt
    // templating and preprocess the raw contents directly.
    ASSIGN_OR_RETURN(preprocessed_contents,
                     PreprocessContents(contents, session_config_, tokenizer_,
                                        benchmark_info_));
  } else {
    bool is_first_turn = session_state_ == SessionState::kFresh;
    ContentType content_type;
    if (session_config_.GetApplyPromptTemplateInSession()) {
      // A fresh session, or one that just finished decoding, starts a new
      // turn (kFirst); otherwise this prefill continues the current turn
      // (kMiddle).
      content_type = (is_first_turn || session_state_ == SessionState::kDecoded)
                         ? ContentType::kFirst
                         : ContentType::kMiddle;
    } else {
      content_type = ContentType::kNA;
    }
    ASSIGN_OR_RETURN(
        std::vector<InputData> templated_contents,
        ApplyPromptTemplates(contents, content_type, session_config_,
                             tokenizer_, is_first_turn));
    ASSIGN_OR_RETURN(preprocessed_contents,
                     PreprocessContents(templated_contents, session_config_,
                                        tokenizer_, benchmark_info_));
  }
  return PrefillInternal(preprocessed_contents,
                         /*wait_for_completion=*/true);
}
// Asynchronous variant of RunPrefill: templating and preprocessing run on the
// calling thread, while the executor prefill itself is scheduled on the
// worker thread pool. Success, failure, or cancellation is reported through
// `callback`. The returned TaskController is currently always null.
absl::StatusOr<std::unique_ptr<TaskController>> SessionBasic::RunPrefillAsync(
    const std::vector<InputData>& contents,
    absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
  if (contents.empty()) {
    return absl::InvalidArgumentError("Input is empty.");
  }
  ABSL_LOG(INFO) << "RunPrefillAsync: ";
  for (const auto& content : contents) {
    ABSL_LOG(INFO) << content;
  }
  if (cancelled_.load()) {
    // Reset the cancelled flag before processing the next turn.
    cancelled_ = false;
  }
  std::vector<InputData> preprocessed_contents;
  if (benchmark_info_.has_value() &&
      benchmark_info_->GetBenchmarkParams().num_prefill_tokens() > 0) {
    // Benchmark mode with an explicit prefill-token count skips templating.
    ASSIGN_OR_RETURN(preprocessed_contents,
                     PreprocessContents(contents, session_config_, tokenizer_,
                                        benchmark_info_));
  } else {
    bool is_first_turn = session_state_ == SessionState::kFresh;
    ContentType content_type;
    if (session_config_.GetApplyPromptTemplateInSession()) {
      // New turn (kFirst) after a fresh start or a finished decode;
      // otherwise continue the current turn (kMiddle).
      content_type = (is_first_turn || session_state_ == SessionState::kDecoded)
                         ? ContentType::kFirst
                         : ContentType::kMiddle;
    } else {
      content_type = ContentType::kNA;
    }
    ASSIGN_OR_RETURN(
        std::vector<InputData> templated_contents,
        ApplyPromptTemplates(contents, content_type, session_config_,
                             tokenizer_, is_first_turn));
    ASSIGN_OR_RETURN(preprocessed_contents,
                     PreprocessContents(templated_contents, session_config_,
                                        tokenizer_, benchmark_info_));
  }
  RETURN_IF_ERROR(worker_thread_pool_.Schedule(
      [this, preprocessed_contents = std::move(preprocessed_contents),
       callback = std::move(callback)]() mutable {
        // Do not block inside the pool task; completion is signaled through
        // the callback instead.
        absl::Status status = this->PrefillInternal(
            preprocessed_contents, /*wait_for_completion=*/false);
        ABSL_LOG(INFO) << "RunPrefillAsync status: " << status;
        // Cancellation takes precedence over the prefill status.
        if (cancelled_.load()) {
          callback(
              absl::CancelledError("Session is cancelled during prefill."));
          return;
        }
        if (!status.ok()) {
          callback(status);
        } else {
          callback(Responses(TaskState::kDone));
        }
      }));
  return nullptr;
}
// Runs one complete (non-streaming) decode turn. Requires a prior successful
// prefill. When in-session prompt templating is enabled, the closing template
// part is prefilled first so the prompt is fully set up before decoding.
// Dispatches to executor-side sampling when no session-owned sampler exists,
// otherwise to the custom-sampling path seeded with the last prefill token.
absl::StatusOr<Responses> SessionBasic::DecodeInternal(
    const DecodeConfig& decode_config) {
  if (session_state_ != SessionState::kPrefilled) {
    return absl::InternalError("Session is not prefilled yet.");
  }
  // We need to do a last prefill before initializing the decode, to make sure
  // the prompt is correctly set up for decode.
  if (session_config_.GetApplyPromptTemplateInSession()) {
    std::vector<InputData> contents;
    contents.emplace_back(InputText(""));
    ASSIGN_OR_RETURN(
        std::vector<InputData> templated_contents,
        ApplyPromptTemplates(contents, ContentType::kLast, session_config_,
                             tokenizer_, /*is_first_turn=*/false));
    if (!templated_contents.empty()) {
      ASSIGN_OR_RETURN(std::vector<InputData> preprocessed_contents,
                       PreprocessContents(templated_contents, session_config_,
                                          tokenizer_, benchmark_info_));
      RETURN_IF_ERROR(PrefillInternal(preprocessed_contents,
                                      /*wait_for_completion=*/true));
    }
  }
  session_state_ = SessionState::kDecoded;
  if (sampler_ == nullptr) {
    // No session-owned sampler: sampling happens inside the executor (see
    // Create() for when that is the case).
    ASSIGN_OR_RETURN(
        auto responses,
        Decode(executor_, tokenizer_, stop_token_detector_,
               session_config_.GetNumOutputCandidates(),
               decode_config.GetConstraint(), benchmark_info_, &cancelled_,
               decode_config.GetMaxOutputTokens().value_or(
                   session_config_.GetMaxOutputTokens())));
    return responses;
  } else {
    // Custom sampling: seed every candidate with the last prefill token id,
    // shaped [num_candidates, 1].
    std::vector<int> decoded_ids(session_config_.GetNumOutputCandidates(),
                                 last_prefill_token_id_);
    LITERT_ASSIGN_OR_RETURN(
        auto decoded_ids_buffer,
        CopyToTensorBuffer<int>(decoded_ids,
                                {session_config_.GetNumOutputCandidates(), 1}));
    ASSIGN_OR_RETURN(
        auto responses,
        DecodeCustomSampling(executor_, tokenizer_, stop_token_detector_,
                             session_config_.GetNumOutputCandidates(),
                             *sampler_, std::move(decoded_ids_buffer),
                             decode_config.GetConstraint(), benchmark_info_,
                             &cancelled_,
                             decode_config.GetMaxOutputTokens().value_or(
                                 session_config_.GetMaxOutputTokens())));
    return responses;
  }
}
// Streaming counterpart of DecodeInternal: results are pushed through
// `callback` as decoding progresses. Note that unlike DecodeInternal, this
// neither checks/advances session_state_ nor prefills the closing prompt
// template — callers are expected to have prefilled already.
absl::Status SessionBasic::DecodeInternalStreaming(
    absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
    const DecodeConfig& decode_config) {
  if (sampler_ == nullptr) {
    // Executor-side sampling (see Create() for when the sampler is null).
    RETURN_IF_ERROR(DecodeStreaming(
        executor_, tokenizer_, stop_token_detector_,
        session_config_.GetNumOutputCandidates(), decode_config.GetConstraint(),
        benchmark_info_, std::move(callback), &cancelled_,
        decode_config.GetMaxOutputTokens().value_or(
            session_config_.GetMaxOutputTokens())));
  } else {
    // Custom sampling: seed every candidate with the last prefill token id,
    // shaped [num_candidates, 1].
    std::vector<int> decoded_ids(session_config_.GetNumOutputCandidates(),
                                 last_prefill_token_id_);
    LITERT_ASSIGN_OR_RETURN(
        auto decoded_ids_buffer,
        CopyToTensorBuffer<int>(decoded_ids,
                                {session_config_.GetNumOutputCandidates(), 1}));
    RETURN_IF_ERROR(DecodeCustomSamplingStreaming(
        executor_, tokenizer_, stop_token_detector_,
        session_config_.GetNumOutputCandidates(), *sampler_,
        std::move(decoded_ids_buffer), decode_config.GetConstraint(),
        benchmark_info_, std::move(callback), &cancelled_,
        decode_config.GetMaxOutputTokens().value_or(
            session_config_.GetMaxOutputTokens())));
  }
  return absl::OkStatus();
}
| absl::StatusOr<Responses> SessionBasic::RunDecode() { | |
| return RunDecode(DecodeConfig::CreateDefault()); | |
| } | |
| absl::StatusOr<Responses> SessionBasic::RunDecode( | |
| const DecodeConfig& decode_config) { | |
| ABSL_LOG(INFO) << "RunDecodeSync"; | |
| if (cancelled_.load()) { | |
| // Reset the cancelled flag before processing the next turn. | |
| cancelled_ = false; | |
| } | |
| return DecodeInternal(decode_config); | |
| } | |
| absl::StatusOr<std::unique_ptr<TaskController>> SessionBasic::RunDecodeAsync( | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) { | |
| return RunDecodeAsync(std::move(callback), DecodeConfig::CreateDefault()); | |
| } | |
// Schedules a streaming decode on the worker thread pool. Tokens and errors
// are delivered via `callback`; the returned TaskController is currently
// always null.
absl::StatusOr<std::unique_ptr<TaskController>> SessionBasic::RunDecodeAsync(
    absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
    const DecodeConfig& decode_config) {
  ABSL_LOG(INFO) << "RunDecodeAsync";
  if (cancelled_.load()) {
    // Reset the cancelled flag before processing the next turn.
    cancelled_ = false;
  }
  RETURN_IF_ERROR(worker_thread_pool_.Schedule(
      [this, callback = std::move(callback), decode_config]() mutable {
        // NOTE(review): the returned status is dropped; presumably
        // DecodeInternalStreaming reports failures through `callback` before
        // returning — confirm, otherwise callers are never notified of an
        // early failure.
        this->DecodeInternalStreaming(std::move(callback), decode_config)
            .IgnoreError();
      }));
  return nullptr;
}
| absl::StatusOr<Responses> SessionBasic::GenerateContent( | |
| const std::vector<InputData>& contents) { | |
| if (cancelled_.load()) { | |
| // Reset the cancelled flag before processing the next turn. | |
| cancelled_ = false; | |
| } | |
| RETURN_IF_ERROR(RunPrefill(contents)); | |
| return RunDecode(DecodeConfig::CreateDefault()); | |
| } | |
// Synchronous wrapper around RunTextScoringAsync: schedules the scoring task,
// blocks until the worker pool drains (bounded by Engine::kDefaultTimeout),
// and returns whatever the callback captured.
absl::StatusOr<Responses> SessionBasic::RunTextScoring(
    const std::vector<absl::string_view>& target_text,
    bool store_token_lengths) {
  // Default-constructed StatusOr holds an error; the callback overwrites it
  // once scoring completes.
  absl::StatusOr<Responses> collected_responses;
  auto scoring_sync_callback =
      [&collected_responses](absl::StatusOr<Responses> responses) {
        collected_responses = std::move(responses);
      };
  ASSIGN_OR_RETURN(
      auto task_controller,
      RunTextScoringAsync(target_text, std::move(scoring_sync_callback),
                          store_token_lengths));
  RETURN_IF_ERROR(worker_thread_pool_.WaitUntilDone(Engine::kDefaultTimeout));
  return collected_responses;
}
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>> | |
| SessionBasic::RunTextScoringAsync( | |
| const std::vector<absl::string_view>& target_text, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| bool store_token_lengths) { | |
| if (target_text.size() != 1) { | |
| return absl::InvalidArgumentError("Target text size should be 1."); | |
| } | |
| // TODO(b/435040163): Handle the temperature. Should it be calculated from | |
| // the sampler or the sampler parameters? For now, hardcode it to 1.0f for | |
| // testing. | |
| auto temperature = 1.0f; | |
| RETURN_IF_ERROR(worker_thread_pool_.Schedule( | |
| [this, callback = std::move(callback), target_text, store_token_lengths, | |
| temperature]() mutable { | |
| std::vector<int> decoded_ids(session_config_.GetNumOutputCandidates(), | |
| last_prefill_token_id_); | |
| auto decoded_ids_buffer = CopyToTensorBuffer<int>( | |
| decoded_ids, {session_config_.GetNumOutputCandidates(), 1}); | |
| if (!decoded_ids_buffer.HasValue()) { | |
| callback(absl::InternalError(decoded_ids_buffer.Error().Message())); | |
| return; | |
| } | |
| callback(ScoreCustomSampling( | |
| executor_, tokenizer_, target_text, temperature, | |
| std::move(decoded_ids_buffer.Value()), store_token_lengths)); | |
| })); | |
| return nullptr; | |
| } | |
| absl::Status SessionBasic::GenerateContentStream( | |
| const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) { | |
| return GenerateContentStream(contents, std::move(callback), | |
| DecodeConfig::CreateDefault()); | |
| } | |
| absl::Status SessionBasic::GenerateContentStream( | |
| const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| const DecodeConfig& decode_config) { | |
| if (cancelled_.load()) { | |
| // Reset the cancelled flag before processing the next turn. | |
| cancelled_ = false; | |
| } | |
| ASSIGN_OR_RETURN( | |
| auto task_controller, | |
| RunPrefillAsync( | |
| contents, | |
| [this, callback = std::move(callback), decode_config = decode_config]( | |
| absl::StatusOr<Responses> responses) mutable { | |
| if (!responses.ok()) { | |
| callback(responses.status()); | |
| } else { | |
| if (cancelled_.load()) { | |
| callback(absl::CancelledError( | |
| "Session is cancelled during prefill.")); | |
| return; | |
| } | |
| auto status = RunDecodeAsync(std::move(callback), decode_config); | |
| } | |
| })); | |
| return absl::OkStatus(); | |
| } | |
| absl::StatusOr<BenchmarkInfo> SessionBasic::GetBenchmarkInfo() { | |
| if (benchmark_info_.has_value()) { | |
| return benchmark_info_.value(); | |
| } | |
| return absl::InternalError( | |
| "Benchmark is not enabled. Please make sure the BenchmarkParams is set " | |
| "in the EngineSettings."); | |
| } | |
| absl::StatusOr<BenchmarkInfo*> SessionBasic::GetMutableBenchmarkInfo() { | |
| if (benchmark_info_.has_value()) { | |
| return &benchmark_info_.value(); | |
| } | |
| return absl::InternalError( | |
| "Benchmark is not enabled. Please make sure the BenchmarkParams is set " | |
| "in the EngineSettings."); | |
| } | |
| } // namespace litert::lm | |