| import torch.nn as nn |
| import torch |
| from transformers import AutoModel, AutoConfig |
|
|
class RefactorSpanModel(nn.Module):
    """Transformer encoder with a sequence-level head and a per-token head.

    The encoder architecture is built from the configuration of
    ``base_model_path``.  ``AutoModel.from_config`` creates the architecture
    only and does not load pretrained weights, so the weights are expected to
    come from a later ``load_state_dict`` call.

    Two linear heads, each with a single output unit, sit on top of the
    encoder:

    * ``classifier`` is applied to the encoder's second output (the pooled
      representation).
    * ``start_span`` is applied to the encoder's first output (the per-token
      hidden states).
    """

    def __init__(self, base_model_path: str = 'microsoft/codebert-base'):
        """Build the encoder and both heads.

        Args:
            base_model_path: Model id or local path whose configuration
                defines the encoder architecture.
        """
        super().__init__()
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(0.5)
        # Size the heads from the config instead of hard-coding 768 so that a
        # different base model still produces matching shapes.
        hidden_size = self.base_config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        self.start_span = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.Tensor, attention_mask=None):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Optional mask forwarded to the encoder.  Pass it
                for padded batches so padding positions are not attended to;
                ``None`` keeps the encoder's default behaviour.

        Returns:
            A ``(refactor, span)`` tuple of raw linear outputs (no activation
            is applied here): ``refactor`` comes from the pooled output and
            ``span`` holds one value per position of the hidden states.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # outputs[0]: per-token hidden states; outputs[1]: pooled output.
        # NOTE(review): outputs[1] assumes the base model has a pooling layer.
        outputs_hidden = self.dropout(outputs[0])
        outputs_pool = self.dropout(outputs[1])
        refactor = self.classifier(outputs_pool)
        span = self.start_span(outputs_hidden)
        return refactor, span
| |
class RefactorModel(nn.Module):
    """Transformer encoder with a single sequence-level linear head.

    The encoder architecture is built from the configuration of
    ``base_model_path``.  ``AutoModel.from_config`` creates the architecture
    only and does not load pretrained weights, so the weights are expected to
    come from a later ``load_state_dict`` call.

    ``classifier`` is a linear layer with one output unit applied to the
    encoder's second output (the pooled representation).
    """

    def __init__(self, base_model_path: str = 'microsoft/codebert-base'):
        """Build the encoder and the classification head.

        Args:
            base_model_path: Model id or local path whose configuration
                defines the encoder architecture.
        """
        super().__init__()
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(0.5)
        # Size the head from the config instead of hard-coding 768 so that a
        # different base model still produces matching shapes.
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids: torch.Tensor, attention_mask=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Optional mask forwarded to the encoder.  Pass it
                for padded batches so padding positions are not attended to;
                ``None`` keeps the encoder's default behaviour.

        Returns:
            The raw linear output (no activation is applied here) computed
            from the encoder's pooled output.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled output.
        # NOTE(review): this assumes the base model has a pooling layer.
        outputs_pool = self.dropout(outputs[1])
        refactor = self.classifier(outputs_pool)
        return refactor
| |
if __name__ == "__main__":
    # Smoke test: build each model and check that its checkpoint loads with
    # exactly matching keys and shapes (strict=True raises on any mismatch).
    for model_cls, checkpoint in (
        (RefactorSpanModel, 'pytorch_model_RSP.bin'),
        (RefactorModel, 'pytorch_model_RP.bin'),
    ):
        model = model_cls()
        # map_location="cpu" lets a checkpoint saved from CUDA tensors load on
        # a machine without a GPU; the freshly built model is on CPU anyway.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
| |