// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/components/sampling_cpu_util.h" #include #include #include #include #include #include #include #include #include #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/str_format.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl namespace litert::lm { absl::StatusOr>> TopKTokenIds( absl::Span logits, int k, int batch_size, int sequence_size) { if (logits.size() % (batch_size * sequence_size) != 0) { return absl::InvalidArgumentError( absl::StrFormat("Logits vector size must be a multiple of batch " "size * sequence_size. But got %d and " "%d * %d.", logits.size(), batch_size, sequence_size)); } const int vocab_size = logits.size() / (batch_size * sequence_size); std::vector> output_indices( batch_size, std::vector(sequence_size * k)); if (k == 1) { // Greedy sampling. Use std::max_element to be more efficient. for (int b = 0; b < batch_size; ++b) { for (int s = 0; s < sequence_size; ++s) { auto max_iterator = std::max_element( logits.begin() + (b * sequence_size + s) * vocab_size, logits.begin() + (b * sequence_size + s + 1) * vocab_size); // Gets the index of the max element in the logits of each batch. const int max_logit_idx = std::distance(logits.begin() + (b * sequence_size + s) * vocab_size, max_iterator); output_indices[b][s] = max_logit_idx; } } } else { std::vector indices(vocab_size); for (int b = 0; b < batch_size; ++b) { for (int s = 0; s < sequence_size; ++s) { // Fill with 0, 1, 2,... std::iota(indices.begin(), indices.end(), 0); // Define the comparator for descending probability auto desc_prob_comp = [&logits, vocab_size, b, s, sequence_size]( int i1, int i2) { int offset = (b * sequence_size + s) * vocab_size; return logits[offset + i1] > logits[offset + i2]; }; // Partition Top-K. // O(N) average time complexity. // Rearranges 'indices' such that the k elements with the highest // probabilities are in the range [indices.begin(), indices.begin() + // k). The element at indices[k] is not necessarily the (k+1)th largest. std::nth_element(indices.begin(), indices.begin() + k, indices.end(), desc_prob_comp); std::copy(indices.begin(), indices.begin() + k, output_indices[b].begin() + s * k); } } } return output_indices; } absl::StatusOr>> Softmax( absl::Span logits, absl::Span topk_token_ids, float temperature, int batch_size, int sequence_size, std::vector>& max_logit_values) { if (logits.empty()) { return absl::InvalidArgumentError("Logits vector cannot be empty."); } if (logits.size() % (batch_size * sequence_size) != 0) { return absl::InvalidArgumentError(absl::StrFormat( "Logits vector size must be a multiple of batch size * sequence " "length. But got %d and %d * %d.", logits.size(), batch_size, sequence_size)); } if (temperature < 0.0) { // A very small positive temperature can mimic greedy sampling. // If temp = 0, we will clamp it to epsilon later. return absl::InvalidArgumentError( absl::StrCat("Temperature must be >= 0, but got ", temperature)); } const int vocab_size = logits.size() / (batch_size * sequence_size); const int k = topk_token_ids.size() / (batch_size * sequence_size); std::vector> probabilities( batch_size, std::vector(sequence_size * k)); max_logit_values.assign(batch_size, std::vector(sequence_size)); for (size_t b = 0; b < batch_size; ++b) { for (size_t s = 0; s < sequence_size; ++s) { // Define the comparator for ascending logit to be used with // std::max_element. The comparator takes two indices from topk_token_ids // and compares the corresponding logit values. auto logit_comp = [&logits, vocab_size, b, s, sequence_size]( int topk_idx1, int topk_idx2) { int offset = (b * sequence_size + s) * vocab_size; return logits[offset + topk_idx1] < logits[offset + topk_idx2]; }; int topk_offset = (b * sequence_size + s) * k; // Use std::max_element to find the index in topk_token_ids that // corresponds to the maximum logit value. auto max_iterator = std::max_element( topk_token_ids.begin() + topk_offset, topk_token_ids.begin() + topk_offset + k, logit_comp); const int max_logit_idx = std::distance(topk_token_ids.begin() + topk_offset, max_iterator); int offset = (b * sequence_size + s) * vocab_size; max_logit_values[b][s] = logits[offset + topk_token_ids[topk_offset + max_logit_idx]]; float sum_of_exps = 0.0; float current_temp = std::max(temperature, std::numeric_limits::epsilon()); for (size_t i = 0; i < k; ++i) { probabilities[b][s * k + i] = std::exp((logits[offset + topk_token_ids[topk_offset + i]] - max_logit_values[b][s]) / current_temp); sum_of_exps += probabilities[b][s * k + i]; } if (sum_of_exps <= std::numeric_limits::epsilon()) { // Handle potential zero sum (uniform distribution fallback) float uniform_prob = 1.0 / static_cast(k); std::fill(probabilities[b].begin() + s * k, probabilities[b].begin() + (s + 1) * k, uniform_prob); } else if (std::isinf(sum_of_exps)) { // Handle inf sum which is caused by very small temperature. // This is to avoid inf sum in the softmax calculation. std::fill(probabilities[b].begin() + s * k, probabilities[b].begin() + (s + 1) * k, 0.0f); probabilities[b][s * k + max_logit_idx] = 1.0f; } else { // Normalize probabilities float inv_sum = 1.0 / sum_of_exps; // Calculate inverse once for slight speedup for (size_t i = 0; i < k; ++i) { probabilities[b][s * k + i] *= inv_sum; } } } } return probabilities; }; absl::StatusOr>> TopKTopPSampling( absl::Span logits, int k, float p, float temperature, std::shared_ptr rng, int batch_size, int sequence_size, std::vector>& sampled_scores) { if (logits.empty()) { return absl::InvalidArgumentError("Logits vector cannot be empty."); } if (logits.size() % (batch_size * sequence_size) != 0) { return absl::InvalidArgumentError( absl::StrFormat("Logits vector size must be a multiple of batch " "size * sequence length. But got %d and %d * %d.", logits.size(), batch_size, sequence_size)); } if (k <= 0) { return absl::InvalidArgumentError("k must be greater than 0."); } if (p < 0.0 || p > 1.0) { return absl::InvalidArgumentError("p must be in the range [0.0, 1.0]."); } if (rng == nullptr) { return absl::InvalidArgumentError("rng cannot be nullptr."); } const int vocab_size = logits.size() / (batch_size * sequence_size); // Ensure k is not larger than the number of probabilities k = std::min(k, vocab_size); auto topk_token_ids = TopKTokenIds(logits, k, batch_size, sequence_size); if (!topk_token_ids.ok()) return topk_token_ids.status(); std::vector flat_topk_token_ids(batch_size * sequence_size * k); for (int b = 0; b < batch_size; ++b) { std::copy((*topk_token_ids)[b].begin(), (*topk_token_ids)[b].end(), flat_topk_token_ids.begin() + b * sequence_size * k); } std::vector> max_logit_values; auto probabilities = Softmax(logits, flat_topk_token_ids, temperature, batch_size, sequence_size, max_logit_values); if (!probabilities.ok()) return probabilities.status(); std::vector> sampled_ids(batch_size, std::vector(sequence_size)); sampled_scores.assign(batch_size, std::vector(sequence_size)); if (k == 1) { // Greedy sampling. Return the topk_token_ids directly. for (int b = 0; b < batch_size; ++b) { for (int s = 0; s < sequence_size; ++s) { sampled_ids[b][s] = (*topk_token_ids)[b][s]; sampled_scores[b][s] = 1.0f; } } return sampled_ids; } float current_temp = std::max(temperature, std::numeric_limits::epsilon()); for (int b = 0; b < batch_size; ++b) { for (int s = 0; s < sequence_size; ++s) { // Define the comparator for descending probability auto desc_prob_comp = [&probabilities, b, s, k](int i1, int i2) { return (*probabilities)[b][s * k + i1] > (*probabilities)[b][s * k + i2]; }; // Sort Only the Top-K. // O(k log k) time complexity. // Sorts the indices of the top k elements based on their probabilities. // - index_of_topk[i] stores the offset within the current batch's top-k // range. std::vector index_of_topk(k); std::iota(index_of_topk.begin(), index_of_topk.end(), 0); std::sort(index_of_topk.begin(), index_of_topk.end(), desc_prob_comp); // Determine Top-P Cutoff Index within Top-K. // O(k) time complexity. double cumulative_prob = 0.0; int final_sample_size = 0; // Actual number of elements to sample from for (int i = 0; i < k; ++i) { // Check if adding this probability would exceed the threshold p. It // stops when cumulative_prob >= p. cumulative_prob += (*probabilities)[b][s * k + index_of_topk[i]]; final_sample_size = i + 1; // Include this element if (cumulative_prob >= p) { break; // Found the smallest set within Top-K satisfying Top-P } } // final_sample_size now holds min(p_cutoff_within_top_k, k) int topk_offset = b * sequence_size * k + s * k; int offset = (b * sequence_size + s) * vocab_size; // Handle Edge Case: Cumulative Probability is Zero. if (cumulative_prob <= std::numeric_limits::epsilon()) { // Fallback: Return the index with the absolute highest probability // (indices[0] after sorting top-k). sampled_ids[b][s] = flat_topk_token_ids[topk_offset + index_of_topk[0]]; sampled_scores[b][s] = std::exp( (logits[offset + sampled_ids[b][s]] - max_logit_values[b][s]) / current_temp); continue; } // O(final_sample_size) which is O(k) time complexity. std::uniform_real_distribution dist(0.0, cumulative_prob); double random_sample = dist(*rng); double current_cumulative = 0.0; for (int i = 0; i < final_sample_size; ++i) { current_cumulative += (*probabilities)[b][s * k + index_of_topk[i]]; if (random_sample <= current_cumulative) { sampled_ids[b][s] = flat_topk_token_ids[topk_offset + index_of_topk[i]]; sampled_scores[b][s] = (*probabilities)[b][s * k + index_of_topk[i]]; break; } } } } return sampled_ids; } } // namespace litert::lm