| import torch |
| from tqdm import tqdm |
| import networkx as nx |
| import numpy as np |
| import torch.nn.functional as F |
| from sklearn.model_selection import StratifiedKFold |
| from functools import partial |
|
|
|
|
class G_data(object):
    """Graph dataset container with precomputed stratified cross-validation folds.

    On construction the graphs are partitioned into stratified folds based on
    each graph's ``label`` attribute. Call :meth:`use_fold_data` to select the
    train/test partition of one fold.
    """

    def __init__(self, num_class, feat_dim, g_list, n_splits=10, seed=0):
        """Store the dataset and compute the fold indices.

        Args:
            num_class: Number of distinct graph labels.
            feat_dim: Dimension of the node feature vectors.
            g_list: List of graph objects; each must expose a ``label``
                attribute.
            n_splits: Number of cross-validation folds (default 10, the
                previously hard-coded value).
            seed: Random state used when shuffling before splitting.
        """
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.g_list = g_list
        self.sep_data(seed=seed, n_splits=n_splits)

    def sep_data(self, seed=0, n_splits=10):
        """Compute stratified fold indices and store them in ``self.idx_list``.

        Args:
            seed: Random state passed to ``StratifiedKFold``.
            n_splits: Number of folds to create.

        ``self.idx_list`` becomes a list of ``(train_idx, test_idx)`` pairs,
        one per fold.
        """
        labels = [g.label for g in self.g_list]
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
        # Only the labels drive the stratification; the feature matrix passed
        # to ``split`` is a placeholder of matching length.
        self.idx_list = list(skf.split(np.zeros(len(labels)), labels))

    def use_fold_data(self, fold_idx):
        """Select fold ``fold_idx`` (0-based) as the active train/test split.

        Sets ``self.fold_idx`` to the 1-based fold number and populates
        ``self.train_gs`` and ``self.test_gs`` with the graphs of that fold.
        """
        self.fold_idx = fold_idx + 1
        train_idx, test_idx = self.idx_list[fold_idx]
        self.train_gs = [self.g_list[i] for i in train_idx]
        self.test_gs = [self.g_list[i] for i in test_idx]
|
|
|
|
class FileLoader(object):
    """Read a graph-classification dataset from a text file into a ``G_data``.

    File layout, as parsed by this class: the first line holds the number of
    graphs. Each graph starts with a header line ``<n_nodes> <label>`` followed
    by one line per node of the form ``<tag> <n_neighbors> <neighbor> ...``.
    """

    def __init__(self, args):
        """Keep the run configuration.

        Args:
            args: Object providing ``data`` (dataset name, used to build the
                path ``data/<name>/<name>.txt``) and ``deg_as_tag`` (use node
                degrees instead of the tags read from the file).
        """
        self.args = args

    def line_genor(self, lines):
        """Yield the items of ``lines`` one by one so callers can use ``next``."""
        yield from lines

    def gen_graph(self, f, i, label_dict, feat_dict, deg_as_tag):
        """Parse one graph from the line iterator ``f``.

        Args:
            f: Iterator over the remaining lines of the dataset file.
            i: Index of the graph being read (unused; kept for the interface).
            label_dict: Mapping raw label -> contiguous index; updated in place.
            feat_dict: Mapping raw node tag -> contiguous index; updated in
                place.
            deg_as_tag: If true, ``g.node_tags`` holds node degrees instead of
                the mapped tags from the file.

        Returns:
            A ``networkx.Graph`` without isolated nodes, carrying ``label``
            (the raw label) and ``node_tags`` (one entry per remaining node,
            in ``g.nodes()`` order).

        Raises:
            ValueError: If a node row references a neighbor index outside
                ``0 .. n_nodes - 1``.
        """
        header = next(f).strip().split()
        n, label = [int(w) for w in header]
        if label not in label_dict:
            label_dict[label] = len(label_dict)

        g = nx.Graph()
        g.add_nodes_from(range(n))
        tag_of_node = {}
        for j in range(n):
            row = next(f).strip().split()
            # row[1] is the neighbor count; anything past the listed
            # neighbors is ignored.
            n_fields = int(row[1]) + 2
            row = [int(w) for w in row[:n_fields]]
            raw_tag = row[0]
            if raw_tag not in feat_dict:
                feat_dict[raw_tag] = len(feat_dict)
            tag_of_node[j] = feat_dict[raw_tag]
            for neighbor in row[2:]:
                if not 0 <= neighbor < n:
                    raise ValueError(
                        'node %d references neighbor %d outside 0..%d'
                        % (j, neighbor, n - 1)
                    )
                # Self-loops are dropped.
                if neighbor != j:
                    g.add_edge(j, neighbor)

        g.label = label
        g.remove_nodes_from(list(nx.isolates(g)))
        if deg_as_tag:
            g.node_tags = list(dict(g.degree).values())
        else:
            # Build the tags from the nodes that actually remain so that
            # node_tags always lines up with g.nodes(), even when a row lists
            # only itself as neighbor or an edge appears in just one row.
            g.node_tags = [tag_of_node[v] for v in g.nodes()]
        return g

    def process_g(self, label_dict, tag2index, tagset, g):
        """Attach tensors to ``g`` and remap its label.

        Sets ``g.label`` to its index in ``label_dict``, ``g.feas`` to the
        one-hot encoding of the node tags (width ``len(tagset)``), and ``g.A``
        to the float adjacency matrix plus the identity.

        Returns:
            The same graph object ``g``.
        """
        g.label = label_dict[g.label]
        # Explicit integer dtype keeps one_hot happy even for an empty tag list.
        tag_ids = torch.tensor(
            [tag2index[tag] for tag in g.node_tags], dtype=torch.long
        )
        g.feas = F.one_hot(tag_ids, len(tagset))
        # to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is the
        # supported replacement and returns a plain ndarray.
        adjacency = torch.as_tensor(nx.to_numpy_array(g), dtype=torch.float32)
        g.A = adjacency + torch.eye(g.number_of_nodes())
        return g

    def load_data(self):
        """Load the dataset named by ``self.args.data`` and return a ``G_data``.

        Reads ``data/<name>/<name>.txt``, builds one graph per entry, derives
        the tag vocabulary over all graphs, and converts every graph with
        :meth:`process_g`.
        """
        args = self.args
        print('loading data ...')
        g_list = []
        label_dict = {}
        feat_dict = {}

        with open('data/%s/%s.txt' % (args.data, args.data), 'r') as f:
            lines = f.readlines()
        line_iter = self.line_genor(lines)
        n_g = int(next(line_iter).strip())
        for i in tqdm(range(n_g), desc="Create graph", unit='graphs'):
            g = self.gen_graph(
                line_iter, i, label_dict, feat_dict, args.deg_as_tag
            )
            g_list.append(g)

        # Tag vocabulary across all graphs; the index of a tag in ``tagset``
        # is its one-hot column.
        tagset = set([])
        for g in g_list:
            tagset = tagset.union(set(g.node_tags))
        tagset = list(tagset)
        tag2index = {tag: idx for idx, tag in enumerate(tagset)}

        new_g_list = [
            self.process_g(label_dict, tag2index, tagset, g)
            for g in tqdm(g_list, desc="Process graph", unit='graphs')
        ]
        num_class = len(label_dict)
        feat_dim = len(tagset)

        print('# classes: %d' % num_class, '# maximum node tag: %d' % feat_dim)
        return G_data(num_class, feat_dim, new_g_list)
|
|