import torch
from torch.nn.modules.loss import _Loss

# SDR variants accepted by the *NegSDR losses below.
_SDR_TYPES = ("snr", "sisdr", "sdsdr")


def _stft_per_source(signals, win, stride):
    """Return the complex STFT of every waveform in ``signals``.

    All leading dimensions are flattened so that each waveform (last
    dimension) is transformed independently with a float32 Hann window of
    length ``win`` and hop ``stride``.

    ``reshape`` is used rather than ``view`` so that non-contiguous inputs
    (e.g. produced by ``transpose``) are accepted.
    """
    window = torch.hann_window(win, dtype=torch.float32, device=signals.device)
    flat = signals.reshape(-1, signals.shape[-1])
    return torch.stft(
        flat,
        n_fft=win,
        hop_length=stride,
        window=window,
        return_complex=True,
    )


class PairwiseNegSDR(_Loss):
    """Negative SDR computed between every (estimate, target) pair.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, remove the time-mean of both inputs first.
        take_log: If True, return the ratio in decibels (``10 * log10``).
        EPS: Small constant added to denominators and inside the log.

    ``forward(ests, targets)`` expects two tensors of identical shape
    ``[batch, n_src, time]`` and returns a ``[batch, n_src, n_src]`` tensor
    whose entry ``[b, i, j]`` is the negative SDR of estimate ``i`` against
    target ``j``.

    Raises:
        TypeError: If the inputs differ in size or are not 3-dimensional.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in _SDR_TYPES
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_estimate = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_estimate
        # Step 2. Pair-wise SI-SDR. (Reshape to use broadcast)
        # s_target: [batch, 1, n_src, time], s_estimate: [batch, n_src, 1, time]
        s_target = torch.unsqueeze(targets, dim=1)
        s_estimate = torch.unsqueeze(ests, dim=2)

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, n_src, 1]
            pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)
            # [batch, 1, n_src, 1]
            s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + self.EPS
            # [batch, n_src, n_src, time]
            pair_wise_proj = pair_wise_dot * s_target / s_target_energy
        else:
            # [batch, n_src, n_src, time]
            pair_wise_proj = s_target.repeat(1, s_target.shape[2], 1, 1)

        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = s_estimate - s_target
        else:
            e_noise = s_estimate - pair_wise_proj
        # [batch, n_src, n_src]
        pair_wise_sdr = torch.sum(pair_wise_proj ** 2, dim=3) / (
            torch.sum(e_noise ** 2, dim=3) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -pair_wise_sdr


class SingleSrcNegSDR(_Loss):
    """Negative SDR for single-source signals.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, remove the time-mean of both inputs first.
        take_log: If True, return the ratio in decibels (``10 * log10``).
        reduction: ``"mean"`` averages over the batch; any other accepted
            value leaves the per-example losses untouched. ``"sum"`` is
            rejected.
        EPS: Small constant added to denominators and inside the log.

    ``forward(ests, targets)`` expects two tensors of identical shape
    ``[batch, time]`` and returns a ``[batch]`` tensor, or a scalar when
    ``reduction == "mean"``.

    Raises:
        TypeError: If the inputs differ in size or are not 2-dimensional.
    """

    def __init__(
        self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8
    ):
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in _SDR_TYPES
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honor the caller-supplied epsilon (it used to be hard-coded to 1e-8).
        self.EPS = EPS

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 2:
            raise TypeError(
                f"Inputs must be of shape [batch, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=1, keepdim=True)
            mean_estimate = torch.mean(ests, dim=1, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_estimate
        # Step 2. SI-SDR.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1]
            dot = torch.sum(ests * targets, dim=1, keepdim=True)
            # [batch, 1]
            s_target_energy = torch.sum(targets ** 2, dim=1, keepdim=True) + self.EPS
            # [batch, time]
            scaled_target = dot * targets / s_target_energy
        else:
            # [batch, time]
            scaled_target = targets

        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_target
        # [batch]
        losses = torch.sum(scaled_target ** 2, dim=1) / (
            torch.sum(e_noise ** 2, dim=1) + self.EPS
        )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses


class MultiSrcNegSDR(_Loss):
    """Negative SDR for multi-source signals with a fixed source order.

    Estimate ``i`` is compared with target ``i`` only, and the result is
    averaged over sources.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, remove the time-mean of both inputs first.
        take_log: If True, return the ratio in decibels (``10 * log10``).
        EPS: Small constant added to denominators and inside the log.

    ``forward(ests, targets)`` expects two tensors of identical shape
    ``[batch, n_src, time]`` and returns a ``[batch]`` tensor.

    Raises:
        TypeError: If the inputs differ in size or are not 3-dimensional.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()

        assert sdr_type in _SDR_TYPES
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honor the caller-supplied epsilon (it used to be hard-coded to 1e-8).
        self.EPS = EPS

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est
        # Step 2. Per-source SI-SDR.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, 1]
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            # [batch, n_src, 1]
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            # [batch, n_src, time]
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # [batch, n_src, time]
            scaled_targets = targets

        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_targets
        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -torch.mean(pair_wise_sdr, dim=-1)


class freq_MAE_WavL1Loss(_Loss):
    """L1 distance on STFT real/imaginary parts plus L1 on the waveform.

    Args:
        win: STFT size (``n_fft``) and Hann window length.
        stride: STFT hop length.

    ``forward(ests, targets)`` unpacks ``ests.shape`` as
    ``(batch, n_src, time)`` and returns a ``[batch]`` tensor: the
    source-averaged spectral L1 plus the source-averaged waveform L1.
    """

    def __init__(self, win=2048, stride=512):
        super().__init__()
        self.EPS = 1e-8
        self.win = win
        self.stride = stride

    def forward(self, ests, targets):
        B, nsrc, _ = ests.shape
        # [batch * n_src, freq, frames], complex
        est_spec = _stft_per_source(ests, self.win, self.stride)
        est_target = _stft_per_source(targets, self.win, self.stride)
        # Mean absolute error on the real and imaginary parts, per waveform.
        freq_L1 = (est_spec.real - est_target.real).abs().mean((1, 2)) + (
            est_spec.imag - est_target.imag
        ).abs().mean((1, 2))
        # [batch]
        freq_L1 = freq_L1.reshape(B, nsrc).mean(-1)

        # [batch, n_src] -> [batch]
        wave_l1 = (ests - targets).abs().mean(-1).mean(-1)
        return freq_L1 + wave_l1


class freq_MSE_Loss(_Loss):
    """Mean squared error on the real and imaginary parts of the STFT.

    Args:
        win: STFT size (``n_fft``) and Hann window length.
        stride: STFT hop length.

    ``forward(ests, targets)`` unpacks ``ests.shape`` as
    ``(batch, n_src, time)`` and returns a ``[batch]`` tensor averaged
    over sources.
    """

    def __init__(self, win=640, stride=160):
        super().__init__()
        self.EPS = 1e-8
        self.win = win
        self.stride = stride

    def forward(self, ests, targets):
        B, nsrc, _ = ests.shape
        # [batch * n_src, freq, frames], complex
        est_spec = _stft_per_source(ests, self.win, self.stride)
        est_target = _stft_per_source(targets, self.win, self.stride)
        freq_mse = (est_spec.real - est_target.real).square().mean((1, 2)) + (
            est_spec.imag - est_target.imag
        ).square().mean((1, 2))
        # [batch]
        freq_mse = freq_mse.reshape(B, nsrc).mean(-1)
        return freq_mse


# aliases
pairwise_neg_sisdr = PairwiseNegSDR("sisdr")
pairwise_neg_sdsdr = PairwiseNegSDR("sdsdr")
pairwise_neg_snr = PairwiseNegSDR("snr")
singlesrc_neg_sisdr = SingleSrcNegSDR("sisdr")
singlesrc_neg_sdsdr = SingleSrcNegSDR("sdsdr")
singlesrc_neg_snr = SingleSrcNegSDR("snr")
multisrc_neg_sisdr = MultiSrcNegSDR("sisdr")
multisrc_neg_sdsdr = MultiSrcNegSDR("sdsdr")
multisrc_neg_snr = MultiSrcNegSDR("snr")
freq_mae_wavl1loss = freq_MAE_WavL1Loss()
pairwise_neg_sisdr_freq_mse = (PairwiseNegSDR("sisdr"), freq_MSE_Loss())
pairwise_neg_snr_multidecoder = (PairwiseNegSDR("snr"), MultiSrcNegSDR("snr"))