from typing import Union, Tuple
import numpy as np
from numpy.typing import NDArray
import torch
from torch import nn
from functools import partial
import matplotlib.pyplot as plt
from PIL import Image
import librosa
import miniaudio
from mae import MaskedAutoencoderViT
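
# Pipeline: stream an MP3 with miniaudio -> log-mel spectrogram (librosa) ->
# pad/normalize to a fixed (2048, 128) "image" -> reconstruct it with a
# pre-trained masked autoencoder (MAE) and visualise the result.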

def load_audio(
    path: str,
    sr: int = 32000,
    duration: int = 20,
) -> np.ndarray:
    # Stream the file as mono float32 at `sr` Hz; one next() on the generator
    # yields the first sr * duration samples (fewer if the file is shorter).
    g = miniaudio.stream_file(path, output_format=miniaudio.SampleFormat.FLOAT32, nchannels=1,
                              sample_rate=sr, frames_to_read=sr * duration)
    signal = np.array(next(g))
    return signal

def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    mel_spec = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
        window='hann', pad_mode='constant'
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)
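
# Frame math: hop_length=320 at sr=32000 gives ~100 frames per second, so the
# 21 s clip loaded below yields ~2100 frames, later truncated to 2048
# (~20.48 s) to match the model's input size.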

def display_image(
    img: Union[NDArray, Image.Image],
    figsize: Tuple[float, float] = (5, 5),
) -> None:
    plt.figure(figsize=figsize)
    plt.imshow(img, origin='lower', aspect='auto')  # e.g. cmap='viridis' or 'coolwarm'
    plt.axis('off')
    plt.colorbar()
    plt.tight_layout()
    plt.show()

def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Zero mean / unit variance over the whole spectrogram.
    return (arr - arr.mean()) / (arr.std() + eps)

if __name__ == '__main__':
    mp3_file = "/Users/chenjing22/Downloads/songs/See You Again.mp3"
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))  # (time, freq)

    # Pad (with the spectrogram's floor value) or truncate along the time axis
    # to a fixed length of 2048 frames.
    length = mel_spec.shape[0]
    target_length = 2048
    mel_spec = mel_spec[:target_length] if length > target_length else np.pad(
        mel_spec, ((0, target_length - length), (0, 0)), mode='constant', constant_values=mel_spec.min()
    )

    # Normalize.
    mel_spec = normalize(mel_spec)  # (2048, 128)
    display_image(mel_spec.T, figsize=(10, 4))  # transposed back to (freq, time) for display

    # Model
    mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )
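
    # Patch math: (2048 / 16) * (128 / 16) = 128 * 8 = 1024 patches, each
    # flattened to 16 * 16 = 256 values, hence y's shape (1, 1024, 256) in the
    # forward pass below.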

    # Load pre-trained weights
    ckpt_path = 'music-mae-32kHz.pth'
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    device = 'cpu'  # or 'cuda'
    mae.to(device)
    mae.eval()  # inference mode: disables dropout and other training-only behaviour

    x = torch.from_numpy(mel_spec).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 2048, 128)
    mse_loss, y, mask = mae(x, mask_ratio=0.7)  # y: (1, 1024, 256), mask: (1, 1024)
    # mask == 0 marks the visible (unmasked) patches: keep the ground truth
    # there, so only the 70% masked patches show the model's reconstruction.
    y[mask == 0.] = mae.patchify(x)[mask == 0.]
    x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).detach().numpy()
    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))
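
    # Optional sketch (assumptions: `mask` shares y's patch layout and is 1 for
    # masked patches, as in the standard MAE): render the mask itself as an
    # image by expanding each patch flag to its 16 * 16 = 256 pixels.
    mask_img = mae.unpatchify(mask.unsqueeze(-1).repeat(1, 1, 256))  # (1, 1, 2048, 128)
    display_image(mask_img.squeeze(0).squeeze(0).detach().numpy().T, figsize=(10, 4))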