Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
// Maximum number of tokens configured on the main executor in the tests below
// (passed to SetMaxNumTokens()). Must be longer than # of prefill tokens,
// which is 4.
constexpr int kMaxNumTokens = 16;
static_assert(kMaxNumTokens > 4,
              "kMaxNumTokens must be longer than the number of prefill tokens");
| TEST(EngineTest, CreateEngine_WithoutCache) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello world!")); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| // 2nd run with the same engine. | |
| session->reset(); // Destroy the previous first. | |
| session = (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| } | |
| TEST(EngineTestWithoutParallelLoading, CreateEngineAndRunInference) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings->SetParallelFileSectionLoading(false); | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello world!")); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| } | |
| TEST(EngineTest, CreateEngine_WithCache) { | |
| auto cache_path = std::filesystem::path(::testing::TempDir()) / | |
| absl::StrCat("cache-", std::rand()); | |
| std::filesystem::remove_all(cache_path); | |
| absl::Cleanup remove_cache = [cache_path] { | |
| std::filesystem::remove_all(cache_path); | |
| }; | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir( | |
| cache_path.string()); | |
| // 1st run to populate the cache. | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello world!")); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| // 2nd run with the same engine and the same cache. | |
| session->reset(); // Destroy the previous first. | |
| session = (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| // 3rd run with a new engine and the same cache. | |
| session->reset(); // Destroy the previous first. | |
| llm->reset(); | |
| llm = EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| session = (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| } | |
| TEST(EngineTest, CreateEngine_WithModelAndCacheFromFileDescriptor) { | |
| auto cache_path = std::filesystem::path(::testing::TempDir()) / | |
| absl::StrCat("cache-", std::rand(), ".cache"); | |
| std::filesystem::remove_all(cache_path); | |
| { | |
| // Create an empty file - ScopedFile expects the file to exist. | |
| std::ofstream cache_file(cache_path.string()); | |
| } | |
| absl::Cleanup remove_cache = [cache_path] { | |
| std::filesystem::remove_all(cache_path); | |
| }; | |
| ASSERT_OK_AND_ASSIGN(auto scoped_cache_file, | |
| ScopedFile::OpenWritable(cache_path.string())); | |
| auto shared_scoped_cache_file = | |
| std::make_shared<ScopedFile>(std::move(scoped_cache_file)); | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| ASSERT_OK_AND_ASSIGN(auto task_descriptor, | |
| ScopedFile::Open(task_path.string())); | |
| auto shared_task_descriptor = | |
| std::make_shared<ScopedFile>(std::move(task_descriptor)); | |
| auto model_assets = ModelAssets::Create(shared_task_descriptor); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetScopedCacheFile( | |
| shared_scoped_cache_file); | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello world!")); | |
| ABSL_CHECK_OK((*session)->RunPrefill(inputs)); | |
| auto responses = (*session)->RunDecode(); | |
| EXPECT_OK(responses); | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_FALSE(responses->GetTexts()[0].empty()); | |
| } | |
| TEST(EngineTest, CreateEngine_WithBenchmark) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| // Enable Benchmark | |
| engine_settings->GetMutableBenchmarkParams(); | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| auto benchmark_info = (*session)->GetMutableBenchmarkInfo(); | |
| ASSERT_OK(benchmark_info); | |
| const auto& init_phases = (*benchmark_info)->GetInitPhases(); | |
| EXPECT_TRUE(init_phases.contains(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kTokenizer)))); | |
| EXPECT_TRUE(init_phases.contains(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kExecutor)))); | |
| EXPECT_TRUE(init_phases.contains(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kTotal)))); | |
| } | |
| TEST(EngineTest, CreateEngine_AsyncTokenizer_ValidatesConcurrency) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = | |
| EngineSettings::CreateDefault(*model_assets, Backend::CPU); | |
| ASSERT_OK(engine_settings); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| // Enable Benchmark to measure the phases and prove concurrency | |
| engine_settings->GetMutableBenchmarkParams(); | |
| absl::StatusOr<std::unique_ptr<Engine>> llm = | |
| EngineFactory::CreateAny(*engine_settings); | |
| ABSL_CHECK_OK(llm); | |
| absl::StatusOr<std::unique_ptr<Engine::Session>> session = | |
| (*llm)->CreateSession(SessionConfig::CreateDefault()); | |
| ABSL_CHECK_OK(session); | |
| auto benchmark_info = (*session)->GetMutableBenchmarkInfo(); | |
| ASSERT_OK(benchmark_info); | |
| const auto& init_phases = benchmark_info.value()->GetInitPhases(); | |
| auto total_time = init_phases.at(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kTotal))); | |
| auto executor_time = init_phases.at(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kExecutor))); | |
| auto tokenizer_time = init_phases.at(std::string( | |
| BenchmarkInfo::InitPhaseToString(BenchmarkInfo::InitPhase::kTokenizer))); | |
| ABSL_LOG(INFO) << "total_time: " << total_time; | |
| ABSL_LOG(INFO) << "executor_time: " << executor_time; | |
| ABSL_LOG(INFO) << "tokenizer_time: " << tokenizer_time; | |
| // The total duration should be greater than or equal to the longest | |
| // concurrent branch. | |
| EXPECT_GE(total_time, executor_time); | |
| EXPECT_GE(total_time, tokenizer_time); | |
| // The total duration (minus the sequential part) should be less than the sum | |
| // of the two parallel branches. This is to prove that the tokenizer and | |
| // executor are loaded concurrently. | |
| auto rest_time = total_time - std::max(executor_time, tokenizer_time); | |
| EXPECT_LT(total_time - rest_time, executor_time + tokenizer_time); | |
| // Verifying tokenizer resolves tokens successfully without data bounds errors | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello concurrent world!")); | |
| ABSL_CHECK_OK(session.value()->RunPrefill(inputs)); | |
| auto responses = session.value()->RunDecode(); | |
| ASSERT_OK(responses); | |
| EXPECT_EQ(responses.value().GetTexts().size(), 1); | |
| EXPECT_FALSE(responses.value().GetTexts()[0].empty()); | |
| } | |
| TEST(EngineTest, CreateEngine_FailsNoVisionModel) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = EngineSettings::CreateDefault( | |
| *model_assets, /*backend=*/Backend::CPU, /*vision_backend=*/Backend::CPU, | |
| /*audio_backend=*/std::nullopt); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| EXPECT_THAT(EngineFactory::CreateAny(*engine_settings), | |
| testing::status::StatusIs( | |
| absl::StatusCode::kNotFound, | |
| "TF_LITE_VISION_ENCODER not found in the model.")); | |
| } | |
| TEST(EngineTest, CreateEngine_FailsNoAudioModel) { | |
| auto task_path = | |
| std::filesystem::path(::testing::SrcDir()) / | |
| "litert_lm/runtime/testdata/test_lm_new_metadata.task"; | |
| auto model_assets = ModelAssets::Create(task_path.string()); | |
| ASSERT_OK(model_assets); | |
| auto engine_settings = EngineSettings::CreateDefault( | |
| *model_assets, /*backend=*/Backend::CPU, /*vision_backend=*/std::nullopt, | |
| /*audio_backend=*/Backend::CPU); | |
| engine_settings->GetMutableMainExecutorSettings().SetMaxNumTokens( | |
| kMaxNumTokens); | |
| engine_settings->GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| EXPECT_THAT(EngineFactory::CreateAny(*engine_settings), | |
| testing::status::StatusIs( | |
| absl::StatusCode::kNotFound, | |
| "TF_LITE_AUDIO_ENCODER_HW not found in the model.")); | |
| } | |
| // TODO (b/397975034): Add more tests for Engine. | |
| } // namespace | |
| } // namespace litert::lm | |