|
import numpy as np |
|
import torch |
|
|
|
import torchcrepe |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class At:
    """Fixed-value periodicity threshold

    Any frame whose periodicity is strictly below ``value`` has its pitch
    replaced by ``torchcrepe.UNVOICED``. The input pitch tensor is not
    modified; a thresholded copy is returned.

    Arguments
        value (float)
            Periodicity below which a frame is treated as unvoiced
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, pitch, periodicity):
        """Return a copy of ``pitch`` with low-periodicity frames unvoiced

        Arguments
            pitch (torch.tensor)
                Pitch values to threshold
            periodicity (torch.tensor)
                Periodicity used as the boolean mask over ``pitch``

        Returns
            thresholded (torch.tensor)
                Copy of ``pitch`` with masked frames set to UNVOICED
        """
        # Copy first so the caller's tensor is left untouched
        thresholded = pitch.clone()
        below_threshold = periodicity < self.value
        thresholded[below_threshold] = torchcrepe.UNVOICED
        return thresholded
|
|
|
|
|
class Hysteresis:
    """Hysteresis thresholding

    Processing happens in three steps:

    1. Frames whose periodicity is below ``lower_bound`` are marked unvoiced.
    2. Every frame gets a threshold of at least ``lower_bound``. The
       threshold rises parabolically once the frame's whitened log2-pitch
       lies more than ``stds`` standard deviations from the mean, and it is
       capped at 1.
    3. Any contiguous run of above-threshold frames that follows a
       below-threshold frame and never exceeds ``upper_bound`` is forced
       unvoiced.

    Arguments
        lower_bound (float)
            Periodicity below which frames are always unvoiced
        upper_bound (float)
            Periodicity a voiced run must exceed at least once to be kept
        width (float)
            Curvature of the parabolic threshold in whitened pitch space
        stds (float)
            Distance from the mean, in standard deviations, at which the
            threshold starts rising above ``lower_bound``
        return_threshold (bool)
            Whether to also return the per-frame threshold
    """

    def __init__(self,
                 lower_bound=.19,
                 upper_bound=.31,
                 width=.2,
                 stds=1.7,
                 return_threshold=False):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.width = width
        self.stds = stds
        self.return_threshold = return_threshold

    def __call__(self, pitch, periodicity):
        """Apply hysteresis thresholding

        Arguments
            pitch (torch.tensor)
                Pitch in Hz; flattened before processing
            periodicity (torch.tensor)
                Periodicity with the same number of elements as ``pitch``

        Returns
            pitch (torch.tensor [shape=(1, pitch.numel())])
                Thresholded pitch; unvoiced frames are torchcrepe.UNVOICED
            threshold (torch.tensor [shape=(pitch.numel(),)])
                Per-frame threshold; only returned if ``return_threshold``
        """
        # Save output device
        device = pitch.device

        # Perform hysteresis in log-2 space, with numpy on the CPU
        pitch = torch.log2(pitch).detach().flatten().cpu().numpy()

        # Flatten periodicity. Detach it as well, because .numpy() raises
        # on a tensor that requires grad.
        periodicity = periodicity.detach().flatten().cpu().numpy()

        # Ignore confidently unvoiced pitch
        pitch[periodicity < self.lower_bound] = torchcrepe.UNVOICED

        # Whiten pitch
        mean, std = np.nanmean(pitch), np.nanstd(pitch)

        # A constant voiced pitch has zero spread. Dividing by it would turn
        # every frame into NaN (0 / 0), and unwhitening could not undo that.
        # Fall back to unit scale so the pitch is only centred. This also
        # covers std == NaN, i.e. every frame is already unvoiced.
        if not std > 0:
            std = 1.
        pitch = (pitch - mean) / std

        # Require high confidence to make predictions far from the mean
        parabola = self.width * pitch ** 2 - self.width * self.stds ** 2
        threshold = \
            self.lower_bound + np.clip(parabola, 0, 1 - self.lower_bound)

        # Unvoiced (NaN) frames fall back to the base threshold
        threshold[np.isnan(threshold)] = self.lower_bound

        # Apply hysteresis to prevent short, unconfident voiced regions
        self._suppress_unconfident_regions(periodicity, threshold)

        # Remove pitch with low periodicity
        pitch[periodicity < threshold] = torchcrepe.UNVOICED

        # Unwhiten
        pitch = pitch * std + mean

        # Convert to Hz
        pitch = torch.tensor(2 ** pitch, device=device)[None, :]

        # Optionally return threshold
        if self.return_threshold:
            return pitch, torch.tensor(threshold, device=device)

        return pitch

    def _suppress_unconfident_regions(self, periodicity, threshold):
        """Raise the threshold to 1 over voiced runs that lack confidence

        A run starts right after a frame with periodicity below its
        threshold. It continues while periodicity stays above threshold.
        If no frame in the run exceeds ``self.upper_bound``, ``threshold``
        is set to 1 over that run, in place. A run that begins at the
        first frame has no preceding unvoiced frame and is never checked.
        """
        length = len(periodicity)
        i = 0
        while i < length - 1:

            # Detect unvoiced to voiced transition
            if periodicity[i] < threshold[i] and \
                    periodicity[i + 1] > threshold[i + 1]:

                # Grow region until next unvoiced frame or end of array
                start, end, keep = i + 1, i + 1, False
                while end < length and periodicity[end] > threshold[end]:
                    if periodicity[end] > self.upper_bound:
                        keep = True
                    end += 1

                # Force unvoiced if the run never passed the upper bound
                if not keep:
                    threshold[start:end] = 1

                i = end

            else:
                i += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Silence:
    """Zero out periodicity wherever the audio is too quiet

    Loudness is computed with ``torchcrepe.loudness.a_weighted``. Frames
    whose loudness is strictly below ``value`` get a periodicity of 0. The
    input periodicity tensor is not modified; a copy is returned.

    Arguments
        value (float)
            Loudness threshold for silence (presumably dB, given the -60
            default — TODO confirm against torchcrepe.loudness)
    """

    def __init__(self, value=-60):
        self.value = value

    def __call__(self,
                 periodicity,
                 audio,
                 sample_rate=torchcrepe.SAMPLE_RATE,
                 hop_length=None,
                 pad=True):
        """Return a copy of ``periodicity`` with silent frames set to 0

        Arguments
            periodicity (torch.tensor)
                Periodicity, masked by the per-frame loudness
            audio (torch.tensor)
                Audio forwarded to ``torchcrepe.loudness.a_weighted``
            sample_rate (int)
                Audio sample rate, forwarded to the loudness computation
            hop_length (int or None)
                Hop length, forwarded to the loudness computation
            pad (bool)
                Padding flag, forwarded to the loudness computation

        Returns
            masked (torch.tensor)
                Copy of ``periodicity`` with quiet frames zeroed
        """
        frame_loudness = torchcrepe.loudness.a_weighted(
            audio, sample_rate, hop_length, pad)
        is_silent = frame_loudness < self.value

        # Work on a copy so the caller's periodicity is left untouched
        masked = periodicity.clone()
        masked[is_silent] = 0.
        return masked
|
|