'''
Myna: a ViT-based audio embedding model that operates on mel spectrograms,
together with preprocessing and Hugging Face Hub export utilities.

Modified from the vit_pytorch library: https://github.com/lucidrains/vit-pytorch
'''
# standard library
import json
import math
import os
import shutil

# third-party
import torch
import torchaudio
import torchaudio.transforms as T
from einops import rearrange
from einops.layers.torch import Rearrange
from nnAudio.features.mel import MelSpectrogram
from torch import nn

# for uploading to huggingface hub
from huggingface_hub import HfApi, PyTorchModelHubMixin
from transformers import PretrainedConfig, PreTrainedModel

def pair(t):
    '''Return t if it is already a (height, width) pair, otherwise duplicate it.'''
    return t if isinstance(t, (tuple, list)) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    '''Fixed 2D sin-cos positional embeddings of shape (h * w, dim).'''
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
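
# Example (sketch): one embedding row per patch, in the same row-major order the
# Rearrange patchifier below produces. For an 8 x 256 patch grid at dim 384:
#   pe = posemb_sincos_2d(h=8, w=256, dim=384)
#   pe.shape  # torch.Size([2048, 384])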

def load_model(model: nn.Module, checkpoint_path: str, device: str = 'cpu', ignore_layers: list = ['linear_head'], verbose: bool = False):
    '''Load a checkpoint into model, skipping any keys that start with a name in ignore_layers.'''
    checkpoint = torch.load(checkpoint_path, map_location=device)
filtered_state_dict = {
k: v for k, v in checkpoint.items()
if not any(k.startswith(layer) for layer in ignore_layers)
}
model.load_state_dict(filtered_state_dict, strict=False)
if ignore_layers and verbose:
print(f'==> Loaded model from {checkpoint_path}, ignoring layers: {", ".join(ignore_layers)}')
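
# Example (sketch, assuming a local checkpoint exists at this path): the default
# ignore_layers lets a pretraining checkpoint load into a feature extractor
# whose linear_head is nn.Identity:
#   model = Myna(MynaConfig(arch='vit-s-16'))
#   load_model(model, 'checkpoints/myna-85m.pth', verbose=True)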

class FeedForward(nn.Module):
    '''Pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear.'''
    def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)

class Attention(nn.Module):
    '''Pre-norm multi-head self-attention with scaled dot-product weights.'''
    def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class Transformer(nn.Module):
    '''Stack of residual attention + feed-forward blocks, with a final LayerNorm.'''
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
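
# Example (sketch): the encoder preserves the (batch, tokens, dim) shape:
#   enc = Transformer(dim=384, depth=12, heads=6, dim_head=64, mlp_dim=1536)
#   out = enc(torch.randn(2, 48, 384))  # torch.Size([2, 48, 384])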

class MynaPreprocessor:
    '''Loads an audio file, downmixes to mono, resamples to target_sr, and computes a mel spectrogram.'''
    def __init__(self, target_sr: int = 16000, n_mels: int = 128):
        self.target_sr = target_sr
        self.n_mels = n_mels
        self.mel_spec = MelSpectrogram(sr=target_sr, n_mels=n_mels, verbose=False)

    def __call__(self, filename: str, n_frames: int = None):
        # returns a (1, n_mels, T) spectrogram, or a (B, 1, n_mels, n_frames)
        # batch of fixed-size chunks when n_frames is given
signal, sr = torchaudio.load(filename)
if signal.shape[0] > 1:
signal = signal.mean(dim=0, keepdim=True)
if sr != self.target_sr:
resampler = T.Resample(orig_freq=sr, new_freq=self.target_sr)
signal = resampler(signal)
ms = self.mel_spec(signal)
if n_frames:
ms = self._batch_spectrogram(ms, n_frames)
return ms
def _batch_spectrogram(self, ms: torch.Tensor, n_frames: int):
# sanity check
assert ms.dim() == 3 and ms.shape[0] == 1
# discard excess frames
num_chunks = ms.shape[-1] // n_frames
ms = ms[:, :, :num_chunks * n_frames]
# split the tensor into chunks and stack them
chunks = torch.chunk(ms, num_chunks, dim=2)
batch = torch.stack(chunks)
return batch
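
# Example (sketch, assuming 'song.wav' exists; with nnAudio's default hop length,
# 96 frames correspond to roughly 3 s of 16 kHz audio):
#   pre = MynaPreprocessor()
#   batch = pre('song.wav', n_frames=96)  # (num_chunks, 1, 128, 96)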

class MynaConfig(PretrainedConfig):
    '''Configuration for Myna: ViT hyperparameters plus audio preprocessing settings.'''
    model_type = 'myna'

    def __init__(
self, spec_size=(128, 4096), patch_size=16, dim=384, depth=12,
heads=6, mlp_dim=1536, dim_head = 64, arch=None, additional_patch_size = None,
hybrid_mode: bool = False, n_samples = 50000, sr = 16000, **kwargs
):
super().__init__(**kwargs)
self.spec_size = spec_size
self.patch_size = patch_size
self.dim = dim
self.depth = depth
self.heads = heads
self.mlp_dim = mlp_dim
self.dim_head = dim_head
self.arch = arch
self.additional_patch_size = additional_patch_size
self.hybrid_mode = hybrid_mode
self.n_samples = n_samples # number of samples for inference
self.sr = sr # for preprocessing
self.n_frames = self._get_n_frames(n_samples)
# load architecture if provided
if arch:
arch = self._get_arch(arch)
self.dim = arch['dim']
self.depth = arch['depth']
self.heads = arch['heads']
self.mlp_dim = arch['mlp_dim']
def _get_arch(self, arch: str):
if arch.lower() in ['vit-s-16', 'vit-s-32']:
# dim 384, depth 12, MLP 1536, 6 heads, 22M parameters
return {'dim': 384, 'depth': 12, 'mlp_dim': 1536, 'heads': 6}
if arch.lower() == 'vit-b-16':
# dim 768, depth 12, MLP 3072, 12 heads, 87M parameters
return {'dim': 768, 'depth': 12, 'mlp_dim': 3072, 'heads': 12}
if arch.lower() == 'vit-l-16':
# dim 1024, depth 24, MLP 4096, 16 heads, 303M parameters
return {'dim': 1024, 'depth': 24, 'mlp_dim': 4096, 'heads': 16}
raise ValueError(f'Architecture {arch} not implemented')

    def _get_n_frames(self, n_samples: int):
        '''How many mel frames do n_samples audio samples produce, rounded down to a multiple of the time patch size?'''
mel_spectrogram = MelSpectrogram(sr=self.sr, n_mels=self.spec_size[0], verbose=False)
patch_size_time = self.patch_size if isinstance(self.patch_size, int) else self.patch_size[1]
mel_frames = mel_spectrogram(torch.randn(1, 1, n_samples)).shape[-1]
mel_frames = math.floor(mel_frames / patch_size_time) * patch_size_time
return mel_frames
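
# Example (sketch): arch presets override dim/depth/heads/mlp_dim:
#   cfg = MynaConfig(arch='vit-s-16')
#   (cfg.dim, cfg.depth, cfg.heads)  # (384, 12, 6)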

class Myna(PreTrainedModel, PyTorchModelHubMixin):
    '''ViT encoder over mel-spectrogram patches; forward returns mean-pooled
    embeddings. In hybrid mode, embeddings from both patch sizes are concatenated.'''
    config_class = MynaConfig

    def __init__(self, config: MynaConfig):
        super().__init__(config)
        # preprocessing follows the config's sample rate and mel-bin count
        self.preprocessor = MynaPreprocessor(target_sr=config.sr, n_mels=config.spec_size[0])
        self.hybrid_mode = config.hybrid_mode
spec_height, spec_width = pair(config.spec_size)
patch_height, patch_width = pair(config.patch_size)
assert spec_height % patch_height == 0 and spec_width % patch_width == 0, 'Spectrogram dimensions must be divisible by the patch size.'
self.additional_patch_size = config.additional_patch_size
if config.additional_patch_size:
patch_height_b, patch_width_b = pair(config.additional_patch_size)
patch_dim_b = patch_height_b * patch_width_b
self.to_patch_embedding_b, self.pos_embedding_b = self._make_embeddings(
patch_height_b, patch_width_b, patch_dim_b, config.dim, spec_height, spec_width
)
patch_dim = patch_height * patch_width
self.to_patch_embedding, self.pos_embedding = self._make_embeddings(
patch_height, patch_width, patch_dim, config.dim, spec_height, spec_width
)
self.transformer = Transformer(config.dim, config.depth, config.heads, config.dim_head, config.mlp_dim)
self.pool = 'mean'
self.to_latent = nn.Identity()
self.linear_head = nn.Identity()

    def forward(self, spec, recurse=True):
        # in hybrid mode, embed with both patch sizes and concatenate the results
        if self.hybrid_mode and recurse:
a = self(spec, recurse=False)
self.toggle_embeddings()
b = self(spec, recurse=False)
self.toggle_embeddings()
return torch.cat((a, b), dim=-1)
# if input shape is not 4d, make it 4d:
if spec.dim() == 2:
# unbatched: n_mels, n_frames
spec = spec.unsqueeze(0).unsqueeze(0)
elif spec.dim() == 3:
# batched but without channels: B, n_mels, n_frames
spec = spec.unsqueeze(1)
assert spec.dim() == 4
device = spec.device
x = self.to_patch_embedding(spec)
n_patches = x.shape[1] # x is of shape (B, n_patches, dim)
x += self.pos_embedding[:n_patches].to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
def toggle_embeddings(self):
if not self.additional_patch_size:
print('toggle_embeddings() called but no additional patch size provided! Ignoring call.')
return
self.to_patch_embedding, self.to_patch_embedding_b = self.to_patch_embedding_b, self.to_patch_embedding
self.pos_embedding, self.pos_embedding_b = self.pos_embedding_b, self.pos_embedding
def _make_embeddings(self, patch_height, patch_width, patch_dim, dim, image_height, image_width):
to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
return to_patch_embedding, pos_embedding

    def from_file(self, filename: str, n_samples: int = None):
        '''Embed an audio file: preprocess into fixed-size mel chunks, then run the model.'''
n_frames = self.config.n_frames
if n_samples and n_samples != self.config.n_samples:
n_frames = self.config._get_n_frames(n_samples)
spec = self.preprocessor(filename, n_frames).to(self.device)
return self(spec)
@property
def n_params(self):
return sum(p.numel() for p in self.parameters())
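
# Example (sketch, assuming an audio file 'song.wav' exists on disk):
#   model = Myna(MynaConfig(arch='vit-s-16'))
#   with torch.no_grad():
#       emb = model.from_file('song.wav')  # (num_chunks, dim)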

def save_model_and_push(model, repo_name, save_dir='myna-temp', to_hub=False):
    '''Save the model and this file locally, write an AutoModel-compatible config, and optionally push to the Hugging Face Hub.'''
    model.save_pretrained(save_dir)
    shutil.copy('myna.py', save_dir)
config = model.config.to_dict()
config.update({
'_name_or_path': repo_name,
'architectures': ['Myna'],
'auto_map': {
'AutoConfig': 'myna.MynaConfig',
'AutoModel': 'myna.Myna'
},
'model_type': 'myna'
})
with open(os.path.join(save_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=4)
print(f'Model saved locally to {save_dir}')
if to_hub:
api = HfApi()
api.create_repo(repo_name, exist_ok=True)
api.upload_folder(folder_path=save_dir, repo_id=repo_name)
print(f"Model pushed to: https://huggingface.co/{repo_name}")
if __name__ == '__main__':
config = MynaConfig(
arch='vit-b-16', # arch='vit-s-16',
patch_size=16,
additional_patch_size=(128, 2),
hybrid_mode=True
)
model = Myna(config)
load_model(model, 'checkpoints/myna-85m.pth', verbose=True)
print(f'Model contains {model.n_params:,} parameters')
save_model_and_push(
model,
repo_name='oriyonay/myna-85m',
save_dir='myna-85m-hybrid',
to_hub=True
)