| import librosa
|
| import numpy as np
|
| import torch
|
|
|
| import torchcrepe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def argmax(logits):
    """Decode by picking the highest-scoring bin along dimension 1

    Returns a tuple of the selected bin indices and the result of
    torchcrepe.convert.bins_to_frequency applied to those indices.
    """
    # Index of the largest logit over the bin axis
    peak_bins = torch.argmax(logits, dim=1)

    # Map the bin indices to frequencies
    frequency = torchcrepe.convert.bins_to_frequency(peak_bins)

    return peak_bins, frequency
|
|
|
|
|
def weighted_argmax(logits):
    """Sample observations using weighted sum near the argmax

    For every (batch, time) position a window reaching from four bins below
    to four bins above the argmax bin (clipped to the valid bin range) is
    kept; all other bins are excluded from the weighted average.

    Arguments
        logits
            Tensor indexed as (batch, bin, time). It is not modified.
            The bin axis is expected to have 360 entries so that it
            broadcasts against the cached per-bin weights.

    Returns
        bins
            Argmax bin index for every (batch, time) position
        frequency
            torchcrepe.convert.cents_to_frequency applied to the
            probability-weighted mean of the per-bin cents in the window
    """
    n_bins = logits.size(1)

    # Center of the analysis window
    bins = logits.argmax(dim=1)

    # Half-open window bounds [start, end), clipped to the valid bin range
    start = (bins - 4).clamp(min=0)
    end = (bins + 5).clamp(max=n_bins)

    # Build the per-bin cents weights once and cache them on the function
    if not hasattr(weighted_argmax, 'weights'):
        weights = torchcrepe.convert.bins_to_cents(torch.arange(360))
        weighted_argmax.weights = weights[None, :, None]

    # Ensure devices are the same (no-op if they already are)
    weighted_argmax.weights = weighted_argmax.weights.to(logits.device)

    with torch.no_grad():
        # Vectorized window mask: True wherever a bin lies outside
        # [start, end) for its (batch, time) position. This replaces a
        # Python loop over every batch/time pair.
        index = torch.arange(n_bins, device=logits.device)[None, :, None]
        outside = (index < start.unsqueeze(1)) | (index >= end.unsqueeze(1))

        # masked_fill is out-of-place, so the caller's logits stay intact.
        # sigmoid(-inf) == 0, which removes masked bins from both sums below.
        probs = torch.sigmoid(logits.masked_fill(outside, -float('inf')))

    # Probability-weighted average of cents inside the window
    cents = (weighted_argmax.weights * probs).sum(dim=1) / probs.sum(dim=1)

    return bins, torchcrepe.convert.cents_to_frequency(cents)
|
|
|
|
|
def viterbi(logits):
    """Decode one bin sequence per batch item with librosa's viterbi decoder

    The logits are softmax-normalized over dimension 1, moved to the CPU as
    a numpy array, decoded item by item, and the resulting bin indices are
    returned as a tensor together with
    torchcrepe.convert.bins_to_frequency applied to them.
    """
    # Lazily build the transition matrix and cache it on the function
    try:
        transition = viterbi.transition
    except AttributeError:
        # Triangular kernel: weight 12 - |i - j| for bins fewer than 12
        # apart, zero otherwise
        index = np.arange(360)
        distance = np.abs(index[None, :] - index[:, None])
        kernel = np.clip(12 - distance, 0, None)

        # Normalize each row so it sums to one
        transition = kernel / kernel.sum(axis=1, keepdims=True)
        viterbi.transition = transition

    # Normalize over the bin axis without tracking gradients
    with torch.no_grad():
        probs = torch.softmax(logits, dim=1)

    # librosa operates on numpy arrays, so decode on the CPU
    decoded = []
    for sequence in probs.cpu().numpy():
        path = librosa.sequence.viterbi(sequence, transition)
        decoded.append(path.astype(np.int64))

    # Back to a tensor on the same device as the probabilities
    bins = torch.tensor(np.array(decoded), device=probs.device)

    return bins, torchcrepe.convert.bins_to_frequency(bins)
|
|
|