charlesxsh commited on
Commit
ffe9bd8
·
1 Parent(s): 72e2faf

add config

Browse files
Files changed (1) hide show
  1. custom_model.py +4 -1
custom_model.py CHANGED
@@ -1,5 +1,5 @@
1
  # custom_model.py
2
- from transformers import PreTrainedModel, PretrainedConfig
3
  import torch
4
  import torch.nn as nn
5
 
@@ -19,3 +19,6 @@ class CustomModel(PreTrainedModel):
19
  def forward(self, input_ids):
20
  output = self.linear(input_ids)
21
  return output
 
 
 
 
1
  # custom_model.py
2
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
3
  import torch
4
  import torch.nn as nn
5
 
 
19
  def forward(self, input_ids):
20
  output = self.linear(input_ids)
21
  return output
22
+
23
+ AutoConfig.register("custom-model", CustomModelConfig)
24
+ AutoModel.register(CustomModelConfig, CustomModel)