| import torch |
| import json |
|
|
| from .model_arch import EGNNDiffusionModel, OlfactoryConditioner |
| from .utils import load_goodscents_subset, sample, validate_molecule |
|
|
|
|
| |
| smiles_list, label_map, label_names = load_goodscents_subset(index=500) |
| num_labels = len(label_names) |
|
|
| |
| model = EGNNDiffusionModel(node_dim=1, embed_dim=8) |
| model.load_state_dict(torch.load('/models/constrained/egnn_state_dict.pth')) |
| model.eval() |
|
|
| conditioner = OlfactoryConditioner(num_labels=num_labels, embed_dim=8) |
| conditioner.load_state_dict(torch.load('/models/constrained/olfactory_conditioner_state_dict.pth')) |
| conditioner.eval() |
|
|
|
|
| |
| descriptor_list: list = [] |
| aroma_vec_list: list = [] |
| with open('data/olfaction-vision-language-dataset.json', 'r') as file: |
| json_string = json.load(file) |
| data = json.loads(json_string) |
| for item in data: |
| item_dict: dict = dict(item) |
| if "descriptors" in item_dict.keys(): |
| descriptor_list.append(item_dict["descriptors"]) |
| if "aroma_vec" in item_dict.keys(): |
| aroma_vec_list.append(item_dict["aroma_vec"]) |
|
|
|
|
| |
| smiles_list, label_map, label_names = load_goodscents_subset(index=1000) |
| num_labels = len(label_names) |
| count: int = 0 |
| for i in range(0, len(descriptor_list)): |
| test_label_vec = torch.zeros(num_labels) |
| for descriptor in descriptor_list[i]: |
| if descriptor in label_names: |
| test_label_vec[label_names.index(descriptor)] = 1 |
|
|
| |
| new_smiles = sample(model, conditioner, label_vec=test_label_vec) |
| print(new_smiles) |
|
|
| |
| valid, props = validate_molecule(new_smiles) |
| print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}") |
| if new_smiles != "": |
| count += 1 |
|
|
| |
| percent_correct: float = float(count) / float(len(aroma_vec_list)) * 100.0 |
| print(f"Percent correct: {percent_correct}") |
|
|