import os import urllib.parse from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple import torch from torch import nn, Tensor from torch.hub import load_state_dict_from_url from torchvision.ops.misc import ConvNormActivation from models.frame_mn.block_types import InvertedResidualConfig, InvertedResidual from models.frame_mn.utils import cnn_out_size # Adapted version of MobileNetV3 pytorch implementation # https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py # points to github releases model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/" # folder to store downloaded models to model_dir = "resources" pretrained_models = { # pytorch ImageNet pre-trained model # own ImageNet pre-trained models will follow # NOTE: for easy loading we provide the adapted state dict ready for AudioSet training (1 input channel, # 527 output classes) # NOTE: the classifier is just a random initialization, feature extractor (conv layers) is pre-trained "mn10_im_pytorch": urllib.parse.urljoin(model_url, "mn10_im_pytorch.pt"), # self-trained models on ImageNet "mn01_im": urllib.parse.urljoin(model_url, "mn01_im.pt"), "mn02_im": urllib.parse.urljoin(model_url, "mn02_im.pt"), "mn04_im": urllib.parse.urljoin(model_url, "mn04_im.pt"), "mn05_im": urllib.parse.urljoin(model_url, "mn05_im.pt"), "mn06_im": urllib.parse.urljoin(model_url, "mn06_im.pt"), "mn10_im": urllib.parse.urljoin(model_url, "mn10_im.pt"), "mn20_im": urllib.parse.urljoin(model_url, "mn20_im.pt"), "mn30_im": urllib.parse.urljoin(model_url, "mn30_im.pt"), "mn40_im": urllib.parse.urljoin(model_url, "mn40_im.pt"), # Models trained on AudioSet "mn01_as": urllib.parse.urljoin(model_url, "mn01_as_mAP_298.pt"), "mn02_as": urllib.parse.urljoin(model_url, "mn02_as_mAP_378.pt"), "mn04_as": urllib.parse.urljoin(model_url, "mn04_as_mAP_432.pt"), "mn05_as": urllib.parse.urljoin(model_url, "mn05_as_mAP_443.pt"), "mn10_as": urllib.parse.urljoin(model_url, "mn10_as_mAP_471.pt"), "mn20_as": urllib.parse.urljoin(model_url, "mn20_as_mAP_478.pt"), "mn30_as": urllib.parse.urljoin(model_url, "mn30_as_mAP_482.pt"), "mn40_as": urllib.parse.urljoin(model_url, "mn40_as_mAP_484.pt"), "mn40_as(2)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483.pt"), "mn40_as(3)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483(2).pt"), "mn40_as_no_im_pre": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483.pt"), "mn40_as_no_im_pre(2)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483(2).pt"), "mn40_as_no_im_pre(3)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_482.pt"), "mn40_as_ext": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_487.pt"), "mn40_as_ext(2)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_486.pt"), "mn40_as_ext(3)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_485.pt"), # varying hop size (time resolution) "mn10_as_hop_5": urllib.parse.urljoin(model_url, "mn10_as_hop_5_mAP_475.pt"), "mn10_as_hop_15": urllib.parse.urljoin(model_url, "mn10_as_hop_15_mAP_463.pt"), "mn10_as_hop_20": urllib.parse.urljoin(model_url, "mn10_as_hop_20_mAP_456.pt"), "mn10_as_hop_25": urllib.parse.urljoin(model_url, "mn10_as_hop_25_mAP_447.pt"), # varying n_mels (frequency resolution) "mn10_as_mels_40": urllib.parse.urljoin(model_url, "mn10_as_mels_40_mAP_453.pt"), "mn10_as_mels_64": urllib.parse.urljoin(model_url, "mn10_as_mels_64_mAP_461.pt"), "mn10_as_mels_256": urllib.parse.urljoin(model_url, "mn10_as_mels_256_mAP_474.pt"), # fully-convolutional head "mn10_as_fc": urllib.parse.urljoin(model_url, "mn10_as_fc_mAP_465.pt"), "mn10_as_fc_s2221": urllib.parse.urljoin(model_url, "mn10_as_fc_s2221_mAP_466.pt"), "mn10_as_fc_s2211": urllib.parse.urljoin(model_url, "mn10_as_fc_s2211_mAP_466.pt"), } class MN(nn.Module): def __init__( self, inverted_residual_setting: List[InvertedResidualConfig], block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, in_conv_kernel: int = 3, in_conv_stride: int = 2, in_channels: int = 1, **kwargs: Any, ) -> None: """ MobileNet V3 main class Args: inverted_residual_setting (List[InvertedResidualConfig]): Network structure block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for models norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use in_conv_kernel (int): Size of kernel for first convolution in_conv_stride (int): Size of stride for first convolution in_channels (int): Number of input channels """ super(MN, self).__init__() if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") elif not ( isinstance(inverted_residual_setting, Sequence) and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting]) ): raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") if block is None: block = InvertedResidual depthwise_norm_layer = norm_layer = \ norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) layers: List[nn.Module] = [] kernel_sizes = [in_conv_kernel] strides = [in_conv_stride] # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( ConvNormActivation( in_channels, firstconv_output_channels, kernel_size=in_conv_kernel, stride=in_conv_stride, norm_layer=norm_layer, activation_layer=nn.Hardswish, ) ) # get squeeze excitation config se_cnf = kwargs.get('se_conf', None) # building inverted residual blocks # - keep track of size of frequency and time dimensions for possible application of Squeeze-and-Excitation # on the frequency/time dimension # - applying Squeeze-and-Excitation on the time dimension is not recommended as this constrains the network to # a particular length of the audio clip, whereas Squeeze-and-Excitation on the frequency bands is fine, # as the number of frequency bands is usually not changing f_dim, t_dim = kwargs.get('input_dims', (128, 1000)) # take into account first conv layer f_dim = cnn_out_size(f_dim, 1, 1, 3, 2) t_dim = cnn_out_size(t_dim, 1, 1, 3, 2) for cnf in inverted_residual_setting: f_dim = cnf.out_size(f_dim, idx=0) t_dim = cnf.out_size(t_dim, idx=1) cnf.f_dim, cnf.t_dim = f_dim, t_dim # update dimensions in block config layers.append(block(cnf, se_cnf, norm_layer, depthwise_norm_layer)) kernel_sizes.append(cnf.kernel) strides.append(cnf.stride) # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels self.lastconv_output_channels = lastconv_output_channels layers.append( ConvNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Hardswish, ) ) self.features = nn.Sequential(*layers) # no prediction head needed - we want to use Frame-MobileNet to extract a 3D sequence # i.e.: batch size x sequence length x channel dimension for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.zeros_(m.bias) def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Tensor: fmaps = [] for i, layer in enumerate(self.features): x = layer(x) if return_fmaps: fmaps.append(x) # reshape: batch size x channels x frequency bands x time -> batch size x time x channels # works, because frequency dimension is exactly 1 x = x.squeeze(2).permute(0, 2, 1) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def load_model(self, path, wandb_id): ckpt_path = os.path.join(path, wandb_id + ".ckpt") pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"] pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."} self.load_state_dict(pretrained_weights) print("Loaded model successfully. Wandb_id:", wandb_id) def _mobilenet_v3_conf( width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, strides: Tuple[int] = None, dilation_list_t_dim: Optional[List[int]] = None, **kwargs ): reduce_divider = 2 if reduced_tail else 1 if dilation_list_t_dim is None: dilation_list_t_dim = [1] * 15 if dilated: dilation_list_t_dim[-3:] = [2] * 3 print("dilation_list_t_dim: ") print(dilation_list_t_dim) bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) if strides is None: # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 f_strides = (1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2) t_strides = (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) strides = tuple(zip(f_strides, t_strides)) # InvertedResidualConfig: # input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride, dilation inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, False, "RE", strides[0], (1, dilation_list_t_dim[0])), # 0 bneck_conf(16, 3, 64, 24, False, "RE", strides[1], (1, dilation_list_t_dim[1])), # 1 - C1 bneck_conf(24, 3, 72, 24, False, "RE", strides[2], (1, dilation_list_t_dim[2])), # 2 bneck_conf(24, 5, 72, 40, True, "RE", strides[3], (1, dilation_list_t_dim[3])), # 3 - C2 bneck_conf(40, 5, 120, 40, True, "RE", strides[4], (1, dilation_list_t_dim[4])), # 4 bneck_conf(40, 5, 120, 40, True, "RE", strides[5], (1, dilation_list_t_dim[5])), # 5 bneck_conf(40, 3, 240, 80, False, "HS", strides[6], (1, dilation_list_t_dim[6])), # 6 - C3 bneck_conf(80, 3, 200, 80, False, "HS", strides[7], (1, dilation_list_t_dim[7])), # 7 bneck_conf(80, 3, 184, 80, False, "HS", strides[8], (1, dilation_list_t_dim[8])), # 8 bneck_conf(80, 3, 184, 80, False, "HS", strides[9], (1, dilation_list_t_dim[9])), # 9 bneck_conf(80, 3, 480, 112, True, "HS", strides[10], (1, dilation_list_t_dim[10])), # 10 bneck_conf(112, 3, 672, 112, True, "HS", strides[11], (1, dilation_list_t_dim[11])), # 11 bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", strides[12], (1, dilation_list_t_dim[12])), # 12 - C4 # dilation bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[13], (1, dilation_list_t_dim[13])), # 13 # dilation bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[14], (1, dilation_list_t_dim[14])), # 14 # dilation ] last_channel = adjust_channels(1280 // reduce_divider) return inverted_residual_setting, last_channel def _mobilenet_v3( inverted_residual_setting: List[InvertedResidualConfig], pretrained_name: str, **kwargs: Any, ): model = MN(inverted_residual_setting, **kwargs) if pretrained_name in pretrained_models: model_url = pretrained_models.get(pretrained_name) state_dict = load_state_dict_from_url(model_url, model_dir=model_dir, map_location="cpu") if kwargs['head_type'] == "mlp": num_classes = state_dict['classifier.5.bias'].size(0) elif kwargs['head_type'] == "fully_convolutional": num_classes = state_dict['classifier.1.bias'].size(0) else: print("Loading weights for classifier only implemented for head types 'mlp' and 'fully_convolutional'") num_classes = -1 if kwargs['num_classes'] != num_classes: # if the number of logits is not matching the state dict, # drop the corresponding pre-trained part pretrain_logits = state_dict['classifier.5.bias'].size(0) if kwargs['head_type'] == "mlp" \ else state_dict['classifier.1.bias'].size(0) print(f"Number of classes defined: {kwargs['num_classes']}, " f"but try to load pre-trained layer with logits: {pretrain_logits}\n" "Dropping last layer.") if kwargs['head_type'] == "mlp": del state_dict['classifier.5.weight'] del state_dict['classifier.5.bias'] else: state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')} try: model.load_state_dict(state_dict) except RuntimeError as e: print(str(e)) print("Loading weights pre-trained weights in a non-strict manner.") model.load_state_dict(state_dict, strict=False) elif pretrained_name: raise NotImplementedError(f"Model name '{pretrained_name}' unknown.") return model def mobilenet_v3(pretrained_name: str = None, **kwargs: Any) \ -> MN: """ Constructs a MobileNetV3 architecture from "Searching for MobileNetV3" ". """ inverted_residual_setting, last_channel = _mobilenet_v3_conf(**kwargs) return _mobilenet_v3(inverted_residual_setting, pretrained_name, **kwargs) def get_model(pretrained_name: str = None, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, dilation_list_t_dim=None, strides: Tuple[int, int, int, int] = None, head_type: str = "mlp", multihead_attention_heads: int = 4, input_dim_f: int = 128, input_dim_t: int = 1000, se_dims: str = 'c', se_agg: str = "max", se_r: int = 4): """ Arguments to modify the instantiation of a MobileNetv3 Args: pretrained_name (str): Specifies name of pre-trained model to load width_mult (float): Scales width of network reduced_tail (bool): Scales down network tail dilated (bool): Applies dilated convolution to network tail dilation_list_t_dim (List): List of dilation factors to apply to network tail strides (Tuple): Strides that are set to '2' in original implementation; might be changed to modify the size of receptive field and the downsampling factor in time and frequency dimension head_type (str): decides which classification head to use multihead_attention_heads (int): number of heads in case 'multihead_attention_heads' is used input_dim_f (int): number of frequency bands input_dim_t (int): number of time frames se_dims (Tuple): choose dimension to apply squeeze-excitation on, if multiple dimensions are chosen, then squeeze-excitation is applied concurrently and se layer outputs are fused by se_agg operation se_agg (str): operation to fuse output of concurrent se layers se_r (int): squeeze excitation bottleneck size se_dims (str): contains letters corresponding to dimensions 'c' - channel, 'f' - frequency, 't' - time """ dim_map = {'c': 1, 'f': 2, 't': 3} assert len(se_dims) <= 3 and all([s in dim_map.keys() for s in se_dims]) or se_dims == 'none' input_dims = (input_dim_f, input_dim_t) if se_dims == 'none': se_dims = None else: se_dims = [dim_map[s] for s in se_dims] se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r) m = mobilenet_v3(pretrained_name=pretrained_name, width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated, dilation_list_t_dim=dilation_list_t_dim, strides=strides, head_type=head_type, multihead_attention_heads=multihead_attention_heads, input_dims=input_dims, se_conf=se_conf ) print(m) return m