Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from .mamba_block import TFMambaBlock | |
from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder | |
class SEMamba(nn.Module): | |
""" | |
SEMamba model for speech enhancement using Mamba blocks. | |
This model uses a dense encoder, multiple Mamba blocks, and separate magnitude | |
and phase decoders to process noisy magnitude and phase inputs. | |
""" | |
def __init__(self, cfg): | |
""" | |
Initialize the SEMamba model. | |
Args: | |
- cfg: Configuration object containing model parameters. | |
""" | |
super(SEMamba, self).__init__() | |
self.cfg = cfg | |
self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4 # default tfmamba: 4 | |
# Initialize dense encoder | |
self.dense_encoder = DenseEncoder(cfg) | |
# Initialize Mamba blocks | |
self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)]) | |
# Initialize decoders | |
self.mask_decoder = MagDecoder(cfg) | |
self.phase_decoder = PhaseDecoder(cfg) | |
def forward(self, noisy_mag, noisy_pha): | |
""" | |
Forward pass for the SEMamba model. | |
Args: | |
- noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T]. | |
- noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T]. | |
Returns: | |
- denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T]. | |
- denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T]. | |
- denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2]. | |
""" | |
# Reshape inputs | |
noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F] | |
noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F] | |
# Concatenate magnitude and phase inputs | |
x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F] | |
# Encode input | |
x = self.dense_encoder(x) | |
# Apply Mamba blocks | |
for block in self.TSMamba: | |
x = block(x) | |
# Decode magnitude and phase | |
denoised_mag = rearrange(self.mask_decoder(x) * noisy_mag, 'b c t f -> b f t c').squeeze(-1) | |
denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1) | |
# Combine denoised magnitude and phase into a complex representation | |
denoised_com = torch.stack( | |
(denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)), | |
dim=-1 | |
) | |
return denoised_mag, denoised_pha, denoised_com | |