"""
Most of this code comes from the timm library.
We tried to disentangle it from the timm library version.
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import collections
import logging
import math
import os
import warnings
from collections import OrderedDict
from functools import partial
from itertools import repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.frame_passt.vit_helpers import (DropPath, trunc_normal_,
build_model_with_cfg, adapt_input_conv)
_logger = logging.getLogger()
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
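# Illustrative sketch (not part of the original code): to_2tuple broadcasts a scalar to a
# pair and passes iterables through unchanged, which is why img_size below can be given
# either as 224 or as (128, 998).
def _example_to_2tuple():
    assert to_2tuple(224) == (224, 224)
    assert to_2tuple((128, 998)) == (128, 998)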
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_tiny_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_base_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_base_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_large_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_large_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
'vit_base_patch16_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
    # ViT ImageNet-21K-P pretraining by MIIL
'vit_base_patch16_224_miil_in21k': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
),
'vit_base_patch16_224_miil': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
'/vit_base_patch16_224_1k_miil_84_4.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
),
# PaSST
'passt_s_swa_p16_128_ap476': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_kd_p16_128_ap486': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_l_kd_p16_128_ap47': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_128_ap4761': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_128_ap472': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s16_128_ap468': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s16_128_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s14_128_ap471': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s14_128_ap469': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s12_128_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s12_128_ap470': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=20),
    'openmic2008_passt_u_f128_p16_s10_ap85': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=20),
}
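# Illustrative sketch (not part of the original code): each entry of default_cfgs is a plain
# dict built by _cfg, with keyword overrides winning over the ImageNet defaults. The key used
# here is one of the PaSST entries defined above.
def _example_default_cfg():
    cfg = default_cfgs['passt_s_swa_p16_128_ap476']
    assert cfg['num_classes'] == 527                  # AudioSet classes override the 1000 default
    assert cfg['input_size'] == (1, 128, 998)         # mono spectrogram instead of a 3x224x224 image
    assert cfg['classifier'] == ('head.1', 'head_dist')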
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
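# Illustrative sketch (not part of the original code): the Mlp block only widens the feature
# dimension internally (hidden_features) and returns tokens with the input shape.
def _example_mlp():
    mlp = Mlp(in_features=768, hidden_features=3072)
    tokens = torch.randn(2, 10, 768)            # (batch, tokens, features)
    assert mlp(tokens).shape == (2, 10, 768)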
first_RUN = True
PLUS1_TRICK = False
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, in_chans=1, frame_nr=1, stride=1, overlap=1, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
self.img_size = img_size
self.frame_nr = frame_nr
self.stride = stride
self.seq_len = int(img_size[1]) // frame_nr
self.num_patches = self.seq_len // stride
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(int(img_size[0]), stride + overlap),
stride=stride, padding=(0, 1)) # 128 x 2 kernel
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, F, T = x.shape
if not (F == self.img_size[0] and abs(T - self.img_size[1]) <= 1): # allows for a difference of 1
warnings.warn(f"Input image size ({F}*{T}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x)[:, :, :, 1:] # B embed_dim 1 T (F=1)
x = self.norm(x)
if first_RUN: print("self.norm(x)", x.size())
return x
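# Illustrative sketch (not part of the original code): with the default overlap=1 the projection
# uses a (freq_bins x 2) kernel with padding (0, 1), so a full-height input collapses the
# frequency axis to 1 and yields roughly one embedding per time step (the forward above drops
# the first padded column). The channel count and sizes here are assumptions for illustration.
def _example_patch_embed():
    pe = PatchEmbed(img_size=(128, 250), in_chans=16, frame_nr=1, stride=1, embed_dim=768)
    x = torch.randn(2, 16, 128, 250)            # (batch, channels, freq, time)
    assert pe(x).shape == (2, 768, 1, 250)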
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
        # apply attention dropout only while training (SDPA does not check module.training itself)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                           dropout_p=self.attn_drop if self.training else 0.,
                                           is_causal=False, scale=self.scale)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
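# Illustrative sketch (not part of the original code): standard multi-head self-attention over a
# token sequence; the head dimension is dim // num_heads and the output keeps the input shape.
def _example_attention():
    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    tokens = torch.randn(2, 101, 768)           # e.g. cls + dist + 99 patch tokens
    assert attn(tokens).shape == (2, 101, 768)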
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PaSST(nn.Module):
"""
    Based on the Vision Transformer implementation in the timm library.
    See the get_model function, which adapts the weights of pretrained ImageNet models.
"""
def __init__(self, img_size=(128, 998),
in_chans=1, num_classes=527, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init='',
frame_patchout=300, frame_nr=1, pos_embed_length=1000):
"""
Args:
img_size (int, tuple): input image size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): activation layer
weight_init: (str): weight init scheme
frame_patchout (int): number of frames to patch out
frame_nr (int): the second dimension of the proj-convolution kernel
pos_embed_length (int): length of the positional embedding
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.act_layer = act_layer()
self.in_chans = in_chans
self.frame_patchout = frame_patchout
self.pos_embed_len = pos_embed_length
        # these three convolutions differ from the vanilla PaSST
self.conv_in_1 = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.conv_in_2 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.conv_in_3 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # 64 instead of 4
img_size = (img_size[0], pos_embed_length) # 128, 250
self.patch_embed = embed_layer(
img_size=img_size, in_chans=in_chans, frame_nr=frame_nr, stride=frame_nr, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
# PaSST
# refer to https://arxiv.org/abs/2110.05069 Section 2
self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) # for C and D tokens
self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, 1)) # | f
self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.pos_embed_len)) # __ t
####
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', ''), f"mode: {mode}"
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.new_pos_embed, std=.02)
trunc_normal_(self.freq_new_pos_embed, std=.02)
trunc_normal_(self.time_new_pos_embed, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
if mode.startswith('jax'):
# leave cls token as zeros to match jax impl
raise RuntimeError("Not supported yet")
else:
trunc_normal_(self.cls_token, std=.02)
self.apply(_init_vit_weights)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'}
def forward_features(self, x):
global first_RUN # not jit friendly? use trace instead
# some 2D convolutions
f_dim = x.size(2) # 128
x = self.act_layer(self.conv_in_1(x))
x = self.act_layer(self.conv_in_2(x))
x = self.act_layer(self.conv_in_3(x))
if first_RUN: print("after convs", x.size())
x = x.reshape(x.shape[0], (x.shape[1] * x.shape[2]) // f_dim, f_dim, x.shape[3])
if first_RUN: print("after reshape", x.size())
x = self.patch_embed(x) # [b, e, f, t]
B_dim, E_dim, F_dim, T_dim = x.shape # slow
if first_RUN: print(" patch_embed : ", x.shape)
# Adding Time/Freq information
if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape)
time_new_pos_embed = self.time_new_pos_embed
if x.shape[-1] < time_new_pos_embed.shape[-1]:
if self.training:
toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()
if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape",
time_new_pos_embed.shape)
time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]
else:
time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
else:
# warnings.warn(
# f"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut")
x = x[:, :, :, :time_new_pos_embed.shape[-1]]
x = x + time_new_pos_embed
if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
x = x + self.freq_new_pos_embed
# Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2
if self.training and self.frame_patchout:
if first_RUN: print(f"X Before frame Patchout of {self.frame_patchout} ", x.size())
# ([1, 768, 1, 82])
            # index with the current time dimension; x may have been trimmed to the pos-embed length above
            cur_T = x.shape[-1]
            random_indices = torch.randperm(cur_T)[:cur_T - self.frame_patchout].sort().values
x = x[:, :, :, random_indices]
if first_RUN: print("X after frame Patchout", x.size())
x = x.flatten(2).transpose(1, 2)
# Add the C/D tokens
if first_RUN: print(" self.new_pos_embed.shape", self.new_pos_embed.shape)
cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :]
if first_RUN: print(" self.cls_tokens.shape", cls_tokens.shape)
if self.dist_token is None:
x = torch.cat((cls_tokens, x), dim=1)
else:
dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :]
if first_RUN: print(" self.dist_token.shape", dist_token.shape)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
if first_RUN: print(" final sequence x", x.shape)
x = self.pos_drop(x)
x = self.blocks(x)
if first_RUN: print(f" after {len(self.blocks)} atten blocks x", x.shape)
x = self.norm(x)
return x
def forward(self, x):
global first_RUN
if first_RUN: print("x", x.size())
x = self.forward_features(x)
c, x = x[:, :2].mean(1), x[:, 2:]
if first_RUN: print("x after forward_features", x.size())
first_RUN = False
return x
def load_model(self, path, wandb_id):
ckpt_path = os.path.join(path, wandb_id + ".ckpt")
pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"]
pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."}
self.load_state_dict(pretrained_weights)
print("Loaded model successfully. Wandb_id:", wandb_id)
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)  # NOTE: lecun_normal_ is not imported in this file (branch unused with the default Identity pre_logits)
nn.init.zeros_(module.bias)
else:
if jax_impl:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
else:
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape,
num_tokens)
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= num_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
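# Illustrative sketch (not part of the original code): resizing a square ViT position embedding,
# e.g. from a 14x14 grid (224/16 patches) plus one cls token to a 24x24 grid (384/16 patches).
def _example_resize_pos_embed():
    posemb = torch.randn(1, 1 + 14 * 14, 768)         # cls token + 14x14 grid
    posemb_new = torch.zeros(1, 1 + 24 * 24, 768)     # target layout
    resized = resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(24, 24))
    assert resized.shape == (1, 1 + 24 * 24, 768)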
def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, posemb_len=1000, mode='bicubic'):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(1, posemb_len), mode=mode, align_corners=False)
freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True)
time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True)
_logger.info('New Position cls/dstl embedding %s', posemb_tok.shape)
_logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape)
_logger.info('New TIME Position embedding %s', time_new_pos_embed.shape)
return posemb_tok, freq_new_pos_embed, time_new_pos_embed
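# Illustrative sketch (not part of the original code): converting a square ImageNet ViT position
# embedding into the separate frequency/time embeddings PaSST expects, here for a DeiT-style
# checkpoint with cls + dist tokens (num_tokens=2) and a 24x24 patch grid.
def _example_adapt_pos_embed():
    posemb = torch.randn(1, 2 + 24 * 24, 768)
    tok, freq_pe, time_pe = adapt_image_pos_embed_to_passt(posemb, num_tokens=2, posemb_len=250)
    assert tok.shape == (1, 2, 768)
    assert freq_pe.shape == (1, 768, 1, 1)
    assert time_pe.shape == (1, 768, 1, 250)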
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
state_dict = {k: v for k, v in state_dict.items()}
if "time_new_pos_embed" not in state_dict:
# we are working with ImageNet model
_logger.info("Adapting pos embedding from ImageNet pretrained model to PaSST.")
v = state_dict.pop("pos_embed")
new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt(
v, getattr(model, 'num_tokens', 1), model.pos_embed_len)
state_dict["new_pos_embed"] = new_pos_embed
state_dict["freq_new_pos_embed"] = freq_new_pos_embed
state_dict["time_new_pos_embed"] = time_new_pos_embed
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
embed_dim, C, H, W = v.shape
v = adapt_input_conv(model.in_chans, v, input_conv_name=k)
k1, k2 = model.patch_embed.proj.kernel_size # 128, 2
# clever reshape
assert H * W == k1 * k2, "Error in the kernel size of the patch embedding"
v = v.reshape(embed_dim, model.in_chans, k1, k2) # [embed_dim, 1, k1, k2]
out_dict[k] = v
return out_dict
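# Illustrative sketch (not part of the original code): the patch-embedding adaptation above
# flattens a channel-adapted (768, 1, 16, 16) ImageNet patch kernel into the (768, 1, 128, 2)
# kernel expected by PatchEmbed; the reshape is only valid because 16 * 16 == 128 * 2.
def _example_kernel_reshape():
    v = torch.randn(768, 1, 16, 16)
    assert 16 * 16 == 128 * 2
    assert v.reshape(768, 1, 128, 2).shape == (768, 1, 128, 2)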
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
# NOTE this extra code to support handling of repr size for in21k pretrained models
default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model = build_model_with_cfg(
PaSST, variant, pretrained,
default_cfg=default_cfg,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in default_cfg['url'],
**kwargs)
return model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print(
"\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=7, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4763 SWA \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_128_ap472(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (12, 12):
warnings.warn(
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 20 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 30 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (12, 12):
warnings.warn(
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (14, 14):
warnings.warn(
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s14_128_ap469', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (14, 14):
warnings.warn(
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (16, 16):
warnings.warn(
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (16, 16):
warnings.warn(
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def fix_embedding_layer(model, embed="default"):
if embed == "default":
return model
if embed == "overlap":
model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)
if embed == "am_keepconv":
model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
return model
def lighten_model(model, cut_depth=0):
if cut_depth == 0:
return model
if cut_depth:
if cut_depth < 0:
print(f"\n Reducing model depth by removing every {-cut_depth} layer \n\n")
else:
print(f"\n Reducing model depth by {cut_depth} \n\n")
if len(model.blocks) < cut_depth + 2:
            raise ValueError(f"Cut depth for a ViT with {len(model.blocks)} "
                             f"layers should be between 1 and {len(model.blocks) - 2}")
print(f"\n Before Cutting it was {len(model.blocks)} \n\n")
old_blocks = list(model.blocks.children())
if cut_depth < 0:
print(f"cut_depth={cut_depth}")
old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
else:
old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
model.blocks = nn.Sequential(*old_blocks)
print(f"\n Atfer Cutting it is {len(model.blocks)} \n\n")
return model
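# Illustrative sketch (not part of the original code): with a positive cut_depth, lighten_model
# keeps the first block and drops the next cut_depth blocks, so a 12-block stack with cut_depth=5
# ends up with 7 blocks. The dummy module below stands in for a full ViT.
def _example_lighten_model():
    class _DummyViT(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.Sequential(*[nn.Identity() for _ in range(12)])
    model = lighten_model(_DummyViT(), cut_depth=5)
    assert len(model.blocks) == 7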
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1,
input_fdim=128, input_tdim=998, frame_patchout=300, pos_embed_length=1000
):
"""
    :param arch: base ViT/DeiT or PaSST architecture name
    :param pretrained: load the pretrained weights listed in default_cfgs (ImageNet or AudioSet, depending on arch)
    :param n_classes: number of classes
    :param in_channels: number of input channels (1 for mono audio)
    :param input_fdim: expected number of input frequency bins
    :param input_tdim: expected number of input time frames
    :param frame_patchout: number of time frames to drop from the input during training (structured patchout)
    :param pos_embed_length: length of the time positional embedding
    :return: the PaSST model
"""
model_func = None
input_size = (input_fdim, input_tdim)
if arch == "passt_deit_bd_p16_384": # base deit
model_func = deit_base_distilled_patch16_384
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
model_func = passt_s_kd_p16_128_ap486
elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L
model_func = passt_l_kd_p16_128_ap47
elif arch == "passt_s_swa_p16_128_ap476": # pretrained
model_func = passt_s_swa_p16_128_ap476
elif arch == "passt_s_swa_p16_128_ap4761":
model_func = passt_s_swa_p16_128_ap4761
elif arch == "passt_s_p16_128_ap472":
model_func = passt_s_p16_128_ap472
elif arch == "passt_s_p16_s16_128_ap468":
model_func = passt_s_p16_s16_128_ap468
elif arch == "passt_s_swa_p16_s16_128_ap473":
model_func = passt_s_swa_p16_s16_128_ap473
elif arch == "passt_s_swa_p16_s14_128_ap471":
model_func = passt_s_swa_p16_s14_128_ap471
elif arch == "passt_s_p16_s14_128_ap469":
model_func = passt_s_p16_s14_128_ap469
elif arch == "passt_s_swa_p16_s12_128_ap473":
model_func = passt_s_swa_p16_s12_128_ap473
elif arch == "passt_s_p16_s12_128_ap470":
model_func = passt_s_p16_s12_128_ap470
elif arch == "passt_s_f128_20sec_p16_s10_ap474":
model_func = passt_s_f128_20sec_p16_s10_ap474_swa
elif arch == "passt_s_f128_30sec_p16_s10_ap473":
model_func = passt_s_f128_30sec_p16_s10_ap473_swa
if model_func is None:
raise RuntimeError(f"Unknown model {arch}")
model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,
img_size=input_size, frame_patchout=frame_patchout, pos_embed_length=pos_embed_length)
model = fix_embedding_layer(model)
model = lighten_model(model)
return model
class EnsembelerModel(nn.Module):
def __init__(self, models):
super(EnsembelerModel, self).__init__()
self.models = nn.ModuleList(models)
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
all_out = None
for i, m in enumerate(self.models):
out, _ = m(x)
if all_out is None:
all_out = out
else:
all_out = out + all_out
all_out = all_out / len(self.models)
return all_out, all_out
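# Illustrative sketch (not part of the original code): EnsembelerModel averages the first output
# of each wrapped model; as the tuple unpacking in forward above assumes, every member is expected
# to return a (logits, extra) pair. The dummy module below is a stand-in for real models.
def _example_ensemble():
    class _TupleModel(nn.Module):
        def forward(self, x):
            return x * 2, None
    ens = EnsembelerModel([_TupleModel(), _TupleModel()])
    out, _ = ens(torch.ones(1, 3))
    assert torch.allclose(out, torch.ones(1, 3) * 2)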