# clique / GraphUNets / src / utils / dataset.py
# Uploaded by qingy2024 via huggingface_hub ("Upload folder using huggingface_hub")
# Commit: bf620c6 (verified)
import random
import torch
class GraphData(object):
    """Simple in-memory mini-batch iterator over a list of graph objects.

    Each element of ``data`` must expose ``A``, ``feas`` (something with a
    ``.float()`` method, e.g. a tensor) and ``label`` attributes; see
    ``__getitem__``. The exact meaning/shape of ``A`` and ``feas`` is defined
    by the caller (presumably adjacency matrix and node features -- confirm).

    Typical use::

        for n, As, feats, labels in GraphData(graphs, dim).loader(32, True):
            ...

    Each iteration step yields ``(batch_len, As, feats, labels)`` where
    ``As`` and ``feats`` are lists and ``labels`` is a ``torch.LongTensor``.
    """

    def __init__(self, data, feat_dim):
        super().__init__()
        self.data = data
        self.feat_dim = feat_dim  # kept for callers; not used inside this class
        # Iteration order over ``data``; permuted in place when shuffling.
        self.idx = list(range(len(data)))
        # Offset into ``self.idx`` of the next batch to emit.
        self.pos = 0
        # Defaults so the object is usable even before ``loader()`` is called
        # (previously this raised AttributeError).
        self.batch = 1
        self.shuffle = False

    def __reset__(self):
        """Rewind to the start of an epoch, reshuffling if enabled."""
        self.pos = 0
        if self.shuffle:
            random.shuffle(self.idx)

    def __len__(self):
        """Return the number of batches per epoch, i.e. ceil(len(data) / batch)."""
        # Ceiling division: the old ``n // batch + 1`` over-counted by one
        # whenever n was an exact multiple of batch (and returned 1 for n == 0).
        return (len(self.data) + self.batch - 1) // self.batch

    def __getitem__(self, idx):
        """Return ``(A, feas.float(), label)`` for the graph at position ``idx``."""
        g = self.data[idx]
        return g.A, g.feas.float(), g.label

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(batch_len, As, feats, labels)`` batch.

        Raises:
            StopIteration: at the end of the epoch, after rewinding (and
                reshuffling, if enabled) so the object can be iterated again.
        """
        if self.pos >= len(self.data):
            self.__reset__()
            raise StopIteration
        cur_idx = self.idx[self.pos:self.pos + self.batch]
        items = [self.__getitem__(i) for i in cur_idx]
        # The final batch may be shorter than ``self.batch``.
        self.pos += len(cur_idx)
        gs, hs, labels = map(list, zip(*items))
        return len(gs), gs, hs, torch.LongTensor(labels)

    def loader(self, batch, shuffle, *args):
        """Configure batching and return ``self`` for iteration.

        Args:
            batch: number of graphs per batch; must be >= 1.
            shuffle: if true, permute the order now and at every epoch end.
            *args: accepted for call-site compatibility; ignored.

        Raises:
            ValueError: if ``batch`` is less than 1.
        """
        if batch < 1:
            raise ValueError("batch must be >= 1, got %r" % (batch,))
        self.batch = batch
        self.shuffle = shuffle
        # Start a fresh epoch even if a previous iteration was abandoned mid-way.
        self.pos = 0
        if shuffle:
            random.shuffle(self.idx)
        return self