|
|
'''
Modified from the vit_pytorch library: https://github.com/lucidrains/vit-pytorch
'''
|
|
|
|
|
from einops import rearrange |
|
|
from einops.layers.torch import Rearrange |
|
|
import json |
|
|
import math |
|
|
from nnAudio.features.mel import MelSpectrogram |
|
|
import os |
|
|
import torch |
|
|
from torch import nn |
|
|
import torchaudio |
|
|
import torchaudio.transforms as T |
|
|
|
|
|
|
|
|
from huggingface_hub import HfApi, PyTorchModelHubMixin |
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
import shutil |
|
|
|
|
|
|
|
|
def pair(t):
    """Return ``t`` untouched when it is a tuple or list, else the 2-tuple ``(t, t)``."""
    if isinstance(t, (tuple, list)):
        return t
    return (t, t)
|
|
|
|
|
|
|
|
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build a fixed 2-D sine/cosine positional embedding for an ``h x w`` grid.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        dim: size of each position's embedding; must be a multiple of 4
            because it is split evenly into sin(x), cos(x), sin(y), cos(y).
        temperature: base of the frequency progression ``1 / temperature ** omega``.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``. Positions are flattened row-major and
        the last axis is laid out as ``[sin(x), cos(x), sin(y), cos(y)]``.

    Raises:
        AssertionError: if ``dim`` is not a multiple of 4.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    n_freqs = dim // 4
    # With dim == 4 there is a single frequency and ``n_freqs - 1`` is 0; the
    # unguarded 0 / 0 would turn the entire embedding into NaN.
    omega = torch.arange(n_freqs) / max(n_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # Outer product of every flattened coordinate with every frequency.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)
|
|
|
|
|
|
|
|
def load_model(model: nn.Module, checkpoint_path: str, device: str = 'cpu', ignore_layers = ('linear_head',), verbose: bool = False):
    """Load a state dict from ``checkpoint_path`` into ``model``, skipping some layers.

    Args:
        model: module that receives the weights (modified in place).
        checkpoint_path: path of a file readable by ``torch.load`` whose content
            is a mapping from parameter name to tensor.
        device: ``map_location`` passed to ``torch.load``.
        ignore_layers: iterable of key prefixes; every checkpoint entry whose
            name starts with one of them is dropped. ``None`` or an empty
            iterable keeps every entry.
        verbose: print a short confirmation after loading.

    The weights are loaded with ``strict=False``, so keys missing from the
    checkpoint or unknown to ``model`` do not raise.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # An immutable tuple replaces the former mutable list default, and
    # str.startswith accepts it directly (an empty tuple matches nothing).
    ignored_prefixes = tuple(ignore_layers or ())

    filtered_state_dict = {
        k: v for k, v in checkpoint.items()
        if not k.startswith(ignored_prefixes)
    }

    model.load_state_dict(filtered_state_dict, strict=False)

    if verbose:
        if ignored_prefixes:
            print(f'==> Loaded model from {checkpoint_path}, ignoring layers: {", ".join(ignored_prefixes)}')
        else:
            # Previously verbose mode stayed silent when nothing was ignored.
            print(f'==> Loaded model from {checkpoint_path}')
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Stage order is significant: it fixes the ``net.<index>`` parameter
        # names that appear in saved state dicts.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``; the last axis must have size ``dim``."""
        out = self.net(x)
        return out
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with an input LayerNorm and bias-free projections."""

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        # Scaling applied to the query/key dot products.
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        # A single projection yields queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        """Attend over axis 1 of ``x`` (``b n dim``) and return a tensor of the same shape."""
        normed = self.norm(x)

        # Split the fused projection into three parts and expose the head axis.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        mixed = torch.matmul(weights, v)
        # Fold the heads back into the feature axis before the output projection.
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feed-forward) pairs followed by a LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # One [Attention, FeedForward] pair per layer, created in order.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        """Run ``x`` through every layer with residual connections, then normalize."""
        hidden = x
        for attention, feed_forward in self.layers:
            hidden = attention(hidden) + hidden
            hidden = feed_forward(hidden) + hidden
        return self.norm(hidden)
|
|
|
|
|
|
|
|
class MynaPreprocessor:
    """Turn an audio file into a mono mel spectrogram, optionally cut into fixed-width chunks."""

    def __init__(self, target_sr: int = 16000, n_mels: int = 128):
        """
        Args:
            target_sr: sample rate the audio is resampled to before the mel transform.
            n_mels: number of mel bins requested from the mel transform.
        """
        self.target_sr = target_sr
        self.n_mels = n_mels
        self.mel_spec = MelSpectrogram(sr=target_sr, n_mels=n_mels, verbose=False)

    def __call__(self, filename: str, n_frames: int = None):
        """Load ``filename`` and return its mel spectrogram.

        Args:
            filename: path handed to ``torchaudio.load``.
            n_frames: when truthy, the spectrogram is split into chunks of this
                many frames (see ``_batch_spectrogram``); when ``None`` or 0 the
                full spectrogram is returned as produced by the mel transform.

        Raises:
            ValueError: if ``n_frames`` is given and the spectrogram has fewer
                than ``n_frames`` frames.
        """
        signal, sr = torchaudio.load(filename)

        # Down-mix multi-channel audio by averaging the channels.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)

        if sr != self.target_sr:
            resampler = T.Resample(orig_freq=sr, new_freq=self.target_sr)
            signal = resampler(signal)

        ms = self.mel_spec(signal)

        if n_frames:
            ms = self._batch_spectrogram(ms, n_frames)

        return ms

    def _batch_spectrogram(self, ms: torch.Tensor, n_frames: int):
        """Split a single spectrogram into non-overlapping chunks along the last axis.

        Args:
            ms: 3-D tensor whose first axis has size 1.
            n_frames: chunk width along the last axis; must be positive.

        Returns:
            Tensor of shape ``(num_chunks, 1, ms.shape[1], n_frames)``. Trailing
            frames that do not fill a whole chunk are dropped.

        Raises:
            ValueError: if ``n_frames`` is not positive or ``ms`` has fewer than
                ``n_frames`` frames (no complete chunk can be formed).
        """
        assert ms.dim() == 3 and ms.shape[0] == 1

        if n_frames <= 0:
            raise ValueError(f'n_frames must be positive, got {n_frames}')

        num_chunks = ms.shape[-1] // n_frames
        if num_chunks == 0:
            # Without this guard the failure surfaced as an opaque error from
            # torch.chunk complaining about a chunk count of 0.
            raise ValueError(
                f'Spectrogram has {ms.shape[-1]} frames, fewer than the '
                f'{n_frames} frames needed for a single chunk'
            )

        # Drop the remainder so every chunk is exactly n_frames wide.
        ms = ms[:, :, :num_chunks * n_frames]

        return torch.stack(ms.split(n_frames, dim=2))
|
|
|
|
|
|
|
|
class MynaConfig(PretrainedConfig):
    """Configuration for the ``Myna`` model.

    Holds the spectrogram/patch geometry, the transformer hyperparameters and
    the audio settings used to derive ``n_frames``.
    """

    model_type = 'myna'

    def __init__(
        self, spec_size=(128, 4096), patch_size=16, dim=384, depth=12,
        heads=6, mlp_dim=1536, dim_head = 64, arch=None, additional_patch_size = None,
        hybrid_mode: bool = False, n_samples = 50000, sr = 16000, **kwargs
    ):
        """
        Args:
            spec_size: spectrogram size as ``(height, width)``, or a single int
                used for both.
            patch_size: patch size as an int or ``(height, width)``.
            dim, depth, heads, mlp_dim: transformer hyperparameters; all four
                are overridden by the preset when ``arch`` is given.
            dim_head: per-head attention width.
            arch: optional preset name understood by ``_get_arch``.
            additional_patch_size: optional second patch size.
            hybrid_mode: stored as-is for the model to read.
            n_samples: number of audio samples used to compute ``n_frames``.
            sr: sample rate passed to the mel transform.
            **kwargs: forwarded to ``PretrainedConfig``.

        Raises:
            ValueError: if ``arch`` is given but not a known preset.
        """
        super().__init__(**kwargs)
        self.spec_size = spec_size
        self.patch_size = patch_size
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dim_head = dim_head
        self.arch = arch
        self.additional_patch_size = additional_patch_size
        self.hybrid_mode = hybrid_mode

        # sr, spec_size and patch_size must be set before _get_n_frames reads them.
        self.n_samples = n_samples
        self.sr = sr
        self.n_frames = self._get_n_frames(n_samples)

        # A named preset takes precedence over the individual hyperparameters.
        if arch:
            preset = self._get_arch(arch)
            self.dim = preset['dim']
            self.depth = preset['depth']
            self.heads = preset['heads']
            self.mlp_dim = preset['mlp_dim']

    def _get_arch(self, arch: str):
        """Return the hyperparameter dict of preset ``arch`` (case-insensitive).

        Raises:
            ValueError: if the preset name is unknown.
        """
        small = {'dim': 384, 'depth': 12, 'mlp_dim': 1536, 'heads': 6}
        presets = {
            'vit-s-16': small,
            'vit-s-32': small,
            'vit-b-16': {'dim': 768, 'depth': 12, 'mlp_dim': 3072, 'heads': 12},
            'vit-l-16': {'dim': 1024, 'depth': 24, 'mlp_dim': 4096, 'heads': 16},
        }
        preset = presets.get(arch.lower())
        if preset is None:
            raise ValueError(f'Architecture {arch} not implemented')
        return preset

    def _get_n_frames(self, n_samples: int):
        ''' How many frames is n_samples samples?

        The count is measured by running a dummy signal through the mel
        transform, then rounded down to a multiple of the time-axis patch size.
        '''
        # spec_size may be a single int (the model expands it with pair());
        # indexing it unconditionally crashed for that case.
        n_mels = self.spec_size if isinstance(self.spec_size, int) else self.spec_size[0]
        mel_spectrogram = MelSpectrogram(sr=self.sr, n_mels=n_mels, verbose=False)
        patch_size_time = self.patch_size if isinstance(self.patch_size, int) else self.patch_size[1]
        mel_frames = mel_spectrogram(torch.randn(1, 1, n_samples)).shape[-1]
        # Integer floor division avoids the float round-trip of math.floor(a / b).
        return (mel_frames // patch_size_time) * patch_size_time
|
|
|
|
|
|
|
|
class Myna(PreTrainedModel, PyTorchModelHubMixin):
    """Transformer encoder over spectrogram patches.

    A spectrogram is cut into patches, linearly embedded, given fixed sin/cos
    positional embeddings, passed through the transformer and mean-pooled over
    the patch axis. When ``additional_patch_size`` is configured, a second
    patch embedding is built; in hybrid mode the outputs obtained with both
    embeddings are concatenated.
    """

    config_class = MynaConfig

    def __init__(self, config: MynaConfig):
        """Build the embeddings and transformer described by ``config``.

        Raises:
            AssertionError: if the spectrogram size is not divisible by the patch size.
        """
        super().__init__(config)

        self.hybrid_mode = config.hybrid_mode
        spec_height, spec_width = pair(config.spec_size)
        patch_height, patch_width = pair(config.patch_size)

        # Use the configured sample rate and spectrogram height instead of the
        # preprocessor defaults, so from_file() produces spectrograms that match
        # the geometry this model is built for. Identical for the default config.
        self.preprocessor = MynaPreprocessor(target_sr=config.sr, n_mels=spec_height)

        assert spec_height % patch_height == 0 and spec_width % patch_width == 0, 'Spectrogram dimensions must be divisible by the patch size.'

        self.additional_patch_size = config.additional_patch_size
        if config.additional_patch_size:
            patch_height_b, patch_width_b = pair(config.additional_patch_size)
            patch_dim_b = patch_height_b * patch_width_b

            self.to_patch_embedding_b, self.pos_embedding_b = self._make_embeddings(
                patch_height_b, patch_width_b, patch_dim_b, config.dim, spec_height, spec_width
            )

        patch_dim = patch_height * patch_width

        self.to_patch_embedding, self.pos_embedding = self._make_embeddings(
            patch_height, patch_width, patch_dim, config.dim, spec_height, spec_width
        )

        self.transformer = Transformer(config.dim, config.depth, config.heads, config.dim_head, config.mlp_dim)

        self.pool = 'mean'
        self.to_latent = nn.Identity()

        self.linear_head = nn.Identity()

    def forward(self, spec, recurse=True):
        """Embed a spectrogram (or batch of spectrograms).

        Args:
            spec: tensor with 2, 3 or 4 dimensions. A 2-D input gets leading
                batch and channel axes; a 3-D input gets a channel axis at
                index 1; a 4-D input is used as-is.
            recurse: internal flag; when ``False`` the hybrid branch is skipped
                and only the currently active embedding is used.

        Returns:
            The mean over the patch axis of the transformer output, passed
            through ``to_latent`` and ``linear_head``. In hybrid mode the
            results of both embeddings are concatenated along the last axis.
        """
        if self.hybrid_mode and recurse:
            # If no additional patch size was configured, toggle_embeddings()
            # is a no-op and both passes use the same embedding.
            a = self(spec, recurse=False)
            self.toggle_embeddings()
            try:
                b = self(spec, recurse=False)
            finally:
                # Always swap back: without this, an exception in the second
                # pass left the model permanently on the alternate embedding.
                self.toggle_embeddings()
            return torch.cat((a, b), dim=-1)

        if spec.dim() == 2:
            spec = spec.unsqueeze(0).unsqueeze(0)
        elif spec.dim() == 3:
            spec = spec.unsqueeze(1)
        assert spec.dim() == 4

        device = spec.device

        x = self.to_patch_embedding(spec)
        n_patches = x.shape[1]
        # The positional table is a plain tensor (not a parameter or buffer),
        # so it is moved to the input's device and dtype on every call.
        x += self.pos_embedding[:n_patches].to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

    def toggle_embeddings(self):
        """Swap the active patch/positional embeddings with the additional pair.

        Prints a notice and does nothing when no additional patch size was configured.
        """
        if not self.additional_patch_size:
            print('toggle_embeddings() called but no additional patch size provided! Ignoring call.')
            return
        self.to_patch_embedding, self.to_patch_embedding_b = self.to_patch_embedding_b, self.to_patch_embedding
        self.pos_embedding, self.pos_embedding_b = self.pos_embedding_b, self.pos_embedding

    def _make_embeddings(self, patch_height, patch_width, patch_dim, dim, image_height, image_width):
        """Create the patch-embedding module and positional table for one patch size.

        Returns:
            ``(to_patch_embedding, pos_embedding)``: a module mapping
            ``b c H W`` input to ``b n dim`` patch tokens, and a tensor with one
            ``dim``-sized row per patch of the full ``image_height x image_width`` grid.
        """
        to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        return to_patch_embedding, pos_embedding

    def from_file(self, filename: str, n_samples: int = None):
        """Preprocess the audio file ``filename`` and return the model output for it.

        Args:
            filename: path handed to the preprocessor.
            n_samples: optional chunk length in audio samples; when given and
                different from ``config.n_samples``, the chunk width in frames
                is recomputed, otherwise ``config.n_frames`` is used.
        """
        n_frames = self.config.n_frames
        if n_samples and n_samples != self.config.n_samples:
            n_frames = self.config._get_n_frames(n_samples)
        spec = self.preprocessor(filename, n_frames).to(self.device)
        return self(spec)

    @property
    def n_params(self):
        """Total number of elements across all model parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
def save_model_and_push(model, repo_name, save_dir='myna-temp', to_hub=False, source_file='myna.py'):
    """Save ``model`` locally with remote-code metadata and optionally upload it.

    Args:
        model: object exposing ``save_pretrained`` and ``config.to_dict()``.
        repo_name: Hub repository id; also written as ``_name_or_path``.
        save_dir: local directory that receives the saved files.
        to_hub: when true, create the repository if needed and upload ``save_dir``.
        source_file: path of the module source to ship with the model. It used
            to be hard-coded as ``'myna.py'`` relative to the working directory;
            the default keeps that behavior.
    """
    model.save_pretrained(save_dir)
    # The auto_map below refers to module ``myna``, so the copy is always named
    # myna.py regardless of the source file's own name.
    shutil.copy(source_file, os.path.join(save_dir, 'myna.py'))

    config = model.config.to_dict()
    config.update({
        '_name_or_path': repo_name,
        'architectures': ['Myna'],
        'auto_map': {
            'AutoConfig': 'myna.MynaConfig',
            'AutoModel': 'myna.Myna'
        },
        'model_type': 'myna'
    })

    # Overwrite the config written by save_pretrained with the augmented one.
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=4)

    print(f'Model saved locally to {save_dir}')

    if to_hub:
        api = HfApi()
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(folder_path=save_dir, repo_id=repo_name)
        print(f"Model pushed to: https://huggingface.co/{repo_name}")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Export script: build the hybrid ViT-B/16 model, load its checkpoint,
    # then save it locally and push it to the Hub.
    config = MynaConfig(arch='vit-b-16', patch_size=16, additional_patch_size=(128, 2), hybrid_mode=True)
    model = Myna(config)

    load_model(model, 'checkpoints/myna-85m.pth', verbose=True)
    print(f'Model contains {model.n_params:,} parameters')

    save_model_and_push(model, repo_name='oriyonay/myna-85m', save_dir='myna-85m-hybrid', to_hub=True)