# Negative SNR / SI-SDR / SD-SDR losses and STFT-domain L1 / MSE losses (PyTorch).
import torch | |
from torch.nn.modules.loss import _Loss | |
class PairwiseNegSDR(_Loss):
    """Negative SDR-family metric for every (estimate, target) pair.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, the time-axis mean is subtracted from both inputs
            before anything else is computed.
        take_log: If True, the energy ratio is converted with ``10 * log10``.
        EPS: Small constant added to denominators and to the log argument.

    Shape:
        - ests: ``[batch, n_src, time]``
        - targets: ``[batch, n_src, time]``
        - output: ``[batch, n_src, n_src]``; entry ``[b, i, j]`` is the negated
          metric of estimate ``i`` measured against target ``j``.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the ``[batch, n_src, n_src]`` matrix of negated pairwise metrics.

        Raises:
            TypeError: If the inputs differ in size or are not 3-dimensional.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )

        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_estimate = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_estimate

        # Step 2. Pair-wise metric. Insert singleton axes so that estimates
        # (dim 1) broadcast against targets (dim 2).
        # [batch, 1, n_src, time]
        s_target = torch.unsqueeze(targets, dim=1)
        # [batch, n_src, 1, time]
        s_estimate = torch.unsqueeze(ests, dim=2)

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, n_src, 1]
            pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)
            # [batch, 1, n_src, 1]
            s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + self.EPS
            # [batch, n_src, n_src, time]: each target rescaled by its
            # projection coefficient onto each estimate.
            pair_wise_proj = pair_wise_dot * s_target / s_target_energy
        else:
            # [batch, n_src, n_src, time]: broadcast view of the targets along
            # the estimate axis (same values as a repeat, without the copy).
            pair_wise_proj = s_target.expand(-1, s_target.shape[2], -1, -1)

        if self.sdr_type in ["sdsdr", "snr"]:
            # Residual measured against the unscaled target.
            e_noise = s_estimate - s_target
        else:
            # "sisdr": residual measured against the rescaled target.
            e_noise = s_estimate - pair_wise_proj

        # [batch, n_src, n_src]
        pair_wise_sdr = torch.sum(pair_wise_proj ** 2, dim=3) / (
            torch.sum(e_noise ** 2, dim=3) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -pair_wise_sdr
class SingleSrcNegSDR(_Loss):
    """Negative SDR-family metric for single-source signals.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, the time-axis mean is subtracted from both inputs
            before anything else is computed.
        take_log: If True, the energy ratio is converted with ``10 * log10``.
        reduction: ``"none"`` returns one value per batch item, ``"mean"``
            averages over the batch. ``"sum"`` is rejected.
        EPS: Small constant added to denominators and to the log argument.

    Shape:
        - ests: ``[batch, time]``
        - targets: ``[batch, time]``
        - output: ``[batch]``, or a scalar when ``reduction == "mean"``.
    """

    def __init__(
        self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8
    ):
        assert reduction != "sum", "reduction='sum' is not implemented"
        super().__init__(reduction=reduction)
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon (it used to be overwritten by a
        # hard-coded 1e-8 regardless of the argument).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negated metric for each batch item (or its batch mean).

        Raises:
            TypeError: If the inputs differ in size or are not 2-dimensional.
        """
        if targets.size() != ests.size() or targets.ndim != 2:
            raise TypeError(
                f"Inputs must be of shape [batch, time], got {targets.size()} and {ests.size()} instead"
            )

        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=1, keepdim=True)
            mean_estimate = torch.mean(ests, dim=1, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_estimate

        # Step 2. Reference signal: rescaled target for "sisdr"/"sdsdr",
        # the raw target for "snr".
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1]
            dot = torch.sum(ests * targets, dim=1, keepdim=True)
            # [batch, 1]
            s_target_energy = torch.sum(targets ** 2, dim=1, keepdim=True) + self.EPS
            # [batch, time]
            scaled_target = dot * targets / s_target_energy
        else:
            # [batch, time]
            scaled_target = targets

        if self.sdr_type in ["sdsdr", "snr"]:
            # Residual measured against the unscaled target.
            e_noise = ests - targets
        else:
            # "sisdr": residual measured against the rescaled target.
            e_noise = ests - scaled_target

        # [batch]
        losses = torch.sum(scaled_target ** 2, dim=1) / (
            torch.sum(e_noise ** 2, dim=1) + self.EPS
        )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
class MultiSrcNegSDR(_Loss):
    """Negative SDR-family metric averaged over sources, with fixed pairing.

    Estimate ``i`` is always compared with target ``i``; the per-source values
    are then averaged along the source axis.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, the time-axis mean is subtracted from both inputs
            before anything else is computed.
        take_log: If True, the energy ratio is converted with ``10 * log10``.
        EPS: Small constant added to denominators and to the log argument.

    Shape:
        - ests: ``[batch, n_src, time]``
        - targets: ``[batch, n_src, time]``
        - output: ``[batch]``
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon (it used to be overwritten by a
        # hard-coded 1e-8 regardless of the argument).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negated, source-averaged metric for each batch item.

        Raises:
            TypeError: If the inputs differ in size or are not 3-dimensional.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )

        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est

        # Step 2. Reference signal: rescaled target for "sisdr"/"sdsdr",
        # the raw target for "snr".
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, 1]
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            # [batch, n_src, 1]
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            # [batch, n_src, time]
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # [batch, n_src, time]
            scaled_targets = targets

        if self.sdr_type in ["sdsdr", "snr"]:
            # Residual measured against the unscaled target.
            e_noise = ests - targets
        else:
            # "sisdr": residual measured against the rescaled target.
            e_noise = ests - scaled_targets

        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        # Average over sources -> [batch]
        return -torch.mean(pair_wise_sdr, dim=-1)
class freq_MAE_WavL1Loss(_Loss):
    """L1 loss on STFT real/imaginary parts plus L1 loss on the waveform.

    Args:
        win: STFT size, used both as ``n_fft`` and as the Hann window length.
        stride: STFT hop length.

    Shape:
        - ests: ``[batch, n_src, time]``
        - targets: same shape as ``ests``
        - output: ``[batch]`` (both terms are averaged over sources).
    """

    def __init__(self, win=2048, stride=512):
        super().__init__()
        self.EPS = 1e-8  # kept for attribute compatibility; not used in forward
        self.win = win
        self.stride = stride

    def forward(self, ests, targets):
        """Return the per-batch-item sum of the spectral and waveform L1 terms."""
        B, nsrc, _ = ests.shape
        # One window shared by both transforms.
        window = torch.hann_window(self.win).to(ests.device).float()

        # Fold sources into the batch axis for the STFT. ``reshape`` (rather
        # than ``view``) also accepts non-contiguous inputs, e.g. tensors that
        # come out of a transpose/permute.
        est_spec = torch.stft(
            ests.reshape(-1, ests.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )
        target_spec = torch.stft(
            targets.reshape(-1, targets.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

        # Mean absolute error over (freq, frame), real and imaginary parts
        # summed -> [batch * n_src]
        freq_L1 = (est_spec.real - target_spec.real).abs().mean((1, 2)) + (
            est_spec.imag - target_spec.imag
        ).abs().mean((1, 2))
        # [batch, n_src] -> [batch]
        freq_L1 = freq_L1.reshape(B, nsrc).mean(-1)

        # Waveform mean absolute error over time -> [batch, n_src] -> [batch]
        wave_l1 = (ests - targets).abs().mean(-1)
        wave_l1 = wave_l1.reshape(B, nsrc).mean(-1)
        return freq_L1 + wave_l1
class freq_MSE_Loss(_Loss):
    """Mean-squared error on the real and imaginary parts of the STFT.

    Args:
        win: STFT size, used both as ``n_fft`` and as the Hann window length.
        stride: STFT hop length.

    Shape:
        - ests: ``[batch, n_src, time]``
        - targets: same shape as ``ests``
        - output: ``[batch]`` (averaged over sources).
    """

    def __init__(self, win=640, stride=160):
        super().__init__()
        self.EPS = 1e-8  # kept for attribute compatibility; not used in forward
        self.win = win
        self.stride = stride

    def forward(self, ests, targets):
        """Return the per-batch-item spectral MSE."""
        B, nsrc, _ = ests.shape
        # One window shared by both transforms.
        window = torch.hann_window(self.win).to(ests.device).float()

        # Fold sources into the batch axis for the STFT. ``reshape`` (rather
        # than ``view``) also accepts non-contiguous inputs, e.g. tensors that
        # come out of a transpose/permute.
        est_spec = torch.stft(
            ests.reshape(-1, ests.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )
        target_spec = torch.stft(
            targets.reshape(-1, targets.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

        # Mean squared error over (freq, frame), real and imaginary parts
        # summed -> [batch * n_src]
        freq_mse = (est_spec.real - target_spec.real).square().mean((1, 2)) + (
            est_spec.imag - target_spec.imag
        ).square().mean((1, 2))
        # [batch, n_src] -> [batch]
        freq_mse = freq_mse.reshape(B, nsrc).mean(-1)
        return freq_mse
# Ready-to-use loss instances built with each class's default settings.

# Pairwise variants: every estimate scored against every target.
pairwise_neg_sisdr = PairwiseNegSDR(sdr_type="sisdr")
pairwise_neg_sdsdr = PairwiseNegSDR(sdr_type="sdsdr")
pairwise_neg_snr = PairwiseNegSDR(sdr_type="snr")

# Single-source variants: [batch, time] inputs.
singlesrc_neg_sisdr = SingleSrcNegSDR(sdr_type="sisdr")
singlesrc_neg_sdsdr = SingleSrcNegSDR(sdr_type="sdsdr")
singlesrc_neg_snr = SingleSrcNegSDR(sdr_type="snr")

# Multi-source variants: estimate i paired with target i, averaged over sources.
multisrc_neg_sisdr = MultiSrcNegSDR(sdr_type="sisdr")
multisrc_neg_sdsdr = MultiSrcNegSDR(sdr_type="sdsdr")
multisrc_neg_snr = MultiSrcNegSDR(sdr_type="snr")

# Spectral + waveform L1 loss.
freq_mae_wavl1loss = freq_MAE_WavL1Loss()

# Loss pairs bundled as tuples; each tuple holds its own fresh instances.
pairwise_neg_sisdr_freq_mse = (
    PairwiseNegSDR(sdr_type="sisdr"),
    freq_MSE_Loss(),
)
pairwise_neg_snr_multidecoder = (
    PairwiseNegSDR(sdr_type="snr"),
    MultiSrcNegSDR(sdr_type="snr"),
)