from itertools import combinations

import torch
from torch import nn


class MixITLossWrapper(nn.Module):
    r"""Mixture invariant loss wrapper.

    Args:
        loss_func: function with signature ``(est_targets, targets, **kwargs)``.
        generalized (bool): Determines how MixIT is applied. If ``False``,
            apply MixIT to any number of mixtures, as long as they all contain
            the same number of sources (see :meth:`~MixITLossWrapper.best_part_mixit`).
            If ``True`` (default), apply MixIT to two mixtures, which do not
            necessarily have to contain the same number of sources
            (see :meth:`~MixITLossWrapper.best_part_mixit_generalized`).

    In both modes, the best partition and the corresponding reordering are
    computed automatically.

    Examples:
        >>> import torch
        >>> from asteroid.losses import multisrc_mse
        >>> mixtures = torch.randn(10, 2, 16000)
        >>> est_sources = torch.randn(10, 4, 16000)
        >>> # Compute MixIT loss based on pairwise losses
        >>> loss_func = MixITLossWrapper(multisrc_mse)
        >>> loss_val = loss_func(est_sources, mixtures)
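        >>> # With return_est=True (see :meth:`forward`), the estimated
        >>> # mixtures reordered by the best partition are returned as well:
        >>> loss_val, est_mixes = loss_func(est_sources, mixtures, return_est=True)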

    References
        [1] Scott Wisdom et al. "Unsupervised sound separation using
        mixtures of mixtures." arXiv:2006.12701 (2020).
    """

    def __init__(self, loss_func, generalized=True):
        super().__init__()
        self.loss_func = loss_func
        self.generalized = generalized
    def forward(self, est_targets, targets, return_est=False, **kwargs):
        r"""Find the best partition and return the loss.

        Args:
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            return_est: Boolean. Whether to return the estimated mixtures
                (e.g. to compute metrics or to save examples).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - Best partition loss for each batch sample, averaged over
              the batch. torch.Tensor(loss_value)
            - The estimated mixtures (estimated sources summed according to
              the partition) if ``return_est`` is True.
              torch.Tensor of shape :math:`(batch, nmix, ...)`.
        """
        # Check input dimensions
        assert est_targets.shape[0] == targets.shape[0]
        assert est_targets.shape[-1] == targets.shape[-1]

        if not self.generalized:
            min_loss, min_loss_idx, parts = self.best_part_mixit(
                self.loss_func, est_targets, targets, **kwargs
            )
        else:
            min_loss, min_loss_idx, parts = self.best_part_mixit_generalized(
                self.loss_func, est_targets, targets, **kwargs
            )
        # Take the mean over the batch
        mean_loss = torch.mean(min_loss)
        if not return_est:
            return mean_loss
        # Sum the sources along the best partition to get the estimated mixtures
        reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts)
        return mean_loss, reordered

    @staticmethod
    def best_part_mixit(loss_func, est_targets, targets, **kwargs):
        r"""Find the best partition of the estimated sources that gives the
        minimum loss for the MixIT training paradigm in [1]. Valid for any
        number of mixtures, as long as they all contain the same number of
        sources.

        Args:
            loss_func: function with signature ``(est_targets, targets, **kwargs)``
                The loss function to get batch losses from.
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - :class:`torch.Tensor`:
              The loss corresponding to the best partition, of size :math:`(batch, 1)`.
            - :class:`torch.LongTensor`:
              The indices of the best partition.
            - :class:`list`:
              The list of possible partitions of the sources.
        """
        nmix = targets.shape[1]
        nsrc = est_targets.shape[1]
        if nsrc % nmix != 0:
            raise ValueError(
                "The number of estimated sources must be a multiple of the "
                "number of mixtures (each mixture is assumed to contain the "
                "same number of sources)."
            )
        nsrcmix = nsrc // nmix

        # Generate all ordered partitions of a list lst of length n into
        # l = n // k parts of size k. Order matters because the i-th part is
        # matched against the i-th mixture, so the total number of such
        # partitions is n! / (k!)^l. The algorithm recursively distributes
        # the remaining items over the remaining parts.
        def parts_mixit(lst, k, l):
            if l == 0:
                yield []
            else:
                for c in combinations(lst, k):
                    rest = [x for x in lst if x not in c]
                    for r in parts_mixit(rest, k, l - 1):
                        yield [list(c), *r]
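
        # Example enumeration (sketch): with 4 estimated sources and 2
        # mixtures (k=2, l=2), parts_mixit(range(4), 2, 2) yields the
        # 4! / (2!)^2 = 6 ordered partitions
        #   [[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0, 3], [1, 2]],
        #   [[1, 2], [0, 3]], [[1, 3], [0, 2]], [[2, 3], [0, 1]].
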
        # Generate all the possible partitions
        parts = list(parts_mixit(range(nsrc), nsrcmix, nmix))
        # Compute the loss corresponding to each partition
        loss_set = MixITLossWrapper.loss_set_from_parts(
            loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
        )
        # Indices and values of min losses for each batch element
        min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_indexes, parts

    @staticmethod
    def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs):
        r"""Find the best partition of the estimated sources that gives the
        minimum loss for the MixIT training paradigm in [1]. Valid only for
        two mixtures, but those mixtures do not necessarily have to contain
        the same number of sources, e.g. the case where one mixture is silent
        is allowed.

        Args:
            loss_func: function with signature ``(est_targets, targets, **kwargs)``
                The loss function to get batch losses from.
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets (mixtures).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - :class:`torch.Tensor`:
              The loss corresponding to the best partition, of size :math:`(batch, 1)`.
            - :class:`torch.LongTensor`:
              The indices of the best partition.
            - :class:`list`:
              The list of possible partitions of the sources.
        """
        nmix = targets.shape[1]  # number of mixtures
        nsrc = est_targets.shape[1]  # number of estimated sources
        if nmix != 2:
            raise ValueError("best_part_mixit_generalized only works with two mixtures")

        # Generate all partitions of a list lst into two parts of any size:
        # every subset c is paired with its complement. Empty parts are kept,
        # so the case where one estimated mixture is silent is covered.
        def parts_mixit_gen(lst):
            partitions = []
            for k in range(len(lst) + 1):
                for c in combinations(lst, k):
                    rest = [x for x in lst if x not in c]
                    partitions.append([list(c), rest])
            return partitions
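
        # Example enumeration (sketch): parts_mixit_gen(range(2)) yields the
        # 2^2 = 4 two-part partitions
        #   [[], [0, 1]], [[0], [1]], [[1], [0]], [[0, 1], []],
        # where an empty part corresponds to a silent estimated mixture.
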
        # Generate all the possible partitions
        parts = parts_mixit_gen(range(nsrc))
        # Compute the loss corresponding to each partition
        loss_set = MixITLossWrapper.loss_set_from_parts(
            loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
        )
        # Indices and values of min losses for each batch element
        min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_indexes, parts

    @staticmethod
    def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
        """Compute the loss for each possible partition.

        Common loop shared by :meth:`best_part_mixit` and
        :meth:`best_part_mixit_generalized`. Returns a tensor of shape
        ``(batch, n_parts)``, one column per partition.
        """
        loss_set = []
        for partition in parts:
            # Sum the estimated sources within each part to build the
            # estimated mixtures for this partition
            est_mixes = torch.stack(
                [est_targets[:, idx, :].sum(1) for idx in partition], dim=1
            )
            # Get the batch loss for this partition
            loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
        loss_set = torch.cat(loss_set, dim=1)
        return loss_set
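
    # Shape walkthrough (illustrative): for est_targets of shape
    # (batch, 4, time) and targets of shape (batch, 2, time), best_part_mixit
    # builds 6 ordered partitions, each est_mixes is (batch, 2, time), and
    # loss_set_from_parts returns a (batch, 6) loss_set; torch.min over dim=1
    # then selects the best partition for each batch element.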

    @staticmethod
    def reorder_source(est_targets, targets, min_loss_idx, parts):
        """Reorder sources according to the best partition.

        Args:
            est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
                The batch of training targets.
            min_loss_idx: torch.LongTensor. The indices of the best partitions,
                one per batch element.
            parts: list of the possible partitions of the sources.

        Returns:
            :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`.
        """
        # The best partition may differ for each batch element
        ordered = torch.zeros_like(targets)
        for b, idx in enumerate(min_loss_idx):
            best_partition = parts[idx]
            # Sum the estimated sources within each part to get the
            # estimated mixtures
            ordered[b, :, :] = torch.stack(
                [est_targets[b, src_idx, :].sum(0) for src_idx in best_partition],
                dim=0,
            )
        return ordered
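

# Minimal smoke-test sketch (not part of the asteroid API): ``mse_loss`` below
# is a hypothetical stand-in with the expected
# ``(est_targets, targets) -> (batch,)`` loss signature, so this file can be
# run standalone without importing asteroid.losses.multisrc_mse.
if __name__ == "__main__":

    def mse_loss(est_targets, targets):
        # Mean squared error per batch element, averaged over sources and time
        return ((est_targets - targets) ** 2).mean(dim=(1, 2))

    batch, nmix, nsrc, time = 10, 2, 4, 16000
    mixtures = torch.randn(batch, nmix, time)
    est_sources = torch.randn(batch, nsrc, time)

    # Standard MixIT: each mixture gets the same number of estimated sources
    loss_func = MixITLossWrapper(mse_loss, generalized=False)
    print(loss_func(est_sources, mixtures))

    # Generalized MixIT: two mixtures, possibly unbalanced (or silent) parts
    loss_func_gen = MixITLossWrapper(mse_loss, generalized=True)
    loss_val, est_mixes = loss_func_gen(est_sources, mixtures, return_est=True)
    print(loss_val, est_mixes.shape)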