Spaces:
Runtime error
Runtime error
| import os | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw, AllChem | |
| from rdkit.Geometry import Point3D | |
| from rdkit import RDLogger | |
| import imageio | |
| import networkx as nx | |
| import numpy as np | |
| import rdkit.Chem | |
| import wandb | |
| import matplotlib.pyplot as plt | |
class MolecularVisualization:
    """Render molecular graphs as PNG images and GIF chains using RDKit.

    A graph is described by a 1-D array of atom-type indices (``-1`` marks a
    node to skip) and a square adjacency matrix whose entries encode the bond
    order: 1 single, 2 double, 3 triple, 4 aromatic, anything else no bond.
    """

    def __init__(self, remove_h, dataset_infos):
        """Store the visualization configuration.

        Args:
            remove_h: kept on the instance; not read by the methods of this class.
            dataset_infos: object exposing ``atom_decoder``, an indexable
                mapping from integer atom type to atomic symbol.
        """
        self.remove_h = remove_h
        self.dataset_infos = dataset_infos

    @staticmethod
    def _bond_type(bond):
        """Return the RDKit bond type for an adjacency entry, or None for no bond."""
        if bond == 1:
            return Chem.rdchem.BondType.SINGLE
        if bond == 2:
            return Chem.rdchem.BondType.DOUBLE
        if bond == 3:
            return Chem.rdchem.BondType.TRIPLE
        if bond == 4:
            return Chem.rdchem.BondType.AROMATIC
        return None

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph to an RDKit molecule.

        Args:
            node_list: atom-type indices of one graph (n); entries equal to
                ``-1`` are skipped and produce no atom.
            adjacency_matrix: bond orders of the same graph (n x n). Only the
                strict upper triangle is read.

        Returns:
            The RDKit molecule, or None if RDKit raises a KekulizeException
            while finalizing it.
        """
        # Maps an integer atom type to the atomic symbol.
        atom_decoder = self.dataset_infos.atom_decoder

        mol = Chem.RWMol()

        # Graph node index -> RDKit atom index (they differ once nodes are skipped).
        node_to_idx = {}
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            node_to_idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(node)]))

        for ix, row in enumerate(adjacency_matrix):
            for iy, bond in enumerate(row):
                # Only traverse half of the symmetric matrix.
                if iy <= ix:
                    continue
                bond_type = self._bond_type(bond)
                if bond_type is None:
                    continue
                # A skipped node has no atom, so a bond touching it cannot be
                # added; ignore it instead of failing with a KeyError.
                if ix not in node_to_idx or iy not in node_to_idx:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
        """Draw the first molecules of a list to ``molecule_<i>.png`` files.

        Args:
            path: output directory; created if it does not exist.
            molecules: sequence of ``(nodes, adjacency)`` pairs whose elements
                expose ``.numpy()`` (presumably torch tensors - confirm with caller).
            num_molecules_to_visualize: number of molecules to draw; clamped
                to ``len(molecules)``.
            log: wandb key under which each image is logged when a wandb run
                is active; pass None to disable logging.
        """
        os.makedirs(path, exist_ok=True)

        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, 'molecule_{}.png'.format(i))
            nodes, adjacency = molecules[i][0], molecules[i][1]
            mol = self.mol_from_graphs(nodes.numpy(), adjacency.numpy())
            if mol is None:
                # mol_from_graphs already reported the failure; there is
                # nothing to draw for this entry.
                continue
            try:
                Draw.MolToFile(mol, file_path)
                if wandb.run and log is not None:
                    print(f"Saving {file_path} to wandb")
                    wandb.log({log: wandb.Image(file_path)}, commit=True)
            except Chem.rdchem.KekulizeException:
                print("Can't kekulize molecule")

    def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
        """Render a chain of graphs as per-frame PNGs, a GIF and a grid image.

        Every frame is drawn with the 2D coordinates computed for the final
        molecule of the chain, so atoms keep the same position across frames.

        Args:
            path: directory receiving the frame images and the grid image;
                created if missing. The GIF is written next to this
                directory as ``<directory name>.gif``.
            nodes_list: array of node arrays, one per frame (frames x n).
            adjacency_matrix: array of adjacency matrices (frames x n x n).
            trainer: unused.

        Returns:
            The list of RDKit molecules, one per frame.
        """
        RDLogger.DisableLog('rdApp.*')
        os.makedirs(path, exist_ok=True)

        num_frames = nodes_list.shape[0]
        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]

        # Reference layout: coordinates of the atoms in the final molecule.
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)
        final_conformer = final_molecule.GetConformer()
        coords = []
        for i in range(final_molecule.GetNumAtoms()):
            position = final_conformer.GetAtomPosition(i)
            coords.append((position.x, position.y, position.z))

        # Align all the molecules on the reference layout.
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conformer = mol.GetConformer()
            for j in range(mol.GetNumAtoms()):
                x, y, z = coords[j]
                conformer.SetAtomPosition(j, Point3D(x, y, z))

        # Draw one image per frame.
        save_paths = []
        for frame in range(num_frames):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
            save_paths.append(file_name)

        # Derive the chain name with os.path so that it works with any path
        # separator and with a trailing separator.
        chain_dir = os.path.normpath(path)
        chain_name = os.path.basename(chain_dir)

        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_path = os.path.join(os.path.dirname(chain_dir), '{}.gif'.format(chain_name))
        # Repeat the last frame so the final molecule stays visible longer.
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log({"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True)

        # Draw all frames in a single grid image.
        try:
            img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            img.save(os.path.join(path, '{}_grid_image.png'.format(chain_name)))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols
class NonMolecularVisualization:
    """Render generic (non-molecular) graphs with networkx and matplotlib."""

    def to_networkx(self, node_list, adjacency_matrix):
        """Convert a single graph to a networkx graph.

        Args:
            node_list: node values of one graph (n); entries equal to ``-1``
                are not added as nodes.
            adjacency_matrix: numpy array of edge types (n x n); every entry
                ``>= 1`` becomes an edge.

        Returns:
            An undirected ``nx.Graph``. Nodes carry ``number``, ``symbol`` and
            ``color_val`` attributes; edges carry ``color`` and ``weight``.
        """
        graph = nx.Graph()
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            graph.add_node(i, number=i, symbol=node, color_val=node)

        # NOTE: networkx adds missing endpoints automatically, so an edge that
        # touches a skipped node re-creates that node without attributes.
        rows, cols = np.where(adjacency_matrix >= 1)
        for src, dst in zip(rows.tolist(), cols.tolist()):
            edge_type = adjacency_matrix[src][dst]
            graph.add_edge(src, dst, color=float(edge_type), weight=3 * edge_type)
        return graph

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
        """Draw a graph to an image file, coloring nodes by a Laplacian eigenvector.

        Args:
            graph: networkx graph to draw.
            pos: node positions, or None to compute a spring layout.
            path: destination file passed to ``plt.savefig``.
            iterations: spring-layout iterations, used only when ``pos`` is None.
            node_size: node size forwarded to ``nx.draw``.
            largest_component: when True, draw only the largest connected component.
        """
        if largest_component:
            # The first component of maximal size is kept; an empty graph has
            # no component and is drawn as is.
            biggest = max(nx.connected_components(graph), key=len, default=None)
            if biggest is not None:
                graph = graph.subgraph(biggest)

        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        # Node colors come from the second eigenvector of the normalized
        # Laplacian. With fewer than two nodes that eigenvector does not
        # exist, so fall back to a constant color instead of failing.
        num_nodes = graph.number_of_nodes()
        if num_nodes >= 2:
            _, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
            node_colors = U[:, 1]
            m = max(np.abs(np.min(node_colors)), np.max(node_colors))
        else:
            node_colors = np.zeros(num_nodes)
            m = 0.0
        # Symmetric color range so that zero maps to the middle of the colormap.
        vmin, vmax = -m, m

        plt.figure()
        try:
            nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=node_colors,
                    cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
            plt.tight_layout()
            plt.savefig(path)
        finally:
            # Closes every open figure (unchanged behavior); now also done
            # when drawing or saving fails, so figures never accumulate.
            plt.close("all")

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        """Draw the first graphs of a list to ``graph_<i>.png`` files.

        Args:
            path: output directory; created if it does not exist.
            graphs: sequence of ``(nodes, adjacency)`` pairs whose elements
                expose ``.numpy()`` (presumably torch tensors - confirm with caller).
            num_graphs_to_visualize: number of graphs to draw; clamped to
                ``len(graphs)``.
            log: wandb key under which each image is logged when a wandb run
                is active; pass None to disable logging.
        """
        os.makedirs(path, exist_ok=True)

        # Never index past the end of the list.
        num_graphs_to_visualize = min(num_graphs_to_visualize, len(graphs))

        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
            if wandb.run and log is not None:
                # The image is only read back when it is actually logged.
                im = plt.imread(file_path)
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
        """Render a chain of graphs as per-frame PNGs and a GIF.

        Every frame is drawn with the spring layout computed (with a fixed
        seed) for the final graph of the chain.

        Args:
            path: directory receiving the frame images. The GIF is written
                next to this directory as ``<directory name>.gif``.
            nodes_list: array of node arrays, one per frame (frames x n).
            adjacency_matrix: array of adjacency matrices (frames x n x n).
            trainer: unused; accepted so that the signature matches
                ``MolecularVisualization.visualize_chain``.
        """
        num_frames = nodes_list.shape[0]
        graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]

        # Reference layout: node positions of the final graph.
        final_graph = graphs[-1]
        final_pos = nx.spring_layout(final_graph, seed=0)

        # Draw one image per frame.
        save_paths = []
        for frame in range(num_frames):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
            save_paths.append(file_name)

        # Derive the chain name with os.path so that it works with any path
        # separator and with a trailing separator.
        chain_dir = os.path.normpath(path)
        chain_name = os.path.basename(chain_dir)

        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_path = os.path.join(os.path.dirname(chain_dir), '{}.gif'.format(chain_name))
        # Repeat the last frame so the final graph stays visible longer.
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        if wandb.run:
            wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})