yhamidullah commited on
Commit
3505b99
·
verified ·
1 Parent(s): 90d6057

create models.py

Browse files
Files changed (1) hide show
  1. models.py +28 -0
models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel
4
+
5
+ class CustomClassifier(PreTrainedModel):
6
+ config_class = CustomClassificationConfig
7
+
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+ self.encoder = nn.Sequential(
11
+ nn.Linear(config.input_dim, config.hidden_dim),
12
+ nn.ReLU(),
13
+ nn.Linear(config.hidden_dim, config.hidden_dim),
14
+ nn.ReLU(),
15
+ )
16
+ self.classifier = nn.Linear(config.hidden_dim, config.num_classes)
17
+
18
+ def forward(self, input_ids=None, labels=None, **kwargs):
19
+ # input_ids: shape (batch_size, input_dim)
20
+ hidden = self.encoder(input_ids)
21
+ logits = self.classifier(hidden)
22
+
23
+ loss = None
24
+ if labels is not None:
25
+ loss_fn = nn.CrossEntropyLoss()
26
+ loss = loss_fn(logits, labels)
27
+
28
+ return {"loss": loss, "logits": logits}