// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/components/sampling_cpu_util.h" #include #include #include #include #include #include "absl/types/span.h" // from @com_google_absl namespace litert::lm { namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; TEST(SamplingCpuUtilTest, TopKTokenIds_BatchSize1) { const std::vector logits = {0.1, 0.5, 0.4, 0.2}; auto topk_token_ids = TopKTokenIds(absl::MakeConstSpan(logits), /*k=*/2, /*batch_size=*/1, /*sequence_size=*/1); ASSERT_TRUE(topk_token_ids.ok()); EXPECT_EQ(topk_token_ids->size(), 1); EXPECT_THAT((*topk_token_ids)[0], UnorderedElementsAre(1, 2)); } TEST(SamplingCpuUtilTest, TopKTokenIds_BatchSize2) { const std::vector logits = {0.1, 0.5, 0.4, 0.2}; auto topk_token_ids = TopKTokenIds(absl::MakeConstSpan(logits), /*k=*/1, /*batch_size=*/2, /*sequence_size=*/1); ASSERT_TRUE(topk_token_ids.ok()); EXPECT_EQ(topk_token_ids->size(), 2); EXPECT_THAT((*topk_token_ids)[0], ElementsAre(1)); EXPECT_THAT((*topk_token_ids)[1], ElementsAre(0)); } TEST(SamplingCpuUtilTest, TopKTokenIds_SequenceLength2) { const std::vector logits = {0.1, 0.5, 0.4, 0.2, 0.6, 0.3}; auto topk_token_ids = TopKTokenIds(absl::MakeConstSpan(logits), /*k=*/1, /*batch_size=*/1, /*sequence_size=*/2); ASSERT_TRUE(topk_token_ids.ok()); EXPECT_EQ(topk_token_ids->size(), 1); EXPECT_THAT((*topk_token_ids)[0], ElementsAre(1, 1)); } TEST(SamplingCpuUtilTest, Softmax_BatchSize1) { const std::vector logits = {0.1f, 0.1f}; const std::vector topk_indices = {0, 1}; std::vector> max_logit_values; auto probabilities = Softmax(absl::MakeConstSpan(logits), absl::MakeConstSpan(topk_indices), /*temperature=*/1.0, /*batch_size=*/1, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); EXPECT_EQ(probabilities->size(), 1); EXPECT_THAT((*probabilities)[0], ElementsAre(0.5, 0.5)); EXPECT_EQ(max_logit_values.size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(0.1f)); } TEST(SamplingCpuUtilTest, Softmax_AllZeroLogits) { const std::vector logits = {0.0f, 0.0f}; const std::vector topk_indices = {0, 1}; std::vector> max_logit_values; auto probabilities = Softmax(absl::MakeConstSpan(logits), absl::MakeConstSpan(topk_indices), /*temperature=*/1.0, /*batch_size=*/1, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); EXPECT_EQ(probabilities->size(), 1); EXPECT_THAT((*probabilities)[0], ElementsAre(0.5, 0.5)); EXPECT_EQ(max_logit_values.size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(0.0f)); } TEST(SamplingCpuUtilTest, Softmax_TemperatureVerySmall) { const std::vector logits = {0.0f, 1.0f, 2.0f}; const std::vector topk_indices = {0, 1, 2}; std::vector> max_logit_values; auto probabilities = Softmax(absl::MakeConstSpan(logits), absl::MakeConstSpan(topk_indices), /*temperature=*/0.00000001f, /*batch_size=*/1, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); // Very small temperature should mimic greedy sampling. EXPECT_EQ(probabilities->size(), 1); EXPECT_THAT((*probabilities)[0], ElementsAre(0.0f, 0.0f, 1.0f)); EXPECT_EQ(max_logit_values.size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(2.0f)); } TEST(SamplingCpuUtilTest, Softmax_TemperatureExactlyZero) { const std::vector logits = {0.0f, 1.0f, 2.0f}; const std::vector topk_indices = {0, 1, 2}; std::vector> max_logit_values; auto probabilities = Softmax(absl::MakeConstSpan(logits), absl::MakeConstSpan(topk_indices), /*temperature=*/0.0f, /*batch_size=*/1, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); // Exactly zero temperature should mimic greedy sampling. EXPECT_EQ(probabilities->size(), 1); EXPECT_THAT((*probabilities)[0], ElementsAre(0.0f, 0.0f, 1.0f)); EXPECT_EQ(max_logit_values.size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(2.0f)); } TEST(SamplingCpuUtilTest, Softmax_TemperatureInf) { const std::vector logits = {0.0f, 1.0f, 2.0f, 3.0f}; const std::vector topk_indices = {0, 1, 2, 3}; std::vector> max_logit_values; auto probabilities = Softmax(absl::MakeConstSpan(logits), absl::MakeConstSpan(topk_indices), /*temperature=*/100000000000.0f, /*batch_size=*/1, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); // Very large temperature should mimic uniform sampling. EXPECT_EQ(probabilities->size(), 1); EXPECT_THAT((*probabilities)[0], ElementsAre(0.25f, 0.25f, 0.25f, 0.25f)); EXPECT_EQ(max_logit_values.size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(3.0f)); } TEST(SamplingCpuUtilTest, Softmax_BatchSize3) { // Batch size of 3, vocab size of 2. const std::vector logits = {0.1f, 0.1f, 0.0f, 5.0f, 1.0f, 0.0f}; absl::Span logits_span = absl::MakeConstSpan(logits); const std::vector topk_indices = {0, 1, 0, 1, 0, 1}; absl::Span topk_indices_span = absl::MakeConstSpan(topk_indices); std::vector> max_logit_values; auto probabilities = Softmax(logits_span, topk_indices_span, /*temperature=*/1.0f, /*batch_size=*/3, /*sequence_size=*/1, max_logit_values); ASSERT_TRUE(probabilities.ok()); EXPECT_EQ(probabilities->size(), 3); EXPECT_EQ((*probabilities)[0].size(), 2); EXPECT_THAT((*probabilities)[0], ElementsAre(0.5f, 0.5f)); EXPECT_EQ((*probabilities)[1].size(), 2); EXPECT_THAT((*probabilities)[1], ElementsAre(0.00669285096f, 0.993307173f)); EXPECT_EQ((*probabilities)[2].size(), 2); EXPECT_THAT((*probabilities)[2], ElementsAre(0.731058598f, 0.268941432f)); EXPECT_EQ(max_logit_values.size(), 3); EXPECT_EQ(max_logit_values[0].size(), 1); EXPECT_THAT(max_logit_values[0], ElementsAre(0.1f)); EXPECT_EQ(max_logit_values[1].size(), 1); EXPECT_THAT(max_logit_values[1], ElementsAre(5.0f)); EXPECT_EQ(max_logit_values[2].size(), 1); EXPECT_THAT(max_logit_values[2], ElementsAre(1.0f)); } TEST(SamplingCpuUtilTest, TopKTopPSampling_InvalidInputs) { const std::vector probabilities = {0.0, 0.0, 0.3}; auto rng = std::make_shared(0); // Negative k. std::vector> sampled_scores; auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(probabilities), /*k=*/-1, /*p=*/0.5, /*temperature=*/1.0, rng, /*batch_size=*/1, /*sequence_size=*/1, sampled_scores); EXPECT_FALSE(sampled_ids.ok()); // Negative p. sampled_ids = TopKTopPSampling(absl::MakeConstSpan(probabilities), /*k=*/1, /*p=*/-0.5, /*temperature=*/1.0f, rng, /*sequence_size=*/1, /*batch_size=*/1, sampled_scores); EXPECT_FALSE(sampled_ids.ok()); } TEST(SamplingCpuUtilTest, TopKTopPSampling_BatchSize1) { const std::vector probabilities = {0.0, 0.0, 0.3}; auto rng = std::make_shared(0); std::vector> sampled_scores; auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(probabilities), /*k=*/1, /*p=*/0.5, /*temperature=*/1.0f, rng, /*batch_size=*/1, /*sequence_size=*/1, sampled_scores); ASSERT_TRUE(sampled_ids.ok()); EXPECT_EQ(sampled_ids->size(), 1); EXPECT_THAT((*sampled_ids)[0], ElementsAre(2)); EXPECT_EQ(sampled_scores.size(), 1); EXPECT_THAT(sampled_scores[0], ElementsAre(1.0)); } TEST(SamplingCpuUtilTest, TopKTopPSampling_BatchSize1_TopK) { // Test that the sampler does return a sampled token from the top k // instead of always returning the first or the last token. const std::vector logits = {-1.0e7f, 1.0f, -1e3f}; auto rng = std::make_shared(0); std::vector> sampled_scores; auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(logits), /*k=*/3, /*p=*/1.0, /*temperature=*/1.0f, rng, /*batch_size=*/1, /*sequence_size=*/1, sampled_scores); ASSERT_TRUE(sampled_ids.ok()); EXPECT_EQ(sampled_ids->size(), 1); EXPECT_THAT((*sampled_ids)[0], ElementsAre(1)); EXPECT_EQ(sampled_scores.size(), 1); EXPECT_THAT(sampled_scores[0], ElementsAre(1.0)); } TEST(SamplingCpuUtilTest, TopKTopPSampling_BatchSize3) { // Batch of 3, vocab size of 3. The sampled ids are 2, 1, 0. const std::vector logits = {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}; auto rng = std::make_shared(0); std::vector> sampled_scores; auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(logits), /*k=*/2, /*p=*/0.5, /*temperature=*/0.00001f, rng, /*batch_size=*/3, /*sequence_size=*/1, sampled_scores); ASSERT_TRUE(sampled_ids.ok()); EXPECT_EQ(sampled_ids->size(), 3); EXPECT_EQ((*sampled_ids)[0][0], 2); EXPECT_EQ((*sampled_ids)[1][0], 1); EXPECT_EQ((*sampled_ids)[2][0], 0); EXPECT_EQ(sampled_scores.size(), 3); EXPECT_EQ(sampled_scores[0].size(), 1); EXPECT_EQ(sampled_scores[1].size(), 1); EXPECT_EQ(sampled_scores[2].size(), 1); } TEST(SamplingCpuUtilTest, TopKTopPSampling_LargeVocabIndices) { // Tests that sampling works correctly when top-k topk_token_ids are larger // than k. This exposes a bug where vocab topk_token_ids were incorrectly used // as offsets. std::vector logits = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0}; auto rng = std::make_shared(0); std::vector> sampled_scores; // Top p = 0.0001f, should always return the token with the highest logit, // which is the 14th token in this case. auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(logits), /*k=*/15, /*p=*/0.0001f, /*temperature=*/1.0f, rng, /*batch_size=*/1, /*sequence_size=*/1, sampled_scores); ASSERT_TRUE(sampled_ids.ok()); // With very small temperature, it should pick the logit with the highest // value, which is at index 11. EXPECT_EQ(sampled_ids->size(), 1); EXPECT_THAT((*sampled_ids)[0], ElementsAre(14)); EXPECT_EQ(sampled_scores.size(), 1); EXPECT_THAT(sampled_scores[0], ElementsAre(0.99827528f)); } TEST(SamplingCpuUtilTest, Softmax_BatchSize2SequenceLength2) { // Batch size 2, sequence length 2, vocab size 3. const std::vector logits = { 0.1f, 0.1f, 0.2f, // Batch 0, Sequence 0 0.0f, 5.0f, 1.0f, // Batch 0, Sequence 1 1.0f, 0.0f, 2.0f, // Batch 1, Sequence 0 0.0f, 0.0f, 0.0f // Batch 1, Sequence 1 }; absl::Span logits_span = absl::MakeConstSpan(logits); const std::vector topk_indices = { 1, 2, // Batch 0, Sequence 0 1, 2, // Batch 0, Sequence 1 0, 2, // Batch 1, Sequence 0 0, 1 // Batch 1, Sequence 1 }; absl::Span topk_indices_span = absl::MakeConstSpan(topk_indices); std::vector> max_logit_values; auto probabilities = Softmax(logits_span, topk_indices_span, /*temperature=*/1.0f, /*batch_size=*/2, /*sequence_size=*/2, max_logit_values); ASSERT_TRUE(probabilities.ok()); EXPECT_EQ(probabilities->size(), 2); EXPECT_EQ((*probabilities)[0].size(), 4); EXPECT_THAT((*probabilities)[0], ElementsAre(testing::FloatNear(0.47502f, 1e-4f), testing::FloatNear(0.52498f, 1e-4f), testing::FloatNear(0.98201f, 1e-4f), testing::FloatNear(0.017986f, 1e-4f))); EXPECT_EQ((*probabilities)[1].size(), 4); EXPECT_THAT((*probabilities)[1], ElementsAre(testing::FloatNear(0.26894f, 1e-4f), testing::FloatNear(0.73106f, 1e-4f), testing::FloatNear(0.5f, 1e-4f), testing::FloatNear(0.5f, 1e-4f))); EXPECT_EQ(max_logit_values.size(), 2); EXPECT_EQ(max_logit_values[0].size(), 2); EXPECT_THAT(max_logit_values[0], ElementsAre(0.2f, 5.0f)); EXPECT_EQ(max_logit_values[1].size(), 2); EXPECT_THAT(max_logit_values[1], ElementsAre(2.0f, 0.0f)); } TEST(SamplingCpuUtilTest, TopKTopPSampling_BatchSize2SequenceLength2) { // Batch of 2, sequence length 2, vocab size of 3. const std::vector logits = { 0.0f, 0.0f, 1.0f, // b0, s0: top is 2 1.0f, 0.0f, 0.0f, // b0, s1: top is 0 0.0f, 1.0f, 0.0f, // b1, s0: top is 1 0.0f, 0.0f, 1.0f // b1, s1: top is 2 }; auto rng = std::make_shared(0); std::vector> sampled_scores; auto sampled_ids = TopKTopPSampling(absl::MakeConstSpan(logits), /*k=*/2, /*p=*/0.5f, /*temperature=*/0.00001f, rng, /*batch_size=*/2, /*sequence_size=*/2, sampled_scores); ASSERT_TRUE(sampled_ids.ok()); EXPECT_EQ(sampled_ids->size(), 2); EXPECT_THAT((*sampled_ids)[0], ElementsAre(2, 0)); EXPECT_THAT((*sampled_ids)[1], ElementsAre(1, 2)); EXPECT_EQ(sampled_scores.size(), 2); EXPECT_EQ(sampled_scores[0].size(), 2); EXPECT_EQ(sampled_scores[1].size(), 2); } } // namespace } // namespace litert::lm