// Copyright 2026 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/executor/audio_executor_utils.h" #include #include #include #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/match.h" // from @com_google_absl #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "litert/cc/litert_macros.h" // from @litert #include "litert/cc/litert_model.h" // from @litert #include "runtime/components/model_resources.h" #include "runtime/engine/io_types.h" #include "runtime/util/status_macros.h" namespace litert::lm { namespace { constexpr absl::string_view kPrevPrefix = "prev_"; constexpr absl::string_view kFeatureStatesNamePattern = "feature_state"; constexpr absl::string_view kSegmentMaskName = "segment_mask"; constexpr absl::string_view kMaskName = "mask"; bool IsStreamingEncoder(const std::vector& input_names) { // A huristic to check if the model is a streaming model by checking if the // input names contain the prev_mask name. return std::any_of(input_names.begin(), input_names.end(), [](absl::string_view input_name) { return absl::StrContains(input_name, kPrevPrefix); }); } } // namespace absl::StatusOr GetAudioExecutorPropertiesFromModelResources(ModelResources& model_resources) { AudioExecutorProperties properties; ASSIGN_OR_RETURN( auto audio_encoder_model, model_resources.GetTFLiteModel(ModelType::kTfLiteAudioEncoderHw)); LITERT_ASSIGN_OR_RETURN(auto input_names, audio_encoder_model->GetSignatureInputNames()); properties.is_streaming_model = IsStreamingEncoder(input_names); LITERT_ASSIGN_OR_RETURN( auto mask_tensor_type, audio_encoder_model->GetInputTensorType( 0, properties.is_streaming_model ? kSegmentMaskName : kMaskName)); LITERT_ASSIGN_OR_RETURN(int input_sequence_length, mask_tensor_type.Layout().NumElements()); ASSIGN_OR_RETURN( auto audio_adapter_model, model_resources.GetTFLiteModel(ModelType::kTfLiteAudioAdapter)); LITERT_ASSIGN_OR_RETURN(auto adapter_output_tensor_type, audio_adapter_model->GetOutputTensorType(0, 0)); int output_sequence_length = adapter_output_tensor_type.Layout().Dimensions() [adapter_output_tensor_type.Layout().Dimensions().size() - 2]; if (properties.is_streaming_model) { // Get the feature states tensor type and use it to get the overlap size. std::string feature_states_name = absl::StrCat(kFeatureStatesNamePattern, "_0"); LITERT_ASSIGN_OR_RETURN( auto feature_states_tensor_type, audio_encoder_model->GetInputTensorType(0, feature_states_name), _ << "The Audio Streaming Encoder model must have a feature_states " "input " "buffer."); // The overlap size is the number of elements in the feature states tensor, // which is 3 for gemma3n. LITERT_ASSIGN_OR_RETURN(properties.streaming_chunk_overlap_size, feature_states_tensor_type.Layout().NumElements()); // Get the segment mask tensor type and use it to get the chunk size. LITERT_ASSIGN_OR_RETURN( auto segment_mask_tensor_type, audio_encoder_model->GetInputTensorType(0, kSegmentMaskName), _ << "The Audio Streaming Encoder model must have a segment_mask input " "buffer."); // The chunk size is the last dimension of the segment mask tensor, which is // the number of frames in each segment. properties.streaming_chunk_size = segment_mask_tensor_type.Layout().Dimensions() [segment_mask_tensor_type.Layout().Dimensions().size() - 1]; properties.audio_shrink_factor = (input_sequence_length - properties.streaming_chunk_overlap_size) / output_sequence_length; } else { properties.audio_shrink_factor = input_sequence_length / output_sequence_length; } return properties; } } // namespace litert::lm