File size: 910 Bytes
fd7d432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn
from torchvision import models

class MTL(nn.Module):
    """Multi-task classifier: one shared ResNet-50 trunk feeding two linear heads.

    The trunk is a torchvision ResNet-50 loaded with
    ``ResNet50_Weights.DEFAULT`` and stripped of its last child module.
    Its flattened output is passed, unchanged, to both a "school" head and
    a "type" head, each a single ``nn.Linear`` wrapped in ``nn.Sequential``.

    Attributes:
        resnet_feature_extractor: ``nn.Sequential`` holding every child of
            the ResNet-50 except the last one.
        num_features: Input width of the ResNet's original ``fc`` layer;
            used as the input width of both heads.
        class_school_head: Linear head producing ``num_classes_school``
            outputs.
        class_type_head: Linear head producing ``num_classes_type`` outputs.
    """

    def __init__(self, num_classes_school, num_classes_type):
        """Build the shared trunk and the two classification heads.

        Args:
            num_classes_school: Output width of the "school" head.
            num_classes_type: Output width of the "type" head.
        """
        super().__init__()

        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Keep all child modules of the backbone but the final one, which is
        # replaced by the two task-specific heads below.
        *trunk_layers, _last_layer = backbone.children()
        self.resnet_feature_extractor = nn.Sequential(*trunk_layers)
        self.num_features = backbone.fc.in_features

        # The heads are built in this order (school, then type) on purpose:
        # it fixes both the parameter registration order and the order in
        # which random initialisation is drawn.
        school_linear = nn.Linear(self.num_features, num_classes_school)
        self.class_school_head = nn.Sequential(school_linear)

        type_linear = nn.Linear(self.num_features, num_classes_type)
        self.class_type_head = nn.Sequential(type_linear)

    def forward(self, img):
        """Run the shared trunk once and apply both heads to its output.

        Args:
            img: Input tensor for the feature extractor
                (presumably a ``(batch, 3, H, W)`` image batch -- confirm
                against the caller).

        Returns:
            A tuple ``(out_school, out_type)`` of the raw linear outputs of
            the two heads; no activation is applied here.
        """
        features = self.resnet_feature_extractor(img)

        # Collapse everything after the batch dimension into one axis so the
        # linear heads receive a 2-D tensor.
        batch_size = features.size(0)
        flat_features = features.view(batch_size, -1)

        return (
            self.class_school_head(flat_features),
            self.class_type_head(flat_features),
        )