// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/core/session_basic.h" #include #include // NOLINT: Required for path manipulation. #include #include #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/functional/any_invocable.h" // from @com_google_absl #include "absl/memory/memory.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_join.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/synchronization/notification.h" // from @com_google_absl #include "absl/time/clock.h" // from @com_google_absl #include "absl/time/time.h" // from @com_google_absl #include "litert/cc/litert_environment.h" // from @litert #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "litert/test/matchers.h" // from @litert #include "runtime/components/constrained_decoding/fake_constraint.h" #include "runtime/components/sentencepiece_tokenizer.h" #include "runtime/components/tokenizer.h" #include "runtime/engine/engine_settings.h" #include "runtime/engine/io_types.h" #include "runtime/executor/audio_executor_settings.h" #include "runtime/executor/audio_litert_compiled_model_executor.h" #include "runtime/executor/executor_settings_base.h" #include 
"runtime/executor/fake_llm_executor.h" #include "runtime/executor/llm_executor_io_types.h" #include "runtime/framework/threadpool.h" #include "runtime/util/convert_tensor_buffer.h" #include "runtime/util/scoped_file.h" #include "runtime/util/status_macros.h" #include "runtime/util/test_utils.h" // IWYU pragma: keep namespace litert::lm { namespace { using ::testing::status::StatusIs; constexpr absl::string_view kTestdataDir = "litert_lm/runtime/components/testdata/"; constexpr absl::string_view kTestAudioModelPath = "litert_lm/runtime/testdata/dummy_audio_only.litertlm"; constexpr int kSpectrogramFrequencySlots = 8; constexpr int kSpectrogramSequenceLength = 10; constexpr int kEmbeddingSequenceLength = 5; constexpr int kEmbeddingDimensions = 6; // Audio embedding tensor will have shape [1, kEmbeddingSequenceLength, // kEmbeddingDimensions]. constexpr std::array kExpectedAudioEmbedding = {0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 3., 3., 0., 1., 2., 4., 4., 4., 1., 2., 3., 5., 5., 5., 0., 1., 2., 4., 4., 4.}; // Mel spectrogram tensor will have shape [1, kSpectrogramSequenceLength, // kSpectrogramFrequencySlots]. constexpr std::array mel_spectrogram_data = { 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1.}; absl::StatusOr> CreateFakeLlmExecutor( std::vector> prefill_tokens, std::vector> decode_tokens, std::optional> audio_embedding = std::nullopt) { auto batch_size = decode_tokens.empty() ? 
1 : decode_tokens[0].size(); auto fake_executor = std::make_unique( 2560, prefill_tokens, decode_tokens, batch_size, audio_embedding); return std::move(fake_executor); } absl::StatusOr RunTextScoring( const std::vector>& prefill_tokens, const std::vector>& decode_tokens, absl::string_view input_prompt, absl::string_view target_text, const proto::SamplerParameters& sampler_params, bool store_token_lengths, Tokenizer* tokenizer, ThreadPool* worker_thread_pool) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSIGN_OR_RETURN(auto executor, CreateFakeLlmExecutor(prefill_tokens, decode_tokens)); auto session = SessionBasic::Create( executor.get(), tokenizer, /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool); std::vector inputs; inputs.emplace_back(InputText(std::string(input_prompt))); EXPECT_OK((*session)->RunPrefill(inputs)); std::vector target_texts; target_texts.push_back(target_text); return (*session)->RunTextScoring(target_texts, store_token_lengths); } class ExtendedTokenizer : public Tokenizer { public: static absl::StatusOr> CreateFromFile( absl::string_view model_path) { ASSIGN_OR_RETURN(auto tokenizer, SentencePieceTokenizer::CreateFromFile(model_path)); return absl::WrapUnique(new ExtendedTokenizer(std::move(tokenizer))); } void SetExtendedToken(int token_id, absl::string_view token_str) { extended_tokens_to_id_[token_str] = token_id; id_to_extended_tokens_[token_id] = token_str; } absl::StatusOr> TextToTokenIds( absl::string_view text) override { std::vector token_ids; bool is_extended_token_found = false; do { is_extended_token_found = false; for (const auto& [extended_token_str, extended_token_id] : extended_tokens_to_id_) { auto 
extended_token_pos = text.find(extended_token_str); if (extended_token_pos != std::string::npos) { // The text before the extended token. ASSIGN_OR_RETURN( auto text_ids, tokenizer_->TextToTokenIds(text.substr(0, extended_token_pos))); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); token_ids.push_back(extended_token_id); text = text.substr(extended_token_pos + extended_token_str.size()); is_extended_token_found = true; } } } while (is_extended_token_found); if (!text.empty()) { ASSIGN_OR_RETURN(auto text_ids, tokenizer_->TextToTokenIds(text)); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); } return token_ids; } absl::StatusOr TokenIdsToText( const std::vector& token_ids) override { std::vector token_strs; for (int token_id : token_ids) { if (id_to_extended_tokens_.contains(token_id)) { token_strs.push_back(id_to_extended_tokens_[token_id]); } else { token_strs.push_back(tokenizer_->TokenIdsToText({token_id}).value()); } } return absl::StrJoin(token_strs, ""); } absl::StatusOr TokenToId(absl::string_view token) override { if (extended_tokens_to_id_.contains(token)) { return extended_tokens_to_id_[token]; } return tokenizer_->TokenToId(token); } TokenizerType GetTokenizerType() const override { return tokenizer_->GetTokenizerType(); } std::vector GetTokens() const override { return tokenizer_->GetTokens(); } private: explicit ExtendedTokenizer(std::unique_ptr tokenizer) : tokenizer_(std::move(tokenizer)) {}; absl::flat_hash_map id_to_extended_tokens_; absl::flat_hash_map extended_tokens_to_id_; std::unique_ptr tokenizer_; }; class SessionBasicTest : public testing::Test { protected: void SetUp() override { auto tokenizer = ExtendedTokenizer::CreateFromFile( (std::filesystem::path(::testing::SrcDir()) / std::string(kTestdataDir) / "sentencepiece.model") .string()); ASSERT_OK(tokenizer); tokenizer.value()->SetExtendedToken(256000, ""); tokenizer_ = std::move(*tokenizer); 
sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); // Creating the thread pool of a single thread to execute the works. worker_thread_pool_ = std::make_unique(/*name_prefix=*/"engine", /*max_num_threads=*/1); } std::unique_ptr tokenizer_; proto::SamplerParameters sampler_params_; std::unique_ptr worker_thread_pool_; }; absl::StatusOr> CreateAudioExecutor(Environment& env, const std::string& model_path, int max_sequence_length, Backend backend) { ASSIGN_OR_RETURN(auto model_file, ScopedFile::Open(model_path)); auto model_file_ptr = std::make_shared(std::move(model_file)); ASSIGN_OR_RETURN(auto model_assets, ModelAssets::Create(model_file_ptr)); // Create the audio executor settings. ASSIGN_OR_RETURN(auto audio_executor_settings, AudioExecutorSettings::CreateDefault( model_assets, max_sequence_length, backend)); // Create the audio executor. return litert::lm::AudioLiteRtCompiledModelExecutor::Create( audio_executor_settings, env); } absl::AnyInvocable)> CreateStreamingTestCallback( absl::Status& status_ref, std::vector& texts_ref, absl::Notification& done_ref, bool delay_on_next = false) { return [&status_ref, &texts_ref, &done_ref, delay_on_next](absl::StatusOr responses) mutable { if (!responses.ok()) { status_ref = std::move(responses.status()); done_ref.Notify(); return; } if (responses->GetTaskState() == TaskState::kDone) { done_ref.Notify(); return; } if (delay_on_next) { absl::SleepFor(absl::Milliseconds(50)); } EXPECT_EQ(responses->GetTexts().size(), 1); texts_ref.push_back(responses->GetTexts()[0]); }; } TEST_F(SessionBasicTest, SessionCreation) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // The prefill tokens are 
the expected tokens that will be passed in // at each time the Prefill function is called. The values are the // token ids of the input prompt "Hello World!". // The decode tokens are the expected tokens that will be returned // by the Decode function. The values are the token ids of the // output response "How's it going?" followed by the stop token id // (2294). /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); EXPECT_OK(session); // Second session creation should fail because there is already a session. EXPECT_THAT(SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()), StatusIs(absl::StatusCode::kFailedPrecondition)); // After the first session is destroyed, the second session creation should // succeed. session->reset(); EXPECT_OK(SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get())); } TEST_F(SessionBasicTest, RunPrefill) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // The prefill tokens are the expected tokens that will be passed in // at each time the Prefill function is called. The values are the // token ids of the input prompt "Hello World!". // The decode tokens are the expected tokens that will be returned // by the Decode function. 
The values are the token ids of the // output response "How's it going?" followed by the stop token id // (2294). /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); } TEST_F(SessionBasicTest, EmptyInputTextReturnsError) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN(auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{}}, /*decode_tokens=*/{{}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("")); EXPECT_THAT((*session)->RunPrefill(inputs), StatusIs(absl::StatusCode::kInvalidArgument, "No token IDs found in preprocessed_contents.")); } TEST_F(SessionBasicTest, RunDecode) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); auto responses = (*session)->RunDecode(); EXPECT_OK(responses); // Expect a single output candidate. EXPECT_EQ(responses->GetTexts().size(), 1); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); } TEST_F(SessionBasicTest, RunDecodeWithMaxOutputTokens) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetMaxOutputTokens(2); auto responses = (*session)->RunDecode(decode_config); EXPECT_OK(responses); // Expect a single output candidate. 
EXPECT_EQ(responses->GetTexts().size(), 1); EXPECT_EQ(responses->GetTexts()[0], " How'"); } TEST_F(SessionBasicTest, RunDecodeWithMultipleOutputCandidates) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetNumOutputCandidates(3); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?", "Hello World", "How's it going?" /*decode_tokens=*/{{224, 90, 224}, {24, 547, 24}, {8, 58, 8}, {66, 735, 66}, {246, 210, 246}, {18, 466, 18}, {2295, 2294, 2295}, {2294, 0, 2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); auto responses = (*session)->RunDecode(); EXPECT_OK(responses); EXPECT_EQ(responses->GetTexts().size(), 3); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); EXPECT_EQ(responses->GetTexts()[1], " Hello World"); EXPECT_EQ(responses->GetTexts()[2], " How's it going?"); } TEST_F(SessionBasicTest, RunDecodeWithSamplerAndConstrainedDecoding) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. 
proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("How")); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK((*session)->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto responses, (*session)->RunDecode(decode_config)); // Expect a single output candidate. EXPECT_EQ(responses.GetTexts().size(), 1); EXPECT_EQ(responses.GetTexts()[0], "'s it"); } TEST_F(SessionBasicTest, RunDecodeWithConstrainedDecodingNoSampler) { // Fake constraint that expects "'s it". 
std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::GPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}}, // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("How")); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK((*session)->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto responses, (*session)->RunDecode(decode_config)); // Expect a single output candidate. EXPECT_EQ(responses.GetTexts().size(), 1); EXPECT_EQ(responses.GetTexts()[0], "'s it"); } absl::AnyInvocable)> CreateTestCallback( bool& done_ref) { return [&done_ref](absl::StatusOr responses) mutable { if (responses.ok() && responses->GetTexts().empty()) { done_ref = true; } }; } TEST_F(SessionBasicTest, RunPrefillAsync) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.SetStartTokenId(2); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); bool done = false; auto callback = CreateTestCallback(done); EXPECT_OK((*session)->RunPrefillAsync(inputs, std::move(callback))); // Wait for the async call to finish. EXPECT_OK(worker_thread_pool_->WaitUntilDone(absl::Seconds(100))); EXPECT_TRUE(done); } TEST_F(SessionBasicTest, RunDecodeAsync) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.SetStartTokenId(2); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); bool done_prefill = false; EXPECT_OK( (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); bool done_decode = false; EXPECT_OK((*session)->RunDecodeAsync(CreateTestCallback(done_decode))); EXPECT_OK(worker_thread_pool_->WaitUntilDone(absl::Seconds(100))); EXPECT_TRUE(done_prefill); EXPECT_TRUE(done_decode); } TEST_F(SessionBasicTest, RunDecodeAsyncWithSamplerAndConstrainedDecoding) { // Fake constraint that expects "'s it". 
std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("How")); bool done_prefill = false; EXPECT_OK( (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); absl::Status status; std::vector texts; absl::Notification done_decode = absl::Notification(); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK((*session)->RunDecodeAsync( CreateStreamingTestCallback(status, texts, done_decode), decode_config)); done_decode.WaitForNotification(); EXPECT_OK(status); EXPECT_EQ(texts.size(), 3); EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); } TEST_F(SessionBasicTest, RunDecodeAsyncWithConstrainedDecodingNoSampler) { // Fake constraint that expects "'s it". 
std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}}, // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("How")); bool done_prefill = false; EXPECT_OK( (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); absl::Status status; std::vector texts; absl::Notification done_decode = absl::Notification(); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK((*session)->RunDecodeAsync( CreateStreamingTestCallback(status, texts, done_decode), decode_config)); done_decode.WaitForNotification(); EXPECT_OK(status); EXPECT_EQ(texts.size(), 3); EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); } TEST_F(SessionBasicTest, RunTextScoringEmptyTargetTextFailure) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector target_text; EXPECT_THAT((*session)->RunTextScoring(target_text, /*store_token_lengths=*/false), StatusIs(absl::StatusCode::kInvalidArgument, "Target text size should be 1.")); } TEST_F(SessionBasicTest, RunTextScoringMultipleTargetTextFailure) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector target_text; target_text.push_back("How's it going?"); target_text.push_back("How are you?"); EXPECT_THAT( (*session)->RunTextScoring(target_text, /*store_token_lengths=*/false), StatusIs(absl::StatusCode::kInvalidArgument, "Target text size should be 1.")); } TEST_F(SessionBasicTest, RunTextScoringWithoutTokenLengthsSuccess) { const auto responses_without_token_lengths = RunTextScoring( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, /*input_prompt=*/"Hello World!", /*target_text=*/"How's it going?", sampler_params_, /*store_token_lengths=*/false, tokenizer_.get(), worker_thread_pool_.get()); EXPECT_OK(responses_without_token_lengths); // Expect a single output candidate with score 0.0f. EXPECT_EQ(responses_without_token_lengths->GetScores().size(), 1); EXPECT_EQ(responses_without_token_lengths->GetScores()[0], 0.0f); EXPECT_FALSE(responses_without_token_lengths->GetTokenLengths().has_value()); } TEST_F(SessionBasicTest, RunTextScoringWithTokenLengthsSuccess) { const auto responses_with_token_lengths = RunTextScoring( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, /*input_prompt=*/"Hello World!", /*target_text=*/"How's it going?", sampler_params_, /*store_token_lengths=*/true, tokenizer_.get(), worker_thread_pool_.get()); EXPECT_OK(responses_with_token_lengths); // Expect a single output candidate with score 0.0f and token length 7. EXPECT_EQ(responses_with_token_lengths->GetScores().size(), 1); EXPECT_EQ(responses_with_token_lengths->GetScores()[0], 0.0f); EXPECT_TRUE(responses_with_token_lengths->GetTokenLengths().has_value()); EXPECT_EQ(responses_with_token_lengths->GetTokenLengths()->size(), 1); EXPECT_EQ((*responses_with_token_lengths->GetTokenLengths())[0], 7); } TEST_F(SessionBasicTest, RunTextScoringAsyncWithTokenLengthsSuccess) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); absl::Notification done; absl::StatusOr responses; auto callback = [&done, &responses](absl::StatusOr r) { responses = std::move(r); done.Notify(); }; std::vector target_texts; target_texts.push_back("How's it going?"); ASSERT_OK((*session)->RunTextScoringAsync(target_texts, std::move(callback), /*store_token_lengths=*/true)); done.WaitForNotification(); EXPECT_OK(responses); // Expect a single output candidate with score 0.0f and token length 7. EXPECT_EQ(responses->GetScores().size(), 1); EXPECT_EQ(responses->GetScores()[0], 0.0f); EXPECT_TRUE(responses->GetTokenLengths().has_value()); EXPECT_EQ(responses->GetTokenLengths()->size(), 1); EXPECT_EQ((*responses->GetTokenLengths())[0], 7); } TEST_F(SessionBasicTest, GenerateContentStream) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector texts; absl::Notification done = absl::Notification(); EXPECT_OK((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, texts, done))); done.WaitForNotification(); EXPECT_OK(status); EXPECT_EQ(texts.size(), 7); EXPECT_THAT(texts, testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); } TEST_F(SessionBasicTest, GenerateContentStreamEmptyInput) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; absl::Status status; std::vector texts; absl::Notification done; EXPECT_THAT((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, texts, done)), StatusIs(absl::StatusCode::kInvalidArgument, "Input is empty.")); } TEST_F(SessionBasicTest, GenerateContentStreamPrefillError) { // Configure the executor to fail at prefill. ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" 
/*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto* fake_executor = static_cast(executor.get()); fake_executor->SetPrefillStatus(absl::InternalError("Prefill failed")); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector texts; absl::Notification done; EXPECT_OK((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, texts, done))); done.WaitForNotification(); EXPECT_FALSE(status.ok()); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Prefill failed")); } TEST_F(SessionBasicTest, GenerateContentStreamDecodeError) { // Configure the executor to fail at decode. ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto* fake_executor = static_cast(executor.get()); fake_executor->SetDecodeStatus(absl::InternalError("Decode failed")); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector texts; absl::Notification done; EXPECT_OK((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, texts, done))); done.WaitForNotification(); EXPECT_FALSE(status.ok()); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Decode failed")); } TEST_F(SessionBasicTest, ProcessAndCombineContentsSingleText) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get())); std::vector preprocessed_contents; ASSERT_OK_AND_ASSIGN(auto ids_buffer, tokenizer_->TokenIdsToTensorBuffer( {90, 547, 58, 735, 210, 466})); preprocessed_contents.emplace_back(InputText(std::move(ids_buffer))); auto executor_input_or = session->ProcessAndCombineContents(preprocessed_contents); ASSERT_OK(executor_input_or); ASSERT_OK_AND_ASSIGN(auto text_data, executor_input_or->GetTextDataPtr()); ASSERT_NE(text_data, nullptr); LITERT_ASSERT_OK_AND_ASSIGN( auto token_ids_span, ReferTensorBufferAsSpan(text_data->GetTokenIds())); EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()), testing::ElementsAre(90, 547, 58, 735, 210, 466)); } TEST_F(SessionBasicTest, ProcessAndCombineContentsMultiText) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get())); std::vector preprocessed_contents; LITERT_ASSERT_OK_AND_ASSIGN(auto ids_buffer1, tokenizer_->TokenIdsToTensorBuffer({90, 547})); preprocessed_contents.emplace_back(InputText(std::move(ids_buffer1))); LITERT_ASSERT_OK_AND_ASSIGN( auto ids_buffer2, tokenizer_->TokenIdsToTensorBuffer({58, 735, 210, 466})); preprocessed_contents.emplace_back(InputText(std::move(ids_buffer2))); auto executor_input_or = session->ProcessAndCombineContents(preprocessed_contents); ASSERT_OK(executor_input_or); ASSERT_OK_AND_ASSIGN(auto text_data, executor_input_or->GetTextDataPtr()); ASSERT_NE(text_data, nullptr); LITERT_ASSERT_OK_AND_ASSIGN( auto token_ids_span, ReferTensorBufferAsSpan(text_data->GetTokenIds())); EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()), testing::ElementsAre(90, 547, 58, 735, 210, 466)); } TEST_F(SessionBasicTest, ProcessAndCombineContentsEmptyFails) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get())); std::vector preprocessed_contents; auto result = session->ProcessAndCombineContents(preprocessed_contents); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, "No token IDs found in preprocessed_contents.")); } TEST_F(SessionBasicTest, RunIncrementalPrefillWithDecode) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User:"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "[END]"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model:"); session_config.GetMutableLlmModelType().mutable_gemma3n(); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/ { {2, 423, 8, 179, 29, 207, 19, 547, 58}, // prefill chunk 1.1 {735, 210, 466, 2294}, // prefill chunk 1.2 {433, 2172, 1920, 432, 197, 979, 3076, 29}, // prefill ran before decode with turn change template {423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466, 2294}, // prefill chunk 2.1 {433, 2172, 1920, 432, 197, 979, 3076, 29}, // prefill ran before decode with turn change template }, /*decode_tokens=*/ {{1}, {2}, {3}, {2294}, {1}, {2}, {3}, {2294}})); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create(executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get())); { std::vector inputs; inputs.emplace_back(InputText("Hello ")); 
EXPECT_OK(session->RunPrefill(inputs)); } { std::vector inputs; inputs.emplace_back(InputText("World!")); EXPECT_OK(session->RunPrefill(inputs)); } { EXPECT_OK(session->RunDecode()); } { std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); } { EXPECT_OK(session->RunDecode()); } } // TODO: b/441514829 - Enable the tests on Windows once the bug is fixed. #if !defined(WIN32) && !defined(_WIN32) && !defined(__WIN32__) && \ !defined(__NT__) && !defined(_WIN64) TEST_F(SessionBasicTest, ProcessAndCombineContentsAudioSuccess) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutableLlmModelType().mutable_gemma3n(); LITERT_ASSERT_OK_AND_ASSIGN( auto env, Environment::Create(std::vector())); ASSERT_OK_AND_ASSIGN( auto audio_executor, CreateAudioExecutor(env, (std::filesystem::path(::testing::SrcDir()) / std::string(kTestAudioModelPath)) .string(), /*max_sequence_length=*/0, Backend::CPU)); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294, 256000}}, // "How's it going?" 
/*decode_tokens=*/ {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, /*audio_embedding=*/ std::vector(kExpectedAudioEmbedding.begin(), kExpectedAudioEmbedding.end()))); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/audio_executor.get(), session_config, std::nullopt, worker_thread_pool_.get())); std::vector preprocessed_contents; ASSERT_OK_AND_ASSIGN( auto ids_buffer, tokenizer_->TokenIdsToTensorBuffer({90, 547, 58, 735, 210, 466, 256000})); preprocessed_contents.emplace_back(InputText(std::move(ids_buffer))); LITERT_ASSERT_OK_AND_ASSIGN( TensorBuffer mel_spectrogram_data, CopyToTensorBuffer( mel_spectrogram_data, {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); InputAudio input_audio(std::move(mel_spectrogram_data)); preprocessed_contents.emplace_back(std::move(input_audio)); preprocessed_contents.emplace_back(InputAudioEnd()); ASSERT_OK_AND_ASSIGN( auto result, session->ProcessAndCombineContents(preprocessed_contents)); ASSERT_OK_AND_ASSIGN(auto text_data, result.GetTextDataPtr()); ASSERT_NE(text_data, nullptr); LITERT_ASSERT_OK_AND_ASSIGN( auto token_ids_span, ReferTensorBufferAsSpan(text_data->GetTokenIds())); // The input to audio executor has length 10. // The processed audio embedding has length 5. 
EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()),
              testing::ElementsAre(90, 547, 58, 735, 210, 466, 256000, -2, -2,
                                   -2, -2, -2, -4));
}

TEST_F(SessionBasicTest, ProcessAndCombineContentsTextAndAudioSuccess) {
  const std::vector> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User:");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "[END]");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model:");
  session_config.GetMutableLlmModelType().mutable_gemma3n();
  LITERT_ASSERT_OK_AND_ASSIGN(auto env,
                              Environment::Create(std::vector()));
  ASSERT_OK_AND_ASSIGN(
      auto audio_executor,
      CreateAudioExecutor(env,
                          (std::filesystem::path(::testing::SrcDir()) /
                           std::string(kTestAudioModelPath))
                              .string(),
                          /*max_sequence_length=*/0, Backend::CPU));
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "User:Hello World![END]Model:"
          /*prefill_tokens=*/{{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210,
                               466, 2294, 256000, -2, -2, -2, -2, -2, -4},
                              {433, 2172, 1920, 432, 197, 979, 3076, 29}},
          // "How's it going?"
/*decode_tokens=*/ {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, /*audio_embedding=*/ std::vector(kExpectedAudioEmbedding.begin(), kExpectedAudioEmbedding.end()))); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/audio_executor.get(), session_config, std::nullopt, worker_thread_pool_.get())); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); LITERT_ASSERT_OK_AND_ASSIGN( TensorBuffer mel_spectrogram_data, CopyToTensorBuffer( mel_spectrogram_data, {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); InputAudio input_audio(std::move(mel_spectrogram_data)); inputs.emplace_back(std::move(input_audio)); inputs.emplace_back(InputAudioEnd()); EXPECT_OK(session->RunPrefill(inputs)); } TEST_F(SessionBasicTest, ProcessAndCombineContentsTextAudioTextSuccess) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User:"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "[END]"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model:"); session_config.GetMutableLlmModelType().mutable_gemma3n(); LITERT_ASSERT_OK_AND_ASSIGN( auto env, Environment::Create(std::vector())); ASSERT_OK_AND_ASSIGN( auto audio_executor, CreateAudioExecutor(env, (std::filesystem::path(::testing::SrcDir()) / std::string(kTestAudioModelPath)) .string(), /*max_sequence_length=*/0, Backend::CPU)); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "User:Hello World!What does the audio say?" 
// "[END]Model:" /*prefill_tokens=*/ {{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466, 2294, 256000, -2, -2, -2, -2, -2, -4, 583, 378, 844, 166, 3, 14, 1252, 54, 58, 626, 2295}, {3995, 2172, 1920, 432, 197, 979, 3076, 29}}, // "How's it going?" /*decode_tokens=*/ {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, /*audio_embedding=*/ std::vector(kExpectedAudioEmbedding.begin(), kExpectedAudioEmbedding.end()))); ASSERT_OK_AND_ASSIGN( auto session, SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/audio_executor.get(), session_config, std::nullopt, worker_thread_pool_.get())); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); LITERT_ASSERT_OK_AND_ASSIGN( TensorBuffer mel_spectrogram_data, CopyToTensorBuffer( mel_spectrogram_data, {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); InputAudio input_audio(std::move(mel_spectrogram_data)); inputs.emplace_back(std::move(input_audio)); inputs.emplace_back(InputAudioEnd()); inputs.emplace_back(InputText("What does the audio say?")); EXPECT_OK(session->RunPrefill(inputs)); } #endif // !defined(WIN32) && !defined(_WIN32) && !defined(__WIN32__) && \ // !defined(__NT__) && !defined(_WIN64) TEST_F(SessionBasicTest, GenerateContentStreamWithCancellation) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); auto session = SessionBasic::Create(fake_executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); ASSERT_OK(session); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector responses; absl::Notification done; (*session) ->GenerateContentStream( inputs, CreateStreamingTestCallback(status, responses, done, /*delay_on_next=*/true)) .IgnoreError(); // Wait for a short time to ensure the decoding has started. absl::SleepFor(absl::Milliseconds(100)); // Cancel the process. (*session)->CancelProcess(); // Wait for the callback to be done. done.WaitForNotification(); EXPECT_THAT(status, StatusIs(absl::StatusCode::kCancelled)); } TEST_F(SessionBasicTest, GenerateContentStreamAndDelete) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); auto session = SessionBasic::Create(fake_executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); ASSERT_OK(session); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector responses; absl::Notification done; (*session) ->GenerateContentStream( inputs, CreateStreamingTestCallback(status, responses, done, /*delay_on_next=*/true)) .IgnoreError(); // Reset the session to trigger destructor. session->reset(); // Wait for thread pool to finish. If destructor does not wait, this might // cause use-after-free. worker_thread_pool_->WaitUntilDone(absl::Seconds(10)).IgnoreError(); if (!done.HasBeenNotified()) { done.Notify(); } } class SessionBasicCancellationTest : public testing::TestWithParam { protected: void SetUp() override { auto tokenizer = ExtendedTokenizer::CreateFromFile( (std::filesystem::path(::testing::SrcDir()) / std::string(kTestdataDir) / "sentencepiece.model") .string()); ASSERT_OK(tokenizer); tokenizer.value()->SetExtendedToken(256000, ""); tokenizer_ = std::move(*tokenizer); sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); // Creating the thread pool of a single thread to execute the works. 
worker_thread_pool_ = std::make_unique(/*name_prefix=*/"engine", /*max_num_threads=*/1); } bool use_benchmark_info_ = GetParam(); std::unique_ptr tokenizer_; proto::SamplerParameters sampler_params_; std::unique_ptr worker_thread_pool_; }; TEST_P(SessionBasicCancellationTest, GenerateContentStreamCancelThenGenerateWithBenchmark) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, // The second prefill doesn't have bos token. {90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); std::optional benchmark_info; if (use_benchmark_info_) { proto::BenchmarkParams benchmark_params; benchmark_info.emplace(benchmark_params); } auto session = SessionBasic::Create(fake_executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, benchmark_info, worker_thread_pool_.get()); ASSERT_OK(session); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector responses; absl::Notification done1; (*session) ->GenerateContentStream( inputs, CreateStreamingTestCallback(status, responses, done1, /*delay_on_next=*/true)) .IgnoreError(); // Cancel the process. (*session)->CancelProcess(); // Wait for the callback to be done. done1.WaitForNotification(); EXPECT_THAT(status, StatusIs(absl::StatusCode::kCancelled)); // Generate again after cancellation. // The second generation should succeed. 
status = absl::OkStatus(); responses.clear(); absl::Notification done2; (*session) ->GenerateContentStream( inputs, CreateStreamingTestCallback(status, responses, done2, /*delay_on_next=*/true)) .IgnoreError(); done2.WaitForNotification(); EXPECT_OK(status); // Reset worker thread pool to stop accessing session and fake executor. // We need to reset session first because it holds a reference to the // thread pool. (*session).reset(); worker_thread_pool_.reset(); } INSTANTIATE_TEST_SUITE_P(SessionBasicCancellationTest, SessionBasicCancellationTest, testing::Bool(), testing::PrintToStringParamName()); TEST_F(SessionBasicTest, GenerateContentStreamOnCancelledSession) { ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); auto session = SessionBasic::Create(fake_executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); ASSERT_OK(session); (*session)->CancelProcess(); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; std::vector responses; absl::Notification done; // The session is cancelled, so the call should return with a kCancelled // error. EXPECT_OK((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, responses, done))); // Wait for the callback to be done. 
done.WaitForNotification();
  EXPECT_OK(status);
}

TEST_F(SessionBasicTest,
       TestBenchmarkModeWithoutNumPrefillTokensRespectPromptTemplate) {
  const std::vector> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User\n");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "\n");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model\n");
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // Expected tokens: "User\nHello World!" +
          // "\nModel\n"
          /*prefill_tokens=*/{{2, 4, 0, 39, 637, 0, 3328, 8, 179, 90, 547, 58,
                               735, 210, 466, 2294},
                              {0, 40, 23, 0, 4, 0, 39, 637, 0, 197, 979,
                               3076}},
          /*decode_tokens=*/{{224}}));
  proto::BenchmarkParams benchmark_params;
  BenchmarkInfo benchmark_info(benchmark_params);
  auto session = SessionBasic::Create(
      executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr,
      /*audio_executor=*/nullptr, session_config, benchmark_info,
      worker_thread_pool_.get());
  ASSERT_OK(session);

  std::vector inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK((*session)->RunPrefill(inputs));
  EXPECT_EQ((*session)->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1);
}

TEST_F(SessionBasicTest,
       TestBenchmarkModeWithNumPrefillTokensIgnorePromptTemplate) {
  const std::vector> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User\n");
session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // Expected tokens: "Hello World!" (No templates) /*prefill_tokens=*/{{90, 547, 58, 735, 210, 466, 2294}}, /*decode_tokens=*/{{224}})); proto::BenchmarkParams benchmark_params; benchmark_params.set_num_prefill_tokens(7); BenchmarkInfo benchmark_info(benchmark_params); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, benchmark_info, worker_thread_pool_.get()); ASSERT_OK(session); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK((*session)->RunPrefill(inputs)); EXPECT_EQ((*session)->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1); } TEST_F(SessionBasicTest, GenerateContentStreamWithSamplerAndConstrainedDecoding) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" 
/*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); std::vector inputs; inputs.emplace_back(InputText("How")); absl::Status status; std::vector texts; absl::Notification done_decode = absl::Notification(); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK((*session)->GenerateContentStream( inputs, CreateStreamingTestCallback(status, texts, done_decode), decode_config)); done_decode.WaitForNotification(); EXPECT_OK(status); EXPECT_EQ(texts.size(), 3); EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); } TEST_F(SessionBasicTest, SaveAndRewindCheckpoint) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); // Set up a fake executor. ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, {2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto session = SessionBasic::Create( executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, /*audio_executor=*/nullptr, session_config, std::nullopt, worker_thread_pool_.get()); ASSERT_OK(session); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); // Save checkpoint before prefill. EXPECT_OK((*session)->SaveCheckpoint("start")); EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(0)); EXPECT_OK((*session)->RunPrefill(inputs)); // Save checkpoint after prefill. 
EXPECT_OK((*session)->SaveCheckpoint("prefill")); EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(8)); // Rewind should succeed and restore the state. EXPECT_OK((*session)->RewindToCheckpoint("start")); EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(0)); // The "prefill" checkpoint should no longer exist. EXPECT_THAT((*session)->RewindToCheckpoint("prefill"), testing::status::StatusIs(absl::StatusCode::kNotFound)); // Checkpoint that doesn't exist should fail. EXPECT_THAT((*session)->RewindToCheckpoint("nonexistent"), testing::status::StatusIs(absl::StatusCode::kNotFound)); } } // namespace } // namespace litert::lm