|
|
|
|
|
""" |
|
https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816 |
|
""" |
|
from typing import Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
import torchaudio |
|
|
|
|
|
def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
    """Compute a Hann-weighted moving energy along the second-to-last axis of ``spec``.

    The squared values of ``spec`` are summed over its last two dimensions.
    The resulting per-frame energy is smoothed along the new last axis with a
    centred Hann window of ``n_frame`` taps, and the result is divided by
    ``n_frame``.

    In this file the function is called from ``local_snr`` with the output of
    ``torch.view_as_real`` on a complex spectrogram. For the ``[b, c, t, f]``
    layout documented in ``LocalSnrTarget.forward`` that input is
    ``[b, c, t, f, 2]`` and the output is ``[b, c, t]``.

    :param spec: real-valued tensor; its last two dims are reduced away.
    :param n_frame: window length in frames. An even value is bumped to the
        next odd number so the window can be centred on each frame.
    :param device: device on which the Hann window is allocated.
    :return: tensor with the last two dims of ``spec`` removed.
    """
    # An odd length gives a symmetric amount of padding on each side.
    window_len = n_frame + 1 if n_frame % 2 == 0 else n_frame
    half = window_len // 2

    # Squared magnitude: sum over the (real, imag) pair, then over the next axis.
    frame_energy = spec.pow(2).sum(dim=-1).sum(dim=-1)

    # Zero-pad only the last (time) axis so every frame gets a full window.
    padded = F.pad(frame_energy, (half, half, 0, 0))

    # NOTE: torch.hann_window is periodic by default, so for odd lengths the
    # taps are not perfectly symmetric; this mirrors the upstream behaviour.
    hann = torch.hann_window(window_len, device=device, dtype=padded.dtype)

    frames = padded.unfold(-1, size=window_len, step=1)
    weighted_sum = (frames * hann).sum(dim=-1)
    return weighted_sum / window_len
|
|
|
|
|
def local_snr(spec_clean: torch.Tensor,
              spec_noise: torch.Tensor,
              n_frame: int = 5,
              db: bool = False,
              eps: float = 1e-12,
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the frame-local signal-to-noise ratio of two spectrograms.

    Both inputs are turned into real ``[..., 2]`` (real, imag) views. Their
    Hann-smoothed local energies are computed with ``local_energy``, and the
    ratio clean / noise is returned.

    :param spec_clean: complex spectrogram, or a real tensor whose last
        dimension (size 2) holds (real, imag). The layout used in this file is
        ``[b, c, t, f]`` complex.
    :param spec_noise: same layout as ``spec_clean``.
    :param n_frame: smoothing window length in frames; passed to ``local_energy``.
    :param db: if True, return the SNR in decibels (``10 * log10``).
    :param eps: lower bound applied to the noise energy before dividing, and
        to the ratio before taking ``log10``.
    :return: ``(snr, energy_clean, energy_noise)``. For ``[b, c, t, f]`` complex
        inputs each tensor is ``[b, c, t]``.
    :raises ValueError: if an input is real-valued but its last dimension is
        not of size 2.
    """
    def _as_real(spec: torch.Tensor, name: str) -> torch.Tensor:
        # Complex tensors are viewed as [..., 2]. Tensors that already carry
        # the (real, imag) pair are passed through unchanged.
        if torch.is_complex(spec):
            return torch.view_as_real(spec)
        if spec.dim() == 0 or spec.shape[-1] != 2:
            raise ValueError(
                f"{name} must be complex or real with a trailing dimension of size 2, "
                f"got dtype={spec.dtype}, shape={tuple(spec.shape)}"
            )
        return spec

    spec_clean = _as_real(spec_clean, "spec_clean")
    spec_noise = _as_real(spec_noise, "spec_noise")

    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)

    # Clamping the denominator avoids division by zero in silent noise frames.
    snr = energy_clean / energy_noise.clamp_min(eps)

    if db:
        # Floor again so log10 never sees 0 (the clean energy may be zero).
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, energy_clean, energy_noise
|
|
|
|
|
class LocalSnrTarget(nn.Module):
    """Clamped local-SNR target computed from clean and noise spectrograms.

    ``sample_rate``, ``nfft``, ``win_size`` and ``hop_size`` are stored as
    attributes only; ``forward`` does not use them.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: float = -15,
                 max_local_snr: float = 30,
                 db: bool = True,
                 ):
        """
        :param sample_rate: stored only.
        :param nfft: stored only.
        :param win_size: stored only.
        :param hop_size: stored only.
        :param n_frame: smoothing window length in frames, passed to ``local_snr``.
        :param min_local_snr: lower clamp bound applied to the SNR.
        :param max_local_snr: upper clamp bound applied to the SNR.
        :param db: if True the SNR (and therefore the clamp range) is in dB;
            otherwise it is a linear, non-negative ratio.
        :raises ValueError: if ``min_local_snr > max_local_snr``. ``torch.clamp``
            would otherwise silently map every value to ``max_local_snr``.
        """
        super().__init__()
        if min_local_snr > max_local_snr:
            raise ValueError(
                f"min_local_snr ({min_local_snr}) must not exceed "
                f"max_local_snr ({max_local_snr})"
            )

        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size

        self.n_frame = n_frame

        self.min_local_snr = min_local_snr
        self.max_local_snr = max_local_snr

        self.db = db

    def forward(self,
                spec_clean: torch.Tensor,
                spec_noise: torch.Tensor,
                ) -> torch.Tensor:
        """Return the clamped local SNR between ``spec_clean`` and ``spec_noise``.

        :param spec_clean: torch.complex, shape: [b, c, t, f]
        :param spec_noise: torch.complex, shape: [b, c, t, f]
        :return: lsnr, shape [b, t] when c == 1. ``squeeze(1)`` leaves the
            channel axis in place (giving [b, c, t]) when c > 1.
        """
        lsnr, _, _ = local_snr(
            spec_clean=spec_clean,
            spec_noise=spec_noise,
            n_frame=self.n_frame,
            db=self.db,
        )

        # NOTE: when db=False the clamp range is applied to a linear ratio.
        lsnr = lsnr.clamp(self.min_local_snr, self.max_local_snr).squeeze(1)
        return lsnr
|
|
|
|
|
def main():
    """Smoke-test ``LocalSnrTarget`` on a random waveform and print tensor shapes."""
    sample_rate = 8000
    nfft = 512
    win_size = 512
    hop_size = 256
    window_fn = "hamming"

    transform = torchaudio.transforms.Spectrogram(
        n_fft=nfft,
        win_length=win_size,
        hop_length=hop_size,
        power=None,  # keep the complex STFT rather than a power spectrogram
        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
    )

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    # Call modules through __call__ (not .forward) so registered hooks run.
    spec = transform(noisy)
    # Swap the last two axes, then insert a channel axis -> [b, 1, t, f].
    spec = spec.permute(0, 2, 1)
    spec = torch.unsqueeze(spec, dim=1)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    local = LocalSnrTarget(
        sample_rate=sample_rate,
        nfft=nfft,
        win_size=win_size,
        hop_size=hop_size,
        n_frame=5,
        min_local_snr=-15,
        max_local_snr=30,
        db=True,
    )
    # The same spectrogram is used as clean and noise, so the target is 0 dB
    # wherever the local energy is non-zero; this only exercises shapes.
    lsnr_target = local(spec, spec)
    print(f"lsnr_target.shape: {lsnr_target.shape}")
|
|
|
|
|
# Run the shape smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|