Spaces:
Running
Running
File size: 2,201 Bytes
406f22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
import warnings
import torch
import numpy as np
import soundfile as sf
def get_device(tensor_or_module, default=None):
    """Return the `torch.device` an object lives on.

    Works for tensors (via `.device`) and modules (via the device of the
    first parameter). When neither applies, `default` is used; with no
    default, a TypeError is raised.
    """
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    if hasattr(tensor_or_module, "parameters"):
        # Module-like object: report the device of its first parameter.
        return next(tensor_or_module.parameters()).device
    if default is None:
        raise TypeError(
            f"Don't know how to get device of {type(tensor_or_module)} object"
        )
    return torch.device(default)
class Separator:
    """Abstract interface for source-separation models.

    Subclasses implement `forward_wav` (waveform in, separated sources
    out) and `sample_rate` (the rate the model operates at).
    """

    def forward_wav(self, wav, **kwargs):
        """Separate the waveform `wav`. Must be overridden."""
        raise NotImplementedError

    def sample_rate(self):
        """Return the model's operating sample rate. Must be overridden."""
        raise NotImplementedError
def separate(model, wav, **kwargs):
    """Separate `wav` with `model`, dispatching on the input's type.

    Torch tensors go through `torch_separate`, numpy arrays through
    `numpy_separate`; anything else is rejected.
    """
    if isinstance(wav, torch.Tensor):
        return torch_separate(model, wav, **kwargs)
    if isinstance(wav, np.ndarray):
        return numpy_separate(model, wav, **kwargs)
    raise ValueError(
        f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
    )
@torch.no_grad()
def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
    """Core logic of `separate`.

    Args:
        model: Separator-like object exposing `in_channels` (None means
            any channel count) and either a `forward_wav` method or being
            callable itself.
        wav: Input waveform tensor; the channel axis is read at dim -2.
        **kwargs: Forwarded to the model's separation function.

    Returns:
        Separated sources, rescaled to the input's summed magnitude and
        moved back to the input tensor's original device.

    Raises:
        RuntimeError: if the input channel count does not match
            `model.in_channels`.
    """
    if model.in_channels is not None and wav.shape[-2] != model.in_channels:
        # Fix: the two fragments previously joined without a space,
        # yielding "...channels.Please match...".
        raise RuntimeError(
            f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels. "
            f"Please match the number of channels."
        )
    # Handle device placement: run on the model's device, restore after.
    input_device = get_device(wav, default="cpu")
    model_device = get_device(model, default="cpu")
    wav = wav.to(model_device)
    # Forward: prefer a dedicated `forward_wav`, else call the model directly.
    separate_func = getattr(model, "forward_wav", model)
    out_wavs = separate_func(wav, **kwargs)
    # FIXME: for now this is the best we can do.
    # Rescale sources so their summed magnitude matches the input's.
    # NOTE(review): divides by out_wavs.abs().sum(); an all-zero output
    # would produce NaN/inf — left as-is per the FIXME above.
    out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
    # Back to input device (and numpy if necessary)
    out_wavs = out_wavs.to(input_device)
    return out_wavs
def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
    """Numpy interface to `separate`.

    Converts `wav` to a torch tensor, runs `torch_separate`, and converts
    the separated sources back to a numpy array.
    """
    tensor_sources = torch_separate(model, torch.from_numpy(wav), **kwargs)
    return tensor_sources.data.numpy()
|