|
import librosa
|
|
import numpy as np
|
|
import torch
|
|
|
|
import torchcrepe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def argmax(logits):
    """Decode each frame by picking its single highest-scoring bin.

    Args:
        logits: tensor of per-bin scores; the reduction is taken over dim 1.

    Returns:
        A ``(bins, frequency)`` pair. ``bins`` holds the index of the maximum
        along dim 1, and ``frequency`` is those indices passed through
        ``torchcrepe.convert.bins_to_frequency``.
    """
    best_bins = torch.argmax(logits, dim=1)
    frequency = torchcrepe.convert.bins_to_frequency(best_bins)
    return best_bins, frequency
|
|
|
|
|
|
def weighted_argmax(logits):
    """Sample observations using a weighted sum near the argmax.

    For every (batch, frame) the argmax bin is located. Every bin more than
    four positions away from it is discarded. The result is the average of
    the per-bin cent values, weighted by the sigmoid of the remaining logits.

    Args:
        logits: tensor indexed as (batch, bin, frame). The cached weights
            have 360 entries along the bin axis, so ``logits.size(1)`` must
            be 360 for the weighted sum to broadcast. ``logits`` is NOT
            modified.

    Returns:
        A ``(bins, frequency)`` pair. ``bins`` holds the argmax bin index per
        (batch, frame). ``frequency`` is the weighted cent value converted
        with ``torchcrepe.convert.cents_to_frequency``.
    """
    # Center of the analysis window for each (batch, frame).
    bins = logits.argmax(dim=1)

    # Keep only bins whose index lies within 4 of the center. This is exactly
    # the window [max(0, center - 4), min(n_bins, center + 5)). The mask is
    # built in a single vectorized comparison rather than a Python loop over
    # every (batch, frame). It is applied out of place so the caller's
    # `logits` tensor is left untouched (the previous in-place writes mutated
    # the argument and fail on leaf tensors that require grad).
    bin_index = torch.arange(logits.size(1), device=logits.device)[None, :, None]
    outside_window = (bin_index - bins[:, None, :]).abs() > 4
    windowed = logits.masked_fill(outside_window, -float('inf'))

    # Lazily build and cache the per-bin cent values, shaped (1, 360, 1) so
    # they broadcast over batch and frame.
    if not hasattr(weighted_argmax, 'weights'):
        weights = torchcrepe.convert.bins_to_cents(torch.arange(360))
        weighted_argmax.weights = weights[None, :, None]

    # Move the cached weights to the input's device (no-op if already there).
    weighted_argmax.weights = weighted_argmax.weights.to(logits.device)

    # sigmoid(-inf) == 0, so masked bins contribute nothing to either sum.
    with torch.no_grad():
        probs = torch.sigmoid(windowed)

    cents = (weighted_argmax.weights * probs).sum(dim=1) / probs.sum(dim=1)

    return bins, torchcrepe.convert.cents_to_frequency(cents)
|
|
|
|
|
|
def viterbi(logits):
    """Decode the most likely bin path per batch item with Viterbi smoothing.

    A softmax is applied over dim 1. Each batch item (iterated over dim 0)
    is then passed, as a (bin, frame) array, to ``librosa.sequence.viterbi``
    together with a banded 360x360 transition matrix.

    Args:
        logits: tensor indexed as (batch, bin, frame). The transition matrix
            is 360x360, so ``logits.size(1)`` is expected to be 360.

    Returns:
        A ``(bins, frequency)`` pair. ``bins`` is an int64 tensor of decoded
        bin indices on the input's device. ``frequency`` is ``bins`` passed
        through ``torchcrepe.convert.bins_to_frequency``.
    """
    # Build the transition matrix once and cache it on the function. Entry
    # (i, j) is max(12 - |i - j|, 0), and each row is normalized to sum to 1,
    # so jumps of 12 or more bins between frames have zero probability.
    if not hasattr(viterbi, 'transition'):
        src, dst = np.meshgrid(range(360), range(360))
        band = np.maximum(12 - abs(src - dst), 0)
        viterbi.transition = band / band.sum(axis=1, keepdims=True)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

    # librosa works on NumPy arrays, so decode each batch item on the CPU.
    target_device = probs.device
    paths = []
    for observation in probs.cpu().numpy():
        state_path = librosa.sequence.viterbi(observation, viterbi.transition)
        paths.append(state_path.astype(np.int64))
    bins = torch.tensor(np.array(paths), device=target_device)

    return bins, torchcrepe.convert.bins_to_frequency(bins)
|
|
|