|
|
|
|
|
""" |
|
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py |
|
""" |
|
from collections import defaultdict |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from scipy.signal import get_window |
|
|
|
|
|
def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    """Build the (inverse) Fourier convolution kernel used by ConvSTFT / ConviSTFT.

    :param nfft: FFT size; the kernel carries ``nfft // 2 + 1`` frequency bins.
    :param win_size: frame length in samples; must not exceed ``nfft``.
    :param hop_size: hop length in samples. Unused here; kept for signature compatibility.
    :param win_type: window name accepted by ``scipy.signal.get_window``. ``None`` or the
        string ``"None"`` selects a rectangular (all-ones) window; any other name selects
        the square root of the periodic window.
    :param inverse: if True, return the synthesis kernel (transposed pseudo-inverse of the
        analysis basis) instead of the analysis kernel.
    :return: tuple ``(kernel, window)`` of float32 tensors. ``kernel`` has shape
        ``[2 * (nfft // 2 + 1), 1, win_size]`` (all real rows, then all imaginary rows) and is
        already multiplied by the window; ``window`` has shape ``[1, win_size, 1]``.
    :raises ValueError: if ``win_size > nfft``.
    """
    if win_size > nfft:
        # rfft(eye(nfft)) has only nfft rows, so a longer window cannot be applied.
        raise ValueError(f"win_size ({win_size}) must not exceed nfft ({nfft})")

    if win_type is None or win_type == "None":
        window = np.ones(win_size)
    else:
        # Both the analysis and the synthesis kernel are multiplied by this window,
        # so taking the square root makes their product the original window.
        window = get_window(win_type, win_size, fftbins=True) ** 0.5

    # Row n of rfft(eye(nfft)) is the DFT of a unit impulse at sample n; keep the
    # first win_size samples (frames shorter than nfft are implicitly zero-padded).
    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    kernel = np.concatenate([np.real(fourier_basis), np.imag(fourier_basis)], axis=1).T

    if inverse:
        kernel = np.linalg.pinv(kernel).T

    # Apply the window along the time axis, then add the conv "in/out channel" axis.
    kernel = (kernel * window)[:, None, :]
    return (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32)),
    )
|
|
|
|
|
class ConvSTFT(nn.Module):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    The analysis kernel from ``init_kernels`` is stored as ``self.weight`` (frozen unless
    ``requires_grad=True``). No padding is applied, so the number of output frames is
    ``(num_samples - win_size) // hop_size + 1``.
    """

    def __init__(self,
                 nfft: int,
                 win_size: int,
                 hop_size: int,
                 win_type: str = "hamming",
                 power: int = None,
                 requires_grad: bool = False):
        """
        :param nfft: FFT size; ``None`` picks the next power of two >= ``win_size``.
        :param win_size: frame length in samples.
        :param hop_size: hop length in samples (the convolution stride).
        :param win_type: window name forwarded to ``init_kernels``.
        :param power: selects what ``forward`` returns: ``None`` -> complex spectrum,
            ``1`` -> magnitude, ``2`` -> power (squared magnitude).
        :param requires_grad: whether the analysis kernel is trainable.
        """
        super(ConvSTFT, self).__init__()

        self.nfft = int(2 ** np.ceil(np.log2(win_size))) if nfft is None else nfft

        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size

        self.stride = hop_size
        self.dim = self.nfft
        self.power = power

    def forward(self, waveform: torch.Tensor):
        """
        :param waveform: ``[num_samples]``, ``[b, num_samples]`` or ``[b, 1, num_samples]``.
        :return: tensor of shape ``[b, nfft // 2 + 1, num_frames]``: complex when
            ``power is None``, magnitude when ``power == 1``, power when ``power == 2``.
        :raises AssertionError: if ``power`` is not ``None``, 1 or 2.
        """
        if waveform.dim() == 1:
            # Unbatched mono signal: add batch and channel axes.
            waveform = waveform[None, None, :]
        elif waveform.dim() == 2:
            waveform = torch.unsqueeze(waveform, 1)

        matrix = F.conv1d(waveform, self.weight, stride=self.stride)

        # The kernel stacks all real-part rows first, then all imaginary-part rows.
        dim = self.dim // 2 + 1
        real = matrix[:, :dim, :]
        imag = matrix[:, dim:, :]

        if self.power is None:
            return torch.complex(real, imag)
        if self.power == 1:
            return torch.sqrt(real ** 2 + imag ** 2)
        if self.power == 2:
            return real ** 2 + imag ** 2
        # Exception type kept as AssertionError for backward compatibility.
        raise AssertionError(f"unsupported power: {self.power!r}; expected None, 1 or 2")
|
|
|
|
|
class ConviSTFT(nn.Module):
    """Inverse STFT implemented as a strided transposed 1-D convolution with overlap-add.

    ``forward`` reconstructs a whole signal at once. ``forward_chunk`` reconstructs it
    incrementally, carrying the not-yet-finished overlap region between calls.
    """

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 requires_grad: bool = False):
        """
        :param win_size: frame length in samples.
        :param hop_size: hop length in samples (the transposed-convolution stride).
        :param nfft: FFT size; ``None`` picks the next power of two >= ``win_size``.
        :param win_type: window name forwarded to ``init_kernels``.
        :param requires_grad: whether the synthesis kernel is trainable.
        """
        super(ConviSTFT, self).__init__()
        self.nfft = int(2 ** np.ceil(np.log2(win_size))) if nfft is None else nfft

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.stride = hop_size
        self.dim = self.nfft

        self.register_buffer("window", window)
        # Identity filter bank: transposed conv with it overlap-adds per-frame window
        # values into a single channel (used to normalize the reconstruction).
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def _overlap_add(self, spec: torch.Tensor):
        """Synthesize and overlap-add all frames of ``spec``, plus the squared window.

        :param spec: complex tensor ``[b, nfft // 2 + 1, t]``.
        :return: ``(waveform, coff)``, un-normalized, with shapes
            ``[b, 1, (t - 1) * hop_size + win_size]`` and ``[1, 1, (t - 1) * hop_size + win_size]``.
        """
        spec = torch.view_as_real(spec)
        # Stack real and imaginary parts along the channel axis to match self.weight.
        matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)

        window_sq = self.window.repeat(1, 1, matrix.size(-1)) ** 2
        coff = F.conv_transpose1d(window_sq, self.enframe, stride=self.stride)
        return waveform, coff

    def forward(self,
                spec: torch.Tensor):
        """
        self.weight shape: [f*2, 1, win_size]
        self.window shape: [1, win_size, 1]
        self.enframe shape: [win_size, 1, win_size]

        :param spec: complex torch.Tensor, shape: [b, f, t] with f = nfft // 2 + 1
        :return: waveform, shape: [b, 1, (t - 1) * hop_size + win_size]
        """
        waveform, coff = self._overlap_add(spec)
        # Normalize by the overlap-added squared window; 1e-8 guards against division by zero.
        return waveform / (coff + 1e-8)

    @torch.no_grad()
    def forward_chunk(self,
                      spec: torch.Tensor,
                      cache_dict: dict = None
                      ):
        """Streaming counterpart of ``forward``.

        Feeding consecutive chunks of frames and concatenating the outputs reproduces
        (up to floating-point rounding) the first ``total_frames * hop_size`` samples
        that ``forward`` returns for the whole spectrogram.

        :param spec: complex tensor, shape: [b, f, t]; any number t >= 1 of frames.
        :param cache_dict: dict returned by the previous call, or None (or an empty dict)
            for the first chunk.
            waveform_cache shape: [b, 1, win_size - hop_size]
            coff_cache shape: [1, 1, win_size - hop_size]
        :return: (waveform_output, new_cache_dict); waveform_output has shape
            [b, 1, t * hop_size].
        """
        if cache_dict is None:
            cache_dict = {}
        # .get tolerates dicts lacking either key.
        waveform_cache = cache_dict.get("waveform_cache")
        coff_cache = cache_dict.get("coff_cache")

        waveform_current, coff_current = self._overlap_add(spec)

        overlap_size = self.win_size - self.hop_size
        # Samples before t * hop_size receive no contribution from later frames, so they
        # are final. Whatever remains (exactly overlap_size samples) is carried forward.
        num_done = spec.size(-1) * self.hop_size

        if waveform_cache is not None:
            waveform_current[:, :, :overlap_size] += waveform_cache
        if coff_cache is not None:
            coff_current[:, :, :overlap_size] += coff_cache

        waveform_output = waveform_current[:, :, :num_done] / (coff_current[:, :, :num_done] + 1e-8)

        new_cache_dict = {
            "waveform_cache": waveform_current[:, :, num_done:],
            "coff_cache": coff_current[:, :, num_done:],
        }
        return waveform_output, new_cache_dict
|
|
|
|
|
def main():
    """Demo: full-signal iSTFT vs. naive per-frame reconstruction.

    Each frame is inverted on its own with ``ConviSTFT.forward`` and written into the
    output buffer, so overlapping regions are overwritten by the later frame rather than
    overlap-added. Both results are printed at samples 300..301 for comparison.
    """
    nfft = 512
    win_size = 512
    hop_size = 256

    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)

    mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
    b, num_samples = mixture.shape

    spec = stft.forward(mixture)
    num_frames = spec.shape[-1]
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    waveform = istft.forward(spec)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(num_frames):
        begin = i * hop_size
        # Clamp so a frame reaching past the buffer cannot cause a shape mismatch.
        end = min(begin + win_size, num_samples)
        sub_waveform = istft.forward(spec[:, :, i:i + 1])
        waveform[:, :, begin:end] = sub_waveform[:, :, :end - begin]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
|
|
|
|
|
def main2():
    """Demo: full-signal iSTFT vs. streaming iSTFT via ``ConviSTFT.forward_chunk``.

    Small noise is added to the spectrum first, then the signal is reconstructed both
    at once and one frame at a time. Both results are printed at samples 300..301.
    """
    nfft = 512
    win_size = 512
    hop_size = 256

    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)

    mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
    b, num_samples = mixture.shape

    spec = stft.forward(mixture)
    b, f, t = spec.shape

    # The noise is real-valued, so only the real part of each bin is perturbed.
    spec = spec + 0.01 * torch.randn(size=(b, f, t), dtype=torch.float32)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    waveform = istft.forward(spec)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    cache_dict = None
    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(t):
        sub_spec = spec[:, :, i:i + 1]
        sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict)
        # Each chunk yields a hop-sized block of finished samples; use its actual length
        # rather than win_size - hop_size, which only matches when win_size == 2 * hop_size.
        begin = i * hop_size
        end = begin + sub_waveform.size(-1)
        waveform[:, :, begin:end] = sub_waveform
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: runs the streaming (chunked) iSTFT demo.
    main2()
|
|