| import torch |
| import config |
| from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM |
|
|
def load_esm2_model(model_name):
    """Load the tokenizer and two model variants for an ESM-2 checkpoint.

    Args:
        model_name: HuggingFace hub id or local path of the ESM-2 checkpoint
            (e.g. "facebook/esm2_t33_650M_UR50D").

    Returns:
        A 3-tuple ``(tokenizer, masked_model, embedding_model)`` where
        ``masked_model`` is the masked-language-modeling head variant and
        ``embedding_model`` is the bare encoder used for hidden states.
    """
    # NOTE(review): this downloads/instantiates the encoder weights twice
    # (once inside the MLM variant, once standalone) — acceptable if both
    # are needed independently, but worth confirming.
    tok = AutoTokenizer.from_pretrained(model_name)
    mlm = AutoModelForMaskedLM.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    return tok, mlm, encoder
|
|
def get_latents(model, tokenizer, sequence, device):
    """Return per-token hidden states for a single sequence.

    Args:
        model: An encoder whose forward pass exposes ``last_hidden_state``
            (e.g. the embedding model from ``load_esm2_model``).
        tokenizer: Matching tokenizer; called with ``return_tensors="pt"``.
        sequence: A single input sequence (string).
        device: Target device for the tokenized inputs — assumes the model
            already lives on this device (caller's responsibility).

    Returns:
        A tensor of shape ``(seq_len, hidden_dim)`` — the batch dimension
        is squeezed off since exactly one sequence is encoded.
    """
    encoded = tokenizer(sequence, return_tensors="pt").to(device)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    return hidden.squeeze(0)