# Hugging Face Space status banner (page-scrape residue): "Running on Zero"
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
import torch | |
import torch.nn as nn | |
import numpy as np | |
from pesq import pesq | |
from joblib import Parallel, delayed | |
def phase_losses(phase_r, phase_g, cfg):
    """
    Compute three anti-wrapped phase losses between a reference and a generated phase.

    The three terms compare, respectively, the raw phases, their first-order
    differences along the frequency axis, and their first-order differences
    along the time axis. Each difference is passed through
    ``anti_wrapping_function`` and averaged over all elements.

    Args:
        phase_r (torch.Tensor): Reference phase, expected as (batch, freq, time).
        phase_g (torch.Tensor): Generated phase, same layout as ``phase_r``.
        cfg (dict): Configuration; ``cfg['stft_cfg']['n_fft']`` determines the
            number of frequency bins (``n_fft // 2 + 1``).
    Returns:
        tuple: ``(ip_loss, gd_loss, iaf_loss)`` scalar tensors for the raw-phase,
        frequency-difference and time-difference terms.
    """
    device = phase_g.device

    def difference_matrix(size):
        # Square matrix with -1 on the main diagonal and +1 on the first
        # super-diagonal; right-multiplying by it yields x[j - 1] - x[j]
        # along the multiplied axis (and -x[0] for the first column).
        ones = torch.ones(size, size)
        band = torch.triu(ones, diagonal=1) - torch.triu(ones, diagonal=2) - torch.eye(size)
        return band.to(device)

    n_freq = cfg['stft_cfg']['n_fft'] // 2 + 1
    n_time = phase_r.size(-1)

    # Differences along frequency: move freq to the last axis before multiplying.
    freq_diff = difference_matrix(n_freq)
    gd_ref = phase_r.permute(0, 2, 1) @ freq_diff
    gd_gen = phase_g.permute(0, 2, 1) @ freq_diff

    # Differences along time: time is already the last axis.
    time_diff = difference_matrix(n_time)
    iaf_ref = phase_r @ time_diff
    iaf_gen = phase_g @ time_diff

    ip_loss = anti_wrapping_function(phase_r - phase_g).mean()
    gd_loss = anti_wrapping_function(gd_ref - gd_gen).mean()
    iaf_loss = anti_wrapping_function(iaf_ref - iaf_gen).mean()
    return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
    """
    Return the absolute phase difference after removing whole multiples of 2*pi.

    Each element has its nearest multiple of 2*pi subtracted, and the absolute
    value of the remainder is returned, so the result is non-negative.

    Args:
        x (torch.Tensor): Tensor of phase differences in radians.
    Returns:
        torch.Tensor: Tensor of the same shape holding the wrapped absolute differences.
    """
    two_pi = 2 * np.pi
    nearest_turns = torch.round(x / two_pi)
    wrapped = x - nearest_turns * two_pi
    return wrapped.abs()
def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components.

    Args:
        y (torch.Tensor): Input signal tensor, passed directly to ``torch.stft``.
        n_fft (int): Number of FFT points.
        hop_size (int): Hop size between successive frames.
        win_size (int): Length of the Hann analysis window.
        center (bool): Whether to pad the input on both sides (reflect padding).
        compress_factor (float, optional): Exponent applied to the magnitude. Defaults to 1.0.
    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - mag: magnitude spectrogram raised to ``compress_factor``.
            - pha: phase spectrogram in radians.
            - com: ``(mag * cos(pha), mag * sin(pha))`` stacked on a new last
              axis, i.e. the (compressed) real and imaginary parts.
    """
    # Small constant keeping sqrt and atan2 away from an exactly-zero input.
    eps = torch.finfo(y.dtype).eps
    hann_window = torch.hann_window(win_size, device=y.device)

    stft_spec = torch.stft(
        y,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True,
    )
    real_part = stft_spec.real
    imag_part = stft_spec.imag

    # Magnitude is sqrt(re^2 + im^2): the squares are summed, not multiplied.
    mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
    # torch.atan2 takes (y, x), so the imaginary part comes first.
    pha = torch.atan2(imag_part + eps, real_part + eps)

    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
def pesq_score(utts_r, utts_g, cfg):
    """
    Compute the mean PESQ (Perceptual Evaluation of Speech Quality) score over utterance pairs.

    Each reference/generated pair is scored in a joblib worker. A pair whose
    PESQ computation raises is reported on stdout and contributes -1 to the mean.

    Args:
        utts_r (list of torch.Tensor): Reference utterances.
        utts_g (list of torch.Tensor): Generated utterances, index-aligned with ``utts_r``.
        cfg (dict): Configuration; ``cfg['stft_cfg']['sampling_rate']`` is the
            sampling rate passed to ``pesq``.
    Returns:
        float: Mean PESQ score across all pairs (failed pairs count as -1).
    """

    def eval_pesq(clean_utt, esti_utt, sr):
        """
        Score one clean/estimated pair, returning -1 if PESQ fails.

        Args:
            clean_utt (np.ndarray): Clean reference utterance.
            esti_utt (np.ndarray): Estimated (generated) utterance.
            sr (int): Sampling rate.
        Returns:
            float: PESQ score, or -1 when ``pesq`` raises.
        """
        try:
            return pesq(sr, clean_utt, esti_utt)
        except Exception as e:
            # Error can happen due to silent period or other issues
            print(f"Error computing PESQ score: {e}")
            return -1

    def to_numpy(utt):
        # Drop singleton dimensions and move to host memory as a NumPy array.
        return utt.squeeze().cpu().numpy()

    # One delayed job per utterance pair, evaluated lazily by the worker pool.
    jobs = (
        delayed(eval_pesq)(
            to_numpy(clean),
            to_numpy(utts_g[idx]),
            cfg['stft_cfg']['sampling_rate'],
        )
        for idx, clean in enumerate(utts_r)
    )
    per_utterance_scores = Parallel(n_jobs=30)(jobs)

    return np.mean(per_utterance_scores)