Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
| using ::testing::status::StatusIs; | |
| constexpr absl::string_view kTestdataDir = | |
| "litert_lm/runtime/components/testdata/"; | |
| constexpr absl::string_view kTestAudioModelPath = | |
| "litert_lm/runtime/testdata/dummy_audio_only.litertlm"; | |
| constexpr int kSpectrogramFrequencySlots = 8; | |
| constexpr int kSpectrogramSequenceLength = 10; | |
| constexpr int kEmbeddingSequenceLength = 5; | |
| constexpr int kEmbeddingDimensions = 6; | |
| // Audio embedding tensor will have shape [1, kEmbeddingSequenceLength, | |
| // kEmbeddingDimensions]. | |
| constexpr std::array<float, kEmbeddingSequenceLength * kEmbeddingDimensions> | |
| kExpectedAudioEmbedding = {0., 0., 0., 0., 0., 0., 0., 1., 2., 3., | |
| 3., 3., 0., 1., 2., 4., 4., 4., 1., 2., | |
| 3., 5., 5., 5., 0., 1., 2., 4., 4., 4.}; | |
| // Mel spectrogram tensor will have shape [1, kSpectrogramSequenceLength, | |
| // kSpectrogramFrequencySlots]. | |
| constexpr std::array<float, | |
| kSpectrogramSequenceLength * kSpectrogramFrequencySlots> | |
| mel_spectrogram_data = { | |
| 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., | |
| 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., | |
| 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., | |
| 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., | |
| 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1.}; | |
| absl::StatusOr<std::unique_ptr<FakeLlmExecutor>> CreateFakeLlmExecutor( | |
| std::vector<std::vector<int>> prefill_tokens, | |
| std::vector<std::vector<int>> decode_tokens, | |
| std::optional<std::vector<float>> audio_embedding = std::nullopt) { | |
| auto batch_size = decode_tokens.empty() ? 1 : decode_tokens[0].size(); | |
| auto fake_executor = std::make_unique<FakeLlmExecutor>( | |
| 2560, prefill_tokens, decode_tokens, batch_size, audio_embedding); | |
| return std::move(fake_executor); | |
| } | |
| absl::StatusOr<Responses> RunTextScoring( | |
| const std::vector<std::vector<int>>& prefill_tokens, | |
| const std::vector<std::vector<int>>& decode_tokens, | |
| absl::string_view input_prompt, absl::string_view target_text, | |
| const proto::SamplerParameters& sampler_params, bool store_token_lengths, | |
| Tokenizer* tokenizer, ThreadPool* worker_thread_pool) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSIGN_OR_RETURN(auto executor, | |
| CreateFakeLlmExecutor(prefill_tokens, decode_tokens)); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer, /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText(std::string(input_prompt))); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| std::vector<absl::string_view> target_texts; | |
| target_texts.push_back(target_text); | |
| return (*session)->RunTextScoring(target_texts, store_token_lengths); | |
| } | |
| class ExtendedTokenizer : public Tokenizer { | |
| public: | |
| static absl::StatusOr<std::unique_ptr<ExtendedTokenizer>> CreateFromFile( | |
| absl::string_view model_path) { | |
| ASSIGN_OR_RETURN(auto tokenizer, | |
| SentencePieceTokenizer::CreateFromFile(model_path)); | |
| return absl::WrapUnique(new ExtendedTokenizer(std::move(tokenizer))); | |
| } | |
| void SetExtendedToken(int token_id, absl::string_view token_str) { | |
| extended_tokens_to_id_[token_str] = token_id; | |
| id_to_extended_tokens_[token_id] = token_str; | |
| } | |
| absl::StatusOr<std::vector<int>> TextToTokenIds( | |
| absl::string_view text) override { | |
| std::vector<int> token_ids; | |
| bool is_extended_token_found = false; | |
| do { | |
| is_extended_token_found = false; | |
| for (const auto& [extended_token_str, extended_token_id] : | |
| extended_tokens_to_id_) { | |
| auto extended_token_pos = text.find(extended_token_str); | |
| if (extended_token_pos != std::string::npos) { | |
| // The text before the extended token. | |
| ASSIGN_OR_RETURN( | |
| auto text_ids, | |
| tokenizer_->TextToTokenIds(text.substr(0, extended_token_pos))); | |
| token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); | |
| token_ids.push_back(extended_token_id); | |
| text = text.substr(extended_token_pos + extended_token_str.size()); | |
| is_extended_token_found = true; | |
| } | |
| } | |
| } while (is_extended_token_found); | |
| if (!text.empty()) { | |
| ASSIGN_OR_RETURN(auto text_ids, tokenizer_->TextToTokenIds(text)); | |
| token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); | |
| } | |
| return token_ids; | |
| } | |
| absl::StatusOr<std::string> TokenIdsToText( | |
| const std::vector<int>& token_ids) override { | |
| std::vector<std::string> token_strs; | |
| for (int token_id : token_ids) { | |
| if (id_to_extended_tokens_.contains(token_id)) { | |
| token_strs.push_back(id_to_extended_tokens_[token_id]); | |
| } else { | |
| token_strs.push_back(tokenizer_->TokenIdsToText({token_id}).value()); | |
| } | |
| } | |
| return absl::StrJoin(token_strs, ""); | |
| } | |
| absl::StatusOr<int> TokenToId(absl::string_view token) override { | |
| if (extended_tokens_to_id_.contains(token)) { | |
| return extended_tokens_to_id_[token]; | |
| } | |
| return tokenizer_->TokenToId(token); | |
| } | |
| TokenizerType GetTokenizerType() const override { | |
| return tokenizer_->GetTokenizerType(); | |
| } | |
| std::vector<std::string> GetTokens() const override { | |
| return tokenizer_->GetTokens(); | |
| } | |
| private: | |
| explicit ExtendedTokenizer(std::unique_ptr<SentencePieceTokenizer> tokenizer) | |
| : tokenizer_(std::move(tokenizer)) {}; | |
| absl::flat_hash_map<int, std::string> id_to_extended_tokens_; | |
| absl::flat_hash_map<std::string, int> extended_tokens_to_id_; | |
| std::unique_ptr<SentencePieceTokenizer> tokenizer_; | |
| }; | |
| class SessionBasicTest : public testing::Test { | |
| protected: | |
| void SetUp() override { | |
| auto tokenizer = ExtendedTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestdataDir) / "sentencepiece.model") | |
| .string()); | |
| ASSERT_OK(tokenizer); | |
| tokenizer.value()->SetExtendedToken(256000, "<start_of_audio>"); | |
| tokenizer_ = std::move(*tokenizer); | |
| sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); | |
| // Creating the thread pool of a single thread to execute the works. | |
| worker_thread_pool_ = std::make_unique<ThreadPool>(/*name_prefix=*/"engine", | |
| /*max_num_threads=*/1); | |
| } | |
| std::unique_ptr<Tokenizer> tokenizer_; | |
| proto::SamplerParameters sampler_params_; | |
| std::unique_ptr<ThreadPool> worker_thread_pool_; | |
| }; | |
| absl::StatusOr<std::unique_ptr<AudioLiteRtCompiledModelExecutor>> | |
| CreateAudioExecutor(Environment& env, const std::string& model_path, | |
| int max_sequence_length, Backend backend) { | |
| ASSIGN_OR_RETURN(auto model_file, ScopedFile::Open(model_path)); | |
| auto model_file_ptr = std::make_shared<ScopedFile>(std::move(model_file)); | |
| ASSIGN_OR_RETURN(auto model_assets, ModelAssets::Create(model_file_ptr)); | |
| // Create the audio executor settings. | |
| ASSIGN_OR_RETURN(auto audio_executor_settings, | |
| AudioExecutorSettings::CreateDefault( | |
| model_assets, max_sequence_length, backend)); | |
| // Create the audio executor. | |
| return litert::lm::AudioLiteRtCompiledModelExecutor::Create( | |
| audio_executor_settings, env); | |
| } | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> CreateStreamingTestCallback( | |
| absl::Status& status_ref, std::vector<std::string>& texts_ref, | |
| absl::Notification& done_ref, bool delay_on_next = false) { | |
| return [&status_ref, &texts_ref, &done_ref, | |
| delay_on_next](absl::StatusOr<Responses> responses) mutable { | |
| if (!responses.ok()) { | |
| status_ref = std::move(responses.status()); | |
| done_ref.Notify(); | |
| return; | |
| } | |
| if (responses->GetTaskState() == TaskState::kDone) { | |
| done_ref.Notify(); | |
| return; | |
| } | |
| if (delay_on_next) { | |
| absl::SleepFor(absl::Milliseconds(50)); | |
| } | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| texts_ref.push_back(responses->GetTexts()[0]); | |
| }; | |
| } | |
| TEST_F(SessionBasicTest, SessionCreation) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // The prefill tokens are the expected tokens that will be passed in | |
| // at each time the Prefill function is called. The values are the | |
| // token ids of the input prompt "Hello World!". | |
| // The decode tokens are the expected tokens that will be returned | |
| // by the Decode function. The values are the token ids of the | |
| // output response "How's it going?" followed by the stop token id | |
| // (2294). | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| EXPECT_OK(session); | |
| // Second session creation should fail because there is already a session. | |
| EXPECT_THAT(SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get()), | |
| StatusIs(absl::StatusCode::kFailedPrecondition)); | |
| // After the first session is destroyed, the second session creation should | |
| // succeed. | |
| session->reset(); | |
| EXPECT_OK(SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| } | |
| TEST_F(SessionBasicTest, RunPrefill) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // The prefill tokens are the expected tokens that will be passed in | |
| // at each time the Prefill function is called. The values are the | |
| // token ids of the input prompt "Hello World!". | |
| // The decode tokens are the expected tokens that will be returned | |
| // by the Decode function. The values are the token ids of the | |
| // output response "How's it going?" followed by the stop token id | |
| // (2294). | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionBasicTest, EmptyInputTextReturnsError) { | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN(auto executor, CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{}}, | |
| /*decode_tokens=*/{{}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("")); | |
| EXPECT_THAT((*session)->RunPrefill(inputs), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "No token IDs found in preprocessed_contents.")); | |
| } | |
| TEST_F(SessionBasicTest, RunDecode) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| // The response is " How's it going?" since "!" is the stop token which is | |
| // not included in the response. | |
| EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeWithMaxOutputTokens) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetMaxOutputTokens(2); | |
| auto responses = (*session)->RunDecode(decode_config); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_EQ(responses->GetTexts()[0], " How'"); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeWithMultipleOutputCandidates) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetNumOutputCandidates(3); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?", "Hello World", "How's it going?" | |
| /*decode_tokens=*/{{224, 90, 224}, | |
| {24, 547, 24}, | |
| {8, 58, 8}, | |
| {66, 735, 66}, | |
| {246, 210, 246}, | |
| {18, 466, 18}, | |
| {2295, 2294, 2295}, | |
| {2294, 0, 2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 3); | |
| // The response is " How's it going?" since "!" is the stop token which is | |
| // not included in the response. | |
| EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); | |
| EXPECT_EQ(responses->GetTexts()[1], " Hello World"); | |
| EXPECT_EQ(responses->GetTexts()[2], " How's it going?"); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeWithSamplerAndConstrainedDecoding) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto responses, (*session)->RunDecode(decode_config)); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses.GetTexts().size(), 1); | |
| EXPECT_EQ(responses.GetTexts()[0], "'s it"); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeWithConstrainedDecodingNoSampler) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::GPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto responses, (*session)->RunDecode(decode_config)); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses.GetTexts().size(), 1); | |
| EXPECT_EQ(responses.GetTexts()[0], "'s it"); | |
| } | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> CreateTestCallback( | |
| bool& done_ref) { | |
| return [&done_ref](absl::StatusOr<Responses> responses) mutable { | |
| if (responses.ok() && responses->GetTexts().empty()) { | |
| done_ref = true; | |
| } | |
| }; | |
| } | |
| TEST_F(SessionBasicTest, RunPrefillAsync) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.SetStartTokenId(2); | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| bool done = false; | |
| auto callback = CreateTestCallback(done); | |
| EXPECT_OK((*session)->RunPrefillAsync(inputs, std::move(callback))); | |
| // Wait for the async call to finish. | |
| EXPECT_OK(worker_thread_pool_->WaitUntilDone(absl::Seconds(100))); | |
| EXPECT_TRUE(done); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeAsync) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.SetStartTokenId(2); | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| bool done_prefill = false; | |
| EXPECT_OK( | |
| (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); | |
| bool done_decode = false; | |
| EXPECT_OK((*session)->RunDecodeAsync(CreateTestCallback(done_decode))); | |
| EXPECT_OK(worker_thread_pool_->WaitUntilDone(absl::Seconds(100))); | |
| EXPECT_TRUE(done_prefill); | |
| EXPECT_TRUE(done_decode); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeAsyncWithSamplerAndConstrainedDecoding) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| bool done_prefill = false; | |
| EXPECT_OK( | |
| (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done_decode = absl::Notification(); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunDecodeAsync( | |
| CreateStreamingTestCallback(status, texts, done_decode), decode_config)); | |
| done_decode.WaitForNotification(); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
| TEST_F(SessionBasicTest, RunDecodeAsyncWithConstrainedDecodingNoSampler) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| bool done_prefill = false; | |
| EXPECT_OK( | |
| (*session)->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done_decode = absl::Notification(); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunDecodeAsync( | |
| CreateStreamingTestCallback(status, texts, done_decode), decode_config)); | |
| done_decode.WaitForNotification(); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
| TEST_F(SessionBasicTest, RunTextScoringEmptyTargetTextFailure) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<absl::string_view> target_text; | |
| EXPECT_THAT((*session)->RunTextScoring(target_text, | |
| /*store_token_lengths=*/false), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
| TEST_F(SessionBasicTest, RunTextScoringMultipleTargetTextFailure) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<absl::string_view> target_text; | |
| target_text.push_back("How's it going?"); | |
| target_text.push_back("How are you?"); | |
| EXPECT_THAT( | |
| (*session)->RunTextScoring(target_text, /*store_token_lengths=*/false), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
| TEST_F(SessionBasicTest, RunTextScoringWithoutTokenLengthsSuccess) { | |
| const auto responses_without_token_lengths = RunTextScoring( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*input_prompt=*/"Hello World!", | |
| /*target_text=*/"How's it going?", sampler_params_, | |
| /*store_token_lengths=*/false, tokenizer_.get(), | |
| worker_thread_pool_.get()); | |
| EXPECT_OK(responses_without_token_lengths); | |
| // Expect a single output candidate with score 0.0f. | |
| EXPECT_EQ(responses_without_token_lengths->GetScores().size(), 1); | |
| EXPECT_EQ(responses_without_token_lengths->GetScores()[0], 0.0f); | |
| EXPECT_FALSE(responses_without_token_lengths->GetTokenLengths().has_value()); | |
| } | |
| TEST_F(SessionBasicTest, RunTextScoringWithTokenLengthsSuccess) { | |
| const auto responses_with_token_lengths = RunTextScoring( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*input_prompt=*/"Hello World!", | |
| /*target_text=*/"How's it going?", sampler_params_, | |
| /*store_token_lengths=*/true, tokenizer_.get(), | |
| worker_thread_pool_.get()); | |
| EXPECT_OK(responses_with_token_lengths); | |
| // Expect a single output candidate with score 0.0f and token length 7. | |
| EXPECT_EQ(responses_with_token_lengths->GetScores().size(), 1); | |
| EXPECT_EQ(responses_with_token_lengths->GetScores()[0], 0.0f); | |
| EXPECT_TRUE(responses_with_token_lengths->GetTokenLengths().has_value()); | |
| EXPECT_EQ(responses_with_token_lengths->GetTokenLengths()->size(), 1); | |
| EXPECT_EQ((*responses_with_token_lengths->GetTokenLengths())[0], 7); | |
| } | |
| TEST_F(SessionBasicTest, RunTextScoringAsyncWithTokenLengthsSuccess) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| absl::Notification done; | |
| absl::StatusOr<Responses> responses; | |
| auto callback = [&done, &responses](absl::StatusOr<Responses> r) { | |
| responses = std::move(r); | |
| done.Notify(); | |
| }; | |
| std::vector<absl::string_view> target_texts; | |
| target_texts.push_back("How's it going?"); | |
| ASSERT_OK((*session)->RunTextScoringAsync(target_texts, std::move(callback), | |
| /*store_token_lengths=*/true)); | |
| done.WaitForNotification(); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate with score 0.0f and token length 7. | |
| EXPECT_EQ(responses->GetScores().size(), 1); | |
| EXPECT_EQ(responses->GetScores()[0], 0.0f); | |
| EXPECT_TRUE(responses->GetTokenLengths().has_value()); | |
| EXPECT_EQ(responses->GetTokenLengths()->size(), 1); | |
| EXPECT_EQ((*responses->GetTokenLengths())[0], 7); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStream) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done = absl::Notification(); | |
| EXPECT_OK((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, texts, done))); | |
| done.WaitForNotification(); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(texts.size(), 7); | |
| EXPECT_THAT(texts, | |
| testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStreamEmptyInput) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done; | |
| EXPECT_THAT((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, texts, done)), | |
| StatusIs(absl::StatusCode::kInvalidArgument, "Input is empty.")); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStreamPrefillError) { | |
| // Configure the executor to fail at prefill. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto* fake_executor = static_cast<FakeLlmExecutor*>(executor.get()); | |
| fake_executor->SetPrefillStatus(absl::InternalError("Prefill failed")); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done; | |
| EXPECT_OK((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, texts, done))); | |
| done.WaitForNotification(); | |
| EXPECT_FALSE(status.ok()); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Prefill failed")); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStreamDecodeError) { | |
| // Configure the executor to fail at decode. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto* fake_executor = static_cast<FakeLlmExecutor*>(executor.get()); | |
| fake_executor->SetDecodeStatus(absl::InternalError("Decode failed")); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done; | |
| EXPECT_OK((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, texts, done))); | |
| done.WaitForNotification(); | |
| EXPECT_FALSE(status.ok()); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Decode failed")); | |
| } | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsSingleText) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "<test>User\n"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "<end>\n"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "<test>Model\n"); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> preprocessed_contents; | |
| ASSERT_OK_AND_ASSIGN(auto ids_buffer, tokenizer_->TokenIdsToTensorBuffer( | |
| {90, 547, 58, 735, 210, 466})); | |
| preprocessed_contents.emplace_back(InputText(std::move(ids_buffer))); | |
| auto executor_input_or = | |
| session->ProcessAndCombineContents(preprocessed_contents); | |
| ASSERT_OK(executor_input_or); | |
| ASSERT_OK_AND_ASSIGN(auto text_data, executor_input_or->GetTextDataPtr()); | |
| ASSERT_NE(text_data, nullptr); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto token_ids_span, | |
| ReferTensorBufferAsSpan<int>(text_data->GetTokenIds())); | |
| EXPECT_THAT(std::vector<int>(token_ids_span.begin(), token_ids_span.end()), | |
| testing::ElementsAre(90, 547, 58, 735, 210, 466)); | |
| } | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsMultiText) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "<test>User"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "<end>\n"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "<test>Model\n"); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> preprocessed_contents; | |
| LITERT_ASSERT_OK_AND_ASSIGN(auto ids_buffer1, | |
| tokenizer_->TokenIdsToTensorBuffer({90, 547})); | |
| preprocessed_contents.emplace_back(InputText(std::move(ids_buffer1))); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto ids_buffer2, | |
| tokenizer_->TokenIdsToTensorBuffer({58, 735, 210, 466})); | |
| preprocessed_contents.emplace_back(InputText(std::move(ids_buffer2))); | |
| auto executor_input_or = | |
| session->ProcessAndCombineContents(preprocessed_contents); | |
| ASSERT_OK(executor_input_or); | |
| ASSERT_OK_AND_ASSIGN(auto text_data, executor_input_or->GetTextDataPtr()); | |
| ASSERT_NE(text_data, nullptr); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto token_ids_span, | |
| ReferTensorBufferAsSpan<int>(text_data->GetTokenIds())); | |
| EXPECT_THAT(std::vector<int>(token_ids_span.begin(), token_ids_span.end()), | |
| testing::ElementsAre(90, 547, 58, 735, 210, 466)); | |
| } | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsEmptyFails) { | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> preprocessed_contents; | |
| auto result = session->ProcessAndCombineContents(preprocessed_contents); | |
| EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, | |
| "No token IDs found in preprocessed_contents.")); | |
| } | |
| TEST_F(SessionBasicTest, RunIncrementalPrefillWithDecode) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "User:"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "[END]"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "Model:"); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/ | |
| { | |
| {2, 423, 8, 179, 29, 207, 19, 547, 58}, // prefill chunk 1.1 | |
| {735, 210, 466, 2294}, // prefill chunk 1.2 | |
| {433, 2172, 1920, 432, 197, 979, 3076, | |
| 29}, // prefill ran before decode with turn change template | |
| {423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466, | |
| 2294}, // prefill chunk 2.1 | |
| {433, 2172, 1920, 432, 197, 979, 3076, | |
| 29}, // prefill ran before decode with turn change template | |
| }, | |
| /*decode_tokens=*/ | |
| {{1}, {2}, {3}, {2294}, {1}, {2}, {3}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionBasic::Create(executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| { | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello ")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| { | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| { | |
| EXPECT_OK(session->RunDecode()); | |
| } | |
| { | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| { | |
| EXPECT_OK(session->RunDecode()); | |
| } | |
| } | |
| // TODO: b/441514829 - Enable the tests on Windows once the bug is fixed. | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsAudioSuccess) { | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto env, Environment::Create(std::vector<Environment::Option>())); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto audio_executor, | |
| CreateAudioExecutor(env, | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestAudioModelPath)) | |
| .string(), | |
| /*max_sequence_length=*/0, Backend::CPU)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!<start_of_audio>" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294, 256000}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/ | |
| {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*audio_embedding=*/ | |
| std::vector<float>(kExpectedAudioEmbedding.begin(), | |
| kExpectedAudioEmbedding.end()))); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/audio_executor.get(), session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> preprocessed_contents; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto ids_buffer, | |
| tokenizer_->TokenIdsToTensorBuffer({90, 547, 58, 735, 210, 466, 256000})); | |
| preprocessed_contents.emplace_back(InputText(std::move(ids_buffer))); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| TensorBuffer mel_spectrogram_data, | |
| CopyToTensorBuffer<float>( | |
| mel_spectrogram_data, | |
| {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); | |
| InputAudio input_audio(std::move(mel_spectrogram_data)); | |
| preprocessed_contents.emplace_back(std::move(input_audio)); | |
| preprocessed_contents.emplace_back(InputAudioEnd()); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto result, session->ProcessAndCombineContents(preprocessed_contents)); | |
| ASSERT_OK_AND_ASSIGN(auto text_data, result.GetTextDataPtr()); | |
| ASSERT_NE(text_data, nullptr); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto token_ids_span, | |
| ReferTensorBufferAsSpan<int>(text_data->GetTokenIds())); | |
| // The input to audio executor has length 10. | |
| // The processed audio embedding has length 5. | |
| EXPECT_THAT(std::vector<int>(token_ids_span.begin(), token_ids_span.end()), | |
| testing::ElementsAre(90, 547, 58, 735, 210, 466, 256000, -2, -2, | |
| -2, -2, -2, -4)); | |
| } | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsTextAndAudioSuccess) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "User:"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "[END]"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "Model:"); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto env, Environment::Create(std::vector<Environment::Option>())); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto audio_executor, | |
| CreateAudioExecutor(env, | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestAudioModelPath)) | |
| .string(), | |
| /*max_sequence_length=*/0, Backend::CPU)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "User:Hello World!<start_of_audio>[END]Model:" | |
| /*prefill_tokens=*/{{2, 423, 8, 179, 29, 207, 19, | |
| 547, 58, 735, 210, 466, 2294, 256000, | |
| -2, -2, -2, -2, -2, -4}, | |
| {433, 2172, 1920, 432, 197, 979, 3076, 29}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/ | |
| {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*audio_embedding=*/ | |
| std::vector<float>(kExpectedAudioEmbedding.begin(), | |
| kExpectedAudioEmbedding.end()))); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/audio_executor.get(), session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!<start_of_audio>")); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| TensorBuffer mel_spectrogram_data, | |
| CopyToTensorBuffer<float>( | |
| mel_spectrogram_data, | |
| {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); | |
| InputAudio input_audio(std::move(mel_spectrogram_data)); | |
| inputs.emplace_back(std::move(input_audio)); | |
| inputs.emplace_back(InputAudioEnd()); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionBasicTest, ProcessAndCombineContentsTextAudioTextSuccess) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "User:"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "[END]"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "Model:"); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto env, Environment::Create(std::vector<Environment::Option>())); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto audio_executor, | |
| CreateAudioExecutor(env, | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestAudioModelPath)) | |
| .string(), | |
| /*max_sequence_length=*/0, Backend::CPU)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "User:Hello World!<start_of_audio>What does the audio say?" | |
| // "[END]Model:" | |
| /*prefill_tokens=*/ | |
| {{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210, | |
| 466, 2294, 256000, -2, -2, -2, -2, -2, -4, 583, 378, | |
| 844, 166, 3, 14, 1252, 54, 58, 626, 2295}, | |
| {3995, 2172, 1920, 432, 197, 979, 3076, 29}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/ | |
| {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*audio_embedding=*/ | |
| std::vector<float>(kExpectedAudioEmbedding.begin(), | |
| kExpectedAudioEmbedding.end()))); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/audio_executor.get(), session_config, | |
| std::nullopt, worker_thread_pool_.get())); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!<start_of_audio>")); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| TensorBuffer mel_spectrogram_data, | |
| CopyToTensorBuffer<float>( | |
| mel_spectrogram_data, | |
| {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); | |
| InputAudio input_audio(std::move(mel_spectrogram_data)); | |
| inputs.emplace_back(std::move(input_audio)); | |
| inputs.emplace_back(InputAudioEnd()); | |
| inputs.emplace_back(InputText("What does the audio say?")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStreamWithCancellation) { | |
| // Configure the executor to have a delay to simulate a long-running task. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto fake_executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| fake_executor->SetDecodeDelay(absl::Milliseconds(200)); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| auto session = | |
| SessionBasic::Create(fake_executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> responses; | |
| absl::Notification done; | |
| (*session) | |
| ->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, responses, done, | |
| /*delay_on_next=*/true)) | |
| .IgnoreError(); | |
| // Wait for a short time to ensure the decoding has started. | |
| absl::SleepFor(absl::Milliseconds(100)); | |
| // Cancel the process. | |
| (*session)->CancelProcess(); | |
| // Wait for the callback to be done. | |
| done.WaitForNotification(); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kCancelled)); | |
| } | |
| TEST_F(SessionBasicTest, GenerateContentStreamAndDelete) { | |
| // Configure the executor to have a delay to simulate a long-running task. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto fake_executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| fake_executor->SetDecodeDelay(absl::Milliseconds(200)); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| auto session = | |
| SessionBasic::Create(fake_executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> responses; | |
| absl::Notification done; | |
| (*session) | |
| ->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, responses, done, | |
| /*delay_on_next=*/true)) | |
| .IgnoreError(); | |
| // Reset the session to trigger destructor. | |
| session->reset(); | |
| // Wait for thread pool to finish. If destructor does not wait, this might | |
| // cause use-after-free. | |
| worker_thread_pool_->WaitUntilDone(absl::Seconds(10)).IgnoreError(); | |
| if (!done.HasBeenNotified()) { | |
| done.Notify(); | |
| } | |
| } | |
| class SessionBasicCancellationTest : public testing::TestWithParam<bool> { | |
| protected: | |
| void SetUp() override { | |
| auto tokenizer = ExtendedTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestdataDir) / "sentencepiece.model") | |
| .string()); | |
| ASSERT_OK(tokenizer); | |
| tokenizer.value()->SetExtendedToken(256000, "<start_of_audio>"); | |
| tokenizer_ = std::move(*tokenizer); | |
| sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); | |
| // Creating the thread pool of a single thread to execute the works. | |
| worker_thread_pool_ = std::make_unique<ThreadPool>(/*name_prefix=*/"engine", | |
| /*max_num_threads=*/1); | |
| } | |
| bool use_benchmark_info_ = GetParam(); | |
| std::unique_ptr<Tokenizer> tokenizer_; | |
| proto::SamplerParameters sampler_params_; | |
| std::unique_ptr<ThreadPool> worker_thread_pool_; | |
| }; | |
| TEST_P(SessionBasicCancellationTest, | |
| GenerateContentStreamCancelThenGenerateWithBenchmark) { | |
| // Configure the executor to have a delay to simulate a long-running task. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto fake_executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, | |
| // The second prefill doesn't have bos token. | |
| {90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| fake_executor->SetDecodeDelay(absl::Milliseconds(200)); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| std::optional<BenchmarkInfo> benchmark_info; | |
| if (use_benchmark_info_) { | |
| proto::BenchmarkParams benchmark_params; | |
| benchmark_info.emplace(benchmark_params); | |
| } | |
| auto session = | |
| SessionBasic::Create(fake_executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| benchmark_info, worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> responses; | |
| absl::Notification done1; | |
| (*session) | |
| ->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, responses, done1, | |
| /*delay_on_next=*/true)) | |
| .IgnoreError(); | |
| // Cancel the process. | |
| (*session)->CancelProcess(); | |
| // Wait for the callback to be done. | |
| done1.WaitForNotification(); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kCancelled)); | |
| // Generate again after cancellation. | |
| // The second generation should succeed. | |
| status = absl::OkStatus(); | |
| responses.clear(); | |
| absl::Notification done2; | |
| (*session) | |
| ->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, responses, done2, | |
| /*delay_on_next=*/true)) | |
| .IgnoreError(); | |
| done2.WaitForNotification(); | |
| EXPECT_OK(status); | |
| // Reset worker thread pool to stop accessing session and fake executor. | |
| // We need to reset session first because it holds a reference to the | |
| // thread pool. | |
| (*session).reset(); | |
| worker_thread_pool_.reset(); | |
| } | |
| INSTANTIATE_TEST_SUITE_P(SessionBasicCancellationTest, | |
| SessionBasicCancellationTest, testing::Bool(), | |
| testing::PrintToStringParamName()); | |
| TEST_F(SessionBasicTest, GenerateContentStreamOnCancelledSession) { | |
| ASSERT_OK_AND_ASSIGN( | |
| auto fake_executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| auto session = | |
| SessionBasic::Create(fake_executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| std::nullopt, worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| (*session)->CancelProcess(); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| std::vector<std::string> responses; | |
| absl::Notification done; | |
| // The session is cancelled, so the call should return with a kCancelled | |
| // error. | |
| EXPECT_OK((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, responses, done))); | |
| // Wait for the callback to be done. | |
| done.WaitForNotification(); | |
| EXPECT_OK(status); | |
| } | |
| TEST_F(SessionBasicTest, | |
| TestBenchmarkModeWithoutNumPrefillTokensRespectPromptTemplate) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "<test>User\n"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "<end>\n"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "<test>Model\n"); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // Expected tokens: "</s><test>User\nHello World!" + | |
| // "<end>\n<test>Model\n" | |
| /*prefill_tokens=*/{{2, 4, 0, 39, 637, 0, 3328, 8, 179, 90, 547, 58, | |
| 735, 210, 466, 2294}, | |
| {0, 40, 23, 0, 4, 0, 39, 637, 0, 197, 979, 3076}}, | |
| /*decode_tokens=*/{{224}})); | |
| proto::BenchmarkParams benchmark_params; | |
| BenchmarkInfo benchmark_info(benchmark_params); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, benchmark_info, | |
| worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| EXPECT_EQ((*session)->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1); | |
| } | |
| TEST_F(SessionBasicTest, | |
| TestBenchmarkModeWithNumPrefillTokensIgnorePromptTemplate) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "<test>User\n"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "<end>\n"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "<test>Model\n"); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // Expected tokens: "Hello World!" (No templates) | |
| /*prefill_tokens=*/{{90, 547, 58, 735, 210, 466, 2294}}, | |
| /*decode_tokens=*/{{224}})); | |
| proto::BenchmarkParams benchmark_params; | |
| benchmark_params.set_num_prefill_tokens(7); | |
| BenchmarkInfo benchmark_info(benchmark_params); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, benchmark_info, | |
| worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| EXPECT_EQ((*session)->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1); | |
| } | |
| TEST_F(SessionBasicTest, | |
| GenerateContentStreamWithSamplerAndConstrainedDecoding) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), | |
| /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, | |
| /*benchmark_info=*/std::nullopt, worker_thread_pool_.get()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| absl::Status status; | |
| std::vector<std::string> texts; | |
| absl::Notification done_decode = absl::Notification(); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, texts, done_decode), | |
| decode_config)); | |
| done_decode.WaitForNotification(); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
| TEST_F(SessionBasicTest, SaveAndRewindCheckpoint) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| // Set up a fake executor. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, | |
| {2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto session = SessionBasic::Create( | |
| executor.get(), tokenizer_.get(), /*vision_executor=*/nullptr, | |
| /*audio_executor=*/nullptr, session_config, std::nullopt, | |
| worker_thread_pool_.get()); | |
| ASSERT_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| // Save checkpoint before prefill. | |
| EXPECT_OK((*session)->SaveCheckpoint("start")); | |
| EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(0)); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| // Save checkpoint after prefill. | |
| EXPECT_OK((*session)->SaveCheckpoint("prefill")); | |
| EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(8)); | |
| // Rewind should succeed and restore the state. | |
| EXPECT_OK((*session)->RewindToCheckpoint("start")); | |
| EXPECT_THAT((*session)->GetCurrentStep(), testing::status::IsOkAndHolds(0)); | |
| // The "prefill" checkpoint should no longer exist. | |
| EXPECT_THAT((*session)->RewindToCheckpoint("prefill"), | |
| testing::status::StatusIs(absl::StatusCode::kNotFound)); | |
| // Checkpoint that doesn't exist should fail. | |
| EXPECT_THAT((*session)->RewindToCheckpoint("nonexistent"), | |
| testing::status::StatusIs(absl::StatusCode::kNotFound)); | |
| } | |
| } // namespace | |
| } // namespace litert::lm | |