// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_AUDIO_LITERT_COMPILED_MODEL_EXECUTOR_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_AUDIO_LITERT_COMPILED_MODEL_EXECUTOR_H_ #include #include #include #include #include #include "absl/base/nullability.h" // from @com_google_absl #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "litert/cc/litert_compiled_model.h" // from @litert #include "litert/cc/litert_environment.h" // from @litert #include "litert/cc/litert_model.h" // from @litert #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "runtime/components/model_resources.h" #include "runtime/engine/io_types.h" #include "runtime/executor/audio_executor.h" #include "runtime/executor/audio_executor_settings.h" #include "runtime/executor/llm_executor_io_types.h" namespace litert::lm { // The context for streaming audio encoder model, which contains // the state buffers of the audio encoder model. class AudioStreamingContext : public AudioContext { public: explicit AudioStreamingContext( absl::flat_hash_map state_buffers) : state_buffers_(std::move(state_buffers)) {}; absl::StatusOr> Clone() const override; absl::flat_hash_map& state_buffers() { return state_buffers_; } private: // The state buffers of the audio encoder model. It includes the kv caches and // the convolution features and masks of the last timestamp. absl::flat_hash_map state_buffers_; }; // The Audio Executor that uses the LiteRT CompiledModel to run the audio // encoder and audio adapter models to encode the spectrogram tensor into audio // soft token embeddings. class AudioLiteRtCompiledModelExecutor : public AudioExecutor { public: // Create an AudioLiteRtCompiledModelExecutor to encode the spectrogram // LiteRT TensorBuffer into audio embeddings LiteRT TensorBuffer. // Args: // - executor_settings: The audio executor settings. // - env: The LiteRT environment. // Returns: // A unique pointer to the AudioLiteRtCompiledModelExecutor if successful, // or an error status if failed. static absl::StatusOr> Create(AudioExecutorSettings executor_settings, Environment& env); // Run the audio encoder and audio adapter models to encode the spectrogram // tensor into audio embeddings. It is caller's responsibility to ensure the // spectrogram tensor is valid and has the correct shape. It is assumed that // all the timestamps in the spectrogram tensor are valid. // Args: // - spectrogram_tensor: The spectrogram tensor to encode, in shape of // [..., timestamp, frequency_bins]. // Returns: // A ExecutorAudioData object containing the audio embeddings and the // number of valid tokens. absl::StatusOr Encode( const TensorBuffer& spectrogram_tensor) override; // Run the audio encoder and audio adapter models to encode the spectrogram // tensor into audio embeddings. It is caller's responsibility to ensure the // spectrogram tensor is valid and has the correct shape. // The spectrogram mask is used to indicate the valid timestamps in the // spectrogram tensor. // Args: // - spectrogram_tensor: The spectrogram tensor to encode, in shape of // [..., timestamp, frequency_bins]. // - spectrogram_mask: The spectrogram mask to indicate the valid timestamps // in the spectrogram tensor, in shape of [..., timestamp]. // Returns: // A ExecutorAudioData object containing the audio embeddings and the // number of valid tokens. absl::StatusOr Encode( const TensorBuffer& spectrogram_tensor, const TensorBuffer& spectrogram_mask); // Reset the audio encoder, which will be a stateful object when streaming // model is used. absl::Status Reset() override { return audio_encoder_->Reset(); } // Get the audio executor properties. absl::StatusOr GetAudioExecutorProperties() const override { return executor_properties_; } // Create a new audio context for the audio executor. absl::StatusOr> CreateNewContext() override; // Clone the audio context for the audio executor. absl::StatusOr> CloneContext() override; // Clone the audio context from the given audio context. absl::StatusOr> CloneContext( const AudioContext& audio_context) override; // Restore the audio context for the audio executor. absl::Status RestoreContext( std::unique_ptr audio_context) override; private: // The Audio Encoder LiteRT CompiledModel wrapper manage the input and // output buffers of the audio encoder model. It is not expected to be used // directly by the user. It is used by the AudioLiteRtCompiledModelExecutor // to encode the spectrogram tensor into audio embeddings. The user should // use the AudioLiteRtCompiledModelExecutor instead. class AudioEncoder { public: virtual ~AudioEncoder() = default; virtual absl::Status Initialize() = 0; virtual absl::Status ClearInputBuffers() = 0; virtual absl::Status Reset() = 0; const CompiledModel& GetCompiledModel() const { return compiled_model_; } CompiledModel& GetMutableCompiledModel() { return compiled_model_; } const absl::flat_hash_map& GetInputBuffersMap() const { return input_buffers_map_; } absl::flat_hash_map& GetMutableInputBuffersMap() { return input_buffers_map_; } const absl::flat_hash_map& GetOutputBuffersMap() const { return output_buffers_map_; } absl::flat_hash_map& GetMutableOutputBuffersMap() { return output_buffers_map_; } const TensorBuffer& GetInputMaskBuffer() const { return *input_mask_buffer_; } TensorBuffer& GetMutableInputMaskBuffer() { return *input_mask_buffer_; } const TensorBuffer& GetInputSpectrogramBuffer() const { return *spectrogram_buffer_; } TensorBuffer& GetMutableInputSpectrogramBuffer() { return *spectrogram_buffer_; } const TensorBuffer& GetOutputMaskBuffer() const { return *output_mask_buffer_; } TensorBuffer& GetMutableOutputMaskBuffer() { return *output_mask_buffer_; } const TensorBuffer& GetOutputFeaturesBuffer() const { return *output_features_buffer_; } TensorBuffer& GetMutableOutputFeaturesBuffer() { return *output_features_buffer_; } protected: CompiledModel compiled_model_; // The input buffer for the spectrogram mask. TensorBuffer* input_mask_buffer_; // The input buffer for the spectrogram tensor. TensorBuffer* spectrogram_buffer_; // The output buffer for the valid tokens mask. TensorBuffer* output_mask_buffer_; // The output buffer for the features. TensorBuffer* output_features_buffer_; // The input names for the audio encoder model. std::vector input_names_; // The output names for the audio encoder model. std::vector output_names_; // The input buffers map for the audio encoder model. absl::flat_hash_map input_buffers_map_; // The output buffers map for the audio encoder model. absl::flat_hash_map output_buffers_map_; }; // Audio Encoder for static LiteRT model, where the whole audio is provided at // once. class AudioStaticEncoder : public AudioEncoder { public: // Create an AudioStaticEncoder to run audio static encoder LiteRT // CompiledModel. // Args: // - env: The LiteRT environment. // - model: The audio encoder model. // Returns: // A unique pointer to the AudioStaticEncoder if successful, or an error // status if failed. static absl::StatusOr> Create( const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model); // Initialize the AudioStaticEncoder, which will create the input and output // buffers for the audio encoder model. absl::Status Initialize() override; absl::Status ClearInputBuffers() override; absl::Status Reset() override { return ClearInputBuffers(); } private: AudioStaticEncoder(const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model) : executor_settings_(executor_settings), env_(env), model_(*model) {} const AudioExecutorSettings& executor_settings_; Environment& env_; const Model& model_; }; // Audio Encoder for streaming LiteRT model, where the audio is provided in // streaming fashion. // // For streaming audio encoder model, the input buffers map contains two // parts: // 1. The inputs from the new audio segment. It includes // - segment_values: The spectrogram segment. // - segment_mask: The spectrogram mask. // 2. The inputs from the internal state. It includes // - prev_features: The previous features. // - prev_mask: The previous mask. // - prev_conv_out_mask: The previous conv out mask. // and for each transformer layer (12 layers for gemma3n): // - prev_q_{layer_idx}: The previous q tensor. // - prev_k_{layer_idx}: The previous k tensor. // - prev_v_{layer_idx}: The previous v tensor. // - conv_padding_{layer_idx}: The conv padding. // and for each subsample layer (2 layers for gemma3n): // - feature_states_{layer_idx}: The feature states. // // For streaming audio encoder model, the output buffers map contains two // parts: // 1. The outputs from the new audio segment. It includes // - features: The features. // - mask: The valid tokens mask. // 2. The outputs from the internal state, and are used for next round of // input. It includes // - prev_features: The previous features. // - prev_mask: The previous mask. // - prev_conv_out_mask: The previous conv out mask. // and for each transformer layer (12 layers for gemma3n): // - prev_q_{layer_idx}: The previous q tensor. // - prev_k_{layer_idx}: The previous k tensor. // - prev_v_{layer_idx}: The previous v tensor. // - conv_padding_{layer_idx}: The conv padding. // and for each subsample layer (2 layers for gemma3n): // - class AudioStreamingEncoder : public AudioEncoder { public: // Create an AudioStreamingEncoder to run audio streaming encoder LiteRT // CompiledModel. // Args: // - env: The LiteRT environment. // - model: The audio encoder model. // Returns: // A unique pointer to the AudioStreamingEncoder if successful, or an // error status if failed. static absl::StatusOr> Create( const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model); // Initialize the AudioStreamingEncoder, which will create the input and // output buffers for the audio encoder model. absl::Status Initialize() override; int GetOverlapSize() const { return overlap_size_; } // Swap the internal state buffers between input and output buffers map, so // the previous state will be used for the current state. absl::Status SwapInternalStateBuffers(); absl::Status ClearInputBuffers() override; absl::Status Reset() override; absl::StatusOr> CreateNewContext(); absl::StatusOr> CloneContext(); absl::Status RestoreContext( std::unique_ptr audio_streaming_context); private: AudioStreamingEncoder(const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model) : executor_settings_(executor_settings), env_(env), model_(*model) {} const AudioExecutorSettings& executor_settings_; Environment& env_; const Model& model_; int overlap_size_; }; // The Audio Adapter LiteRT CompiledModel wrapper manage the input and // output buffers of the audio adapter model. It is not expected to be used // directly by the user. It is used by the AudioLiteRtCompiledModelExecutor to // encode the audio embeddings into audio soft tokens. The user should use the // AudioLiteRtCompiledModelExecutor instead. class AudioAdapter { public: // Create an AudioAdapter to run audio adapter LiteRT CompiledModel. // Args: // - env: The LiteRT environment. // - model: The audio adapter model. // Returns: // A unique pointer to the AudioAdapter if successful, or an error status // if failed. static absl::StatusOr> Create( const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model); // Initialize the AudioAdapter, which will create the input and output // buffers for the audio adapter model. absl::Status Initialize(); const CompiledModel& GetCompiledModel() const { return compiled_model_; } CompiledModel& GetMutableCompiledModel() { return compiled_model_; } const std::vector& GetInputBuffers() const { return input_buffers_; } std::vector& GetMutableInputBuffers() { return input_buffers_; } const TensorBuffer& GetFeaturesBuffer() const { return *features_buffer_; } TensorBuffer& GetMutableFeaturesBuffer() { return *features_buffer_; } const TensorBuffer& GetMaskBuffer() const { return *mask_buffer_; } TensorBuffer& GetMutableMaskBuffer() { return *mask_buffer_; } const std::vector& GetOutputBuffers() const { return output_buffers_; } std::vector& GetMutableOutputBuffers() { return output_buffers_; } private: AudioAdapter(const AudioExecutorSettings& executor_settings, Environment& env, const Model* absl_nonnull model) : executor_settings_(executor_settings), env_(env), model_(*model) {} const AudioExecutorSettings& executor_settings_; Environment& env_; const Model& model_; CompiledModel compiled_model_; // The input buffers for the audio adapter model. std::vector input_buffers_; // The input buffers for the input features. TensorBuffer* features_buffer_; // The input buffer for the input mask. TensorBuffer* mask_buffer_; // The output buffers for the audio adapter model. std::vector output_buffers_; }; explicit AudioLiteRtCompiledModelExecutor( AudioExecutorSettings executor_settings, AudioExecutorProperties executor_properties, Environment& env, std::unique_ptr resources, std::unique_ptr audio_encoder, std::unique_ptr audio_adapter, int sequence_length, int spectrogram_feature_dimensions, int audio_embedding_dimensions, int encoder_shrinking_factor) : sequence_length_(sequence_length), spectrogram_feature_dimensions_(spectrogram_feature_dimensions), audio_embedding_dimensions_(audio_embedding_dimensions), encoder_shrinking_factor_(encoder_shrinking_factor), executor_settings_(std::move(executor_settings)), executor_properties_(std::move(executor_properties)), env_(env), resources_(std::move(resources)), audio_encoder_(std::move(audio_encoder)), audio_adapter_(std::move(audio_adapter)) {} // Run the audio encoder and audio adapter models to encode the spectrogram // tensor into audio embeddings. // Args: // - spectrogram_tensor: The spectrogram tensor buffer to encode. // - spectrogram_mask: The spectrogram mask buffer to indicate the valid // timestamps. // - audio_embeddings: The output buffer for the audio embeddings to write // into. // Returns: // The number of valid tokens in the audio embeddings. absl::StatusOr EncodeInternal(absl::Span spectrogram_tensor, absl::Span spectrogram_mask, absl::Span audio_embeddings); int sequence_length_; int spectrogram_feature_dimensions_; int audio_embedding_dimensions_; int encoder_shrinking_factor_; AudioExecutorSettings executor_settings_; AudioExecutorProperties executor_properties_; /// The LiteRT environment. Environment& env_; std::unique_ptr resources_; std::unique_ptr audio_encoder_; std::unique_ptr audio_adapter_; }; } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_AUDIO_LITERT_COMPILED_MODEL_EXECUTOR_H_