import random

import torch


class GraphData(object):
    """Mini-batch iterator over a sequence of graph objects.

    Each element of ``data`` must expose the attributes ``A``, ``feas``
    (an object with a ``.float()`` method) and ``label``; these are the
    only attributes this class reads.
    # NOTE(review): ``A`` is presumably an adjacency matrix and ``feas`` a
    # node-feature tensor -- confirm against the code that builds ``data``.

    Typical use::

        for n, gs, hs, labels in GraphData(data, feat_dim).loader(32, True):
            ...

    The object is its own iterator (``__iter__`` returns ``self``), so the
    read position is shared: breaking out of a loop early and iterating
    again resumes from where the previous loop stopped.
    """

    def __init__(self, data, feat_dim):
        """Store the dataset and set up iteration state.

        Args:
            data: Indexable, sized collection of graph objects.
            feat_dim: Feature dimensionality; stored on the instance but
                not otherwise used by this class.
        """
        super(GraphData, self).__init__()
        self.data = data
        self.feat_dim = feat_dim
        # Visiting order of the samples; shuffled in place when requested.
        self.idx = list(range(len(data)))
        # Offset into ``self.idx`` of the next sample to serve.
        self.pos = 0
        # Defaults so len()/next() work even before loader() is called.
        self.batch = 1
        self.shuffle = False

    def __reset__(self):
        """Rewind to the start of the epoch, reshuffling if enabled."""
        self.pos = 0
        if self.shuffle:
            random.shuffle(self.idx)

    def __len__(self):
        """Return the number of batches produced per epoch.

        Uses ceiling division so the count matches what ``__next__``
        actually yields (a trailing partial batch counts as one batch, and
        an exact multiple of the batch size does not add an extra one).
        """
        return (len(self.data) + self.batch - 1) // self.batch

    def __getitem__(self, idx):
        """Return ``(A, feas.float(), label)`` for the graph at ``idx``."""
        g = self.data[idx]
        return g.A, g.feas.float(), g.label

    def __iter__(self):
        """Return ``self``; iteration state lives on the instance."""
        return self

    def __next__(self):
        """Return the next mini-batch.

        Returns:
            A 4-tuple ``(n, gs, hs, labels)`` where ``n`` is the number of
            graphs in the batch, ``gs`` is the list of ``A`` values, ``hs``
            the list of ``feas.float()`` values and ``labels`` a
            ``torch.LongTensor`` built from the graph labels.

        Raises:
            StopIteration: When every sample has been served. The position
                is rewound (and the order reshuffled if enabled) first, so
                the object can be iterated again for the next epoch.
        """
        if self.pos >= len(self.data):
            self.__reset__()
            raise StopIteration
        cur_idx = self.idx[self.pos:self.pos + self.batch]
        samples = [self[i] for i in cur_idx]
        # Advance by the actual slice length: the final batch may be short.
        self.pos += len(cur_idx)
        gs, hs, labels = map(list, zip(*samples))
        return len(gs), gs, hs, torch.LongTensor(labels)

    def loader(self, batch, shuffle, *args):
        """Configure batching and return ``self`` for iteration.

        Args:
            batch: Number of graphs per mini-batch; must be at least 1.
            shuffle: If truthy, shuffle the sample order now and again at
                the end of every epoch.
            *args: Accepted and ignored.

        Returns:
            This ``GraphData`` instance.

        Raises:
            ValueError: If ``batch`` is smaller than 1 (a non-positive
                batch size can never make progress through the data).
        """
        if batch < 1:
            raise ValueError(
                "batch must be a positive integer, got {!r}".format(batch))
        self.batch = batch
        self.shuffle = shuffle
        if shuffle:
            random.shuffle(self.idx)
        return self