import os
import warnings

import joblib
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from PIL import Image
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors, Draw, MACCSkeys
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.exceptions import InconsistentVersionWarning
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding

warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
RDLogger.DisableLog('rdApp.*')


class BBBTokenizer(PreTrainedTokenizer):
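    """Multimodal "tokenizer" for TITAN-BBB: maps each SMILES string to three
    normalized feature tensors: tabular (RDKit 2D descriptors plus MACCS keys),
    image (ResNet-50 embedding of the 2D drawing), and text (mean-pooled
    ChemBERTa embedding)."""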

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Calculator for the full list of RDKit 2D molecular descriptors.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [name for name, _ in Descriptors.descList]
        )

        # Frozen ChemBERTa encoder for SMILES text embeddings.
        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()

        # ImageNet-pretrained ResNet-50 with the final classification layer
        # removed, used as a frozen feature extractor for molecule drawings.
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Per-modality feature normalizers; loaded lazily for the requested
        # task on the first call to _batch_encode_plus.
        self.feature_transformer_tab = None
        self.feature_transformer_img = None
        self.feature_transformer_txt = None
        self.task = None

    def generate_tab_features(self, smiles):
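        """Return normalized RDKit 2D descriptors concatenated with MACCS keys."""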
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            # Unparseable SMILES: fall back to a zero vector of the width
            # the fitted transformer expects.
            return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)

        rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
        # Replace NaN and +/-inf descriptor values before normalization.
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        tab_input = np.clip(tab_input, -1e5, 1e5)
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
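        """Return the normalized ResNet-50 embedding of the 2D molecule drawing."""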
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: use a black placeholder image.
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1).numpy()
        img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
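        """Return the normalized mean-pooled ChemBERTa embedding of the SMILES."""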
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Mean-pool the token embeddings from the last hidden layer.
        hidden_states = outputs.last_hidden_state[0].mean(dim=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        task: str = 'classification',
        return_tensors: str = "pt",
        **kwargs
    ):
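        """Encode a batch of SMILES strings into normalized tabular, image,
        and text feature tensors for the given task."""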
        # (Re)load the per-modality normalizers whenever the task changes.
        if self.task != task:
            if task == 'classification':
                prefix = 'cls'
            elif task == 'regression':
                prefix = 'reg'
            else:
                raise ValueError(f"unknown task {task!r}; expected 'classification' or 'regression'")

            files = [f"normalize_{prefix}_{mod}.joblib" for mod in ("tabular", "image", "text")]
            model_dir = snapshot_download("SaeedLab/TITAN-BBB", allow_patterns=files)
            self.feature_transformer_tab = joblib.load(os.path.join(model_dir, files[0]))
            self.feature_transformer_img = joblib.load(os.path.join(model_dir, files[1]))
            self.feature_transformer_txt = joblib.load(os.path.join(model_dir, files[2]))
            self.task = task

        # Featurize every SMILES under all three modalities and stack the
        # results into batch tensors.
        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))

        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }
        return BatchEncoding(output, tensor_type=return_tensors)

    # encode() and __call__() both delegate to _batch_encode_plus().
    def encode(self,
               batch_smiles: list[str],
               task: str = 'classification',
               return_tensors: str = "pt",
               **kwargs):
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def __call__(self,
                 batch_smiles: list[str],
                 task: str = 'classification',
                 return_tensors: str = "pt",
                 **kwargs):
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    # The PreTrainedTokenizer interface requires the vocabulary methods below;
    # this tokenizer has no real vocabulary, so they are minimal stubs.
    def _tokenize(self, text, **kwargs):
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        return ()

    def get_vocab(self):
        return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}

    @property
    def vocab_size(self):
        return len(self.get_vocab())
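
# A minimal usage sketch, assuming BBBTokenizer can be instantiated with no
# arguments. The fitted normalizers are downloaded from the SaeedLab/TITAN-BBB
# Hub repo on the first call for a given task.
if __name__ == "__main__":
    bbb_tokenizer = BBBTokenizer()
    batch = bbb_tokenizer(["CCO", "c1ccccc1O"], task='classification')
    print({key: tuple(val.shape) for key, val in batch.items()})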