|
import torch.nn as nn |
|
from torchvision import models |
|
|
|
class MTL(nn.Module): |
|
def __init__(self, num_classes_school, num_classes_type): |
|
super(MTL, self).__init__() |
|
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) |
|
self.resnet_feature_extractor = nn.Sequential(*list(resnet.children())[:-1]) |
|
self.num_features = resnet.fc.in_features |
|
|
|
self.class_school_head = nn.Sequential( |
|
nn.Linear(self.num_features, num_classes_school) |
|
) |
|
self.class_type_head = nn.Sequential( |
|
nn.Linear(self.num_features, num_classes_type) |
|
) |
|
|
|
def forward(self, img): |
|
visual_emb = self.resnet_feature_extractor(img) |
|
visual_emb = visual_emb.view(visual_emb.size(0), -1) |
|
|
|
out_school = self.class_school_head(visual_emb) |
|
out_type = self.class_type_head(visual_emb) |
|
|
|
return out_school, out_type |
|
|