File size: 2,237 Bytes
d568351 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 | import torch.nn as nn
class CalcModeClassifier(nn.Module):
    """Single-layer head producing calc-mode scores from feature vectors.

    The head is ``nn.Dropout`` followed by one ``nn.Linear`` projection. No
    activation is applied, so the outputs are unnormalized scores (logits).

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_calc_mode_labels: Number of scores produced per feature vector.
        dropout_rate: Drop probability handed to ``nn.Dropout``. The default
            of 0 makes dropout a no-op.
    """

    def __init__(
        self,
        input_dim: int,
        num_calc_mode_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Keep attribute names and registration order (dropout, then linear)
        # so the module repr and state_dict keys stay the same as before.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_calc_mode_labels
        )

    def forward(self, x):
        """Return ``linear(dropout(x))``.

        The last dimension of ``x`` must equal ``input_dim``. Leading
        dimensions pass through unchanged. Dropout only acts in train mode.
        """
        return self.linear(self.dropout(x))
class ActivityClassifier(nn.Module):
    """Dropout + linear projection that scores activity labels.

    Nothing follows the projection, so the result is raw per-label scores
    (logits). Apply softmax or a loss function downstream as needed.

    Args:
        input_dim: Feature size expected in the last dimension of the input.
        num_activity_labels: Width of the output score vector.
        dropout_rate: Drop probability for ``nn.Dropout`` (0.0 disables it).
    """

    def __init__(
        self,
        input_dim: int,
        num_activity_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Submodule names/order are part of the checkpoint format
        # ("linear.weight", "linear.bias"), so they are preserved.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_activity_labels
        )

    def forward(self, x):
        """Apply dropout then the linear layer to ``x`` and return the scores.

        Leading dimensions of ``x`` are preserved; only the last one
        (``input_dim``) is projected to ``num_activity_labels``.
        """
        return self.linear(self.dropout(x))
class RegionClassifier(nn.Module):
    """Linear classification head for region labels.

    Computes ``Linear(Dropout(x))`` with no trailing activation, returning
    unnormalized scores (logits).

    Args:
        input_dim: Dimensionality of the input features (last axis).
        num_region_labels: Number of output scores.
        dropout_rate: ``nn.Dropout`` probability; 0.0 means no dropout.
    """

    def __init__(
        self,
        input_dim: int,
        num_region_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Same attribute names and order as before, keeping state_dict keys
        # and the printed module layout unchanged.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_region_labels
        )

    def forward(self, x):
        """Score ``x``: dropout (train mode only), then linear projection."""
        return self.linear(self.dropout(x))
class InvestmentClassifier(nn.Module):
    """Scores investment labels with a dropout-regularized linear layer.

    No nonlinearity follows the projection, so callers receive logits.

    Args:
        input_dim: Size of the feature axis (last dimension) of the input.
        num_investment_labels: Number of output scores per input vector.
        dropout_rate: Probability used by ``nn.Dropout``; defaults to 0.0.
    """

    def __init__(
        self,
        input_dim: int,
        num_investment_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Attribute names determine state_dict keys, so they are unchanged.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_investment_labels
        )

    def forward(self, x):
        """Return the linear projection of ``x`` after dropout.

        All leading dimensions of ``x`` are kept; the last must be
        ``input_dim``.
        """
        return self.linear(self.dropout(x))
class ReqFormClassifier(nn.Module):
    """Dropout followed by a linear layer that scores req-form labels.

    The output is the raw projection (logits); no softmax is applied here.

    Args:
        input_dim: Expected size of the last dimension of the input.
        num_req_form_labels: Number of scores emitted per input vector.
        dropout_rate: ``nn.Dropout`` probability (0.0 leaves inputs as-is).
    """

    def __init__(
        self,
        input_dim: int,
        num_req_form_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Preserve names and order so saved state_dicts keep loading.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_req_form_labels
        )

    def forward(self, x):
        """Compute ``linear(dropout(x))``, keeping leading dims of ``x``."""
        return self.linear(self.dropout(x))
class SlotClassifier(nn.Module):
    """Linear head producing slot-label scores.

    Applies ``nn.Dropout`` then ``nn.Linear``. Because ``nn.Linear`` acts on
    the last dimension only, inputs with extra leading dimensions work as-is.
    (Presumably this is used per token for slot filling; confirm with the
    caller.) No activation is applied, so outputs are logits.

    Args:
        input_dim: Feature size in the last dimension of the input.
        num_slot_labels: Number of scores per input vector.
        dropout_rate: Probability for ``nn.Dropout``; defaults to 0.0.
    """

    def __init__(
        self,
        input_dim: int,
        num_slot_labels: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Names/order kept identical so state_dict keys do not change.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim, out_features=num_slot_labels
        )

    def forward(self, x):
        """Return slot scores for ``x`` (dropout, then linear projection)."""
        return self.linear(self.dropout(x))
|