from itertools import combinations

import torch
from torch import nn

class MixITLossWrapper(nn.Module):
    r"""Mixture invariant loss wrapper.

    Args:
        loss_func: function with signature ``(est_targets, targets, **kwargs)``.
        generalized (bool): Determines how MixIT is applied. If False,
            apply MixIT for any number of mixtures as long as they contain
            the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`).
            If True (default), apply MixIT for two mixtures, but those mixtures
            do not necessarily have to contain the same number of sources.
            See :meth:`~MixITLossWrapper.best_part_mixit_generalized`.

    For each of these modes, the best partition and reordering will be
    automatically computed.

    Examples:
        >>> import torch
        >>> from asteroid.losses import multisrc_mse
        >>> mixtures = torch.randn(10, 2, 16000)
        >>> est_sources = torch.randn(10, 4, 16000)
        >>> # Compute MixIT loss based on pairwise losses
        >>> loss_func = MixITLossWrapper(multisrc_mse)
        >>> loss_val = loss_func(est_sources, mixtures)

    References
        [1] Scott Wisdom et al. "Unsupervised sound separation using
        mixtures of mixtures." arXiv:2006.12701 (2020).
    """
    def __init__(self, loss_func, generalized=True):
        super().__init__()
        self.loss_func = loss_func
        self.generalized = generalized

    def forward(self, est_targets, targets, return_est=False, **kwargs):
        r"""Find the best partition and return the loss.

        Args:
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            return_est: Boolean. Whether to also return the estimated mixtures
                (e.g. to compute metrics or to save examples).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - Best partition loss for each batch sample, averaged over
              the batch. torch.Tensor(loss_value)
            - The estimated mixtures (estimated sources summed according to the
              partition) if return_est is True. torch.Tensor of shape
              :math:`(batch, nmix, ...)`.
        """
        # Check input dimensions
        assert est_targets.shape[0] == targets.shape[0]
        assert est_targets.shape[-1] == targets.shape[-1]
        if not self.generalized:
            min_loss, min_loss_idx, parts = self.best_part_mixit(
                self.loss_func, est_targets, targets, **kwargs
            )
        else:
            min_loss, min_loss_idx, parts = self.best_part_mixit_generalized(
                self.loss_func, est_targets, targets, **kwargs
            )
        # Take the mean over the batch
        mean_loss = torch.mean(min_loss)
        if not return_est:
            return mean_loss
        # Order and sum along the best partition to get the estimated mixtures
        reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts)
        return mean_loss, reordered

    @staticmethod
    def best_part_mixit(loss_func, est_targets, targets, **kwargs):
        r"""Find the best partition of the estimated sources that gives the
        minimum loss for the MixIT training paradigm in [1]. Valid for any
        number of mixtures as long as they contain the same number of sources.

        Args:
            loss_func: function with signature ``(est_targets, targets, **kwargs)``
                The loss function to get batch losses from.
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - :class:`torch.Tensor`:
                The loss corresponding to the best partition, of size (batch, 1).
            - :class:`torch.LongTensor`:
                The indices of the best partitions.
            - :class:`list`:
                The list of the possible partitions of the sources.
        """
        nmix = targets.shape[1]
        nsrc = est_targets.shape[1]
        if nsrc % nmix != 0:
            raise ValueError(
                "nsrc must be a multiple of nmix: the mixtures are assumed "
                "to contain the same number of sources."
            )
        nsrcmix = nsrc // nmix

        # Generate all partitions of size k from a list lst of length n,
        # where l = n // k is the number of parts. Parts are ordered (part i
        # is matched to mixture i), so the total number of partitions is the
        # multinomial coefficient n! / (k!)^l.
        # The algorithm recursively distributes items over parts.
        def parts_mixit(lst, k, l):
            if l == 0:
                yield []
            else:
                for c in combinations(lst, k):
                    rest = [x for x in lst if x not in c]
                    for r in parts_mixit(rest, k, l - 1):
                        yield [list(c), *r]
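        # For example, with nsrc=4 and nmix=2 (k=2, l=2), this yields the
        # 4! / (2!)^2 = 6 ordered partitions:
        # [[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0, 3], [1, 2]],
        # [[1, 2], [0, 3]], [[1, 3], [0, 2]], [[2, 3], [0, 1]]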
        # Generate all the possible partitions
        parts = list(parts_mixit(range(nsrc), nsrcmix, nmix))
        # Compute the loss corresponding to each partition
        loss_set = MixITLossWrapper.loss_set_from_parts(
            loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
        )
        # Indices and values of the min loss for each batch element
        min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_indexes, parts
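
    # Note: the number of partitions grows quickly with nsrc. For instance,
    # with nsrc=8 and nmix=2, best_part_mixit evaluates 8! / (4!)^2 = 70
    # partitions per forward pass.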

    @staticmethod
    def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs):
        r"""Find the best partition of the estimated sources that gives the
        minimum loss for the MixIT training paradigm in [1]. Valid only for two
        mixtures, but those mixtures do not necessarily have to contain the
        same number of sources, e.g. the case where one mixture is silent is
        allowed.

        Args:
            loss_func: function with signature ``(est_targets, targets, **kwargs)``
                The loss function to get batch losses from.
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - :class:`torch.Tensor`:
                The loss corresponding to the best partition, of size (batch, 1).
            - :class:`torch.LongTensor`:
                The indices of the best partitions.
            - :class:`list`:
                The list of the possible partitions of the sources.
        """
        nmix = targets.shape[1]  # number of mixtures
        nsrc = est_targets.shape[1]  # number of estimated sources
        if nmix != 2:
            raise ValueError("Works only with two mixtures")

        # Generate all two-part partitions (with parts of any size) of a list
        # lst of length n, excluding the two trivial ones with an empty part.
        def parts_mixit_gen(lst):
            partitions = []
            for k in range(1, len(lst)):
                for c in combinations(lst, k):
                    rest = [x for x in lst if x not in c]
                    partitions.append([list(c), rest])
            return partitions
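        # For example, for 3 estimated sources this yields the
        # 2^3 - 2 = 6 partitions:
        # [[0], [1, 2]], [[1], [0, 2]], [[2], [0, 1]],
        # [[0, 1], [2]], [[0, 2], [1]], [[1, 2], [0]]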
        # Generate all the possible partitions
        parts = parts_mixit_gen(range(nsrc))
        # Compute the loss corresponding to each partition
        loss_set = MixITLossWrapper.loss_set_from_parts(
            loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
        )
        # Indices and values of the min loss for each batch element
        min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_indexes, parts

    @staticmethod
    def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
        """Common loop between both best_part_mixit methods."""
        loss_set = []
        for partition in parts:
            # Sum the estimated sources within each part of the given
            # partition to build the estimated mixtures, shape (batch, nmix, ...)
            est_mixes = torch.stack(
                [est_targets[:, part, :].sum(1) for part in partition], dim=1
            )
            # Get the batch losses for the given partition
            loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
        # Shape (batch, n_partitions)
        loss_set = torch.cat(loss_set, dim=1)
        return loss_set

    @staticmethod
    def reorder_source(est_targets, targets, min_loss_idx, parts):
        """Reorder sources according to the best partition.

        Args:
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets.
            min_loss_idx: torch.LongTensor. The indices of the best partitions.
            parts: list of the possible partitions of the sources.

        Returns:
            :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`.
        """
        # The best partition is different for each batch element
        ordered = torch.zeros_like(targets)
        for b, idx in enumerate(min_loss_idx):
            # min_loss_idx has shape (batch, 1); convert each single-element
            # tensor to a Python int to index the list of partitions
            right_partition = parts[int(idx)]
            # Sum the estimated sources within each part to get the estimated mixtures
            ordered[b, :, :] = torch.stack(
                [est_targets[b, part, :].sum(0) for part in right_partition], dim=0
            )
        return ordered
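

# A minimal smoke test for both modes. The inline `multisrc_mse` below is a
# hypothetical stand-in (per-sample MSE averaged over sources and time) so the
# demo runs without the rest of the asteroid package; with asteroid installed,
# `from asteroid.losses import multisrc_mse` can be used instead.
if __name__ == "__main__":

    def multisrc_mse(est_targets, targets):
        # Mean squared error per batch element, shape (batch,)
        return ((est_targets - targets) ** 2).mean(dim=(1, 2))

    mixtures = torch.randn(10, 2, 16000)  # (batch, nmix, time)
    est_sources = torch.randn(10, 4, 16000)  # (batch, nsrc, time)

    # Generalized MixIT (default): two mixtures, any split of the sources
    loss_func = MixITLossWrapper(multisrc_mse)
    loss_val, reordered = loss_func(est_sources, mixtures, return_est=True)
    print(loss_val, reordered.shape)  # scalar loss, (10, 2, 16000)

    # Strict MixIT: each mixture receives exactly nsrc // nmix sources
    loss_func_strict = MixITLossWrapper(multisrc_mse, generalized=False)
    print(loss_func_strict(est_sources, mixtures))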