Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
// The name of the prefill decode model in the task bundle.
// NOTE(review): "Prefil" (single l) appears to be a typo in the identifier;
// kept as-is because renaming would touch other code in this file.
constexpr char kPrefilDecodeModelNameInTaskBundle[] = "TF_LITE_PREFILL_DECODE";
// Each list below holds the candidate tensor names recognized for one slot of
// ModelSignatures by GetModelSignaturesFromInputOutputNames.
// Possible input tokens names:
constexpr std::array<absl::string_view, 2> kInputTokensNames = {"token_ids",
                                                                "tokens"};
// Possible input positions names:
constexpr std::array<absl::string_view, 2> kInputPositionsNames = {"positions",
                                                                   "input_pos"};
// Possible input attention mask names:
constexpr std::array<absl::string_view, 2> kInputAttnMaskNames = {"attn_mask",
                                                                  "mask"};
// Possible embedding names:
constexpr std::array<absl::string_view, 1> kEmbeddingNames = {"embeddings"};
// Possible per layer embedding names:
constexpr std::array<absl::string_view, 1> kPerLayerEmbeddingNames = {
    "per_layer_embeddings"};
// Possible input int32 param names:
constexpr std::array<absl::string_view, 1> kInputInt32ParamNames = {
    "param_tensor"};
// Possible output logits names:
constexpr std::array<absl::string_view, 1> kOutputLogitsNames = {"logits"};
// Builds ModelResources from a task-bundle model asset.
// Creates the asset-bundle resources from the memory-mapped file when one is
// available, otherwise from a (possibly newly created) scoped file handle.
// Fails if the bundle does not contain the prefill/decode model file.
absl::StatusOr<std::unique_ptr<ModelResources>>
BuildModelResourcesFromTaskFormat(const ModelAssets& model_assets) {
  std::unique_ptr<ModelAssetBundleResources> resources;
  if (model_assets.HasMemoryMappedFile()) {
    ASSIGN_OR_RETURN(auto memory_mapped_file,
                     model_assets.GetMemoryMappedFile());
    ASSIGN_OR_RETURN(resources, ModelAssetBundleResources::Create(
                                    /*tag=*/"", memory_mapped_file));
  } else {
    ASSIGN_OR_RETURN(auto scoped_file, model_assets.GetOrCreateScopedFile());
    ASSIGN_OR_RETURN(resources, ModelAssetBundleResources::Create(
                                    /*tag=*/"", scoped_file));
  }
  // The bundle is only usable if it contains the prefill/decode model file.
  auto files_list = resources->ListFiles();
  RET_CHECK(std::find(files_list.begin(), files_list.end(),
                      kPrefilDecodeModelNameInTaskBundle) != files_list.end())
      << kPrefilDecodeModelNameInTaskBundle
      << " model file not found in task bundle.";
  return ModelResourcesTask::Create(std::move(resources));
}
// Builds ModelResources from a LiteRT-LM format model asset.
// Constructs a LitertLmLoader over the memory-mapped file when available,
// otherwise over a duplicated scoped file handle.
absl::StatusOr<std::unique_ptr<ModelResources>>
BuildModelResourcesFromLitertLmFormat(const ModelAssets& model_assets) {
  std::unique_ptr<LitertLmLoader> loader;
  if (model_assets.HasMemoryMappedFile()) {
    ASSIGN_OR_RETURN(auto memory_mapped_file,
                     model_assets.GetMemoryMappedFile());
    loader = std::make_unique<LitertLmLoader>(memory_mapped_file);
  } else {
    // `LitertLmLoader` takes ownership of the ScopedFile it is constructed
    // with, so we duplicate the handle to keep the original (owned by
    // `model_assets`) alive.
    ASSIGN_OR_RETURN(auto scoped_file, model_assets.GetOrCreateScopedFile());
    ASSIGN_OR_RETURN(auto duplicate_file, scoped_file->Duplicate());
    loader = std::make_unique<LitertLmLoader>(std::move(duplicate_file));
  }
  return ModelResourcesLitertLm::Create(std::move(loader));
}
| } // namespace | |
| absl::StatusOr<ModelSignatures> GetModelSignaturesFromInputOutputNames( | |
| const std::vector<absl::string_view>& input_names, | |
| const std::vector<absl::string_view>& output_names, bool strict) { | |
| ModelSignatures model_signatures; | |
| for (auto input_name : input_names) { | |
| if (absl::c_linear_search(kInputTokensNames, input_name)) { | |
| model_signatures.input_tokens = std::string(input_name); | |
| continue; | |
| } | |
| if (absl::c_linear_search(kInputPositionsNames, input_name)) { | |
| model_signatures.input_positions = std::string(input_name); | |
| continue; | |
| } | |
| if (absl::c_linear_search(kInputAttnMaskNames, input_name)) { | |
| model_signatures.input_attn_mask = std::string(input_name); | |
| continue; | |
| } | |
| if (absl::c_linear_search(kEmbeddingNames, input_name)) { | |
| model_signatures.input_embeddings = std::string(input_name); | |
| continue; | |
| } | |
| if (absl::c_linear_search(kPerLayerEmbeddingNames, input_name)) { | |
| model_signatures.input_per_layer_embeddings = std::string(input_name); | |
| continue; | |
| } | |
| if (absl::c_linear_search(kInputInt32ParamNames, input_name)) { | |
| model_signatures.input_int32_param = std::string(input_name); | |
| continue; | |
| } | |
| } | |
| for (auto output_name : output_names) { | |
| if (absl::c_linear_search(kOutputLogitsNames, output_name)) { | |
| model_signatures.output_logits = std::string(output_name); | |
| continue; | |
| } | |
| } | |
| if (strict) { | |
| RET_CHECK(!model_signatures.input_tokens.empty() || | |
| model_signatures.input_embeddings.has_value()) | |
| .SetCode(absl::StatusCode::kFailedPrecondition) | |
| << "Input tokens or embeddings not found."; | |
| RET_CHECK(!model_signatures.input_positions.empty()) | |
| .SetCode(absl::StatusCode::kFailedPrecondition) | |
| << "Input positions not found."; | |
| RET_CHECK(!model_signatures.output_logits.empty()) | |
| .SetCode(absl::StatusCode::kFailedPrecondition) | |
| << "Output logits not found."; | |
| } | |
| return model_signatures; | |
| } | |
| absl::Status GetKVCacheRootNames(std::vector<absl::string_view> input_names, | |
| std::vector<absl::string_view> output_names, | |
| std::string& k_root_name, | |
| std::string& v_root_name) { | |
| for (auto input_name : input_names) { | |
| if (input_name == "kv_cache_k_0") { | |
| k_root_name = "kv_cache_k_"; | |
| v_root_name = "kv_cache_v_"; | |
| return absl::OkStatus(); | |
| } else if (input_name == "k_cache_0") { | |
| k_root_name = "k_cache_"; | |
| v_root_name = "v_cache_"; | |
| return absl::OkStatus(); | |
| } else if (input_name == "kv_cache_c_0") { | |
| k_root_name = "kv_cache_c_"; | |
| v_root_name = "kv_cache_c_"; | |
| return absl::OkStatus(); | |
| } | |
| } | |
| for (auto output_name : output_names) { | |
| if (output_name == "kv_cache_k_0") { | |
| k_root_name = "kv_cache_k_"; | |
| v_root_name = "kv_cache_v_"; | |
| return absl::OkStatus(); | |
| } else if (output_name == "k_cache_0") { | |
| k_root_name = "k_cache_"; | |
| v_root_name = "v_cache_"; | |
| return absl::OkStatus(); | |
| } else if (output_name == "kv_cache_c_0") { | |
| k_root_name = "kv_cache_c_"; | |
| v_root_name = "kv_cache_c_"; | |
| return absl::OkStatus(); | |
| } | |
| } | |
| return absl::FailedPreconditionError("No KV cache inputs found."); | |
| } | |
// Collects all signatures of `model` whose key starts with
// `signature_name_base` and indexes them by maximum sequence length, derived
// from the shape of the signature's input-positions tensor
// (`input_positions_name`). Supports rank-2 [batch_size, max_seq_len] and
// rank-1 [max_seq_len] position tensors; any other rank is an error.
absl::StatusOr<SortedPrefillSignatureMap> GetPrefillRunnerSetFromModel(
    const ::litert::Model& model, absl::string_view signature_name_base,
    absl::string_view input_positions_name) {
  SortedPrefillSignatureMap prefill_runner_set;
  auto signatures = model.GetSignatures();
  for (auto& signature : *signatures) {
    if (auto signature_key = signature.Key();
        absl::StartsWith(signature_key, signature_name_base)) {
      LITERT_ASSIGN_OR_RETURN(auto input_positions_tensor,
                              signature.InputTensor(input_positions_name));
      LITERT_ASSIGN_OR_RETURN(auto ranked_tensor_type,
                              input_positions_tensor.RankedTensorType());
      if (ranked_tensor_type.Layout().Rank() == 2) {
        // [batch_size, max_seq_len]
        prefill_runner_set[ranked_tensor_type.Layout().Dimensions()[1]] =
            std::string(signature_key);
      } else if (ranked_tensor_type.Layout().Rank() == 1) {
        // [max_seq_len]
        prefill_runner_set[ranked_tensor_type.Layout().Dimensions()[0]] =
            std::string(signature_key);
      } else {
        return absl::FailedPreconditionError(
            "Unsupported input tokens tensor dimension.");
      }
    }
  }
  return prefill_runner_set;
}
| absl::StatusOr<std::vector<std::pair<std::string, int>>> | |
| GetOptimizedPrefillWorkGroups( | |
| const SortedPrefillSignatureMap& prefill_runner_set, int input_length) { | |
| std::vector<std::pair<std::string, int>> work_groups; | |
| // Current strategy: | |
| // 1. Use the prefill runner with the largest sequence length, until the | |
| // remaining length is less than its sequence length. | |
| // 2. Finish the remaining length with one prefill call, using the runner with | |
| // the sequence length as small as possible. | |
| // TODO: b/378772479 - Improve this strategy once we have benchmarked costs. | |
| int max_seq_len = prefill_runner_set.begin()->first; | |
| while (input_length >= max_seq_len) { | |
| work_groups.push_back( | |
| std::make_pair(prefill_runner_set.begin()->second, max_seq_len)); | |
| input_length -= max_seq_len; | |
| } | |
| if (input_length > 0) { | |
| for (auto it = prefill_runner_set.begin(); it != prefill_runner_set.end(); | |
| ++it) { | |
| // If the next smaller runner can handle the remaining length, skip the | |
| // current runner. | |
| if (std::next(it) != prefill_runner_set.end() && | |
| std::next(it)->first >= input_length) { | |
| continue; | |
| } | |
| work_groups.push_back(std::make_pair(it->second, input_length)); | |
| break; | |
| } | |
| } | |
| return work_groups; | |
| } | |
// Fills the entire attention mask buffer with its "masked" default value.
// Bool masks default to false; float masks default to a large negative value
// chosen per precision (see the per-case comments below). `is_f16` selects
// the f16-safe constant even for a Float32 mask, presumably because the
// downstream compute runs in f16 — TODO confirm with the caller.
absl::Status InitializeAttentionMask(litert::TensorBuffer& mask, bool is_f16) {
  LITERT_ASSIGN_OR_RETURN(auto mask_size, mask.PackedSize());
  LITERT_ASSIGN_OR_RETURN(auto mask_tensor_type, mask.TensorType());
  // Scoped lock: the buffer stays writable (and mapped) for the lifetime of
  // `mask_lock_and_addr`; `.second` is the host address.
  LITERT_ASSIGN_OR_RETURN(auto mask_lock_and_addr,
                          litert::TensorBufferScopedLock::Create(
                              mask, litert::TensorBuffer::LockMode::kWrite));
  switch (mask_tensor_type.ElementType()) {
    case litert::ElementType::Bool:
      // Boolean mask: Default value = false.
      memset(mask_lock_and_addr.second, 0, mask_size);
      break;
    case litert::ElementType::Float32: {
      // Float mask: Default value is based on precision.
      // Default value reference:
      // third_party/odml/infra/genai/inference/ml_drift/llm/tasks/apply_attention_mask_test_util.cc
      float* mask_ptr = static_cast<float*>(mask_lock_and_addr.second);
      std::fill(mask_ptr, mask_ptr + mask_size / sizeof(float),
                is_f16 ? -45824 : -0.7f * std::numeric_limits<float>::max());
      break;
    }
    case litert::ElementType::Float16: {
      // Float16 mask: Default value is -45824.
      // This value is approximately -0.7 * MaxFloat16 (65504).
      // 0.7 * 65504 = 45852.8. Truncated to 45824.
      // It provides a margin of ~19680 before overflowing to -inf.
      tflite::half* mask_ptr =
          static_cast<tflite::half*>(mask_lock_and_addr.second);
      std::fill(mask_ptr, mask_ptr + mask_size / sizeof(tflite::half),
                tflite::half(-45824.0f));
      break;
    }
    default:
      return absl::InvalidArgumentError(
          "Unsupported attention mask data type.");
  }
  return absl::OkStatus();
}
// Writes the runtime cache-update parameters into `param_tensor` as three
// int32 values {start_index, end_index, end_index}, where
// end_index = start_index + update_length. Any trailing bytes of the tensor
// are zeroed first.
absl::Status FillSingleBufferCacheParamTensor(
    litert::TensorBuffer& param_tensor, int start_index, int update_length) {
  // TODO(sulemanshahid): Local attention optimization is not supported in the
  // OpenCL implementation, enable for WebGPU.
  LITERT_ASSIGN_OR_RETURN(auto packed_size, param_tensor.PackedSize());
  LITERT_ASSIGN_OR_RETURN(
      auto param_tensor_lock_and_addr,
      TensorBufferScopedLock::Create(param_tensor,
                                     TensorBuffer::LockMode::kWrite));
  // Zero the whole tensor so any parameters beyond the three written below
  // are well-defined.
  std::memset(param_tensor_lock_and_addr.second, 0, packed_size);
  // See parameter definition in ml_drift::LlmRuntimeParams.
  // First 2 parameters are used by add_values_to_cache kernel.
  // 3rd parameter is used by runtime_batched_matmul kernel to check the end
  // channel index, which doesn't have to be aligned as the kernel does that.
  int end_index = start_index + update_length;
  int32_t params[] = {start_index, end_index, end_index};
  // Guard: ensure the params fit before the memcpy below.
  LITERT_RETURN_IF_ERROR(sizeof(params) <= packed_size);
  std::memcpy(param_tensor_lock_and_addr.second, params, sizeof(params));
  return absl::OkStatus();
}
// Unmasks attention positions for `steps` timesteps starting at
// `start_timestep`, for every batch row of the 4D mask
// [batch, ..., steps, channel]. For the i-th step (absolute timestep n =
// start_timestep + i) the first n+1 channel positions of that step's row are
// set to the "visible" value: true for Bool masks, 0.0 for float masks
// (which InitializeAttentionMask pre-fills with large negative defaults).
absl::Status FillAttentionMask(litert::TensorBuffer& mask, int start_timestep,
                               int steps) {
  LITERT_ASSIGN_OR_RETURN(auto mask_tensor_type, mask.TensorType());
  RET_CHECK_EQ(mask_tensor_type.Layout().Rank(), 4)
      .SetCode(absl::StatusCode::kInvalidArgument)
      << "Attention mask must be 4D.";
  int batch_size = mask_tensor_type.Layout().Dimensions()[0];
  int channel_size = mask_tensor_type.Layout().Dimensions()[3];
  LITERT_ASSIGN_OR_RETURN(auto mask_size, mask.PackedSize());
  LITERT_ASSIGN_OR_RETURN(auto mask_lock_and_addr,
                          litert::TensorBufferScopedLock::Create(
                              mask, litert::TensorBuffer::LockMode::kWrite));
  // Convert the per-batch byte stride into an element stride for the
  // element type in use; unsupported types are rejected up front.
  int batch_offset = mask_size / batch_size;
  if (mask_tensor_type.ElementType() == litert::ElementType::Bool) {
    batch_offset /= sizeof(bool);
  } else if (mask_tensor_type.ElementType() == litert::ElementType::Float32) {
    batch_offset /= sizeof(float);
  } else if (mask_tensor_type.ElementType() == litert::ElementType::Float16) {
    batch_offset /= sizeof(tflite::half);
  } else {
    return absl::InvalidArgumentError("Unsupported attention mask data type.");
  }
  for (int b = 0; b < batch_size; ++b) {
    for (int i = 0; i < steps; ++i) {
      int current_step = start_timestep + i;
      // Element offset of step i's channel row within batch b.
      int offset = b * batch_offset + i * channel_size;
      // For current step = n, we fill (n+1) positions for the mask sequence.
      if (mask_tensor_type.ElementType() == litert::ElementType::Bool) {
        // Boolean mask: Fill value = true.
        bool* bool_ptr = static_cast<bool*>(mask_lock_and_addr.second);
        std::fill(bool_ptr + offset, bool_ptr + offset + current_step + 1,
                  true);
      } else if (mask_tensor_type.ElementType() ==
                 litert::ElementType::Float16) {
        // Float16 mask: Fill value = 0.0f.
        tflite::half* half_ptr =
            static_cast<tflite::half*>(mask_lock_and_addr.second);
        std::fill(half_ptr + offset, half_ptr + offset + current_step + 1,
                  tflite::half(0.0f));
      } else {  // litert::ElementType::Float32, checked above.
        // Float mask: Fill value = 0.0f.
        float* float_ptr = static_cast<float*>(mask_lock_and_addr.second);
        std::fill(float_ptr + offset, float_ptr + offset + current_step + 1,
                  0.0f);
      }
    }
  }
  return absl::OkStatus();
}
| absl::StatusOr<std::unique_ptr<ModelResources>> | |
| BuildLiteRtCompiledModelResources(const ModelAssets& model_assets) { | |
| ASSIGN_OR_RETURN(auto format, GetFileFormat(model_assets)); | |
| switch (format) { | |
| case FileFormat::TASK: | |
| return BuildModelResourcesFromTaskFormat(model_assets); | |
| case FileFormat::LITERT_LM: | |
| return BuildModelResourcesFromLitertLmFormat(model_assets); | |
| default: | |
| return absl::InvalidArgumentError("Unsupported file format."); | |
| } | |
| } | |
| absl::Status GenericComputeTokenEmbeddings( | |
| const TensorBuffer& input_tokens, absl::Span<float> output_embeddings, | |
| absl::Span<float> output_ple_embeddings, | |
| EmbeddingLookupManager* embedding_lookup_manager, | |
| EmbeddingLookupManager* per_layer_embedding_lookup_manager) { | |
| LITERT_ASSIGN_OR_RETURN(auto input_tokens_span, | |
| ReferTensorBufferAsSpan<int32_t>(input_tokens)); | |
| const int num_tokens = input_tokens_span.size(); | |
| if (embedding_lookup_manager == nullptr) { | |
| return absl::InvalidArgumentError("Embedding lookup manager is missing."); | |
| } | |
| const int embedding_dim = | |
| embedding_lookup_manager->GetTextEmbeddingLookup()->GetFloatsPerToken(); | |
| auto output_buffer_type = | |
| embedding_lookup_manager->GetTextEmbeddingLookup()->GetOutputBufferType(); | |
| std::vector<int32_t> dims = {num_tokens, embedding_dim}; | |
| if (output_buffer_type.has_value()) { | |
| auto span_dims = output_buffer_type->Layout().Dimensions(); | |
| dims.assign(span_dims.begin(), span_dims.end()); | |
| dims[0] = 1; | |
| dims[1] = num_tokens; | |
| } | |
| auto tensor_type = MakeRankedTensorType<float>(dims); | |
| LITERT_ASSIGN_OR_RETURN( | |
| auto wrapped_embeddings, | |
| WrapOrCreateTensorBufferFromHostMemory(tensor_type, output_embeddings)); | |
| RETURN_IF_ERROR(embedding_lookup_manager->LookupPrefill( | |
| input_tokens_span, &wrapped_embeddings.buffer, 0 /*token_offset=*/)); | |
| if (per_layer_embedding_lookup_manager != nullptr && | |
| !output_ple_embeddings.empty()) { | |
| auto ple_output_buffer_type = | |
| per_layer_embedding_lookup_manager->GetTextEmbeddingLookup() | |
| ->GetOutputBufferType(); | |
| const int ple_embedding_dim = | |
| per_layer_embedding_lookup_manager->GetTextEmbeddingLookup() | |
| ->GetFloatsPerToken(); | |
| std::vector<int32_t> ple_dims = {num_tokens, ple_embedding_dim}; | |
| if (ple_output_buffer_type.has_value()) { | |
| auto ple_span_dims = ple_output_buffer_type->Layout().Dimensions(); | |
| ple_dims.assign(ple_span_dims.begin(), ple_span_dims.end()); | |
| ple_dims[0] = 1; | |
| ple_dims[1] = num_tokens; | |
| } | |
| auto ple_tensor_type = MakeRankedTensorType<float>(ple_dims); | |
| LITERT_ASSIGN_OR_RETURN(auto wrapped_ple_embeddings, | |
| WrapOrCreateTensorBufferFromHostMemory( | |
| ple_tensor_type, output_ple_embeddings)); | |
| RETURN_IF_ERROR(per_layer_embedding_lookup_manager->LookupPrefill( | |
| input_tokens_span, &wrapped_ple_embeddings.buffer, | |
| 0 /*token_offset=*/)); | |
| } | |
| return absl::OkStatus(); | |
| } | |
| } // namespace litert::lm | |