| | from torch import nn |
| | from easydict import EasyDict as MyEasyDict |
| | from transformers import BertModel, PreTrainedModel, BertConfig, PretrainedConfig |
| |
|
| |
|
class BertConfig(PretrainedConfig):
    """Pretrained configuration that also carries a free-form ``model_config``.

    NOTE(review): this class reuses the name ``BertConfig`` that is imported
    from ``transformers`` at the top of the file, so it shadows that import in
    this module. It also declares ``model_type = "bert"``, which presumably
    collides with transformers' own BERT registration. Confirm both are
    intentional.

    Args:
        model_config: Mapping of extra settings. It is wrapped in an
            ``EasyDict`` and stored as ``self.model_config``. ``None`` is
            handed to ``EasyDict`` as-is.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "bert"

    def __init__(self, model_config=None, **kwargs):
        # Wrapping does not depend on ``self``, so do it before base init.
        wrapped_settings = MyEasyDict(model_config)
        super().__init__(**kwargs)
        self.model_config = wrapped_settings
| |
|
| |
|
class BERTClassifier(PreTrainedModel):
    """BERT encoder followed by a dropout + linear head on the pooled output.

    Args:
        config: Configuration handed to both ``PreTrainedModel`` and
            ``BertModel``. Presumably it must carry the fields ``BertModel``
            reads (``vocab_size``, ``hidden_size``, ...). TODO confirm for the
            custom ``BertConfig`` in this file.
        num_classes: Number of logits produced by the head. Defaults to 16.
            Deliberately not named ``num_labels``, because
            ``from_pretrained`` would route that keyword into the config
            rather than into this constructor.
        dropout_prob: Dropout probability applied to the pooled output.
            Defaults to 0.1.
    """

    config_class = BertConfig

    def __init__(self, config, *, num_classes=16, dropout_prob=0.1):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(dropout_prob)
        # Size the head from the encoder's own config so it always matches.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len). TODO
                confirm.
            attention_mask: Optional padding mask forwarded to ``BertModel``.
                ``None`` lets ``BertModel`` apply its default.
            token_type_ids: Optional segment ids forwarded to ``BertModel``.
                ``None`` lets ``BertModel`` apply its default.

        Returns:
            Output of the linear head applied to the dropped-out
            ``pooler_output``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

    def print_test(self, x):
        """Placeholder that ignores ``x`` and always returns ``"lmao"``."""
        return "lmao"
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the classifier from the stock BERT-base config.
    # Import only what is used. This block runs at module scope, so importing
    # ``BertConfig`` from transformers here would rebind the module-global
    # name and shadow the custom ``BertConfig`` class defined in this file.
    from transformers import AutoConfig

    # bert-base-uncased is a built-in transformers architecture, so there is
    # no need to enable execution of code downloaded from the Hub.
    config = AutoConfig.from_pretrained('google-bert/bert-base-uncased')
    model = BERTClassifier(config)