# backend/MTL.py — author: Irina, commit 21f70e6 ("pls"), 910 bytes.
import torch.nn as nn
from torchvision import models
class MTL(nn.Module):
    """Multi-task image classifier: a shared ResNet-50 backbone with two linear heads.

    Both heads read the same pooled image embedding. One produces
    ``num_classes_school`` logits and the other produces
    ``num_classes_type`` logits.

    Args:
        num_classes_school: Number of output logits for the "school" head.
        num_classes_type: Number of output logits for the "type" head.
        pretrained: If True (the default, matching the original behavior),
            initialise the backbone from torchvision's default ResNet-50
            weights, which may trigger a download. Pass False to build a
            randomly initialised backbone, e.g. when a fine-tuned state dict
            is loaded immediately afterwards.
    """

    def __init__(
        self,
        num_classes_school: int,
        num_classes_type: int,
        *,
        pretrained: bool = True,
    ):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        resnet = models.resnet50(weights=weights)
        # Keep every child except the final fully connected ``fc`` layer; the
        # remaining stack ends with ResNet's global average pool.
        self.resnet_feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        # Width of the embedding fed to both heads (the input size of the
        # dropped ``fc`` layer).
        self.num_features = resnet.fc.in_features
        # The heads stay single-layer nn.Sequential containers so parameter
        # names (e.g. ``class_school_head.0.weight``) match state dicts saved
        # from the original layout.
        self.class_school_head = nn.Sequential(
            nn.Linear(self.num_features, num_classes_school)
        )
        self.class_type_head = nn.Sequential(
            nn.Linear(self.num_features, num_classes_type)
        )

    def forward(self, img):
        """Return ``(school_logits, type_logits)`` for a batch of images.

        Args:
            img: Image batch with the batch dimension first. Presumably
                (N, 3, H, W), normalised the way the ResNet-50 weights
                expect; confirm against the caller.

        Returns:
            A pair of tensors with shapes (N, num_classes_school) and
            (N, num_classes_type).
        """
        # Flatten everything after the batch dimension. Unlike
        # ``view(size(0), -1)`` this also works for an empty batch (N == 0)
        # and for non-contiguous feature maps.
        visual_emb = self.resnet_feature_extractor(img).flatten(1)
        out_school = self.class_school_head(visual_emb)
        out_type = self.class_type_head(visual_emb)
        return out_school, out_type