| | from itertools import accumulate |
| | import numpy as np |
| | import torch |
| | from torch.utils.data import Dataset |
| |
|
| |
|
class ProcessedLigandPocketDataset(Dataset):
    """Dataset of ligand/pocket samples stored in a single ``.npz`` archive.

    Every array in the archive except ``'names'`` and ``'receptors'`` is a
    concatenation over all samples along axis 0. It is split back into one
    tensor per sample at the positions where the matching mask changes value:
    keys containing ``'lig'`` are split with ``'lig_mask'``, all other keys
    with ``'pocket_mask'``.

    Args:
        npz_path: Path (or file object) accepted by ``numpy.load``.
        center: If true, subtract from each sample's ``'lig_coords'`` and
            ``'pocket_coords'`` the mean over all ligand and pocket rows of
            that sample.
        transform: Optional callable applied to the per-sample dict returned
            by ``__getitem__``; its return value is handed to the caller.
    """

    def __init__(self, npz_path, center=True, transform=None):
        self.transform = transform

        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only load archives you trust.
        with np.load(npz_path, allow_pickle=True) as f:
            # Materialise every array while the archive is still open.
            data = dict(f.items())

        # Split positions depend only on the mask, not on the key being
        # split, so compute them once per mask. The lookup stays lazy so that
        # a mask is only required when some key actually needs it.
        split_points = {}

        def sections_for(mask_key):
            if mask_key not in split_points:
                split_points[mask_key] = self._split_points(data[mask_key])
            return split_points[mask_key]

        # Per-sample storage: arrays for 'names'/'receptors', lists of
        # tensors (one entry per sample) for everything else.
        self.data = {}
        for key, values in data.items():
            if key in ('names', 'receptors'):
                self.data[key] = values
                continue

            mask_key = 'lig_mask' if 'lig' in key else 'pocket_mask'
            self.data[key] = [
                torch.from_numpy(chunk)
                for chunk in np.split(values, sections_for(mask_key))
            ]

            # Record per-sample sizes right after the corresponding mask.
            if key == 'lig_mask':
                self.data['num_lig_atoms'] = torch.tensor(
                    [len(m) for m in self.data['lig_mask']])
            elif key == 'pocket_mask':
                self.data['num_pocket_nodes'] = torch.tensor(
                    [len(m) for m in self.data['pocket_mask']])

        if center:
            lig_coords = self.data['lig_coords']
            pocket_coords = self.data['pocket_coords']
            for i, lig in enumerate(lig_coords):
                pocket = pocket_coords[i]
                # Mean over the ligand and pocket rows taken together.
                mean = (lig.sum(0) + pocket.sum(0)) / (len(lig) + len(pocket))
                lig_coords[i] = lig - mean
                pocket_coords[i] = pocket - mean

    @staticmethod
    def _split_points(mask):
        """Return the indices at which consecutive entries of *mask* differ."""
        return np.where(np.diff(mask))[0] + 1

    def __len__(self):
        """Return the number of samples (length of the ``'names'`` array)."""
        return len(self.data['names'])

    def __getitem__(self, idx):
        """Return a dict holding entry *idx* of every stored property.

        If a ``transform`` was supplied, the dict is passed through it and
        the transform's return value is returned instead.
        """
        sample = {key: val[idx] for key, val in self.data.items()}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    @staticmethod
    def collate_fn(batch):
        """Merge a list of per-sample dicts into a single batch dict.

        * ``'names'`` / ``'receptors'``: collected into a Python list.
        * ``'num_lig_atoms'`` / ``'num_pocket_nodes'`` /
          ``'num_virtual_atoms'``: packed into one tensor.
        * keys containing ``'mask'``: replaced by a float tensor holding, for
          every row, the position of its sample inside ``batch``.
        * anything else: concatenated along dimension 0.

        An empty ``batch`` yields an empty dict.
        """
        if not batch:
            return {}

        list_props = ('names', 'receptors')
        count_props = ('num_lig_atoms', 'num_pocket_nodes',
                       'num_virtual_atoms')

        out = {}
        for prop in batch[0].keys():
            if prop in list_props:
                out[prop] = [sample[prop] for sample in batch]
            elif prop in count_props:
                out[prop] = torch.tensor([sample[prop] for sample in batch])
            elif 'mask' in prop:
                # The stored mask values are discarded; each sample is
                # relabelled with its index in this batch, so indices always
                # start at zero.
                out[prop] = torch.cat(
                    [i * torch.ones(len(sample[prop]))
                     for i, sample in enumerate(batch)],
                    dim=0)
            else:
                out[prop] = torch.cat([sample[prop] for sample in batch],
                                      dim=0)

        return out
| |
|