| | from transformers import PreTrainedModel |
| | from transformers import PretrainedConfig |
| | from typing import List |
| | import torch.nn as nn |
| | import torch |
| |
|
| |
|
class MyModelConfig(PretrainedConfig):
    """Configuration holding the hyperparameters consumed by ``MyModel``.

    Attributes:
        input_dim: Number of input features; ``MyModel`` uses it as the
            ``in_features`` of its first linear layer.
        layers_num: Number of linear layers ``MyModel`` stacks.
    """

    def __init__(
        self,
        input_dim: int = 100,
        layers_num: int = 5,
        **kwargs,
    ) -> None:
        """Store the model hyperparameters and forward the rest to the base class.

        Args:
            input_dim: Number of input features.
            layers_num: Number of linear layers to build.
            **kwargs: Remaining options, passed unchanged to
                ``PretrainedConfig.__init__``.
        """
        # Custom fields are assigned before the base initializer runs,
        # matching the original statement order.
        self.input_dim, self.layers_num = input_dim, layers_num
        super().__init__(**kwargs)
| |
|
class MyModel(PreTrainedModel):
    """Stack of ``nn.Linear`` layers mapping ``config.input_dim`` features to one output.

    With ``layers_num == 1`` the model is a single ``Linear(input_dim, 1)``.
    Otherwise it is ``Linear(input_dim, 30)``, followed by ``layers_num - 2``
    ``Linear(30, 30)`` layers, followed by ``Linear(30, 1)``. No activation
    function is inserted between the layers.
    """

    config_class = MyModelConfig

    def __init__(self, config):
        """Build the linear stack described by ``config``.

        Args:
            config: Object exposing ``input_dim`` and ``layers_num``.

        Raises:
            ValueError: If ``config.layers_num`` is smaller than 1.
        """
        super().__init__(config)

        # Explicit check instead of ``assert`` so that validation still
        # happens when Python runs with ``-O``.
        if config.layers_num < 1:
            raise ValueError(
                f"layers_num must be >= 1, got {config.layers_num}"
            )

        hidden_dim = 30  # width of every intermediate layer

        if config.layers_num == 1:
            layers = [nn.Linear(config.input_dim, 1)]
        else:
            layers = [nn.Linear(config.input_dim, hidden_dim)]
            layers.extend(
                nn.Linear(hidden_dim, hidden_dim)
                for _ in range(config.layers_num - 2)
            )
            layers.append(nn.Linear(hidden_dim, 1))

        # ``nn.Sequential`` is callable and chains the layers in order;
        # ``nn.ModuleList`` is only a container and cannot be called in
        # ``forward``. Sub-modules keep the names "0", "1", ... so the
        # ``state_dict`` keys ("model.0.weight", ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, tensor):
        """Apply the linear layers to ``tensor`` in order and return the result.

        Args:
            tensor: Input whose last dimension is presumably
                ``config.input_dim`` (required by the first ``nn.Linear``).

        Returns:
            The output of the final linear layer.
        """
        return self.model(tensor)
| |
|
if __name__ == '__main__':
    # Create a small example configuration and persist it under the
    # "custom-mymodel" directory.
    save_config = MyModelConfig(
        input_dim=10,
        layers_num=3,
    )
    save_config.save_pretrained("custom-mymodel")

    # Instantiate the model from that configuration and dump its freshly
    # initialized weights.
    # NOTE(review): the weights file is written to the current working
    # directory, not into "custom-mymodel" -- confirm this is intended.
    mymodel = MyModel(save_config)
    torch.save(
        mymodel.state_dict(),
        'pytorch_model.bin',
    )