import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig


class CodeEmbedder(nn.Module):
    """
    A wrapper around a Transformer model (default: CodeBERT) that produces
    dense vector embeddings for code snippets using mean pooling.
    """

    def __init__(self, model_name_or_path="microsoft/codebert-base", trust_remote_code=False):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, trust_remote_code=trust_remote_code
        )

    def mean_pooling(self, token_embeddings, attention_mask):
        """
        Average the token embeddings, ignoring padding tokens.
        """
        # Expand the attention mask to the hidden dimension so padded positions
        # contribute nothing to the sum.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp guards against division by zero for all-padding sequences.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # Pool token states into one vector per snippet, then L2-normalize so that
        # cosine similarity reduces to a dot product.
        embeddings = self.mean_pooling(last_hidden_state, attention_mask)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings
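

# Minimal usage sketch (illustrative, not part of the original module): it assumes
# the matching tokenizer is loaded via transformers.AutoTokenizer, and the example
# snippets below are hypothetical.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    model = CodeEmbedder()
    model.eval()

    snippets = [
        "def add(a, b):\n    return a + b",
        "print('hello world')",
    ]
    batch = tokenizer(snippets, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    # Embeddings are L2-normalized, so cosine similarity is a plain dot product.
    similarity = embeddings[0] @ embeddings[1]
    print(embeddings.shape, similarity.item())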