# Uploaded by fffiloni — migrated from GitHub (commit 406f22d, verified).
import torch
from torch.nn.modules.loss import _Loss
class PairwiseNegSDR(_Loss):
    """Negative SDR-family score for every (estimate, target) source pair.

    Args:
        sdr_type: one of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: subtract each signal's mean over time before scoring.
        take_log: if True, return ``10 * log10(ratio + EPS)``; otherwise the raw ratio.
        EPS: small constant added to the target energy, the noise energy and
            inside the log to avoid division by zero / log of zero.

    ``forward(ests, targets)`` takes two tensors of shape ``[batch, n_src, time]``
    and returns a ``[batch, n_src, n_src]`` tensor whose entry ``[b, i, j]`` is
    the negated score of estimate ``i`` measured against target ``j``.

    Raises:
        TypeError: if the inputs differ in shape or are not 3-D.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, ests, targets):
        shapes_ok = targets.size() == ests.size() and targets.ndim == 3
        if not shapes_ok:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )

        if self.zero_mean:
            targets = targets - targets.mean(dim=2, keepdim=True)
            ests = ests - ests.mean(dim=2, keepdim=True)

        # Insert singleton axes so broadcasting produces the full pair grid:
        # axis 1 indexes estimates, axis 2 indexes targets.
        tgt = targets.unsqueeze(1)  # [batch, 1, n_src, time]
        est = ests.unsqueeze(2)  # [batch, n_src, 1, time]

        if self.sdr_type in ("sisdr", "sdsdr"):
            # Project each estimate onto each target: <est, tgt> / ||tgt||^2 * tgt.
            dot = (est * tgt).sum(dim=3, keepdim=True)  # [batch, n_src, n_src, 1]
            energy = (tgt ** 2).sum(dim=3, keepdim=True) + self.EPS  # [batch, 1, n_src, 1]
            reference = dot * tgt / energy  # [batch, n_src, n_src, time]
        else:
            reference = tgt.repeat(1, tgt.shape[2], 1, 1)  # [batch, n_src, n_src, time]

        # SD-SDR and SNR take the error against the unscaled target; SI-SDR
        # takes it against the projection.
        noise = est - (tgt if self.sdr_type in ("sdsdr", "snr") else reference)

        ratio = (reference ** 2).sum(dim=3) / ((noise ** 2).sum(dim=3) + self.EPS)
        if self.take_log:
            ratio = 10 * torch.log10(ratio + self.EPS)
        return -ratio  # [batch, n_src, n_src]
class SingleSrcNegSDR(_Loss):
    """Negative SDR-family score of a single estimate against its target.

    Args:
        sdr_type: one of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: subtract each signal's mean over time before scoring.
        take_log: if True, return ``10 * log10(ratio + EPS)``; otherwise the raw ratio.
        reduction: ``"none"`` returns one value per batch item; ``"mean"``
            averages them to a scalar. ``"sum"`` is rejected.
        EPS: small constant added to the target energy, the noise energy and
            inside the log to avoid division by zero / log of zero.

    ``forward(ests, targets)`` takes two tensors of shape ``[batch, time]``.

    Raises:
        TypeError: from ``forward`` if the inputs differ in shape or are not 2-D.
    """

    def __init__(
        self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8
    ):
        assert reduction != "sum", "reduction='sum' is not implemented"
        super().__init__(reduction=reduction)
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon; it was previously hard-coded to
        # 1e-8, so the ``EPS`` argument had no effect.
        self.EPS = EPS

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 2:
            raise TypeError(
                f"Inputs must be of shape [batch, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            targets = targets - torch.mean(targets, dim=1, keepdim=True)
            ests = ests - torch.mean(ests, dim=1, keepdim=True)
        # Step 2. Reference signal. For the scale-invariant variants, rescale the
        # target by <est, target> / ||target||^2; for plain SNR, use it as is.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            dot = torch.sum(ests * targets, dim=1, keepdim=True)  # [batch, 1]
            s_target_energy = (
                torch.sum(targets ** 2, dim=1, keepdim=True) + self.EPS
            )  # [batch, 1]
            scaled_target = dot * targets / s_target_energy  # [batch, time]
        else:
            scaled_target = targets  # [batch, time]
        # SD-SDR and SNR measure the error against the unscaled target.
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_target
        # [batch]
        losses = torch.sum(scaled_target ** 2, dim=1) / (
            torch.sum(e_noise ** 2, dim=1) + self.EPS
        )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
class MultiSrcNegSDR(_Loss):
    """Negative SDR-family score per source, averaged over sources.

    Each estimate is scored only against the target at the same source index.

    Args:
        sdr_type: one of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: subtract each signal's mean over time before scoring.
        take_log: if True, return ``10 * log10(ratio + EPS)``; otherwise the raw ratio.
        EPS: small constant added to the target energy, the noise energy and
            inside the log to avoid division by zero / log of zero.

    ``forward(ests, targets)`` takes two tensors of shape ``[batch, n_src, time]``
    and returns a ``[batch]`` tensor: the negated score averaged over sources.

    Raises:
        TypeError: from ``forward`` if the inputs differ in shape or are not 3-D.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon; it was previously hard-coded to
        # 1e-8, so the ``EPS`` argument had no effect.
        self.EPS = EPS

    def forward(self, ests, targets):
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            targets = targets - torch.mean(targets, dim=2, keepdim=True)
            ests = ests - torch.mean(ests, dim=2, keepdim=True)
        # Step 2. Reference signal per source. For the scale-invariant variants,
        # rescale each target by <est, target> / ||target||^2.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)  # [batch, n_src, 1]
            s_target_energy = (
                torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            )  # [batch, n_src, 1]
            scaled_targets = pair_wise_dot * targets / s_target_energy  # [batch, n_src, time]
        else:
            scaled_targets = targets  # [batch, n_src, time]
        # SD-SDR and SNR measure the error against the unscaled target.
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_targets
        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -torch.mean(pair_wise_sdr, dim=-1)
class freq_MAE_WavL1Loss(_Loss):
    """Complex-spectrogram L1 plus waveform L1, averaged over sources.

    Args:
        win: STFT size (``n_fft``), which is also the Hann window length.
        stride: STFT hop length.

    ``forward(ests, targets)`` expects tensors of shape ``[batch, n_src, time]``
    and returns a ``[batch]`` tensor. Each entry is the mean over sources of::

        mean|Re(E) - Re(T)| + mean|Im(E) - Im(T)| + mean|e - t|

    Here ``E``/``T`` are the STFTs and ``e``/``t`` the waveforms.
    """

    def __init__(self, win=2048, stride=512):
        super().__init__()
        self.EPS = 1e-8  # kept for attribute compatibility; forward does not use it
        self.win = win
        self.stride = stride

    def _stft(self, x, window):
        """Complex STFT of every signal in ``x`` flattened to ``[N, time]``."""
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        return torch.stft(
            x.reshape(-1, x.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

    def forward(self, ests, targets):
        B, nsrc, _ = ests.shape
        # Build the window once, directly on the input device, and reuse it for both STFTs.
        window = torch.hann_window(self.win, device=ests.device, dtype=torch.float32)
        est_spec = self._stft(ests, window)
        tgt_spec = self._stft(targets, window)
        # Per-signal L1 on the real and imaginary parts, averaged over freq and frames.
        freq_l1 = (est_spec.real - tgt_spec.real).abs().mean((1, 2)) + (
            est_spec.imag - tgt_spec.imag
        ).abs().mean((1, 2))
        freq_l1 = freq_l1.reshape(B, nsrc).mean(-1)
        wave_l1 = (ests - targets).abs().mean(-1).reshape(B, nsrc).mean(-1)
        return freq_l1 + wave_l1
class freq_MSE_Loss(_Loss):
    """Complex-spectrogram mean squared error, averaged over sources.

    Args:
        win: STFT size (``n_fft``), which is also the Hann window length.
        stride: STFT hop length.

    ``forward(ests, targets)`` expects tensors of shape ``[batch, n_src, time]``
    and returns a ``[batch]`` tensor. Each entry is the mean over sources of::

        mean (Re(E) - Re(T))^2 + mean (Im(E) - Im(T))^2

    Here ``E``/``T`` are the STFTs of the estimate and the target.
    """

    def __init__(self, win=640, stride=160):
        super().__init__()
        self.EPS = 1e-8  # kept for attribute compatibility; forward does not use it
        self.win = win
        self.stride = stride

    def _stft(self, x, window):
        """Complex STFT of every signal in ``x`` flattened to ``[N, time]``."""
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        return torch.stft(
            x.reshape(-1, x.shape[-1]),
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

    def forward(self, ests, targets):
        B, nsrc, _ = ests.shape
        # Build the window once, directly on the input device, and reuse it for both STFTs.
        window = torch.hann_window(self.win, device=ests.device, dtype=torch.float32)
        est_spec = self._stft(ests, window)
        tgt_spec = self._stft(targets, window)
        # Per-signal MSE on the real and imaginary parts, averaged over freq and frames.
        freq_mse = (est_spec.real - tgt_spec.real).square().mean((1, 2)) + (
            est_spec.imag - tgt_spec.imag
        ).square().mean((1, 2))
        return freq_mse.reshape(B, nsrc).mean(-1)
# Ready-made loss instances with default settings.

# Pairwise variants: [batch, n_src, time] inputs -> [batch, n_src, n_src] cost matrix.
pairwise_neg_sisdr, pairwise_neg_sdsdr, pairwise_neg_snr = (
    PairwiseNegSDR("sisdr"),
    PairwiseNegSDR("sdsdr"),
    PairwiseNegSDR("snr"),
)

# Single-source variants: [batch, time] inputs -> per-item loss (reduction="none").
singlesrc_neg_sisdr, singlesrc_neg_sdsdr, singlesrc_neg_snr = (
    SingleSrcNegSDR("sisdr"),
    SingleSrcNegSDR("sdsdr"),
    SingleSrcNegSDR("snr"),
)

# Multi-source variants: [batch, n_src, time] inputs -> [batch], averaged over sources.
multisrc_neg_sisdr, multisrc_neg_sdsdr, multisrc_neg_snr = (
    MultiSrcNegSDR("sisdr"),
    MultiSrcNegSDR("sdsdr"),
    MultiSrcNegSDR("snr"),
)

# Spectral + waveform L1 loss.
freq_mae_wavl1loss = freq_MAE_WavL1Loss()

# Tuples pairing a pairwise loss with an auxiliary loss. Presumably the training
# code uses both members together — confirm at the call site.
pairwise_neg_sisdr_freq_mse = (PairwiseNegSDR("sisdr"), freq_MSE_Loss())
pairwise_neg_snr_multidecoder = (PairwiseNegSDR("snr"), MultiSrcNegSDR("snr"))