|
import torch.nn as nn |
|
import torch |
|
from torchvision import models |
|
import numpy as np |
|
|
|
class EncodingBackbone(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk followed by a linear projection.

    The input is passed through the convolutional part of ResNet-50 (its
    final average-pool and ``fc`` classifier are dropped). It is then globally
    average-pooled, flattened to ``(batch, features)`` and projected to
    ``encoding_size`` dimensions.

    Args:
        encoding_size: Dimensionality of the output encoding.
        freeze_backbone: If True (default), the pretrained ResNet trunk is
            frozen (``requires_grad=False``), so only ``encoding_layer``
            receives gradients. If False, the whole module is trainable.
    """

    def __init__(self, encoding_size=256, freeze_backbone=True):
        super().__init__()

        # Loads ImageNet weights (downloaded on first use). ``pretrained=True``
        # is kept for compatibility with older torchvision releases; newer ones
        # prefer the ``weights=`` argument and emit a deprecation warning here.
        resnet = models.resnet50(pretrained=True)
        # Width of the features feeding ResNet's classifier (2048 for ResNet-50).
        in_features = resnet.fc.in_features

        # Keep every child except the last two (ResNet's avgpool and fc), i.e.
        # the convolutional trunk that yields a spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.global_avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.encoding_layer = nn.Linear(in_features, encoding_size)

        if freeze_backbone:
            # Freeze only the pretrained trunk. Freezing self.parameters()
            # would also freeze the randomly initialised encoding_layer and
            # leave the module with nothing to train.
            # NOTE: requires_grad=False does not stop BatchNorm layers from
            # updating their running statistics while the module is in
            # train() mode; call self.backbone.eval() if that matters.
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch, presumably ``(N, 3, H, W)`` as expected by
                ResNet-50. TODO: confirm with callers.

        Returns:
            Tensor of shape ``(N, encoding_size)``.
        """
        features = self.backbone(x)
        pooled = self.global_avg_pooling(features)  # (N, C, 1, 1)
        # flatten tolerates non-contiguous tensors, unlike view().
        flat = torch.flatten(pooled, 1)  # (N, C)
        return self.encoding_layer(flat)