yhamidullah commited on
Commit
4365a31
·
verified ·
1 Parent(s): b62f504

update models.py

Browse files
Files changed (1) hide show
  1. models.py +11 -0
models.py CHANGED
@@ -2,6 +2,17 @@ import torch
2
  from torch import nn
3
  from transformers import PreTrainedModel
4
 
 
 
 
 
 
 
 
 
 
 
 
5
  class CustomClassifier(PreTrainedModel):
6
  config_class = CustomClassificationConfig
7
 
 
2
  from torch import nn
3
  from transformers import PreTrainedModel
4
 
5
+ from transformers import PretrainedConfig
6
+
7
+ class CustomClassificationConfig(PretrainedConfig):
8
+ model_type = "custom_classifier"
9
+
10
+ def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.input_dim = input_dim
13
+ self.hidden_dim = hidden_dim
14
+ self.num_classes = num_classes
15
+
16
  class CustomClassifier(PreTrainedModel):
17
  config_class = CustomClassificationConfig
18