| from __future__ import annotations |
|
|
| import contextlib |
| import hashlib |
| import io |
| import json |
| import os |
| import re |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import snapshot_download |
| from huggingface_hub.utils import disable_progress_bars |
| from rdkit import Chem, DataStructs, RDLogger |
| from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors |
| from rdkit.Chem.MolStandardize import rdMolStandardize |
| from sentence_transformers import SentenceTransformer |
| from torch import nn |
| from transformers import AutoModel, AutoTokenizer |
| from transformers.utils import logging as transformers_logging |
|
|
# Quiet noisy third-party libraries so the logs stay readable.
# NOTE(review): these defaults are applied after the libraries were imported;
# the explicit disable_progress_bars()/set_verbosity_error() calls make the
# intent take effect regardless of when the env vars are read.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
transformers_logging.set_verbosity_error()
disable_progress_bars()
# Silence RDKit's per-molecule parse/sanitization warnings.
RDLogger.DisableLog("rdApp.*")
|
|
# Default instruction placed in front of every assay query before it is embedded.
DEFAULT_ASSAY_TASK = (
    "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
)
# Default column order of the dense RDKit descriptor vector. Normalization
# statistics (mean/std) are applied element-wise in whatever order is used.
DEFAULT_DESCRIPTOR_NAMES = (
    "mol_wt",
    "logp",
    "tpsa",
    "heavy_atoms",
    "hbd",
    "hba",
    "rot_bonds",
    "ring_count",
    "aromatic_rings",
    "aliphatic_rings",
    "saturated_rings",
    "fraction_csp3",
    "heteroatoms",
    "amide_bonds",
    "fragments",
    "formal_charge",
    "max_atomic_num",
    "metal_atom_count",
    "halogen_count",
    "nitrogen_count",
    "oxygen_count",
    "sulfur_count",
    "phosphorus_count",
    "fluorine_count",
    "chlorine_count",
    "bromine_count",
    "iodine_count",
    "aromatic_atom_count",
    "spiro_atoms",
    "bridgehead_atoms",
)
# Atomic numbers treated as "organic-like"; every other element is counted
# toward the ``metal_atom_count`` descriptor.
ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
# Canonical order of the ``[SECTION]`` blocks in serialized assay text.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]
# Matches a ``[SECTION]`` header line and captures the section name (exactly
# one capturing group, so ``split`` alternates name/body). Built from
# SECTION_ORDER so the two lists cannot drift apart. Trailing spaces/tabs and
# CRLF line endings after the header are tolerated; previously such headers
# were silently not recognized and their text leaked into the prior section.
ASSAY_SECTION_RE = re.compile(
    r"\[(" + "|".join(re.escape(section) for section in SECTION_ORDER) + r")\][ \t]*\r?\n"
)
# Numeric organism identifiers (NCBI Taxonomy IDs) mapped to the species names
# they stand for, so both spellings normalize to the same metadata token.
ORGANISM_ALIASES = {
    "9606": "homo_sapiens",
    "10090": "mus_musculus",
    "10116": "rattus_norvegicus",
    "4932": "saccharomyces_cerevisiae",
}
|
|
|
|
@dataclass()
class AssayQuery:
    """Structured assay description that ``serialize_assay_query`` renders to text.

    Every text field defaults to an empty string. ``target_uniprot`` is an
    optional list of target accessions; ``None`` means no targets.
    """

    # Free-text fields, one per serialized section.
    title: str = ""
    description: str = ""
    organism: str = ""
    readout: str = ""
    assay_format: str = ""
    assay_type: str = ""
    # Target accessions, joined with ", " when serialized.
    target_uniprot: list[str] | None = None
|
|
|
|
def smiles_sha256(smiles: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``smiles`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(smiles.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
| @contextlib.contextmanager |
| def _silent_imports(): |
| buffer = io.StringIO() |
| with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer): |
| yield |
|
|
|
|
@lru_cache(maxsize=1_000_000)
def _standardize_smiles_v2_cached(smiles: str) -> str | None:
    """Standardize one SMILES string (memoized); return ``None`` if it is rejected.

    Pipeline: parse -> Cleanup -> FragmentParent -> uncharge -> canonical
    tautomer -> sanitize. A molecule is rejected if it fails to parse, if any
    step raises, if it has fewer than two heavy atoms, or if the final
    canonical isomeric SMILES is empty or still multi-fragment ('.').
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    try:
        cleaned = rdMolStandardize.Cleanup(parsed)
        parent = rdMolStandardize.FragmentParent(cleaned)
        neutral = rdMolStandardize.Uncharger().uncharge(parent)
        tautomer = rdMolStandardize.TautomerEnumerator().Canonicalize(neutral)
        Chem.SanitizeMol(tautomer)
    except Exception:
        # Best-effort: any RDKit failure simply marks the input as invalid.
        return None
    if tautomer.GetNumHeavyAtoms() < 2:
        return None
    result = Chem.MolToSmiles(tautomer, canonical=True, isomericSmiles=True)
    return result if result and "." not in result else None
|
|
|
|
def standardize_smiles_v2(smiles: str | None) -> str | None:
    """Return the standardized canonical SMILES for ``smiles``, or ``None``.

    ``None``, empty, and whitespace-only input return ``None`` without touching
    RDKit. Anything else is stripped and handed to the memoized standardizer.
    """
    trimmed = (smiles or "").strip()
    return _standardize_smiles_v2_cached(trimmed) if trimmed else None
|
|
|
|
def serialize_assay_query(query: AssayQuery) -> str:
    """Render ``query`` as bracketed sections in SECTION_ORDER.

    Each section is a ``[NAME]`` header line followed by the stripped field
    value; sections are separated by a blank line. Empty fields still emit
    their header. Target accessions are joined with ", ".
    """
    raw_fields = {
        "ASSAY_TITLE": query.title,
        "DESCRIPTION": query.description,
        "ORGANISM": query.organism,
        "READOUT": query.readout,
        "ASSAY_FORMAT": query.assay_format,
        "ASSAY_TYPE": query.assay_type,
        "TARGET_UNIPROT": ", ".join(query.target_uniprot or []),
    }
    rendered = [f"[{section}]\n{raw_fields[section].strip()}" for section in SECTION_ORDER]
    return "\n\n".join(rendered)
|
|
|
|
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
    """Split serialized assay text into a ``{section_name: stripped_body}`` dict.

    Every name in SECTION_ORDER is present (default ""). Text before the first
    recognized header is ignored; if a header repeats, the last body wins.
    """
    sections = dict.fromkeys(SECTION_ORDER, "")
    # The regex has one capturing group, so split() yields
    # [prefix, name1, body1, name2, body2, ...].
    pieces = ASSAY_SECTION_RE.split(assay_text)
    for header, body in zip(pieces[1::2], pieces[2::2]):
        if header in sections:
            sections[header] = body.strip()
    return sections
|
|
|
|
| def _hash_bucket(value: str, dim: int) -> int: |
| return abs(hash(value)) % max(dim, 1) |
|
|
|
|
| def _normalize_metadata_token(value: str) -> str: |
| return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") |
|
|
|
|
def _normalize_organism_token(value: str) -> str:
    """Normalize an organism field, first translating known numeric IDs via ORGANISM_ALIASES.

    Blank input yields "". The alias lookup uses the whitespace-stripped raw
    value and is exact (case-sensitive).
    """
    stripped = value.strip()
    if not stripped:
        return ""
    return _normalize_metadata_token(ORGANISM_ALIASES.get(stripped, stripped))
|
|
|
|
def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
    """Return an L2-normalized hashed bag of assay-metadata tokens of length ``dim``.

    Tokens are built from the organism, READOUT / ASSAY_FORMAT / ASSAY_TYPE
    (normalized) and each comma-separated, upper-cased TARGET_UNIPROT entry.
    Each token adds 1.0 to its hash bucket. An all-zero vector is returned
    unnormalized. ``dim <= 0`` yields an empty float32 array.
    """
    if dim <= 0:
        return np.zeros((0,), dtype=np.float32)
    parsed = _parse_assay_sections(assay_text)

    features: list[str] = []
    organism_token = _normalize_organism_token(parsed.get("ORGANISM", ""))
    if organism_token:
        features.append(f"organism:{organism_token}")
    for section in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
        normalized = _normalize_metadata_token(parsed.get(section, ""))
        if normalized:
            features.append(f"{section.lower()}:{normalized}")
    accessions = (piece.strip().upper() for piece in parsed.get("TARGET_UNIPROT", "").split(","))
    features.extend(f"target:{accession}" for accession in accessions if accession)

    histogram = np.zeros((dim,), dtype=np.float32)
    for feature in features:
        histogram[_hash_bucket(feature, dim)] += 1.0
    magnitude = float(np.linalg.norm(histogram))
    if magnitude > 0:
        histogram /= magnitude
    return histogram
|
|
|
|
def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
    """Return the Morgan bit fingerprint of ``mol`` as a uint8 array of length ``n_bits``."""
    bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, useChirality=use_chirality)
    dense_bits = np.zeros(n_bits, dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(bitvect, dense_bits)
    return dense_bits
|
|
|
|
def _maccs_bits_from_mol(mol) -> np.ndarray:
    """Return the RDKit MACCS keys of ``mol`` as a uint8 array sized to the key vector."""
    keys = MACCSkeys.GenMACCSKeys(mol)
    dense_keys = np.zeros(keys.GetNumBits(), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(keys, dense_keys)
    return dense_keys
|
|
|
|
| def _count_atomic_nums(mol) -> dict[int, int]: |
| counts: dict[int, int] = {} |
| for atom in mol.GetAtoms(): |
| atomic_num = int(atom.GetAtomicNum()) |
| counts[atomic_num] = counts.get(atomic_num, 0) + 1 |
| return counts |
|
|
|
|
def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
    """Return the requested RDKit descriptors of ``mol`` as a float32 vector.

    Every descriptor in the table is computed eagerly; ``names`` then selects
    and orders the output columns. An unknown name raises ``KeyError``.
    Element counts use only the atoms in the graph (no implicit hydrogens).
    """
    element_tally = _count_atomic_nums(mol)
    total_formal_charge = 0
    aromatic_atoms = 0
    for atom in mol.GetAtoms():
        total_formal_charge += int(atom.GetFormalCharge())
        if atom.GetIsAromatic():
            aromatic_atoms += 1

    # Per-element count columns -> atomic number.
    per_element_columns = {
        "nitrogen_count": 7,
        "oxygen_count": 8,
        "sulfur_count": 16,
        "phosphorus_count": 15,
        "fluorine_count": 9,
        "chlorine_count": 17,
        "bromine_count": 35,
        "iodine_count": 53,
    }
    table: dict[str, float] = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": float(mol.GetNumHeavyAtoms()),
        "hbd": float(Lipinski.NumHDonors(mol)),
        "hba": float(Lipinski.NumHAcceptors(mol)),
        "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
        "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
        "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
        "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
        "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
        "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
        "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
        "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
        "fragments": float(len(Chem.GetMolFrags(mol))),
        "formal_charge": float(total_formal_charge),
        "max_atomic_num": float(max(element_tally, default=0)),
        # Anything outside the organic-like element set counts as a "metal".
        "metal_atom_count": float(
            sum(count for atomic_num, count in element_tally.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS)
        ),
        "halogen_count": float(sum(element_tally.get(atomic_num, 0) for atomic_num in (9, 17, 35, 53))),
        "aromatic_atom_count": float(aromatic_atoms),
        "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
        "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
    }
    for column, atomic_num in per_element_columns.items():
        table[column] = float(element_tally.get(atomic_num, 0))
    return np.array([table[column] for column in names], dtype=np.float32)
|
|
|
|
def molecule_ui_metrics(smiles: str) -> dict[str, float | int]:
    """Return molecular weight, Crippen logP, TPSA and heavy-atom count for display.

    The standardized SMILES is used when standardization succeeds, otherwise
    the raw input is parsed. Unparseable input yields all-zero metrics.
    """
    mol = Chem.MolFromSmiles(standardize_smiles_v2(smiles) or smiles)
    if mol is not None:
        return {
            "mol_wt": float(Descriptors.MolWt(mol)),
            "logp": float(Crippen.MolLogP(mol)),
            "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
            "heavy_atoms": int(mol.GetNumHeavyAtoms()),
        }
    return {"mol_wt": 0.0, "logp": 0.0, "tpsa": 0.0, "heavy_atoms": 0}
|
|
|
|
class CompatibilityHead(nn.Module):
    """Score (assay, molecule) pairs from precomputed feature vectors.

    Each side is layer-normalized, linearly projected into a shared
    ``projection_dim`` space and L2-normalized. A pair's logit is the
    learnably-scaled dot product of the two unit vectors plus an MLP score
    over the interaction features ``[a, m, a * m, |a - m|]``.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        # Attribute names are the state_dict keys checkpoints are loaded
        # against; creation order is kept fixed so seeded init is reproducible.
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        interaction_dim = 4 * projection_dim
        self.score_mlp = nn.Sequential(
            nn.Linear(interaction_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features into the shared space; the last axis has unit L2 norm."""
        normed = self.assay_norm(assay_features)
        return F.normalize(self.assay_proj(normed), p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features into the shared space; the last axis has unit L2 norm."""
        normed = self.mol_norm(molecule_features)
        return F.normalize(self.mol_proj(normed), p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against its assay.

        Shapes implied by the broadcasting below: ``assay_features`` is
        ``(B, assay_dim)`` and ``candidate_features`` is ``(B, N, molecule_dim)``.
        Returns ``(logits (B, N), assay_vec (B, P), mol_vec (B, N, P))``.
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        num_candidates = mol_vec.shape[1]
        assay_tiled = assay_vec.unsqueeze(1).expand(-1, num_candidates, -1)
        product = assay_tiled * mol_vec
        distance = torch.abs(assay_tiled - mol_vec)
        cosine = product.sum(dim=-1)
        interaction = torch.cat([assay_tiled, mol_vec, product, distance], dim=-1)
        mlp_scores = self.score_mlp(interaction).squeeze(-1)
        logits = cosine * self.dot_scale + mlp_scores
        return logits, assay_vec, mol_vec
|
|
|
|
class SpaceCompatibilityModel:
    """Inference wrapper that turns assay text and SMILES into inputs for a ``CompatibilityHead``.

    Assay side: an instruction-prefixed text embedding from ``assay_encoder``,
    optionally concatenated with a hashed metadata vector.
    Molecule side, per SMILES: Morgan bits for each configured radius,
    optional MACCS keys, optional (normalized) RDKit descriptors and an
    optional mean-pooled embedding from a Hugging Face molecule transformer
    that is loaded lazily on first use.
    """

    def __init__(
        self,
        *,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        molecule_transformer_model_name: str,
        molecule_transformer_batch_size: int,
        molecule_transformer_max_length: int,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        # The head is only used for inference; eval() disables its dropout.
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.molecule_transformer_model_name = molecule_transformer_model_name
        self.molecule_transformer_batch_size = molecule_transformer_batch_size
        self.molecule_transformer_max_length = molecule_transformer_max_length
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim
        # Populated lazily by _ensure_molecule_transformer_loaded().
        self._molecule_transformer_tokenizer = None
        self._molecule_transformer_model = None
        self._molecule_transformer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _format_assay_query(self, assay_text: str) -> str:
        """Wrap the assay text in the "Instruct: ...\\nQuery: ..." template fed to the encoder."""
        return f"Instruct: {self.assay_task_description.strip()}\nQuery: {assay_text.strip()}"

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Return the 1-D float32 assay feature vector (embedding [+ metadata vector])."""
        assay_features = self.assay_encoder.encode(
            [self._format_assay_query(assay_text)],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def _ensure_molecule_transformer_loaded(self) -> None:
        """Load the molecule tokenizer/model once (fp16 on CUDA, fp32 on CPU); no-op if unconfigured."""
        if not self.molecule_transformer_model_name or self._molecule_transformer_model is not None:
            return
        dtype = torch.float16 if self._molecule_transformer_device.type == "cuda" else torch.float32
        with _silent_imports():
            self._molecule_transformer_tokenizer = AutoTokenizer.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
            )
            self._molecule_transformer_model = AutoModel.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(self._molecule_transformer_device)
        self._molecule_transformer_model.eval()

    def _encode_molecule_transformer_batch(self, smiles_values: list[str]) -> np.ndarray | None:
        """Mean-pool the transformer's last hidden state over non-padding tokens.

        Returns ``None`` when no molecule transformer is configured, an empty
        ``(0, 0)`` float32 array for empty input (without loading the model),
        and otherwise one float32 row per SMILES in input order.

        Raises:
            RuntimeError: if the model/tokenizer are unavailable after loading.
        """
        if not self.molecule_transformer_model_name:
            return None
        if not smiles_values:
            return np.zeros((0, 0), dtype=np.float32)
        self._ensure_molecule_transformer_loaded()
        model = self._molecule_transformer_model
        tokenizer = self._molecule_transformer_tokenizer
        if model is None or tokenizer is None:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise RuntimeError("Molecule transformer is not loaded")
        device = self._molecule_transformer_device
        outputs: list[np.ndarray] = []
        batch_size = max(self.molecule_transformer_batch_size, 1)
        with torch.no_grad():
            for start in range(0, len(smiles_values), batch_size):
                batch = smiles_values[start : start + batch_size]
                encoded = tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.molecule_transformer_max_length,
                    return_tensors="pt",
                )
                encoded = {key: value.to(device) for key, value in encoded.items()}
                # Pool in fp32: summing fp16 activations over the sequence can overflow.
                hidden = model(**encoded).last_hidden_state.to(torch.float32)
                mask = encoded["attention_mask"].unsqueeze(-1).to(torch.float32)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                outputs.append(pooled.cpu().numpy())
        return np.concatenate(outputs, axis=0).astype(np.float32, copy=False)

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Return the ``(len(smiles_values), feature_dim)`` float32 molecule feature matrix.

        Each SMILES is standardized (falling back to the raw string) and parsed
        first, so an unparseable input fails before any transformer work. All
        feature blocks, including the transformer embedding, are computed from
        the same standardized string. Empty input returns a ``(0, 0)`` array.

        Raises:
            ValueError: if a SMILES cannot be parsed by RDKit.
        """
        if not smiles_values:
            return np.zeros((0, 0), dtype=np.float32)

        normalized_values: list[str] = []
        mols = []
        for smiles in smiles_values:
            normalized = standardize_smiles_v2(smiles) or smiles
            mol = Chem.MolFromSmiles(normalized)
            if mol is None:
                raise ValueError(f"Could not parse SMILES: {normalized}")
            normalized_values.append(normalized)
            mols.append(mol)

        transformer_matrix = self._encode_molecule_transformer_batch(normalized_values)

        descriptor_std = self.descriptor_std
        if descriptor_std is not None:
            # A zero std (constant training column) would yield inf/NaN features
            # that poison the whole score; divide by 1 for those columns instead.
            descriptor_std = np.where(descriptor_std == 0, np.float32(1.0), descriptor_std)

        rows: list[np.ndarray] = []
        for idx, mol in enumerate(mols):
            bit_blocks: list[np.ndarray] = [
                _morgan_bits_from_mol(mol, radius=int(radius), n_bits=self.fingerprint_bits, use_chirality=self.use_chirality)
                for radius in self.fingerprint_radii
            ]
            if self.use_maccs:
                bit_blocks.append(_maccs_bits_from_mol(mol))
            output_blocks: list[np.ndarray] = [np.concatenate(bit_blocks, axis=0).astype(np.float32)]
            if self.use_rdkit_descriptors and self.descriptor_names:
                dense = _molecule_descriptor_vector(mol, names=self.descriptor_names)
                if self.descriptor_mean is not None and descriptor_std is not None:
                    dense = (dense - self.descriptor_mean) / descriptor_std
                output_blocks.append(dense.astype(np.float32))
            if transformer_matrix is not None:
                output_blocks.append(np.asarray(transformer_matrix[idx], dtype=np.float32))
            rows.append(np.concatenate(output_blocks, axis=0).astype(np.float32))
        return np.stack(rows, axis=0)
|
|
|
|
def _load_sentence_transformer(model_name: str) -> SentenceTransformer:
    """Load the assay text encoder (bf16 on CUDA, fp32 on CPU) with console output suppressed.

    The tokenizer, when present, is switched to left padding.
    """
    on_gpu = torch.cuda.is_available()
    model_kwargs = {"torch_dtype": torch.bfloat16 if on_gpu else torch.float32}
    with _silent_imports():
        encoder = SentenceTransformer(
            model_name,
            trust_remote_code=True,
            model_kwargs=model_kwargs,
        )
    tokenizer = getattr(encoder, "tokenizer", None)
    if tokenizer is not None:
        tokenizer.padding_side = "left"
    return encoder
|
|
|
|
| def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]: |
| spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec") |
| if spec: |
| return spec |
| radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)])) |
| return { |
| "fingerprint_radii": list(radii), |
| "fingerprint_bits": int(cfg["fingerprint_bits"]), |
| "use_chirality": bool(cfg.get("use_chirality", False)), |
| "use_maccs": bool(cfg.get("use_maccs", False)), |
| "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)), |
| "descriptor_names": [], |
| "descriptor_mean": None, |
| "descriptor_std": None, |
| "molecule_transformer_model_name": str(cfg.get("molecule_transformer_model_name") or ""), |
| "molecule_transformer_max_length": int(cfg.get("molecule_transformer_max_length", 128) or 128), |
| } |
|
|
|
|
def load_compatibility_model(model_dir: str | Path) -> SpaceCompatibilityModel:
    """Build a ``SpaceCompatibilityModel`` from a local training-output directory.

    Reads ``best_model.pt`` (checkpoint with ``model_state_dict``) and
    ``training_metadata.json`` (with a ``config`` section) from ``model_dir``.
    The scoring head is built and validated before the assay encoder is
    loaded, so a mismatched checkpoint fails fast.

    Raises:
        RuntimeError: if the checkpoint's state dict does not match
            ``CompatibilityHead``; only a missing ``mol_norm`` is tolerated.
        KeyError: if required config/checkpoint keys are absent.
    """
    model_path = Path(model_dir)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    metadata = json.loads((model_path / "training_metadata.json").read_text(encoding="utf-8"))
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)
    state_dict = checkpoint["model_state_dict"]

    # Input widths come from the stored projection weights (out_features x in_features).
    head = CompatibilityHead(
        assay_dim=int(state_dict["assay_proj.weight"].shape[1]),
        molecule_dim=int(state_dict["mol_proj.weight"].shape[1]),
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    load_result = head.load_state_dict(state_dict, strict=False)
    # Checkpoints without mol_norm parameters are tolerated (presumably trained
    # before that layer existed); the layer then keeps its default init.
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            f"Checkpoint mismatch: unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )

    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])

    def _optional_float32(key: str) -> np.ndarray | None:
        # Descriptor statistics are optional in the spec.
        value = feature_spec.get(key)
        return None if value is None else np.array(value, dtype=np.float32)

    return SpaceCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg.get("assay_task_description", DEFAULT_ASSAY_TASK),
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", cfg.get("use_maccs", False))),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", cfg.get("use_rdkit_descriptors", False))),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=_optional_float32("descriptor_mean"),
        descriptor_std=_optional_float32("descriptor_std"),
        molecule_transformer_model_name=str(feature_spec.get("molecule_transformer_model_name") or cfg.get("molecule_transformer_model_name") or ""),
        molecule_transformer_batch_size=int(cfg.get("molecule_transformer_batch_size", 128) or 128),
        molecule_transformer_max_length=int(feature_spec.get("molecule_transformer_max_length") or cfg.get("molecule_transformer_max_length", 128) or 128),
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
|
|
|
|
@lru_cache(maxsize=1)
def load_compatibility_model_from_hub(model_repo_id: str) -> SpaceCompatibilityModel:
    """Download the checkpoint files of ``model_repo_id`` and load them (most recent repo memoized).

    Only the checkpoint, training metadata and README are fetched; console
    output from the download is suppressed, but errors still propagate.
    """
    wanted_files = ["best_model.pt", "training_metadata.json", "README.md"]
    with _silent_imports():
        local_dir = snapshot_download(
            repo_id=model_repo_id,
            repo_type="model",
            allow_patterns=wanted_files,
        )
    return load_compatibility_model(local_dir)
|
|
|
|
def rank_compounds(
    model: SpaceCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score candidate SMILES against an assay and return them ranked.

    Each input is standardized. Inputs that fail standardization become
    ``valid=False`` records with ``error="invalid_smiles"``. Valid inputs are
    scored by the model's compatibility head and sorted by descending score;
    NaN scores are placed last. When ``top_k`` is a positive int, only the
    first ``top_k`` valid records are kept. Invalid records are always appended
    after the valid ones, in input order. An empty ``smiles_list`` returns [].
    """
    if not smiles_list:
        return []

    valid_items: list[tuple[str, str]] = []
    invalid_items: list[dict[str, Any]] = []
    for raw_smiles in smiles_list:
        standardized = standardize_smiles_v2(raw_smiles)
        if standardized is None:
            invalid_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": None,
                    "smiles_hash": None,
                    "score": None,
                    "valid": False,
                    "error": "invalid_smiles",
                }
            )
        else:
            valid_items.append((raw_smiles, standardized))

    ranked_items: list[dict[str, Any]] = []
    if valid_items:
        # The assay encoder is the expensive part; only run it when there is
        # at least one valid molecule to score.
        assay_features = model._build_assay_feature_array(assay_text)
        assay_tensor = torch.from_numpy(np.asarray(assay_features, dtype=np.float32)).unsqueeze(0)
        feature_matrix = model.build_molecule_feature_matrix([canonical for _, canonical in valid_items])
        candidate_tensor = torch.from_numpy(np.asarray(feature_matrix, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            logits, _, _ = model.compatibility_head.score_candidates(assay_tensor, candidate_tensor)
        scores = logits.squeeze(0).cpu().numpy().tolist()
        ranked_items = [
            {
                "input_smiles": raw_smiles,
                "canonical_smiles": canonical,
                "smiles_hash": smiles_sha256(canonical),
                "score": float(score),
                "valid": True,
            }
            for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True)
        ]
        # NaN compares False against everything, which corrupts a plain sort;
        # the leading flag puts finite scores first and NaN scores last.
        ranked_items.sort(key=lambda item: (not np.isnan(item["score"]), item["score"]), reverse=True)
        if top_k is not None and top_k > 0:
            ranked_items = ranked_items[:top_k]

    return ranked_items + invalid_items
|
|