LiteRT-LM / runtime /executor /audio_litert_compiled_model_executor.cc
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/executor/audio_litert_compiled_model_executor.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <filesystem> //NOLINT
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/base/nullability.h" // from @com_google_absl
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_common.h" // from @litert
#include "litert/cc/litert_compiled_model.h" // from @litert
#include "litert/cc/litert_element_type.h" // from @litert
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_layout.h" // from @litert
#include "litert/cc/litert_macros.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_options.h" // from @litert
#include "litert/cc/litert_ranked_tensor_type.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "litert/cc/litert_tensor_buffer_types.h" // from @litert
#include "litert/cc/options/litert_cpu_options.h" // from @litert
#include "litert/cc/options/litert_gpu_options.h" // from @litert
#include "runtime/components/model_resources.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/audio_executor_settings.h"
#include "runtime/executor/audio_executor_utils.h"
#include "runtime/executor/common_utils.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/executor/litert_compiled_model_executor_utils.h"
#include "runtime/executor/llm_executor_io_types.h"
#include "runtime/util/file_util.h"
#include "runtime/util/scoped_file.h"
#include "runtime/util/status_macros.h" //NOLINT
namespace litert::lm {
namespace {
absl::Status SetCpuCacheOptions(
const absl::StatusOr<std::string>& weight_cache_file,
std::shared_ptr<litert::lm::ScopedFile> scoped_cache_file,
litert::CpuOptions& cpu_options, absl::string_view logging_prefix) {
if (scoped_cache_file != nullptr) {
ASSIGN_OR_RETURN(auto duplicated, scoped_cache_file->Duplicate());
ASSIGN_OR_RETURN(int fd, duplicated.Release());
cpu_options.SetXNNPackWeightCacheFileDescriptor(fd);
ABSL_LOG(INFO) << logging_prefix
<< " use provided cache file descriptor: " << fd;
} else if (weight_cache_file.ok()) {
const std::string& weight_cache_path = *weight_cache_file;
cpu_options.SetXNNPackWeightCachePath(weight_cache_path.c_str());
ABSL_LOG(INFO) << logging_prefix
<< " use cache path: " << weight_cache_path;
} else {
ABSL_LOG(INFO) << logging_prefix << " does not use cache.";
}
return absl::OkStatus();
}
absl::Status SetGpuOptions(
const std::string& weight_cache_path,
std::shared_ptr<litert::lm::ScopedFile> scoped_cache_file,
const absl::StatusOr<
std::variant<std::string, std::shared_ptr<litert::lm::ScopedFile>>>&
program_cache_file,
const AudioExecutorSettings& executor_settings, absl::string_view cache_key,
absl::string_view logging_prefix, litert::GpuOptions& gpu_options) {
#if defined(LITERT_USE_WEBGPU_ACCELERATOR)
gpu_options.SetBackend(GpuOptions::Backend::kWebGpu);
#endif // defined(LITERT_USE_WEBGPU_ACCELERATOR)
gpu_options.EnableConstantTensorSharing(true);
// TODO(b/484646529): Re-enable precision setting once the GPU audio
// encoder precision is fixed. Similar to vision encoder, we force FP32 for
// now.
// if (executor_settings.GetActivationDataType().has_value()) {
// if (executor_settings.GetActivationDataType().value() ==
// ActivationDataType::FLOAT32) {
// gpu_options.SetPrecision(GpuOptions::Precision::kFp32);
// } else {
// gpu_options.SetPrecision(GpuOptions::Precision::kFp16);
// }
// } else {
// gpu_options.SetPrecision(GpuOptions::Precision::kFp32);
// }
gpu_options.SetPrecision(GpuOptions::Precision::kFp32);
#if defined(__APPLE__)
gpu_options.SetPreferTextureWeights(false);
gpu_options.SetUseMetalArgumentBuffers(true);
#else // !__APPLE__
gpu_options.SetPreferTextureWeights(true);
#endif // !__APPLE__
gpu_options.SetModelCacheKey(cache_key.data());
std::string cache_path = weight_cache_path;
bool serialization_dir_set = false;
if (cache_path != ":nocache") {
if (cache_path.empty()) {
ASSIGN_OR_RETURN(auto model_path,
executor_settings.GetModelAssets().GetPath());
cache_path =
std::filesystem::path(std::string(model_path)).parent_path().string();
if (cache_path.empty()) {
cache_path = std::filesystem::current_path().string();
}
}
gpu_options.SetSerializationDir(cache_path.c_str());
gpu_options.SetSerializeExternalTensors(true);
serialization_dir_set = true;
}
if (program_cache_file.ok()) {
if (std::holds_alternative<std::string>(*program_cache_file)) {
if (!serialization_dir_set) {
cache_path =
std::filesystem::path(std::get<std::string>(*program_cache_file))
.parent_path()
.string();
gpu_options.SetSerializationDir(cache_path.c_str());
}
} else {
auto scoped_cache_file =
std::get<std::shared_ptr<lm::ScopedFile>>(*program_cache_file);
ASSIGN_OR_RETURN(auto duplicated, scoped_cache_file->Duplicate());
ASSIGN_OR_RETURN(int fd, duplicated.Release());
gpu_options.SetProgramCacheFd(fd);
}
gpu_options.SetSerializeProgramCache(true);
} else {
gpu_options.SetSerializeProgramCache(false);
}
return absl::OkStatus();
}
constexpr absl::string_view kFeaturesName = "features";
constexpr absl::string_view kMaskName = "mask";
constexpr absl::string_view kMaskOutName = "mask_out";
constexpr absl::string_view kSrcInputsName = "src_inputs";
constexpr absl::string_view kSegmentValuesName = "segment_values";
constexpr absl::string_view kSegmentMaskName = "segment_mask";
constexpr absl::string_view kPrevMaskName = "prev_mask";
constexpr absl::string_view kPrevPrefix = "prev_";
constexpr absl::string_view kFeatureStatesNamePattern = "feature_state";
template <typename T>
absl::StatusOr<std::vector<T>> GetDataAsVector(TensorBuffer& tensor_buffer) {
LITERT_ASSIGN_OR_RETURN(auto tensor_type, tensor_buffer.TensorType());
LITERT_ASSIGN_OR_RETURN(auto elements, tensor_type.Layout().NumElements());
std::vector<T> data(elements);
LITERT_RETURN_IF_ERROR(tensor_buffer.Read<T>(absl::MakeSpan(data)));
return data;
}
// Returns the first valid token count from the mask tensor.
absl::StatusOr<int> GetValidCount(const TensorBuffer& mask_buffer) {
ASSIGN_OR_RETURN(auto mask, GetDataAsVector<uint8_t>(
const_cast<TensorBuffer&>(mask_buffer)));
for (int i = mask.size() - 1; i >= 0; --i) {
if (mask[i] != 0) {
return i + 1;
}
}
return 0;
}
absl::Status InitializeBuffer(TensorBuffer& buffer) {
LITERT_ASSIGN_OR_RETURN(
auto buffer_lock_and_addr,
TensorBufferScopedLock::Create(buffer, TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(auto packed_size, buffer.PackedSize());
memset(buffer_lock_and_addr.second, 0, packed_size);
return absl::OkStatus();
}
absl::Status InitializeBuffers(std::vector<TensorBuffer>& buffers) {
for (auto& buffer : buffers) {
RETURN_IF_ERROR(InitializeBuffer(buffer));
}
return absl::OkStatus();
}
inline int CeilIntDiv(int a, int b) { return (a + b - 1) / b; }
bool IsStreamingEncoder(const std::vector<absl::string_view>& input_names) {
// A huristic to check if the model is a streaming model by checking if the
// input names contain the prev_mask name.
return std::any_of(input_names.begin(), input_names.end(),
[](absl::string_view input_name) {
return absl::StrContains(input_name, kPrevPrefix);
});
}
} // namespace
absl::StatusOr<std::unique_ptr<AudioContext>> AudioStreamingContext::Clone()
const {
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>
new_state_buffers;
for (auto& [name, buffer] : state_buffers_) {
LITERT_ASSIGN_OR_RETURN(auto new_buffer, buffer.Duplicate());
new_state_buffers[name] = std::move(new_buffer);
}
return std::make_unique<AudioStreamingContext>(std::move(new_state_buffers));
}
absl::StatusOr<
std::unique_ptr<AudioLiteRtCompiledModelExecutor::AudioStaticEncoder>>
AudioLiteRtCompiledModelExecutor::AudioStaticEncoder::Create(
const AudioExecutorSettings& executor_settings, Environment& env,
const Model* absl_nonnull model) {
auto handler = std::unique_ptr<AudioStaticEncoder>(
new AudioStaticEncoder(executor_settings, env, model));
RETURN_IF_ERROR(handler->Initialize());
return handler;
}
absl::Status
AudioLiteRtCompiledModelExecutor::AudioStaticEncoder::Initialize() {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
auto weight_cache_file = executor_settings_.GetWeightCacheFile(
".static_audio_encoder.xnnpack_cache");
std::string weight_cache_path = executor_settings_.GetCacheDir();
if (executor_settings_.GetBackend() == Backend::GPU) {
LITERT_ASSIGN_OR_RETURN(auto& gpu_options, options.GetGpuOptions());
ASSIGN_OR_RETURN(auto model_path,
executor_settings_.GetModelAssets().GetPath());
absl::string_view model_basename = Basename(model_path);
auto program_cache_file = executor_settings_.GetProgramCacheFile(
".mldrift_program_cache.static_audio_encoder.bin");
RETURN_IF_ERROR(SetGpuOptions(
weight_cache_path, executor_settings_.GetScopedEncoderCacheFile(),
program_cache_file, executor_settings_,
absl::StrCat(model_basename, ".static_audio_encoder"), "audio_encoder",
gpu_options));
options.SetHardwareAccelerators(litert::HwAccelerators::kGpu);
} else if (executor_settings_.GetBackend() == Backend::CPU) {
LITERT_ASSIGN_OR_RETURN(auto& cpu_options, options.GetCpuOptions());
cpu_options.SetNumThreads(executor_settings_.GetNumThreads());
std::shared_ptr<ScopedFile> scoped_encoder_cache_file =
executor_settings_.GetScopedEncoderCacheFile();
RETURN_IF_ERROR(SetCpuCacheOptions(weight_cache_file,
scoped_encoder_cache_file, cpu_options,
"audio_encoder"));
options.SetHardwareAccelerators(litert::HwAccelerators::kCpu);
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported backend for AudioStaticEncoder: ",
executor_settings_.GetBackend()));
}
LITERT_ASSIGN_OR_RETURN(compiled_model_,
CompiledModel::Create(env_, model_.Get(), options));
LITERT_ASSIGN_OR_RETURN(auto signatures, model_.GetSignatures());
if (signatures.size() != 1) {
return absl::InvalidArgumentError(
absl::StrCat("The Audio Static Encoder model must have exactly one "
"signature but got ",
signatures.size()));
}
LITERT_ASSIGN_OR_RETURN(auto signature, model_.GetSignature(0));
// Initialize the input buffers.
LITERT_ASSIGN_OR_RETURN(auto input_buffers,
compiled_model_.CreateInputBuffers(
/*signature_index=*/0));
LITERT_RETURN_IF_ERROR(InitializeBuffers(input_buffers));
input_names_.reserve(signature.InputNames().size());
for (int i = 0; i < signature.InputNames().size(); ++i) {
std::string input_name = std::string(signature.InputNames()[i]);
input_names_.push_back(input_name);
absl::string_view input_name_view = input_names_[i];
input_buffers_map_[input_name_view] = std::move(input_buffers[i]);
}
// Get pointers to specific buffers after the map is fully populated.
if (!input_buffers_map_.contains(kMaskName)) {
return absl::InvalidArgumentError(
"The Audio Static Encoder model must have a mask input buffer.");
}
if (!input_buffers_map_.contains(kSrcInputsName)) {
return absl::InvalidArgumentError(
"The Audio Static Encoder model must have a src_inputs input "
"buffer.");
}
input_mask_buffer_ = &input_buffers_map_[kMaskName];
spectrogram_buffer_ = &input_buffers_map_[kSrcInputsName];
// Initialize the output buffers.
LITERT_ASSIGN_OR_RETURN(auto output_buffers,
compiled_model_.CreateOutputBuffers(
/*signature_index=*/0));
if (output_buffers.size() != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"The Audio Static Encoder model must have exactly two output "
"buffer but got ",
output_buffers.size()));
}
LITERT_RETURN_IF_ERROR(InitializeBuffers(output_buffers));
output_names_.reserve(signature.OutputNames().size());
for (int i = 0; i < signature.OutputNames().size(); ++i) {
std::string output_name = std::string(signature.OutputNames()[i]);
output_names_.push_back(output_name);
absl::string_view output_name_view = output_names_[i];
output_buffers_map_[output_name_view] = std::move(output_buffers[i]);
}
// Get pointers to specific buffers after the map is fully populated.
if (!output_buffers_map_.contains(kMaskName) &&
!output_buffers_map_.contains(kMaskOutName)) {
return absl::InvalidArgumentError(
"The Audio Static Encoder model must have a mask output buffer.");
}
if (!output_buffers_map_.contains(kFeaturesName)) {
return absl::InvalidArgumentError(
"The Audio Static Encoder model must have a features output buffer.");
}
output_mask_buffer_ = output_buffers_map_.contains(kMaskName)
? &output_buffers_map_[kMaskName]
: &output_buffers_map_[kMaskOutName];
output_features_buffer_ = &output_buffers_map_[kFeaturesName];
return absl::OkStatus();
}
absl::Status
AudioLiteRtCompiledModelExecutor::AudioStaticEncoder::ClearInputBuffers() {
for (auto& [input_name, input_buffer] : input_buffers_map_) {
LITERT_ASSIGN_OR_RETURN(auto buffer_lock_and_addr,
TensorBufferScopedLock::Create(
input_buffer, TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(auto packed_size, input_buffer.PackedSize());
memset(buffer_lock_and_addr.second, 0, packed_size);
}
return absl::OkStatus();
}
absl::StatusOr<
std::unique_ptr<AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder>>
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::Create(
const AudioExecutorSettings& executor_settings, Environment& env,
const Model* absl_nonnull model) {
auto handler = std::unique_ptr<AudioStreamingEncoder>(
new AudioStreamingEncoder(executor_settings, env, model));
RETURN_IF_ERROR(handler->Initialize());
return handler;
}
absl::Status
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::Initialize() {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
auto weight_cache_file = executor_settings_.GetWeightCacheFile(
".streaming_audio_encoder.xnnpack_cache");
std::string weight_cache_path = executor_settings_.GetCacheDir();
if (executor_settings_.GetBackend() == Backend::GPU) {
LITERT_ASSIGN_OR_RETURN(auto& gpu_options, options.GetGpuOptions());
ASSIGN_OR_RETURN(auto model_path,
executor_settings_.GetModelAssets().GetPath());
absl::string_view model_basename = Basename(model_path);
auto program_cache_file = executor_settings_.GetProgramCacheFile(
".mldrift_program_cache.streaming_audio_encoder.bin");
RETURN_IF_ERROR(SetGpuOptions(
weight_cache_path, executor_settings_.GetScopedEncoderCacheFile(),
program_cache_file, executor_settings_,
absl::StrCat(model_basename, ".streaming_audio_encoder"),
"audio_encoder", gpu_options));
options.SetHardwareAccelerators(litert::HwAccelerators::kGpu);
} else if (executor_settings_.GetBackend() == Backend::CPU) {
LITERT_ASSIGN_OR_RETURN(auto& cpu_options, options.GetCpuOptions());
cpu_options.SetNumThreads(executor_settings_.GetNumThreads());
std::shared_ptr<ScopedFile> scoped_encoder_cache_file =
executor_settings_.GetScopedEncoderCacheFile();
RETURN_IF_ERROR(SetCpuCacheOptions(weight_cache_file,
scoped_encoder_cache_file, cpu_options,
"audio_encoder"));
options.SetHardwareAccelerators(litert::HwAccelerators::kCpu);
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported backend for AudioEncoder: ",
executor_settings_.GetBackend()));
}
LITERT_ASSIGN_OR_RETURN(compiled_model_,
CompiledModel::Create(env_, model_.Get(), options));
LITERT_ASSIGN_OR_RETURN(auto signatures, model_.GetSignatures());
if (signatures.size() != 1) {
return absl::InvalidArgumentError(absl::StrCat(
"The Audio Encoder model must have exactly one signature but got ",
signatures.size()));
}
LITERT_ASSIGN_OR_RETURN(auto signature, model_.GetSignature(0));
// Initialize the input buffers.
LITERT_ASSIGN_OR_RETURN(auto input_buffers,
compiled_model_.CreateInputBuffers(
/*signature_index=*/0));
LITERT_RETURN_IF_ERROR(InitializeBuffers(input_buffers));
input_names_.reserve(signature.InputNames().size());
for (int i = 0; i < signature.InputNames().size(); ++i) {
std::string input_name = std::string(signature.InputNames()[i]);
input_names_.push_back(input_name);
absl::string_view input_name_view = input_names_[i];
input_buffers_map_[input_name_view] = std::move(input_buffers[i]);
}
// Get pointers to specific buffers after the map is fully populated.
if (!input_buffers_map_.contains(kSegmentMaskName)) {
return absl::InvalidArgumentError(
"The Audio Streaming Encoder model must have a segment_mask input "
"buffer.");
}
if (!input_buffers_map_.contains(kSegmentValuesName)) {
return absl::InvalidArgumentError(
"The Audio Streaming Encoder model must have a segment_values input "
"buffer.");
}
input_mask_buffer_ = &input_buffers_map_[kSegmentMaskName];
spectrogram_buffer_ = &input_buffers_map_[kSegmentValuesName];
// Initialize the output buffers.
LITERT_ASSIGN_OR_RETURN(auto output_buffers,
compiled_model_.CreateOutputBuffers(
/*signature_index=*/0));
LITERT_RETURN_IF_ERROR(InitializeBuffers(output_buffers));
output_names_.reserve(signature.OutputNames().size());
for (int i = 0; i < signature.OutputNames().size(); ++i) {
std::string output_name = std::string(signature.OutputNames()[i]);
output_names_.push_back(output_name);
absl::string_view output_name_view = output_names_[i];
output_buffers_map_[output_name_view] = std::move(output_buffers[i]);
}
// Get pointers to specific buffers after the map is fully populated.
if (!output_buffers_map_.contains(kMaskName)) {
return absl::InvalidArgumentError(
"The Audio Streaming Encoder model must have a mask output buffer.");
}
if (!output_buffers_map_.contains(kFeaturesName)) {
return absl::InvalidArgumentError(
"The Audio Streaming Encoder model must have a features output "
"buffer.");
}
output_mask_buffer_ = &output_buffers_map_[kMaskName];
output_features_buffer_ = &output_buffers_map_[kFeaturesName];
// Get the feature states tensor type and use it to get the overlap size.
std::string feature_states_name =
absl::StrCat(kFeatureStatesNamePattern, "_0");
if (!input_buffers_map_.contains(feature_states_name)) {
return absl::InvalidArgumentError(
"The Audio Streaming Encoder model must have a feature_states input "
"buffer.");
}
LITERT_ASSIGN_OR_RETURN(auto feature_states_tensor_type,
input_buffers_map_[feature_states_name].TensorType());
// The overlap size is the number of elements in the feature states tensor,
// which is 3 for gemma3n.
LITERT_ASSIGN_OR_RETURN(overlap_size_,
feature_states_tensor_type.Layout().NumElements());
// Initialize the previous mask buffer to all ones.
if (input_buffers_map_.contains(kPrevMaskName)) {
LITERT_ASSIGN_OR_RETURN(auto prev_mask_type,
input_buffers_map_[kPrevMaskName].TensorType());
LITERT_ASSIGN_OR_RETURN(int prev_mask_size,
prev_mask_type.Layout().NumElements());
input_buffers_map_[kPrevMaskName].Write<uint8_t>(
std::vector<uint8_t>(prev_mask_size, 1));
}
return absl::OkStatus();
}
absl::Status AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::
SwapInternalStateBuffers() {
std::vector<absl::string_view> all_input_names(input_names_.begin(),
input_names_.end());
for (const auto& input_name : all_input_names) {
if (output_buffers_map_.contains(input_name)) {
std::swap(input_buffers_map_[input_name],
output_buffers_map_[input_name]);
}
}
return absl::OkStatus();
}
absl::Status
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::ClearInputBuffers() {
{
LITERT_ASSIGN_OR_RETURN(
auto buffer_lock_and_addr,
TensorBufferScopedLock::Create(GetMutableInputSpectrogramBuffer(),
TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(auto packed_size,
GetInputSpectrogramBuffer().PackedSize());
memset(buffer_lock_and_addr.second, 0, packed_size);
}
{
LITERT_ASSIGN_OR_RETURN(
auto buffer_lock_and_addr,
TensorBufferScopedLock::Create(GetMutableInputMaskBuffer(),
TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(auto packed_size,
GetInputMaskBuffer().PackedSize());
memset(buffer_lock_and_addr.second, 0, packed_size);
}
return absl::OkStatus();
}
absl::Status AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::Reset() {
for (auto& [input_name, input_buffer] : input_buffers_map_) {
LITERT_ASSIGN_OR_RETURN(auto buffer_lock_and_addr,
TensorBufferScopedLock::Create(
input_buffer, TensorBuffer::LockMode::kWrite));
LITERT_ASSIGN_OR_RETURN(auto packed_size, input_buffer.PackedSize());
if (input_name == kPrevMaskName) {
for (int i = 0; i < packed_size; ++i) {
auto* mask_ptr = static_cast<bool*>(buffer_lock_and_addr.second);
mask_ptr[i] = true;
}
} else {
memset(buffer_lock_and_addr.second, 0, packed_size);
}
}
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<AudioLiteRtCompiledModelExecutor::AudioAdapter>>
AudioLiteRtCompiledModelExecutor::AudioAdapter::Create(
const AudioExecutorSettings& executor_settings, Environment& env,
const Model* absl_nonnull model) {
auto handler = std::unique_ptr<AudioAdapter>(
new AudioAdapter(executor_settings, env, model));
RETURN_IF_ERROR(handler->Initialize());
return handler;
}
absl::Status AudioLiteRtCompiledModelExecutor::AudioAdapter::Initialize() {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
auto weight_cache_file =
executor_settings_.GetWeightCacheFile(".audio_adapter.xnnpack_cache");
if (executor_settings_.GetBackend() == Backend::GPU) {
LITERT_ASSIGN_OR_RETURN(auto& gpu_options, options.GetGpuOptions());
gpu_options.EnableConstantTensorSharing(true);
gpu_options.SetPrecision(GpuOptions::Precision::kFp32);
gpu_options.SetPreferTextureWeights(true);
#if defined(LITERT_USE_WEBGPU_ACCELERATOR)
gpu_options.SetBackend(GpuOptions::Backend::kWebGpu);
#endif // defined(LITERT_USE_WEBGPU_ACCELERATOR)
options.SetHardwareAccelerators(litert::HwAccelerators::kGpu);
} else if (executor_settings_.GetBackend() == Backend::CPU) {
LITERT_ASSIGN_OR_RETURN(auto& cpu_options, options.GetCpuOptions());
cpu_options.SetNumThreads(executor_settings_.GetNumThreads());
std::shared_ptr<ScopedFile> scoped_adapter_cache_file =
executor_settings_.GetScopedAdapterCacheFile();
RETURN_IF_ERROR(SetCpuCacheOptions(weight_cache_file,
scoped_adapter_cache_file, cpu_options,
"audio_adapter"));
options.SetHardwareAccelerators(litert::HwAccelerators::kCpu);
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported backend for AudioAdapter: ",
executor_settings_.GetBackend()));
}
LITERT_ASSIGN_OR_RETURN(compiled_model_,
CompiledModel::Create(env_, model_.Get(), options));
LITERT_ASSIGN_OR_RETURN(auto signatures, model_.GetSignatures());
if (signatures.size() != 1) {
return absl::InvalidArgumentError(absl::StrCat(
"The Audio Adapter model must have exactly one signature but got ",
signatures.size()));
}
LITERT_ASSIGN_OR_RETURN(input_buffers_, compiled_model_.CreateInputBuffers(
/*signature_index=*/0));
if (input_buffers_.size() != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"The Audio Adapter model must have exactly two input buffer but got ",
input_buffers_.size()));
}
LITERT_ASSIGN_OR_RETURN(output_buffers_, compiled_model_.CreateOutputBuffers(
/*signature_index=*/0));
LITERT_RETURN_IF_ERROR(InitializeBuffers(input_buffers_));
LITERT_RETURN_IF_ERROR(InitializeBuffers(output_buffers_));
if (output_buffers_.size() != 1) {
return absl::InvalidArgumentError(
absl::StrCat("The Audio Adapter model must have exactly one output "
"buffer but got ",
output_buffers_.size()));
}
LITERT_ASSIGN_OR_RETURN(auto signature, model_.GetSignature(0));
for (int i = 0; i < signature.InputNames().size(); ++i) {
if (absl::StrContains(signature.InputNames()[i], kFeaturesName)) {
features_buffer_ = &input_buffers_[i];
} else if (absl::StrContains(signature.InputNames()[i], kMaskName)) {
mask_buffer_ = &input_buffers_[i];
}
}
if (features_buffer_ == nullptr) {
return absl::InvalidArgumentError(
"The Audio Adapter model must have a features input buffer.");
}
if (mask_buffer_ == nullptr) {
return absl::InvalidArgumentError(
"The Audio Adapter model must have a mask input buffer.");
}
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<AudioLiteRtCompiledModelExecutor>>
AudioLiteRtCompiledModelExecutor::Create(
AudioExecutorSettings executor_settings, Environment& env) {
if (executor_settings.GetMaxSequenceLength() > 0) {
ABSL_LOG(INFO) << "Max sequence length is not used for "
"AudioLiteRtCompiledModelExecutor, "
"which can handle variable length input.";
}
LITERT_ASSIGN_OR_RETURN(
auto resources,
BuildLiteRtCompiledModelResources(executor_settings.GetModelAssets()));
ASSIGN_OR_RETURN(auto audio_encoder_model,
resources->GetTFLiteModel(ModelType::kTfLiteAudioEncoderHw));
ASSIGN_OR_RETURN(auto audio_adapter_model,
resources->GetTFLiteModel(ModelType::kTfLiteAudioAdapter));
std::unique_ptr<AudioEncoder> audio_encoder;
LITERT_ASSIGN_OR_RETURN(auto encoder_signature,
audio_encoder_model->GetSignature(0));
const bool is_streaming_encoder =
IsStreamingEncoder(encoder_signature.InputNames());
if (is_streaming_encoder) {
ASSIGN_OR_RETURN(audio_encoder,
AudioStreamingEncoder::Create(executor_settings, env,
audio_encoder_model));
} else {
ASSIGN_OR_RETURN(audio_encoder,
AudioStaticEncoder::Create(executor_settings, env,
audio_encoder_model));
}
LITERT_ASSIGN_OR_RETURN(
auto audio_adapter,
AudioAdapter::Create(executor_settings, env, audio_adapter_model));
const auto& tmp = audio_encoder->GetInputMaskBuffer();
LITERT_ASSIGN_OR_RETURN(auto mask_tensor_type, tmp.TensorType());
LITERT_ASSIGN_OR_RETURN(int sequence_length,
mask_tensor_type.Layout().NumElements());
LITERT_ASSIGN_OR_RETURN(
auto spectrogram_tensor_type,
audio_encoder->GetInputSpectrogramBuffer().TensorType());
const int spectrogram_feature_dimensions =
spectrogram_tensor_type.Layout().Dimensions().back();
LITERT_ASSIGN_OR_RETURN(auto adapter_output_tensor_type,
audio_adapter->GetOutputBuffers()[0].TensorType());
const auto dims = adapter_output_tensor_type.Layout().Dimensions();
const int audio_embedding_dimensions = dims.back();
LITERT_ASSIGN_OR_RETURN(
auto executor_properties,
GetAudioExecutorPropertiesFromModelResources(*resources));
const int encoder_shrinking_factor = executor_properties.audio_shrink_factor;
if (!is_streaming_encoder) {
if (audio_encoder->GetOutputBuffersMap().size() !=
audio_adapter->GetInputBuffers().size()) {
return absl::InvalidArgumentError(absl::StrCat(
"The number of output buffers of the audio encoder must be equal "
"to the number of input buffers of the audio adapter, but got ",
audio_encoder->GetOutputBuffersMap().size(), " and ",
audio_adapter->GetInputBuffers().size()));
}
}
// Make the audio adapter take the audio encoder's mask and features as
// input.
LITERT_ASSIGN_OR_RETURN(auto encoder_mask_tensor,
audio_encoder->GetOutputMaskBuffer().Duplicate());
audio_adapter->GetMutableInputBuffers()[0] = std::move(encoder_mask_tensor);
LITERT_ASSIGN_OR_RETURN(
auto encoder_features_tensor,
audio_encoder->GetMutableOutputFeaturesBuffer().Duplicate());
audio_adapter->GetMutableInputBuffers()[1] =
std::move(encoder_features_tensor);
ABSL_LOG(INFO) << "AudioLiteRtCompiledModelExecutor created with "
"encoder_shrinking_factor: "
<< encoder_shrinking_factor;
return absl::WrapUnique(new AudioLiteRtCompiledModelExecutor(
std::move(executor_settings), std::move(executor_properties), env,
std::move(resources), std::move(audio_encoder), std::move(audio_adapter),
sequence_length, spectrogram_feature_dimensions,
audio_embedding_dimensions, encoder_shrinking_factor));
}
absl::StatusOr<int> AudioLiteRtCompiledModelExecutor::EncodeInternal(
absl::Span<float> spectrogram_tensor, absl::Span<uint8_t> spectrogram_mask,
absl::Span<float> audio_embeddings) {
RETURN_IF_ERROR(audio_encoder_->ClearInputBuffers());
LITERT_RETURN_IF_ERROR(
audio_encoder_->GetMutableInputSpectrogramBuffer().Write<float>(
spectrogram_tensor));
LITERT_RETURN_IF_ERROR(
audio_encoder_->GetMutableInputMaskBuffer().Write<uint8_t>(
spectrogram_mask));
LITERT_RETURN_IF_ERROR(audio_encoder_->GetMutableCompiledModel().Run(
audio_encoder_->GetMutableInputBuffersMap(),
audio_encoder_->GetMutableOutputBuffersMap()));
ASSIGN_OR_RETURN(int chunk_valid_tokens,
GetValidCount(audio_encoder_->GetOutputMaskBuffer()));
LITERT_RETURN_IF_ERROR(audio_adapter_->GetMutableCompiledModel().Run(
audio_adapter_->GetMutableInputBuffers(),
audio_adapter_->GetMutableOutputBuffers()));
LITERT_RETURN_IF_ERROR(
audio_adapter_->GetMutableOutputBuffers()[0].Read<float>(
absl::MakeSpan(audio_embeddings.data(),
chunk_valid_tokens * audio_embedding_dimensions_)));
if (executor_properties_.is_streaming_model) {
RETURN_IF_ERROR(
reinterpret_cast<AudioStreamingEncoder*>(audio_encoder_.get())
->SwapInternalStateBuffers());
}
return chunk_valid_tokens;
}
absl::StatusOr<ExecutorAudioData> AudioLiteRtCompiledModelExecutor::Encode(
const TensorBuffer& spectrogram_tensor,
const TensorBuffer& spectrogram_mask) {
ASSIGN_OR_RETURN(int input_sequence_length, GetValidCount(spectrogram_mask));
LITERT_ASSIGN_OR_RETURN(
auto spectrogram_host_buffer,
GetDataAsVector<float>(const_cast<TensorBuffer&>(spectrogram_tensor)));
LITERT_ASSIGN_OR_RETURN(
auto spectrogram_mask_host_buffer,
GetDataAsVector<uint8_t>(const_cast<TensorBuffer&>(spectrogram_mask)));
std::vector<float> audio_embeddings(input_sequence_length *
audio_embedding_dimensions_);
// Chunk the spectrogram into smaller pieces and encode them one by one.
int total_valid_tokens = 0;
int pos = 0;
while (pos < input_sequence_length) {
int end = std::min(pos + sequence_length_, input_sequence_length);
auto spectrogram_host_buffer_slice =
absl::MakeSpan(spectrogram_host_buffer)
.subspan(pos * spectrogram_feature_dimensions_,
(end - pos) * spectrogram_feature_dimensions_);
auto spectrogram_mask_host_buffer_slice =
absl::MakeSpan(spectrogram_mask_host_buffer).subspan(pos, end - pos);
auto audio_embeddings_slice =
absl::MakeSpan(audio_embeddings)
.subspan(CeilIntDiv(pos, encoder_shrinking_factor_) *
audio_embedding_dimensions_,
CeilIntDiv(end - pos, encoder_shrinking_factor_) *
audio_embedding_dimensions_);
ASSIGN_OR_RETURN(int chunk_valid_tokens,
EncodeInternal(spectrogram_host_buffer_slice,
spectrogram_mask_host_buffer_slice,
audio_embeddings_slice));
total_valid_tokens += chunk_valid_tokens;
pos = end;
}
// Create the final audio embeddings tensor.
RankedTensorType audio_embeddings_tensor_type(
GetElementType<float>(),
Layout(Dimensions({1, total_valid_tokens, audio_embedding_dimensions_})));
LITERT_ASSIGN_OR_RETURN(
auto audio_embeddings_tensor,
TensorBuffer::CreateManaged(env_, TensorBufferType::kHostMemory,
audio_embeddings_tensor_type,
audio_embeddings.size() * sizeof(float)));
LITERT_RETURN_IF_ERROR(audio_embeddings_tensor.Write<float>(
absl::MakeSpan(audio_embeddings)
.subspan(0, total_valid_tokens * audio_embedding_dimensions_)));
ExecutorAudioData audio_data;
audio_data.SetEmbeddings(std::move(audio_embeddings_tensor));
audio_data.SetValidTokens(total_valid_tokens);
return audio_data;
}
absl::StatusOr<ExecutorAudioData> AudioLiteRtCompiledModelExecutor::Encode(
const TensorBuffer& spectrogram_tensor) {
LITERT_ASSIGN_OR_RETURN(auto tensor_type, spectrogram_tensor.TensorType());
auto dimensions = tensor_type.Layout().Dimensions();
if (dimensions.size() < 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Spectrogram tensor must have at least 2 dimensions, but got ",
dimensions.size()));
}
int input_sequence_length = dimensions[dimensions.size() - 2];
LITERT_ASSIGN_OR_RETURN(
auto mask_tensor,
TensorBuffer::CreateManaged(
env_, TensorBufferType::kHostMemory,
RankedTensorType(GetElementType<uint8_t>(),
Layout(Dimensions({1, input_sequence_length}))),
input_sequence_length * sizeof(uint8_t)));
std::vector<uint8_t> all_ones(input_sequence_length, 1);
LITERT_RETURN_IF_ERROR(mask_tensor.Write<uint8_t>(absl::MakeSpan(all_ones)));
return Encode(spectrogram_tensor, mask_tensor);
}
absl::StatusOr<std::unique_ptr<AudioStreamingContext>>
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CreateNewContext() {
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer> state_buffers;
LITERT_ASSIGN_OR_RETURN(auto signature, compiled_model_.GetSignature(0));
for (auto& [name, buffer] : input_buffers_map_) {
if (name == kSegmentValuesName || name == kSegmentMaskName) {
// Skip the segment values and mask buffers as they are not part of the
// state.
continue;
}
LITERT_ASSIGN_OR_RETURN(auto new_buffer, compiled_model_.CreateInputBuffer(
signature.Key(), name));
if (name == kPrevMaskName) {
LITERT_ASSIGN_OR_RETURN(auto prev_mask_type, buffer.TensorType());
LITERT_ASSIGN_OR_RETURN(int prev_mask_size,
prev_mask_type.Layout().NumElements());
input_buffers_map_[kPrevMaskName].Write<uint8_t>(
std::vector<uint8_t>(prev_mask_size, 1));
} else {
LITERT_RETURN_IF_ERROR(InitializeBuffer(new_buffer));
}
state_buffers[name] = std::move(new_buffer);
}
auto audio_streaming_context =
std::make_unique<AudioStreamingContext>(std::move(state_buffers));
return audio_streaming_context;
}
absl::StatusOr<std::unique_ptr<AudioStreamingContext>>
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CloneContext() {
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer> state_buffers;
LITERT_ASSIGN_OR_RETURN(auto signature, compiled_model_.GetSignature(0));
for (auto& [name, buffer] : input_buffers_map_) {
if (name == kSegmentValuesName || name == kSegmentMaskName) {
// Skip the segment values and mask buffers as they are not part of the
// state.
continue;
}
LITERT_ASSIGN_OR_RETURN(auto new_buffer, compiled_model_.CreateInputBuffer(
signature.Key(), name));
RETURN_IF_ERROR(CopyBuffer(buffer, new_buffer));
state_buffers[name] = std::move(new_buffer);
}
auto audio_streaming_context =
std::make_unique<AudioStreamingContext>(std::move(state_buffers));
return audio_streaming_context;
}
absl::Status
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::RestoreContext(
std::unique_ptr<AudioStreamingContext> audio_streaming_context) {
for (auto& [name, buffer] : audio_streaming_context->state_buffers()) {
if (!input_buffers_map_.contains(name)) {
return absl::InvalidArgumentError(
absl::StrCat("The Audio Streaming Encoder model must have a ", name,
" input buffer."));
}
if (name == kSegmentValuesName || name == kSegmentMaskName) {
// Skip the segment values and mask buffers as they are not part of the
// state.
continue;
}
LITERT_ASSIGN_OR_RETURN(auto buffer_copy, buffer.Duplicate());
input_buffers_map_[name] = std::move(buffer_copy);
}
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<AudioContext>>
AudioLiteRtCompiledModelExecutor::CreateNewContext() {
if (!executor_properties_.is_streaming_model) {
return absl::UnimplementedError(
"CreateNewContext is only supported for streaming models.");
}
return reinterpret_cast<AudioStreamingEncoder*>(audio_encoder_.get())
->CreateNewContext();
}
absl::StatusOr<std::unique_ptr<AudioContext>>
AudioLiteRtCompiledModelExecutor::CloneContext() {
if (!executor_properties_.is_streaming_model) {
return absl::UnimplementedError(
"CloneContext is only supported for streaming models.");
}
ASSIGN_OR_RETURN(
auto audio_encoder_context,
reinterpret_cast<AudioStreamingEncoder*>(audio_encoder_.get())
->CloneContext());
return std::move(audio_encoder_context);
}
absl::StatusOr<std::unique_ptr<AudioContext>>
AudioLiteRtCompiledModelExecutor::CloneContext(
const AudioContext& audio_context) {
if (!executor_properties_.is_streaming_model) {
return absl::UnimplementedError(
"CloneContext is only supported for streaming models.");
}
const AudioStreamingContext& audio_streaming_context =
static_cast<const AudioStreamingContext&>(audio_context);
return audio_streaming_context.Clone();
}
absl::Status AudioLiteRtCompiledModelExecutor::RestoreContext(
std::unique_ptr<AudioContext> audio_context) {
if (!executor_properties_.is_streaming_model) {
return absl::UnimplementedError(
"RestoreContext is only supported for streaming models.");
}
return reinterpret_cast<AudioStreamingEncoder*>(audio_encoder_.get())
->RestoreContext(std::unique_ptr<AudioStreamingContext>(
static_cast<AudioStreamingContext*>(audio_context.release())));
}
} // namespace litert::lm