| | import torch.nn as nn |
| | import torch |
| | from transformers import AutoModel, AutoConfig |
| |
|
class RefactorSpanModel(nn.Module):
    """Transformer encoder with a sequence-level head and a per-position head.

    The encoder is built from its configuration only (``AutoModel.from_config``),
    so its weights stay randomly initialised until a state dict is loaded.

    Args:
        base_model_path: Hugging Face model id or local path whose config
            describes the encoder. Defaults to ``'microsoft/codebert-base'``.
        dropout_prob: Dropout probability applied to both encoder outputs.
    """

    def __init__(self, base_model_path: str = "microsoft/codebert-base",
                 dropout_prob: float = 0.5):
        super().__init__()
        # Only the configuration is fetched here (network or local cache);
        # weights come from a later load_state_dict.
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout_prob)
        # Head widths follow the encoder config instead of a hard-coded 768
        # (768 is codebert-base's hidden size), so other encoders also work.
        # Attribute names are unchanged: they are the state-dict keys that
        # checkpoints are strictly loaded against.
        hidden_size = self.base_config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        self.start_span = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: "torch.Tensor",
                attention_mask: "torch.Tensor | None" = None
                ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Run the encoder and apply both heads.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask forwarded to the encoder so padded
                positions can be ignored. ``None`` is the encoder's own
                default, so omitting it matches the previous call.

        Returns:
            ``(refactor, span)``. ``refactor`` is ``classifier`` applied to the
            encoder's second output. ``span`` is ``start_span`` applied to
            its first output. Each head has a trailing dimension of size 1.
        """
        # Assumes the Hugging Face ordering: outputs[0] is the last hidden
        # state, outputs[1] is the pooled output.
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # Dropout is applied to the pooled output first, then to the hidden
        # states. This keeps the original RNG consumption order in train mode.
        pooled = self.dropout(outputs[1])
        hidden = self.dropout(outputs[0])
        refactor = self.classifier(pooled)
        span = self.start_span(hidden)
        return refactor, span
| | |
class RefactorModel(nn.Module):
    """Transformer encoder with a single sequence-level linear head.

    The encoder is built from its configuration only (``AutoModel.from_config``),
    so its weights stay randomly initialised until a state dict is loaded.

    Args:
        base_model_path: Hugging Face model id or local path whose config
            describes the encoder. Defaults to ``'microsoft/codebert-base'``.
        dropout_prob: Dropout probability applied to the pooled output.
    """

    def __init__(self, base_model_path: str = "microsoft/codebert-base",
                 dropout_prob: float = 0.5):
        super().__init__()
        # Only the configuration is fetched here (network or local cache);
        # weights come from a later load_state_dict.
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout_prob)
        # Head width follows the encoder config instead of a hard-coded 768
        # (768 is codebert-base's hidden size). Attribute names are unchanged
        # because they are the state-dict keys checkpoints load against.
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids: "torch.Tensor",
                attention_mask: "torch.Tensor | None" = None) -> "torch.Tensor":
        """Run the encoder and apply the linear head to its pooled output.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask forwarded to the encoder so padded
                positions can be ignored. ``None`` is the encoder's own
                default, so omitting it matches the previous call.

        Returns:
            ``classifier`` applied to the encoder's second output, with a
            trailing dimension of size 1.
        """
        # Assumes the Hugging Face ordering: outputs[1] is the pooled output.
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs[1])
        return self.classifier(pooled)
| | |
def load_checkpoint(model, checkpoint):
    """Load a saved state dict from ``checkpoint`` into ``model``.

    Args:
        model: The ``nn.Module`` whose parameters are replaced.
        checkpoint: Path to a state-dict file readable by ``torch.load``.

    Returns:
        ``model`` itself, after a strict ``load_state_dict`` and ``eval()``.
    """
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True where the installed torch supports it.
    state_dict = torch.load(checkpoint, map_location="cpu")
    # strict=True: missing or unexpected keys raise, so a checkpoint built
    # for the other model class fails loudly instead of loading partially.
    model.load_state_dict(state_dict, strict=True)
    # Switch off dropout so later inference is deterministic.
    model.eval()
    return model


if __name__ == "__main__":
    # Smoke check: each checkpoint must match its model class exactly.
    for model_cls, checkpoint in (
        (RefactorSpanModel, "pytorch_model_RSP.bin"),
        (RefactorModel, "pytorch_model_RP.bin"),
    ):
        model = load_checkpoint(model_cls(), checkpoint)
| | |