# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin

from .mamba_block2_SEMamba import TFMambaBlock
from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder


class SEMamba(nn.Module, PyTorchModelHubMixin):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    This model uses a dense encoder, multiple Mamba blocks, and separate
    magnitude and phase decoders to process noisy magnitude and phase inputs.
    """

    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration mapping containing model parameters
          (reads cfg['model_cfg']['num_tfmamba']).
        """
        super(SEMamba, self).__init__()
        self.cfg = cfg
        # Number of time-frequency Mamba blocks; falls back to 4 when the
        # config value is None.
        self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4  # default tfmamba: 4

        # Initialize dense encoder
        self.dense_encoder = DenseEncoder(cfg)

        # Initialize Mamba blocks
        self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])

        # Initialize decoders
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2]
          (real/imag parts stacked on the last axis).
        """
        # Reshape inputs to the channel-first [B, 1, T, F] layout the encoder expects.
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]

        # Concatenate magnitude and phase inputs
        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]

        # Zero-pad the F and T axes by 2 ("prevent unpredictable errors" —
        # presumably guards the encoder's strided ops against awkward sizes;
        # the padding is cropped away after decoding).
        # BUG FIX: the padding must match x's dtype. torch.zeros defaults to
        # float32, so for half-precision inputs (e.g. under AMP) torch.cat
        # would otherwise fail or silently promote the activations to float32.
        B, C, T, F = x.shape
        zeros = torch.zeros(B, C, T, 2, device=x.device, dtype=x.dtype)
        x = torch.cat((x, zeros), dim=-1)
        zeros = torch.zeros(B, C, 2, F + 2, device=x.device, dtype=x.dtype)
        x = torch.cat((x, zeros), dim=-2)

        # Encode input
        x = self.dense_encoder(x)

        # Apply Mamba blocks
        for block in self.TSMamba:
            x = block(x)

        # Decode output: [B, 1, T', F'] -> [B, F', T'] after squeezing the channel.
        denoised_mag = rearrange(self.mask_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Crop back to the pre-padding sizes captured above (F, T are the
        # original, unpadded dimensions).
        denoised_mag = denoised_mag[:, :F, :T]
        denoised_pha = denoised_pha[:, :F, :T]

        # Combine denoised magnitude and phase into a complex representation
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha),
             denoised_mag * torch.sin(denoised_pha)),
            dim=-1
        )

        return denoised_mag, denoised_pha, denoised_com