| import torch |
| import pandas as pd |
| import typing |
| import math |
| import utils |
| import numpy as np |
| import os |
|
|
# Root prepended to data paths; empty string resolves relative to the CWD.
base_path = ""
LOGGER = utils.get_logger(__name__)
# Nucleotide character -> integer token id, and its inverse mapping.
DNA_ALPHABET = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
INDEX_TO_DNA = {v: k for k, v in DNA_ALPHABET.items()}
# Vectorized id -> character table; indexing it with an int array of token
# ids yields the corresponding character array (used by batch_dna_detokenize).
lookup_array = np.array([INDEX_TO_DNA[i] for i in range(len(INDEX_TO_DNA))])
|
|
|
|
def dna_detokenize(seq):
    """Convert a sequence of integer token ids into a DNA string.

    Args:
        seq: iterable of values castable to int, each in [0, 3].

    Returns:
        str: the corresponding string over 'ACGT'.
    """
    # Use the precomputed inverse mapping instead of rebuilding
    # list(DNA_ALPHABET.keys()) once per character.
    return ''.join(INDEX_TO_DNA[int(i)] for i in seq)
|
|
def batch_dna_detokenize(batch_seq):
    """
    batch_seq: numpy array of shape [batch_size, seq_len]
    return: list of strings
    """
    # Fancy-index the whole batch in one shot, then join each row's characters.
    char_rows = lookup_array[batch_seq]
    return [''.join(row) for row in char_rows]
|
|
def dna_tokenize(seq):
    """Map a DNA string over 'ACGT' to a list of integer token ids."""
    return list(map(DNA_ALPHABET.__getitem__, seq))
|
|
def batch_dna_tokenize(batch_seq):
    """
    batch_seq: list of strings
    return: numpy array of shape [batch_size, seq_len]
    """
    rows = [[DNA_ALPHABET[base] for base in seq] for seq in batch_seq]
    return np.array(rows)
|
|
class GosaiDataset(torch.utils.data.Dataset):
    """Gosai MPRA dataset: tokenized DNA sequences plus activity labels.

    Attributes:
        seqs: LongTensor [num_seqs, seq_len] of token ids (A=0, C=1, G=2, T=3).
        clss: Tensor [num_seqs, 3] of measurements for the 'hepg2', 'k562'
            and 'sknsh' cell lines (columns in that order).
    """

    def __init__(self):
        # Plain literal: the path has no placeholders, so no f-string needed.
        data_df = pd.read_csv(os.path.join(
            base_path, 'mdlm/gosai_data/processed_data/gosai_all.csv'))
        self.seqs = torch.tensor(
            data_df['seq'].apply(lambda x: [DNA_ALPHABET[c] for c in x]).tolist())
        self.clss = torch.tensor(data_df[['hepg2', 'k562', 'sknsh']].to_numpy())
        LOGGER.info(
            f'Loaded data: seqs shape: {self.seqs.shape}, clss shape: {self.clss.shape}')

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return one example: token ids, labels and an all-ones attention mask."""
        seq = self.seqs[idx]  # index once; reused for the mask length
        return {'seqs': seq,
                'clss': self.clss[idx],
                'attention_mask': torch.ones(len(seq))}
|
|
|
|
def get_datasets_gosai():
    """Build and return the full Gosai dataset."""
    dataset = GosaiDataset()
    return dataset
|
|
|
|
def get_dataloaders_gosai(config, skip_valid=False, valid_seed=None):
    """Create train/valid/test dataloaders for the Gosai dataset.

    Args:
        config: experiment config providing loader/trainer/data settings.
        skip_valid: if True, the valid and test loaders are returned as None.
        valid_seed: optional int; when given, valid/test loaders shuffle with
            a torch.Generator seeded by this value, otherwise they don't
            shuffle at all.

    Returns:
        (train_loader, valid_loader, test_loader); the last two are None when
        skip_valid is True.

    Raises:
        ValueError: if the global train/eval batch sizes do not divide evenly
            across the available GPUs (and gradient-accumulation steps).
    """
    num_gpus = torch.cuda.device_count()
    if config.loader.global_batch_size % (
        num_gpus * config.trainer.accumulate_grad_batches) != 0:
        # Report the value that was actually checked (config.loader.*), not
        # config.training.batch_size, which may not even exist.
        raise ValueError(
            f'Train batch size {config.loader.global_batch_size} '
            f'not divisible by {num_gpus} gpus with accumulation '
            f'{config.trainer.accumulate_grad_batches}.')
    if config.loader.eval_global_batch_size % num_gpus != 0:
        raise ValueError(
            f'Eval batch size {config.loader.eval_global_batch_size} '
            f'not divisible by {num_gpus}.')
    train_set = GosaiDataset()
    # NOTE(review): valid/test are unseeded random subsets of the *training*
    # data — they overlap with train and are not reproducible across runs.
    valid_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))
    test_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.loader.batch_size,
        num_workers=config.loader.num_workers,
        pin_memory=config.loader.pin_memory,
        shuffle=not config.data.streaming,
        # persistent_workers=True is invalid when num_workers == 0.
        persistent_workers=config.loader.num_workers > 0)
    if skip_valid:
        valid_loader = None
        test_loader = None
    else:
        if valid_seed is None:
            shuffle_valid = False
            generator = None
        else:
            shuffle_valid = True
            generator = torch.Generator().manual_seed(valid_seed)
        valid_loader = torch.utils.data.DataLoader(
            valid_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)

    return train_loader, valid_loader, test_loader
|
|
|
|
| |
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    """Random sampler whose position within an epoch can be checkpointed.

    state_dict() captures the RNG state and how many indices have been
    yielded; load_state_dict() restores both so a resumed run replays the
    same permutation and skips the indices already consumed.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # Seed a private generator when none is given, so get_state()/
        # set_state() below always operate on a generator we own.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        # Drop a 'shuffle' kwarg if a caller passes one through:
        # RandomSampler does not accept it.
        kwargs.pop('shuffle', None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0  # indices yielded so far in the current epoch
        self.restarting = False  # set by load_state_dict() until next __iter__

    def state_dict(self):
        """Snapshot the generator state and the in-epoch position."""
        return {'random_state': self.generator.get_state(),
                'counter': self.counter}

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by state_dict()."""
        self.generator.set_state(state_dict.get('random_state'))
        self.counter = state_dict['counter']
        # Tell the next __iter__ call to skip the already-consumed prefix.
        self.restarting = True

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        # Save the generator state *before* drawing the permutation; restoring
        # that state via load_state_dict() reproduces the identical
        # permutation after a restart.
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            # Resuming mid-epoch: drop the indices already yielded before the
            # checkpoint. counter keeps counting from its restored value, so
            # a later state_dict() still reflects total progress this epoch.
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
|
|
|
|
class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    """DistributedSampler whose in-epoch position can be checkpointed.

    __iter__ mirrors torch's stock DistributedSampler (shuffle keyed on
    (seed, epoch), padding/truncation to total_size, per-rank strided
    slice) but additionally tracks how many indices this rank has yielded
    so a resumed run can skip them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0  # indices yielded so far in the current epoch
        self.restarting = False  # set by load_state_dict() until next __iter__

    def state_dict(self):
        """Snapshot the epoch (which seeds the shuffle) and the position."""
        return {'epoch': self.epoch, 'counter': self.counter}

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by state_dict()."""
        self.epoch = state_dict['epoch']
        self.counter = state_dict['counter']
        # Tell the next __iter__ call to skip the already-consumed prefix.
        self.restarting = True

    def __iter__(self):
        if self.shuffle:
            # Deterministic shuffle keyed on (seed, epoch): every rank draws
            # the same global permutation for a given epoch.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Pad by repeating indices so every rank gets the same count.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(
                    padding_size / len(indices)))[:padding_size]
        else:
            # Drop the tail so the total divides evenly across ranks.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # This rank's strided slice of the global index list.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            # Resuming mid-epoch: drop the indices already yielded before the
            # checkpoint.
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
|
|
|
|