# Author: David Harwath
import torch
import torch.nn as nn
import torch.nn.functional
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.models as imagemodels


class Davenet(nn.Module):
    """Five-layer convolutional audio encoder.

    The first convolution uses a ``(40, 1)`` kernel without padding; the
    remaining convolutions only act along the last axis, each followed by a
    max-pool with stride 2 along that axis (four pools in total).

    Args:
        embedding_dim: Number of output channels of the last convolution.
    """

    def __init__(self, embedding_dim=1024):
        super(Davenet, self).__init__()
        self.embedding_dim = embedding_dim
        self.batchnorm1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(
            1, 128, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        self.conv2 = nn.Conv2d(
            128, 256, kernel_size=(1, 11), stride=(1, 1), padding=(0, 5))
        self.conv3 = nn.Conv2d(
            256, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.conv4 = nn.Conv2d(
            512, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.conv5 = nn.Conv2d(
            512, embedding_dim,
            kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.pool = nn.MaxPool2d(
            kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        """Encode ``x`` and drop dimension 2 when it has size 1.

        Args:
            x: Input tensor. A 3-D tensor gets a singleton channel dimension
                inserted at position 1; any other rank is used as given.
                NOTE(review): presumably 40 rows along dim 2 so that
                ``conv1`` reduces that axis to size 1 -- confirm with callers.

        Returns:
            The output of the last pooling step after ``squeeze(2)``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.batchnorm1(x)
        x = F.relu(self.conv1(x))
        # conv2..conv5 are each followed by the shared max-pool.
        for conv in (self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.pool(F.relu(conv(x)))
        return x.squeeze(2)


class _ResnetEmbedder(imagemodels.ResNet):
    """Shared implementation of the ResNet-based image encoders.

    Builds a torchvision ResNet, optionally loads its pretrained weights,
    disables the ``avgpool`` / ``fc`` head and appends a 1x1 convolution
    (``embedder``) that maps the ``layer4`` output to ``embedding_dim``
    channels.

    Args:
        block: Residual block class forwarded to ``imagemodels.ResNet``.
        layers: Per-stage block counts forwarded to ``imagemodels.ResNet``.
        trunk_channels: Channel count expected by ``embedder`` as input.
        url_key: Key into ``imagemodels.resnet.model_urls``.
        embedding_dim: Number of output channels of ``embedder``.
        pretrained: When true, download and load the pretrained weights.
    """

    def __init__(self, block, layers, trunk_channels, url_key,
                 embedding_dim, pretrained):
        super(_ResnetEmbedder, self).__init__(block, layers)
        if pretrained:
            # Load before the head is removed: the downloaded state dict is
            # loaded while ``fc`` still exists on the module.
            # NOTE(review): ``imagemodels.resnet.model_urls`` appears to be
            # unavailable in recent torchvision releases -- confirm the
            # pinned version before relying on ``pretrained=True``.
            self.load_state_dict(
                model_zoo.load_url(imagemodels.resnet.model_urls[url_key]))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(
            trunk_channels, embedding_dim,
            kernel_size=1, stride=1, padding=0)
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        """Run the stem and ``layer1``-``layer4``, then apply ``embedder``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.embedder(x)


class Resnet18(_ResnetEmbedder):
    """ResNet-18 (``BasicBlock``, ``[2, 2, 2, 2]``) with a 1x1 conv head."""

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet18, self).__init__(
            imagemodels.resnet.BasicBlock, [2, 2, 2, 2], 512, 'resnet18',
            embedding_dim, pretrained)


class Resnet34(_ResnetEmbedder):
    """ResNet-34 (``BasicBlock``, ``[3, 4, 6, 3]``) with a 1x1 conv head."""

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet34, self).__init__(
            imagemodels.resnet.BasicBlock, [3, 4, 6, 3], 512, 'resnet34',
            embedding_dim, pretrained)


class Resnet50(_ResnetEmbedder):
    """ResNet-50 (``Bottleneck``, ``[3, 4, 6, 3]``) with a 1x1 conv head."""

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet50, self).__init__(
            imagemodels.resnet.Bottleneck, [3, 4, 6, 3], 2048, 'resnet50',
            embedding_dim, pretrained)


class VGG16(nn.Module):
    """VGG-16 feature stack with its last max-pool swapped for a 3x3 conv.

    Args:
        embedding_dim: Number of output channels of the appended convolution.
        pretrained: Forwarded to the torchvision ``vgg16`` constructor.
    """

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(VGG16, self).__init__()
        features = imagemodels.__dict__['vgg16'](pretrained=pretrained).features
        # Drop the final child of ``features`` (the last max-pool).
        trunk = nn.Sequential(*list(features.children())[:-1])
        # Register the new conv under the next numeric name so the
        # sub-module names stay consecutive.
        next_index = len(list(trunk.children()))
        trunk.add_module(
            str(next_index),
            nn.Conv2d(512, embedding_dim,
                      kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        self.image_model = trunk
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        """Return the output of the modified feature stack for ``x``."""
        return self.image_model(x)


def prep(dict):
    """Return a copy of a mapping with ``"module."`` removed from every key.

    Every occurrence of the substring is removed, not only a leading one.

    Args:
        dict: Mapping to rewrite. The parameter name shadows the builtin; it
            is kept so existing keyword callers keep working.

    Returns:
        A new ``dict`` with the rewritten keys and the original values.
    """
    return {key.replace("module.", ""): value for key, value in dict.items()}


class DavenetAudioFeaturizer(nn.Module):
    """Wraps a ``Davenet`` initialised from a checkpoint file.

    Args:
        weights_path: Path of the checkpoint passed to ``torch.load``.
            Defaults to the location that used to be hard-coded.
    """

    def __init__(self, weights_path="../models/davenet_pt_audio.pth"):
        super().__init__()
        self.audio_model = Davenet()
        # map_location="cpu" lets checkpoints saved from GPU tensors load on
        # hosts without CUDA; load_state_dict copies the values into the
        # freshly built module either way.
        state = torch.load(weights_path, map_location="cpu")
        self.audio_model.load_state_dict(prep(state))

    def forward(self, audio, include_cls):
        """Encode ``audio`` and insert a singleton dimension at position 2.

        Args:
            audio: Input forwarded to the wrapped ``Davenet``.
            include_cls: When truthy, return a ``(patch_tokens, None)`` pair
                instead of the bare tensor.

        Returns:
            ``patch_tokens`` or ``(patch_tokens, None)``.
        """
        patch_tokens = self.audio_model(audio).unsqueeze(2)
        if include_cls:
            return patch_tokens, None
        return patch_tokens

    def get_last_params(self):
        """Return an empty list."""
        return []


class DavenetImageFeaturizer(nn.Module):
    """Wraps a ``VGG16`` initialised from a checkpoint file.

    Args:
        weights_path: Path of the checkpoint passed to ``torch.load``.
            Defaults to the location that used to be hard-coded.
    """

    def __init__(self, weights_path="../models/davenet_pt_image.pth"):
        super().__init__()
        self.image_model = VGG16()
        # See DavenetAudioFeaturizer: load onto CPU so CUDA is not required.
        state = torch.load(weights_path, map_location="cpu")
        self.image_model.load_state_dict(prep(state))

    def forward(self, image, include_cls):
        """Encode ``image`` with the wrapped ``VGG16``.

        Args:
            image: Input forwarded to the wrapped ``VGG16``.
            include_cls: When truthy, return a ``(patch_tokens, None)`` pair
                instead of the bare tensor.

        Returns:
            ``patch_tokens`` or ``(patch_tokens, None)``.
        """
        patch_tokens = self.image_model(image)
        if include_cls:
            return patch_tokens, None
        return patch_tokens

    def get_last_params(self):
        """Return an empty list."""
        return []