Spaces:
Running
Running
| // Copyright 2026 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
| constexpr absl::string_view kPrevPrefix = "prev_"; | |
| constexpr absl::string_view kFeatureStatesNamePattern = "feature_state"; | |
| constexpr absl::string_view kSegmentMaskName = "segment_mask"; | |
| constexpr absl::string_view kMaskName = "mask"; | |
| bool IsStreamingEncoder(const std::vector<absl::string_view>& input_names) { | |
| // A huristic to check if the model is a streaming model by checking if the | |
| // input names contain the prev_mask name. | |
| return std::any_of(input_names.begin(), input_names.end(), | |
| [](absl::string_view input_name) { | |
| return absl::StrContains(input_name, kPrevPrefix); | |
| }); | |
| } | |
| } // namespace | |
| absl::StatusOr<AudioExecutorProperties> | |
| GetAudioExecutorPropertiesFromModelResources(ModelResources& model_resources) { | |
| AudioExecutorProperties properties; | |
| ASSIGN_OR_RETURN( | |
| auto audio_encoder_model, | |
| model_resources.GetTFLiteModel(ModelType::kTfLiteAudioEncoderHw)); | |
| LITERT_ASSIGN_OR_RETURN(auto input_names, | |
| audio_encoder_model->GetSignatureInputNames()); | |
| properties.is_streaming_model = IsStreamingEncoder(input_names); | |
| LITERT_ASSIGN_OR_RETURN( | |
| auto mask_tensor_type, | |
| audio_encoder_model->GetInputTensorType( | |
| 0, properties.is_streaming_model ? kSegmentMaskName : kMaskName)); | |
| LITERT_ASSIGN_OR_RETURN(int input_sequence_length, | |
| mask_tensor_type.Layout().NumElements()); | |
| ASSIGN_OR_RETURN( | |
| auto audio_adapter_model, | |
| model_resources.GetTFLiteModel(ModelType::kTfLiteAudioAdapter)); | |
| LITERT_ASSIGN_OR_RETURN(auto adapter_output_tensor_type, | |
| audio_adapter_model->GetOutputTensorType(0, 0)); | |
| int output_sequence_length = | |
| adapter_output_tensor_type.Layout().Dimensions() | |
| [adapter_output_tensor_type.Layout().Dimensions().size() - 2]; | |
| if (properties.is_streaming_model) { | |
| // Get the feature states tensor type and use it to get the overlap size. | |
| std::string feature_states_name = | |
| absl::StrCat(kFeatureStatesNamePattern, "_0"); | |
| LITERT_ASSIGN_OR_RETURN( | |
| auto feature_states_tensor_type, | |
| audio_encoder_model->GetInputTensorType(0, feature_states_name), | |
| _ << "The Audio Streaming Encoder model must have a feature_states " | |
| "input " | |
| "buffer."); | |
| // The overlap size is the number of elements in the feature states tensor, | |
| // which is 3 for gemma3n. | |
| LITERT_ASSIGN_OR_RETURN(properties.streaming_chunk_overlap_size, | |
| feature_states_tensor_type.Layout().NumElements()); | |
| // Get the segment mask tensor type and use it to get the chunk size. | |
| LITERT_ASSIGN_OR_RETURN( | |
| auto segment_mask_tensor_type, | |
| audio_encoder_model->GetInputTensorType(0, kSegmentMaskName), | |
| _ << "The Audio Streaming Encoder model must have a segment_mask input " | |
| "buffer."); | |
| // The chunk size is the last dimension of the segment mask tensor, which is | |
| // the number of frames in each segment. | |
| properties.streaming_chunk_size = | |
| segment_mask_tensor_type.Layout().Dimensions() | |
| [segment_mask_tensor_type.Layout().Dimensions().size() - 1]; | |
| properties.audio_shrink_factor = | |
| (input_sequence_length - properties.streaming_chunk_overlap_size) / | |
| output_sequence_length; | |
| } else { | |
| properties.audio_shrink_factor = | |
| input_sequence_length / output_sequence_length; | |
| } | |
| return properties; | |
| } | |
| } // namespace litert::lm | |