from itertools import combinations

import torch
from torch import nn

class MixITLossWrapper(nn.Module):
r"""Mixture invariant loss wrapper.
Args:
loss_func: function with signature (est_targets, targets, **kwargs).
generalized (bool): Determines how MixIT is applied. If False ,
apply MixIT for any number of mixtures as soon as they contain
the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.)
If True (default), apply MixIT for two mixtures, but those mixtures do not
necessarly have to contain the same number of sources.
See :meth:`~MixITLossWrapper.best_part_mixit_generalized`.
For each of these modes, the best partition and reordering will be
automatically computed.
Examples:
>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute MixIT loss based on pairwise losses
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
References
[1] Scott Wisdom et al. "Unsupervised sound separation using
mixtures of mixtures." arXiv:2006.12701 (2020)
"""
def __init__(self, loss_func, generalized=True):
super().__init__()
self.loss_func = loss_func
self.generalized = generalized
def forward(self, est_targets, targets, return_est=False, **kwargs):
r"""Find the best partition and return the loss.
Args:
est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`.
The batch of target estimates.
targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
The batch of training targets
return_est: Boolean. Whether to return the estimated mixtures
estimates (To compute metrics or to save example).
**kwargs: additional keyword argument that will be passed to the
loss function.
Returns:
- Best partition loss for each batch sample, average over
the batch. torch.Tensor(loss_value)
- The estimated mixtures (estimated sources summed according to the partition)
if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`.
"""
        # Check input dimensions: same batch size and same time length
        assert est_targets.shape[0] == targets.shape[0]
        assert est_targets.shape[-1] == targets.shape[-1]
if not self.generalized:
min_loss, min_loss_idx, parts = self.best_part_mixit(
self.loss_func, est_targets, targets, **kwargs
)
else:
min_loss, min_loss_idx, parts = self.best_part_mixit_generalized(
self.loss_func, est_targets, targets, **kwargs
)
# Take the mean over the batch
mean_loss = torch.mean(min_loss)
if not return_est:
return mean_loss
# Order and sum on the best partition to get the estimated mixtures
reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts)
return mean_loss, reordered
@staticmethod
def best_part_mixit(loss_func, est_targets, targets, **kwargs):
r"""Find best partition of the estimated sources that gives the minimum
loss for the MixIT training paradigm in [1]. Valid for any number of
mixtures as soon as they contain the same number of sources.
Args:
loss_func: function with signature ``(est_targets, targets, **kwargs)``
The loss function to get batch losses from.
est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
The batch of target estimates.
targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.
        Returns:
            - :class:`torch.Tensor`:
                The loss corresponding to the best partition, of shape :math:`(batch, 1)`.
            - :class:`torch.LongTensor`:
                The indices of the best partition for each batch element.
            - :class:`list`:
                List of all the possible partitions of the sources.
"""
nmix = targets.shape[1]
nsrc = est_targets.shape[1]
        if nsrc % nmix != 0:
            raise ValueError(
                "The number of estimated sources must be a multiple of the "
                "number of mixtures (each mixture is assumed to contain the "
                "same number of sources)."
            )
nsrcmix = nsrc // nmix
        # Generate all partitions of a list lst of length n into l = n // k
        # parts of size k. Part order matters (part i is matched against
        # mixture i), so the generator yields n! / (k!)^l ordered partitions.
        # The algorithm recursively distributes items over parts.
def parts_mixit(lst, k, l):
if l == 0:
yield []
else:
for c in combinations(lst, k):
rest = [x for x in lst if x not in c]
for r in parts_mixit(rest, k, l - 1):
yield [list(c), *r]
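        # For example, parts_mixit(range(4), 2, 2) yields the
        # 4! / (2!)^2 = 6 ordered partitions:
        #   [[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0, 3], [1, 2]],
        #   [[1, 2], [0, 3]], [[1, 3], [0, 2]], [[2, 3], [0, 1]]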
# Generate all the possible partitions
parts = list(parts_mixit(range(nsrc), nsrcmix, nmix))
# Compute the loss corresponding to each partition
loss_set = MixITLossWrapper.loss_set_from_parts(
loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
)
# Indexes and values of min losses for each batch element
min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
return min_loss, min_loss_indexes, parts
@staticmethod
def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs):
r"""Find best partition of the estimated sources that gives the minimum
loss for the MixIT training paradigm in [1]. Valid only for two mixtures,
but those mixtures do not necessarly have to contain the same number of
sources e.g the case where one mixture is silent is allowed..
Args:
loss_func: function with signature ``(est_targets, targets, **kwargs)``
The loss function to get batch losses from.
est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
The batch of target estimates.
targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.
        Returns:
            - :class:`torch.Tensor`:
                The loss corresponding to the best partition, of shape :math:`(batch, 1)`.
            - :class:`torch.LongTensor`:
                The indices of the best partition for each batch element.
            - :class:`list`:
                List of all the possible partitions of the sources.
"""
nmix = targets.shape[1] # number of mixtures
nsrc = est_targets.shape[1] # number of estimated sources
        if nmix != 2:
            raise ValueError("The generalized MixIT loss only supports two mixtures.")
        # Generate all two-part partitions of a list lst of length n. Part
        # order matters (part i is matched against mixture i), and splits
        # that would leave one part empty are skipped, so there are
        # 2^n - 2 partitions in total.
        def parts_mixit_gen(lst):
            partitions = []
            for k in range(1, len(lst)):
                for c in combinations(lst, k):
                    rest = [x for x in lst if x not in c]
                    partitions.append([list(c), rest])
            return partitions
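        # For example, parts_mixit_gen(range(3)) returns the 2^3 - 2 = 6
        # two-part partitions:
        #   [[0], [1, 2]], [[1], [0, 2]], [[2], [0, 1]],
        #   [[0, 1], [2]], [[0, 2], [1]], [[1, 2], [0]]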
# Generate all the possible partitions
parts = parts_mixit_gen(range(nsrc))
# Compute the loss corresponding to each partition
loss_set = MixITLossWrapper.loss_set_from_parts(
loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
)
# Indexes and values of min losses for each batch element
min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
return min_loss, min_loss_indexes, parts
@staticmethod
def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
"""Common loop between both best_part_mixit"""
loss_set = []
for partition in parts:
# sum the sources according to the given partition
est_mixes = torch.stack(
[est_targets[:, idx, :].sum(1) for idx in partition], dim=1
)
# get loss for the given partition
loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
loss_set = torch.cat(loss_set, dim=1)
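        # loss_set has shape (batch, n_partitions): one loss per candidate
        # partition and per batch element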
return loss_set
@staticmethod
def reorder_source(est_targets, targets, min_loss_idx, parts):
"""Reorder sources according to the best partition.
Args:
est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
The batch of target estimates.
targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
The batch of training targets.
min_loss_idx: torch.LongTensor. The indexes of the best permutations.
parts: list of the possible partitions of the sources.
Returns:
:class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`.
"""
        # Each batch element has its own min_loss_idx
        ordered = torch.zeros_like(targets)
        for b, idx in enumerate(min_loss_idx):
            right_partition = parts[int(idx)]
            # Sum the estimated sources of each part to form the estimated mixtures
            ordered[b, :, :] = torch.stack(
                [est_targets[b, src_idx, :].sum(0) for src_idx in right_partition],
                dim=0,
            )
        return ordered
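

# Minimal, self-contained smoke test (a sketch, not part of the wrapper's
# API): `mse_batch_loss` below is a hypothetical stand-in for a real
# per-batch loss such as asteroid.losses.multisrc_mse, i.e. any function
# returning one loss value per batch element.
if __name__ == "__main__":

    def mse_batch_loss(est_targets, targets):
        # Mean squared error reduced over all but the batch dimension,
        # shape (batch,)
        return ((est_targets - targets) ** 2).mean(dim=(1, 2))

    batch, n_mix, n_src, time = 10, 2, 4, 16000
    mixtures = torch.randn(batch, n_mix, time)
    est_sources = torch.randn(batch, n_src, time)

    # Generalized MixIT (default): two mixtures, uneven splits allowed
    loss_func = MixITLossWrapper(mse_batch_loss)
    print(loss_func(est_sources, mixtures))

    # Non-generalized MixIT: each mixture gets exactly n_src // n_mix sources;
    # return_est=True also returns the re-mixed estimates
    loss_func = MixITLossWrapper(mse_batch_loss, generalized=False)
    loss_val, est_mixtures = loss_func(est_sources, mixtures, return_est=True)
    print(loss_val, est_mixtures.shape)  # scalar, (batch, n_mix, time)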