"""Masked Modeling Duo (M2D) Portable Runtime. | |
All you need is: | |
pip install timm, einops, nnAudio | |
""" | |
import logging
from functools import partial
from pathlib import Path

import nnAudio.features
import numpy as np
import timm
import torch
from einops import rearrange
from timm.models.layers import trunc_normal_
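
# Quick start (illustrative sketch; the checkpoint path below is a placeholder, and the
# parent folder name must follow the "<model>-<input size>p<patch size>p<sr>" convention
# parsed by parse_sizes_by_name() below):
#
#     model = PortableM2D('path/to/m2d_vit_base-80x608p16x16-example/checkpoint.pth')
#     emb = model.encode(torch.rand(1, 16000 * 10))  # -> (batch, time frames, model.cfg.feature_d)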


class Config:
    weight_file = ''
    feature_d = 768 * 5
    norm_type = 'all'
    pooling_type = 'mean'
    model = ''
    input_size = [80, 208]
    patch_size = [16, 16]
    sr = '16k'
    flat_features = False


def expand_size(sz):
    if isinstance(sz, int):
        return [sz, sz]
    return sz


class PatchEmbed(torch.nn.Module):
    """2D Image to Patch Embedding -- borrowed from https://pypi.org/project/timm/0.4.12/"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = expand_size(img_size)
        patch_size = expand_size(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x


class LocalViT(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer for M2D Audio."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Workaround for PatchEmbed to avoid an unintended assertion failure,
        # e.g. "AssertionError: Input image width (102) doesn't match model (608)."
        self.patch_embed = PatchEmbed(self.patch_embed.img_size, self.patch_embed.patch_size,
                                      self.patch_embed.proj.in_channels, self.patch_embed.proj.out_channels)
        # Default log-mel normalization statistics (mean, std); overwritten when a checkpoint provides them.
        self.norm_stats = torch.nn.Parameter(torch.tensor([-7.1, 4.2]), requires_grad=False)
        # We do not use the default classification head.
        del self.head

    def patch_size(self):
        return np.array(self.patch_embed.patch_size)

    def grid_size(self):
        # Workaround for a compatibility issue (timm 0.4.5 fails with: return self.patch_embed.grid_size).
        img_size = np.array(self.patch_embed.img_size)
        patch_size = self.patch_size()
        grid_size = img_size // patch_size
        return grid_size

    def forward_encoder(self, x):
        x = self.patch_embed(x)
        # Add positional embedding w/o the cls token.
        pos_embed = self.pos_embed[:, 1:, :]
        if x.shape[1] < pos_embed.shape[1]:  # shorten pos_embed for a short input
            dims = pos_embed.shape[-1]
            fbins = self.grid_size()[0]
            frames = x.shape[1] // fbins
            pos_embed = pos_embed.reshape(1, fbins, -1, dims)[:, :, :frames, :].reshape(1, fbins * frames, dims)
        x = x + pos_embed
        # Append the cls token.
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Apply Transformer blocks.
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x


def parse_sizes_by_name(name):
    # Parse parameters from a weight folder name.
    # Example: "m2d_vit_base-80x1001p16x16p16k" -> input size: 80x1001, patch size: 16x16, sr: 16k.
    model_cls = name.split('-')[0]
    params = name.split('-')[1]
    params = params.split('p')[:3]
    input_str, patch_str, sr = params[0], params[1], params[2] if len(params) > 2 else '16k'
    input_size = [int(a) for a in input_str.split('x')]
    patch_size = [int(a) for a in patch_str.split('x')]
    return input_size, patch_size, sr, model_cls
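
# Illustrative example of the values returned by parse_sizes_by_name():
#     parse_sizes_by_name('m2d_vit_base-80x1001p16x16p16k')
#     # -> ([80, 1001], [16, 16], '16k', 'm2d_vit_base')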


def drop_non_model_weights(model, checkpoint, filename):
    # Keep only the checkpoint entries that match a parameter of the model; drop the rest.
    model_keys = [n for n, p in model.named_parameters()]
    new_ckpt, dropped = {}, []
    for k in checkpoint:
        if k not in model_keys:
            dropped.append(k)
            continue
        new_ckpt[k] = checkpoint[k]
    n_org = len(checkpoint.keys())
    n_cur = len(new_ckpt.keys())
    print(
        f' using {n_cur} parameters, while dropping {n_org - n_cur} out of {n_org} parameters from {Path(filename).parent / Path(filename).name}'
        if n_org > n_cur else f' using {n_cur} parameters from {Path(filename).parent / Path(filename).name}')
    print(' (dropped:', dropped[:5], ')' if len(dropped) < 5 else '...)')
    return new_ckpt


def load_evar_head_parameters(checkpoint, head_norm, head):
    # Load the weights of the task head trained in an EVAR fine-tuning.
    if 'module.head.norm.running_mean' in checkpoint:
        head_norm.load_state_dict({to_k: checkpoint[k] for to_k, k in {
            'running_mean': 'module.head.norm.running_mean', 'running_var': 'module.head.norm.running_var'}.items()})
        head.load_state_dict({to_k: checkpoint[k] for to_k, k in {
            'weight': 'module.head.mlp.mlp.0.weight', 'bias': 'module.head.mlp.mlp.0.bias'}.items()})
    else:
        print(' Not an EVAR checkpoint; skipping loading of head weights.')


def reformat_ckpt_keys(checkpoint):
    # Unwrap the checkpoint if the weights are stored under checkpoint['model'].
    checkpoint = checkpoint['model'] if 'model' in checkpoint else checkpoint
    # Checkpoints saved in an EVAR fine-tuning prefix keys with "module.ar.runtime.backbone."; remove it.
    new_ckpt = {}
    for k in checkpoint:
        new_k = k.replace('module.ar.runtime.backbone.', '')
        new_ckpt[new_k] = checkpoint[k]
    return new_ckpt


def make_it_CLAP(model, checkpoint):
    # Add projectors if the checkpoint belongs to an M2D-CLAP model.
    if 'audio_proj.0.weight' in checkpoint.keys():
        proj_hidden_dim = embed_dim = checkpoint['audio_proj.0.weight'].shape[1]
        model.audio_proj = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, proj_hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(proj_hidden_dim, embed_dim),
        )
        if 'text_proj.weight' in checkpoint.keys():
            dim = checkpoint['text_proj.weight'].shape
            model.text_proj = torch.nn.Linear(dim[1], dim[0])
        else:
            model.text_proj = torch.nn.Identity()


def get_backbone(args, weight_file):
    # Determine the configuration from the folder name of the weight file.
    name = Path(weight_file).parent.name if weight_file is not None \
        else "m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly"
    args.input_size, args.patch_size, args.sr, args.model = parse_sizes_by_name(name)

    # Create a ViT.
    model = LocalViT(
        in_chans=1, img_size=args.input_size, patch_size=args.patch_size, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6))
    if weight_file is None:
        args.mean, args.std = -7.1, 4.2
        model.eval()
        return model, None

    # Load the checkpoint.
    checkpoint = torch.load(weight_file, map_location='cpu')
    checkpoint = reformat_ckpt_keys(checkpoint)
    # Set normalization statistics for backward compatibility. [-7.1, 4.2] is for the 2022 models.
    if 'norm_stats' not in checkpoint:
        checkpoint['norm_stats'] = torch.tensor([-7.1, 4.2])
        print(' using default norm_stats:', checkpoint['norm_stats'])

    # Modify the model if it should be an M2D-CLAP.
    make_it_CLAP(model, checkpoint)

    # Load the weights that belong to the model.
    filtered_ckpt = drop_non_model_weights(model, checkpoint, weight_file)
    msg = model.load_state_dict(filtered_ckpt)
    print(msg)
    logging.info(msg)

    # Expose the normalization statistics so that downstream tasks can use them easily.
    args.mean, args.std = model.state_dict()['norm_stats'].to('cpu').numpy()

    model.eval()
    return model, checkpoint


def get_to_melspec(cfg):
    if cfg.sr == '16k':
        cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 16000, 400, 400, 160
        cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 8000
    elif cfg.sr == '32k':
        cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 32000, 800, 800, 320
        cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 16000
    else:
        assert False, f'Unknown sampling rate: {cfg.sr}'
    to_spec = nnAudio.features.MelSpectrogram(
        sr=cfg.sample_rate,
        n_fft=cfg.n_fft,
        win_length=cfg.window_size,
        hop_length=cfg.hop_size,
        n_mels=cfg.n_mels,
        fmin=cfg.f_min,
        fmax=cfg.f_max,
        center=True,
        power=2,
        verbose=False,
    )
    logging.info(f'Runtime MelSpectrogram({cfg.sample_rate}, {cfg.n_fft}, {cfg.window_size}, {cfg.hop_size}, '
                 + f'{cfg.n_mels}, {cfg.f_min}, {cfg.f_max}):')
    logging.info(to_spec)
    return to_spec
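
# Note (illustrative): with the 16 kHz settings above (hop 160 samples = 10 ms, center=True),
# a 10-second clip maps to an 80 x 1001 log-mel input, matching input sizes such as "80x1001"
# in the weight folder names.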


def get_timestamps(cfg, batch_audio, x):  # Returns timestamps in milliseconds.
    audio_len = len(batch_audio[0])
    sec = audio_len / cfg.sample_rate
    x_len = len(x[0])
    step = sec / x_len * 1000  # sec -> ms
    ts = torch.tensor([step * i for i in range(x_len)]).unsqueeze(0)
    ts = ts.repeat(len(batch_audio), 1)
    return ts


class PortableM2D(torch.nn.Module):
    def __init__(self, weight_file=None, num_classes=None, freeze_embed=False, flat_features=None):
        super().__init__()
        self.cfg = Config()
        self.cfg.weight_file = weight_file
        self.cfg.freeze_embed = freeze_embed
        self.cfg.flat_features = self.cfg.flat_features if flat_features is None else flat_features

        # Create the backbone model.
        self.backbone, checkpoint = get_backbone(self.cfg, self.cfg.weight_file)

        # Finalize the feature dimension.
        d = self.backbone.pos_embed.shape[-1]
        if num_classes is not None and checkpoint is not None and 'module.head.mlp.mlp.0.weight' in checkpoint \
                and checkpoint['module.head.mlp.mlp.0.weight'].shape[-1] == d:
            self.cfg.flat_features = True
        n_stack_feature = 1 if self.cfg.flat_features else (self.cfg.input_size[0] // self.cfg.patch_size[0])
        self.cfg.feature_d = d * n_stack_feature  # 768 if flat_features else 768*5 = 3840

        # Create the head.
        if num_classes is not None:
            self.head_norm = torch.nn.BatchNorm1d(self.cfg.feature_d, affine=False)
            self.head = torch.nn.Linear(self.cfg.feature_d, num_classes)
            trunc_normal_(self.head.weight, std=2e-5)
            if checkpoint is not None:
                load_evar_head_parameters(checkpoint, self.head_norm, self.head)

        # Option: freeze the patch embedding ([2211.09359] How to Fine-Tune Vision Models with SGD).
        if self.cfg.freeze_embed:
            self.backbone.patch_embed.requires_grad_(False)
            logging.info(' ** Freeze patch_embed **')
            logging.info(self.backbone.patch_embed)

        logging.info(f'Model input size: {self.cfg.input_size}')
        logging.info(f'Using weights: {self.cfg.weight_file}')
        logging.info(f'Feature dimension: {self.cfg.feature_d}')
        logging.info(f'Norm stats: {self.cfg.mean}, {self.cfg.std}')

        self.to_spec = get_to_melspec(self.cfg)
        self.eval()

    def to_log_mel_spec(self, batch_audio):
        x = self.to_spec(batch_audio)
        x = (x + torch.finfo().eps).log()
        x = x.unsqueeze(1)
        return x

    def normalize_batch(self, x):
        x = (x - self.cfg.mean) / self.cfg.std
        return x

    def to_normalized_feature(self, batch_audio):
        x = self.to_log_mel_spec(batch_audio)
        x = self.normalize_batch(x)
        return x

    def encode_lms(self, x, average_per_time_frame=False):
        patch_fbins = self.backbone.grid_size()[0]
        unit_frames = self.cfg.input_size[1]
        patch_frames = self.backbone.patch_size()[1]
        embed_d = self.backbone.patch_embed.proj.out_channels
        n_chunk = (x.shape[-1] + unit_frames - 1) // unit_frames
        pad_frames = (patch_frames - (x.shape[-1] % unit_frames % patch_frames)) % patch_frames
        if pad_frames > 0:
            x = torch.nn.functional.pad(x, (0, pad_frames))

        embeddings = []
        if self.cfg.flat_features:
            # Flatten all patch embeddings: each chunk yields (B, F*T', D) patch tokens,
            # or (B, T', D) if averaged per time frame.
            for i in range(n_chunk):
                emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
                emb = emb[..., 1:, :]  # drop the cls token
                if average_per_time_frame:
                    emb = rearrange(emb, 'b (f t) d -> b t d f', f=patch_fbins, d=embed_d).mean(-1)
                embeddings.append(emb)
        else:
            # Stack embeddings along the frequency axis: each chunk yields (B, T', F*D).
            for i in range(n_chunk):
                emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
                emb = emb[..., 1:, :]  # drop the cls token
                emb = rearrange(emb, 'b (f t) d -> b t (f d)', f=patch_fbins, d=embed_d)
                embeddings.append(emb)

        # Concatenate the embedding chunks along the time axis.
        x = torch.cat(embeddings, dim=-2)
        return x

    def encode(self, batch_audio, average_per_time_frame=False):
        x = self.to_normalized_feature(batch_audio)
        return self.encode_lms(x, average_per_time_frame=average_per_time_frame)

    def forward(self, batch_audio, average_per_time_frame=False):
        x = self.encode(batch_audio, average_per_time_frame=average_per_time_frame)
        if hasattr(self, 'head'):
            x = x.mean(1)  # B, D
            x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
            x = self.head(x)
        return x

    def forward_mel(self, batch_mel, average_per_time_frame=False):
        x = self.encode_lms(batch_mel, average_per_time_frame=average_per_time_frame)
        if hasattr(self, 'head'):
            x = x.mean(1)  # B, D
            x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
            x = self.head(x)
        return x

    def get_scene_embeddings(self, batch_audio):
        x = self.encode(batch_audio)
        x = torch.mean(x, dim=1)
        return x

    def get_timestamp_embeddings(self, batch_audio):
        x = self.encode(batch_audio, average_per_time_frame=True)
        ts = get_timestamps(self.cfg, batch_audio, x)
        return x, ts

    def forward_frames(self, batch_audio):
        x, ts = self.get_timestamp_embeddings(batch_audio)
        if hasattr(self, 'head'):
            x = self.head_norm(x.transpose(-1, -2)).transpose(-2, -1)
            x = self.head(x)
        return x, ts

    def encode_clap_audio(self, batch_audio):
        audio_embeddings = self.forward(batch_audio)
        audio_embeddings = audio_embeddings.mean(dim=-2)
        audio_embeddings = self.backbone.audio_proj(audio_embeddings)
        return audio_embeddings

    def encode_clap_text(self, batch_text, truncate=False):
        if not hasattr(self, 'text_encoder'):
            self.text_encoder = GTETextEncoder()
        text_embeddings = self.text_encoder(batch_text, truncate=truncate)
        text_embeddings = self.backbone.text_proj(text_embeddings)
        text_embeddings = text_embeddings.detach().cpu().to(torch.float)
        return text_embeddings


# For the CLAP models
class GTETextEncoder:
    def __init__(self, clip_weight="thenlper/gte-base"):
        from transformers import AutoTokenizer, AutoModel
        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "true"  # To suppress warnings.

        self.tokenizer = AutoTokenizer.from_pretrained(clip_weight)
        self.model = AutoModel.from_pretrained(clip_weight)

    def __call__(self, texts, truncate=True, max_length=512):
        def average_pool(last_hidden_states, attention_mask):
            # Mean-pool the token embeddings, ignoring padded positions.
            last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
            return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

        with torch.no_grad():
            device = next(self.model.parameters()).device
            batch_dict = self.tokenizer(texts, max_length=max_length, padding=True, truncation=truncate,
                                        return_tensors='pt')
            batch_dict['input_ids'] = batch_dict['input_ids'].to(device)
            batch_dict['token_type_ids'] = batch_dict['token_type_ids'].to(device)
            batch_dict['attention_mask'] = batch_dict['attention_mask'].to(device)
            outputs = self.model.to(device)(**batch_dict)
            embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        return embeddings
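

if __name__ == '__main__':
    # Minimal smoke test (illustrative): encode two seconds of random audio with
    # randomly initialized weights. Pass a checkpoint path to PortableM2D() to get real features.
    model = PortableM2D()            # weight_file=None -> default 16 kHz configuration
    wav = torch.rand(2, 16000 * 2)   # (batch, samples) at 16 kHz
    with torch.no_grad():
        emb = model.encode(wav)
    print(emb.shape)                 # e.g. torch.Size([2, 13, 3840]) with the default config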