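"""Reconstruct a music mel-spectrogram with a pre-trained Masked Autoencoder.

Pipeline: stream an MP3 with miniaudio, compute a log-mel spectrogram with
librosa, pad/truncate and standardize it to a fixed (2048, 128) input, run a
MaskedAutoencoderViT with 70% of patches masked, then visualize the input and
the reconstruction with matplotlib.
"""
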
from typing import Union, Tuple
import numpy as np
from numpy.typing import NDArray
import torch
from torch import nn
from functools import partial
import matplotlib.pyplot as plt
from PIL import Image
import librosa
import miniaudio

from mae import MaskedAutoencoderViT


def load_audio(
        path: str,
        sr: int = 32000,
        duration: int = 20,
) -> np.ndarray:
    """Stream the first `duration` seconds of `path` as a mono float32 signal."""
    g = miniaudio.stream_file(path, output_format=miniaudio.SampleFormat.FLOAT32,
                              nchannels=1, sample_rate=sr,
                              frames_to_read=sr * duration)
    signal = np.array(next(g))  # first yielded chunk: up to sr * duration samples
    return signal


def mel_spectrogram(
        signal: np.ndarray,
        sr: int = 32000,
        n_fft: int = 800,
        hop_length: int = 320,
        n_mels: int = 128,
) -> np.ndarray:
    """Log-mel spectrogram in dB, returned as (time, freq)."""
    mel_spec = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
        window='hann', pad_mode='constant'
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)
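
# A quick frame-count check (my arithmetic, not stated in the source): with
# sr=32000 and hop_length=320 this produces 100 frames per second, so the
# 21 s clip loaded in __main__ yields ~2101 frames, just over the model's
# 2048-frame input -- hence the truncation step below.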


def display_image(
        img: Union[NDArray, Image.Image],
        figsize: Tuple[float, float] = (5, 5),
) -> None:
    plt.figure(figsize=figsize)
    plt.imshow(img, origin='lower', aspect='auto')  # e.g. cmap='viridis' or 'coolwarm'
    plt.axis('off')
    plt.colorbar()
    plt.tight_layout()
    plt.show()


def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize to zero mean and unit variance; eps guards against std == 0."""
    return (arr - arr.mean()) / (arr.std() + eps)
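
# Note: this standardizes each spectrogram with its own statistics. Some
# pre-trained audio MAEs are instead trained with fixed dataset-level
# mean/std; if reconstructions look off, check the checkpoint's training
# recipe (an assumption worth verifying, not something stated here).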


if __name__ == '__main__':
    mp3_file = "/Users/chenjing22/Downloads/songs/See You Again.mp3"
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))  # (time, freq)

    # pad or truncate along time to the model's fixed 2048-frame input
    length = mel_spec.shape[0]
    target_length = 2048
    mel_spec = mel_spec[:target_length] if length > target_length else np.pad(
        mel_spec, ((0, target_length - length), (0, 0)),
        mode='constant', constant_values=mel_spec.min()  # pad with the dB floor
    )
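    # Shape bookkeeping (my derivation, not stated in the source): a
    # (2048, 128) input with patch_size=16 splits into (2048/16) * (128/16)
    # = 128 * 8 = 1024 patches of 16*16 = 256 values each, which matches the
    # y: (1, 1024, 256) and mask: (1, 1024) shapes noted further down.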

    # normalize
    mel_spec = normalize(mel_spec)  # (2048, 128)

    display_image(mel_spec.T, figsize=(10, 4))

    # Model
    mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )
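    # Assumption: if this mae module follows the AudioMAE reference code,
    # decoder_mode=1 selects a Swin-style local-attention decoder and
    # no_shift=False keeps shifted windows enabled; verify against mae.py.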

    # Load pre-trained weights
    ckpt_path = 'music-mae-32kHz.pth'
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
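    # If loading fails with missing/unexpected keys, the checkpoint may wrap
    # the weights (e.g. under a 'model' key, as MAE training scripts often do);
    # in that case index into the loaded dict before calling load_state_dict.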

    device = 'cpu'  # or 'cuda'
    mae.to(device)
    mae.eval()  # inference: disable dropout etc.

    x = torch.from_numpy(mel_spec).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 2048, 128)
    with torch.no_grad():
        mse_loss, y, mask = mae(x, mask_ratio=0.7)  # y: (1, 1024, 256), mask: (1, 1024)

        # mask == 1 marks masked patches; where mask == 0 the patch was visible,
        # so copy those patches straight from the input and keep only the
        # decoder's predictions for the masked positions
        y[mask == 0.] = mae.patchify(x)[mask == 0.]
    # .cpu() before .numpy() so this also works when device == 'cuda'
    x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).cpu().numpy()

    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))