import config
from transformers import AutoModelForMaskedLM, AutoTokenizer


class MembraneTokenizer:
    """Thin wrapper around the ESM-2 tokenizer."""
    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not defined here to the wrapped tokenizer.
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_dir):
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        # Reload the tokenizer from a directory written by save_tokenizer().
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)


class MembraneMLM:
    """ESM-2 masked-language model with freeze/unfreeze helpers."""

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not defined here to the wrapped model.
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def freeze_model(self):
        # Freeze every parameter so only explicitly unfrozen ones train.
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_n_layers(self):
        # Unfreeze only the query/key/value projections in the last
        # config.ESM_LAYERS encoder layers; everything else stays frozen.
        model_layers = len(self.model.esm.encoder.layer)
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                attention = layer.attention.self
                for projection in (attention.query, attention.key, attention.value):
                    for param in projection.parameters():
                        param.requires_grad = True
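
    # Illustrative helper (an addition, not in the original class): counts the
    # parameters that freeze_model()/unfreeze_n_layers() left trainable, as a
    # quick sanity check before fine-tuning.
    def count_trainable_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)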

    def forward(self, **inputs):
        return self.model(**inputs)

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        # Reload with AutoModelForMaskedLM so the LM head written by
        # save_model() is restored rather than dropped by a bare AutoModel.
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
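

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes config.ESM_MODEL_PATH points at an ESM-2 checkpoint such as
# "facebook/esm2_t33_650M_UR50D" and config.ESM_LAYERS is a small integer;
# the peptide string below is an arbitrary toy sequence.
if __name__ == "__main__":
    tokenizer = MembraneTokenizer()
    mlm = MembraneMLM()

    # Freeze everything, then re-enable Q/K/V in the last few encoder layers.
    mlm.freeze_model()
    mlm.unfreeze_n_layers()
    print(f"trainable parameters: {mlm.count_trainable_parameters():,}")

    # Run a masked-token forward pass; <mask> is the ESM-2 mask token.
    batch = tokenizer("MKT<mask>IALSYIFCLVFA", return_tensors="pt")
    outputs = mlm(**batch)
    print(f"logits shape: {tuple(outputs.logits.shape)}")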