| from transformers import PreTrainedModel |
| from transformers import PretrainedConfig |
| from typing import List |
| import torch.nn as nn |
| import torch |
|
|
|
|
class MyModelConfig(PretrainedConfig):
    """Configuration holding the two hyper-parameters read by ``MyModel``.

    Args:
        input_dim: Stored as ``self.input_dim`` (default 100). ``MyModel``
            uses it as the input width of its first linear layer.
        layers_num: Stored as ``self.layers_num`` (default 5). ``MyModel``
            uses it as the number of linear layers to build.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, input_dim: int = 100, layers_num: int = 5, **kwargs) -> None:
        # Custom fields are set before delegating to the base class, which
        # receives only the remaining keyword arguments.
        self.layers_num = layers_num
        self.input_dim = input_dim
        super().__init__(**kwargs)
|
|
class MyModel(PreTrainedModel):
    """Stack of ``nn.Linear`` layers reducing ``config.input_dim`` features to 1.

    Layout, depending on ``config.layers_num``:

    * ``1``: a single ``Linear(input_dim, 1)``.
    * ``n >= 2``: ``Linear(input_dim, 30)``, then ``n - 2`` layers of
      ``Linear(30, 30)``, then ``Linear(30, 1)``.

    The layers are held in an ``nn.Sequential`` so the stack can be called
    directly; an ``nn.ModuleList`` only registers sub-modules and cannot be
    invoked. Parameter names in the state dict stay ``model.<index>.*``.

    NOTE(review): there is no activation between the linear layers, so the
    whole stack is a composition of affine maps. Confirm this is intended.
    """

    config_class = MyModelConfig

    def __init__(self, config):
        """Build the linear stack described by ``config``.

        Args:
            config: Object exposing ``input_dim`` and ``layers_num``.

        Raises:
            ValueError: If ``config.layers_num`` is smaller than 1.
        """
        super().__init__(config)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if config.layers_num < 1:
            raise ValueError(
                f"layers_num must be >= 1, got {config.layers_num}"
            )

        hidden_dim = 30  # width of every intermediate layer
        if config.layers_num == 1:
            layers = [nn.Linear(config.input_dim, 1)]
        else:
            layers = [nn.Linear(config.input_dim, hidden_dim)]
            # Layers are created strictly in stack order, input to output.
            layers.extend(
                nn.Linear(hidden_dim, hidden_dim)
                for _ in range(config.layers_num - 2)
            )
            layers.append(nn.Linear(hidden_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, tensor):
        """Pass ``tensor`` through every linear layer in order and return the result."""
        return self.model(tensor)
|
|
if __name__ == '__main__':
    # Write the configuration for a 3-layer model with 10 inputs to the
    # "custom-mymodel" location.
    demo_config = MyModelConfig(layers_num=3, input_dim=10)
    demo_config.save_pretrained("custom-mymodel")

    # Instantiate the model from that configuration and dump its weights.
    # NOTE(review): the weights file is written relative to the current
    # working directory, not inside "custom-mymodel" -- confirm intended.
    demo_model = MyModel(demo_config)
    weights = demo_model.state_dict()
    torch.save(weights, 'pytorch_model.bin')