├── .gitignore
├── data
│   ├── cora
│   │   └── README
│   └── emails
│       └── email_labels.txt
├── README.md
├── utils.py
├── main.py
├── dataset.py
├── community_detection.py
├── model.py
├── execution.py
├── preparation.py
├── Execution-demo-NC.ipynb
└── Execution-demo-LP.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | /runs
2 | /checkpoints
3 | /__pycache__
4 | /.ipynb_checkpoints
5 | .idea
6 | 
--------------------------------------------------------------------------------
/data/cora/README:
--------------------------------------------------------------------------------
1 | This directory contains a selection of the Cora dataset (www.research.whizbang.com/data).
2 | 
3 | The Cora dataset consists of Machine Learning papers. These papers are classified into one of the following seven classes:
4 | 		Case_Based
5 | 		Genetic_Algorithms
6 | 		Neural_Networks
7 | 		Probabilistic_Methods
8 | 		Reinforcement_Learning
9 | 		Rule_Learning
10 | 		Theory
11 | 
12 | The papers were selected in a way such that in the final corpus every paper cites or is cited by at least one other paper. There are 2708 papers in the whole corpus.
13 | 
14 | After stemming and removing stopwords we were left with a vocabulary of size 1433 unique words. All words with document frequency less than 10 were removed.
15 | 
16 | 
17 | THE DIRECTORY CONTAINS TWO FILES:
18 | 
19 | The .content file contains descriptions of the papers in the following format:
20 | 
21 | 		<paper_id> <word_attributes>+ <class_label>
22 | 
23 | The first entry in each line contains the unique string ID of the paper followed by binary values indicating whether each word in the vocabulary is present (indicated by 1) or absent (indicated by 0) in the paper. Finally, the last entry in the line contains the class label of the paper.
24 | 
25 | The .cites file contains the citation graph of the corpus. Each line describes a link in the following format:
26 | 
27 | 		<ID of cited paper> <ID of citing paper>
28 | 
29 | Each line contains two paper IDs. The first entry is the ID of the paper being cited and the second ID stands for the paper which contains the citation. The direction of the link is from right to left. If a line is represented by "paper1 paper2" then the link is "paper2->paper1".
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Message-Passing Graph Neural Networks
2 | 
3 | Code of the first practical model (HC-GNN).
4 | 
5 | ### Required packages
6 | The code has been tested running under Python 3.7.3 with the following packages installed (along with their dependencies):
7 | 
8 | - numpy == 1.16.5
9 | - pandas == 0.25.1
10 | - scikit-learn == 0.21.2
11 | - networkx == 2.3
12 | - community (python-louvain) == 0.13
13 | - pytorch == 1.1.0
14 | - torch_geometric == 1.3.2
15 | 
16 | ### Data requirement
17 | All eight datasets used in the paper are public and can be downloaded from the internet.
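For example, the Cora files follow the format documented in `data/cora/README`; below is a minimal sketch of loading them with pandas (paths assume the layout above; `dataset.py` performs the full preprocessing):

```python
import pandas as pd

# citation pairs, one per line: <ID of cited paper> <ID of citing paper>
cites = pd.read_csv('./data/cora/cora.cites', header=None, sep='\t',
                    names=['source', 'target'])
# one row per paper: string ID, 1433 binary word indicators, class label
content = pd.read_csv('./data/cora/cora.content', header=None, sep='\t')
features = content[range(1, 1434)].values
labels = content[1434]
print(cites.shape, features.shape, labels.nunique())
```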
18 | 
19 | ### Code execution
20 | Link prediction:
21 | ```
22 | python main.py --task LP --dataset grid --mode basemodel --model HCGNN --layer_num 3 --epoch_num 2001 --lr 0.0001 --relu True --dropout True --drop_ratio 0.5 --same_level_gnn GCN --down2up_gnn MEAN --up2down_gnn GAT --fshot False --SEED 123 --gpu True
23 | ```
24 | 
25 | Node classification and community detection:
26 | ```
27 | python main.py --task NC --dataset cora --mode basemodel --model HCGNN --layer_num 2 --epoch_num 201 --lr 0.01 --relu True --dropout False --drop_ratio 0.5 --same_level_gnn GCN --down2up_gnn MEAN --up2down_gnn GCN --fshot True --SEED 1234 --gpu True
28 | ```
29 | 
30 | Model hyper-parameters:
31 | ```
32 | --task: the target downstream task, "LP; NC; Inductive", type=str, default=LP
33 | --dataset: dataset name, type=str, default=grid
34 | --mode: the experiment type, type=str, default=basemodel
35 | --model: the model name, type=str, default=HCGNN
36 | --layer_num: the number of layers of the primary GNN encoder for within-level propagation, type=int, default=3
37 | --epoch_num: epoch number, type=int, default=2001
38 | --lr: learning rate, type=float, default=0.0001
39 | --relu: whether to use ReLU as the activation function in the model, type=bool, default=True
40 | --dropout: whether to use a dropout component in the model, type=bool, default=True
41 | --drop_ratio: dropout ratio if the dropout component is used, type=float, default=0.5
42 | --same_level_gnn: the GNN encoder for within-level propagation, type=str, default=GCN
43 | --down2up_gnn: defines the down2up (bottom-up) propagation, type=str, default=MEAN
44 | --up2down_gnn: defines the up2down (top-down) propagation, type=str, default=GAT
45 | --SEED: random seed, type=int, default=123
46 | --fshot: whether to adopt the few-shot learning setting, type=bool, default=False
47 | --gpu: whether to use a GPU device, type=bool, default=True
48 | ```
49 | 
50 | Two demo files are given to show the execution of the link prediction (LP) and node classification (NC) tasks.
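The demo notebooks drive the same entry point. A minimal sketch of an equivalent call from a notebook cell (flags not passed keep the defaults set in `main.py`):

```python
import subprocess

# mirrors the node classification command above
subprocess.run([
    'python', 'main.py', '--task', 'NC', '--dataset', 'cora',
    '--mode', 'basemodel', '--model', 'HCGNN',
    '--layer_num', '2', '--epoch_num', '201', '--lr', '0.01',
], check=True)
```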
51 | 52 | ## Cite 53 | 54 | Please cite our paper if it is helpful in your own work: 55 | 56 | ```bibtex 57 | @article{ZLP23, 58 | author = {Zhiqiang Zhong and Cheng{-}Te Li and Jun Pang}, 59 | title = {Hierarchical Message-Passing Graph Neural Networks}, 60 | journal = {Data Mining and Knowledge Discovery (DMKD)}, 61 | volume = {37}, 62 | number = {1}, 63 | pages = {381--408}, 64 | publisher = {Springer}, 65 | year = {2023}, 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | from multiprocessing import cpu_count, Pool 4 | import torch 5 | from sklearn.metrics import roc_auc_score, f1_score, normalized_mutual_info_score 6 | 7 | 8 | def evaluate_results(pred, y, idx=None, method=None): 9 | if method == 'roc-auc': 10 | return roc_auc_score( 11 | y_score=pred, y_true=y 12 | ) 13 | elif method == 'mic-f1': 14 | return f1_score( 15 | y_pred=np.argmax(pred.data.cpu().numpy(), axis=1), 16 | y_true=np.argmax(y, axis=1)[idx], average='micro' 17 | ) 18 | elif method == 'mac-f1': 19 | return f1_score( 20 | y_pred=np.argmax(pred.data.cpu().numpy(), axis=1), 21 | y_true=np.argmax(y, axis=1)[idx], average='macro' 22 | ) 23 | elif method == 'nmi': 24 | return normalized_mutual_info_score( 25 | labels_pred=np.argmax(pred.data.cpu().numpy(), axis=1), 26 | labels_true=np.argmax(y, axis=1)[idx] 27 | ) 28 | 29 | 30 | def generate_bigram(ls): 31 | res = [] 32 | for i in range(len(ls)-1): 33 | res += [(ls[i], item) for item in ls[i+1:]] 34 | return res 35 | 36 | 37 | def max_lists(lists): 38 | return max([item for items in lists for item in items]) 39 | 40 | 41 | def min_lists(lists): 42 | return min([item for items in lists for item in items]) 43 | 44 | 45 | def max_node(G): 46 | return max(G) 47 | 48 | 49 | def min_node(G): 50 | return min(G) 51 | 52 | 53 | def single_set_edge_between_community(graph, community, bigrams, threshold): 54 | res = [] 55 | for bigram in bigrams: 56 | if sum(1 for _ in nx.algorithms.edge_boundary( 57 | graph, community['edges_to_lowest'][bigram[0]], community['edges_to_lowest'][bigram[1]] 58 | )) >= threshold: 59 | res.append(True) 60 | else: 61 | res.append(False) 62 | return res 63 | 64 | 65 | def parallel_set_edge_between_community(graph, community, df, threshold): 66 | # TD: parallel 67 | # n_core = cpu_count() 68 | # # n_core = 4 69 | # 70 | # bigrams = df['all_bigrams'].values.tolist() 71 | # pool = Pool(processes=n_core) 72 | # results = [pool.apply_async(single_set_edge_between_community, args=( 73 | # graph, community, bigrams[int(len(bigrams) / cpu_count() * i):int(len(bigrams) / cpu_count() * (i + 1))], threshold) 74 | # ) for i in range(n_core)] 75 | # output = [p for res in [result.get() for result in results] for p in res] 76 | # df['result'] = output 77 | 78 | bigrams = df['all_bigrams'].values.tolist() 79 | df['result'] = single_set_edge_between_community( 80 | graph=graph, community=community, bigrams=bigrams, threshold=threshold 81 | ) 82 | 83 | return df 84 | 85 | 86 | def weights_init(m): 87 | if isinstance(m, torch.nn.Linear): 88 | m.weight.data = torch.nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu') 89 | 90 | 91 | def seed_everything(seed: int): 92 | import random 93 | import os 94 | import numpy as np 95 | import torch 96 | 97 | random.seed(seed) 98 | os.environ['PYTHONHASHSEED'] = str(seed) 99 | np.random.seed(seed) 100 | torch.manual_seed(seed) 101 | 
torch.cuda.manual_seed(seed)
102 |     torch.backends.cudnn.deterministic = True
103 |     torch.backends.cudnn.benchmark = False  # benchmark search picks non-deterministic kernels
104 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | 
4 | from utils import seed_everything
5 | from dataset import get_dataset
6 | from community_detection import hierarchical_structure_generation
7 | from execution import execute_NC, execute_LP
8 | from preparation import LP_preparation, NC_preparation
9 | 
10 | import warnings
11 | warnings.filterwarnings('ignore')
12 | ###########################################################################
13 | 
14 | # sys.argv = ['']  # execution on jupyter notebook
15 | parser = argparse.ArgumentParser()
16 | # general
17 | parser.add_argument('--comment', dest='comment', default='0', type=str,
18 |                     help='comment')
19 | parser.add_argument('--task', dest='task', default='LP', type=str,
20 |                     help='LP; NC; Inductive')
21 | parser.add_argument('--mode', dest='mode', default='baseline', type=str,
22 |                     help='experiment mode. E.g., baseline or basemodel')
23 | parser.add_argument('--model', dest='model', default='GCN', type=str,
24 |                     help='model class name. E.g., GCN, PGNN, HCGNN...')
25 | parser.add_argument('--dataset', dest='dataset', default='grid', type=str,
26 |                     help='cora; grid; communities; ppi')
27 | parser.add_argument('--gpu', dest='gpu', default=True, type=bool,
28 |                     help='whether to use gpu')
29 | parser.add_argument('--SEED', dest='SEED', default=123, type=int)
30 | 
31 | # dataset
32 | parser.add_argument('--ratio_sample_pos', dest='ratio_sample_pos', default=20, type=float)
33 | parser.add_argument('--use_features', dest='use_features', default=True, type=bool,
34 |                     help='whether to use node features')
35 | parser.add_argument('--community_detection_method', dest='community_detection_method', default='Louvain', type=str,
36 |                     help='community detection method, default Louvain')
37 | parser.add_argument('--threshold', dest='threshold', default=1, type=int,
38 |                     help='the threshold for graph generation, default 1')
39 | 
40 | # model
41 | parser.add_argument('--lr', dest='lr', default=1e-2, type=float)
42 | parser.add_argument('--epoch_num', dest='epoch_num', default=201, type=int)
43 | parser.add_argument('--epoch_log', dest='epoch_log', default=10, type=int)
44 | parser.add_argument('--layer_num', dest='layer_num', default=2, type=int)
45 | parser.add_argument('--relu', dest='relu', default=True, type=bool)  # type=bool: any non-empty string parses as True
46 | parser.add_argument('--dropout', dest='dropout', default=False, type=bool)
47 | parser.add_argument('--drop_ratio', dest='drop_ratio', default=0.5, type=float)
48 | parser.add_argument('--feature_pre', dest='feature_pre', default=True, type=bool)
49 | parser.add_argument('--same_level_gnn', dest='same_level_gnn', default='GCN', type=str,
50 |                     help='agg within level. E.g., MEAN, GCN, SAGE, GAT, GIN, ...')
51 | parser.add_argument('--down2up_gnn', dest='down2up_gnn', default='MEAN', type=str,
52 |                     help='aggregation bottom-up. E.g., MEAN, GCN, SAGE, GAT, GIN, ...')
53 | parser.add_argument('--up2down_gnn', dest='up2down_gnn', default='GAT', type=str,
54 |                     help='aggregation top-down. E.g., GCN, SAGE, GAT, GIN, ...')
55 | parser.add_argument('--fshot', dest='fshot', default=False, type=bool)
56 | 
57 | parser.set_defaults(gpu=False, task='LP', model='GCN', dataset='grid', feature_pre=True)  # these override the defaults declared above
58 | args = parser.parse_args()
59 | 
60 | args.device = torch.device('cuda:'+str(0) if args.gpu and torch.cuda.is_available() else 'cpu')
61 | seed_everything(args.SEED)
62 | print(args, '\n')
63 | 
64 | 
65 | if args.task == 'LP':
66 |     ls_df_friends, graphs_complete, graphs, ls_valid_edges, ls_test_edges, features, df_labels = get_dataset(
67 |         dataset_name=args.dataset,
68 |         use_features=args.use_features,
69 |         task=args.task,
70 |         ratio_sample=args.ratio_sample_pos
71 |     )
72 |     ls_hierarchical_community, ls_up2down_edges, ls_down2up_edges = hierarchical_structure_generation(
73 |         dataset_name=args.dataset,
74 |         graphs=graphs,
75 |         method=args.community_detection_method,
76 |         threshold=args.threshold
77 |     )
78 |     ls_adj_same_level, ls_df_train, ls_df_valid, ls_df_test = LP_preparation(
79 |         graphs=graphs,
80 |         ls_df_friends=ls_df_friends,
81 |         ls_test_edges=ls_test_edges,
82 |         ls_valid_edges=ls_valid_edges,
83 |         ls_hierarchical_community=ls_hierarchical_community
84 |     )
85 |     execute_LP(
86 |         args, graphs, features, ls_hierarchical_community,
87 |         ls_adj_same_level, ls_up2down_edges, ls_down2up_edges,
88 |         ls_df_train, ls_df_valid, ls_df_test
89 |     )
90 | else:
91 |     graphs, features, df_labels = get_dataset(
92 |         dataset_name=args.dataset,
93 |         use_features=args.use_features,
94 |         task=args.task,
95 |     )
96 |     ls_hierarchical_community, ls_up2down_edges, ls_down2up_edges = hierarchical_structure_generation(
97 |         dataset_name=args.dataset,
98 |         graphs=graphs,
99 |         method=args.community_detection_method,
100 |         threshold=args.threshold
101 |     )
102 |     ls_adj_same_level = NC_preparation(
103 |         graphs=graphs,
104 |         ls_hierarchical_community=ls_hierarchical_community
105 |     )
106 |     execute_NC(
107 |         args, graphs, df_labels, features, ls_hierarchical_community,
108 |         ls_adj_same_level, ls_up2down_edges, ls_down2up_edges,
109 |     )
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import random
4 | import networkx as nx
5 | from copy import deepcopy
6 | import time
7 | from sklearn import preprocessing
8 | import sys
9 | 
10 | 
11 | def LP_preprocessing(graphs, ratio_sample_pos_link):
12 |     # graph with train and test edges
13 |     graphs_complete = deepcopy(graphs)
14 | 
15 |     # collect negative test edges
16 |     ls_test_edges_neg = []
17 |     for idx, graph in enumerate(graphs):
18 |         start = time.time()
19 |         print('For graph {}, we need to collect {} negative edges.'.format(
20 |             idx, int(graph.number_of_edges() * (ratio_sample_pos_link / 100))
21 |         ))
22 |         df_train_edges_pos = nx.to_pandas_edgelist(graph)
23 |         df_test_edges_neg = pd.DataFrame(
24 |             np.random.choice(list(graph.nodes()), 10 * graph.number_of_edges()), columns=['source']
25 |         )
26 |         df_test_edges_neg['target'] = np.random.choice(list(graph.nodes()), 10*graph.number_of_edges())
27 |         df_test_edges_neg = df_test_edges_neg[df_test_edges_neg['source'] < df_test_edges_neg['target']]
28 |         df_test_edges_neg = df_test_edges_neg.drop_duplicates().reset_index(drop=True)
29 |         # drop candidates that are existing (positive) edges
30 |         df_test_edges_neg = pd.merge(
31 |             df_test_edges_neg, df_train_edges_pos, on=['source', 'target'], how='left', indicator=True
32 |         )
33 |         df_test_edges_neg = df_test_edges_neg[df_test_edges_neg['_merge'] == 'left_only'][['source', 'target']]
34 |         df_test_edges_neg = df_test_edges_neg.sample(
35 |             n=int(graph.number_of_edges() * (ratio_sample_pos_link / 100))
36 |         )
37 |         ls_test_edges_neg.append(df_test_edges_neg.values.tolist())
38 |         print('Generating {} negative instances uses {:.2f} seconds.'.format(idx, time.time() - start))
39 | 
40 |     # collect positive test edges by removing edges from a copy of the graph,
41 |     # never letting an endpoint drop to degree zero
42 |     ls_test_edges_pos = []
43 |     for idx, graph in enumerate(graphs):
44 |         start = time.time()
45 |         G_train = deepcopy(graph)
46 |         edges = np.array(list(G_train.edges())).T
47 |         e = edges.shape[1]
48 |         edges = edges[:, np.random.permutation(e)]  # shuffle candidate edges
49 |         # node_count[n] is how many edges node n can still afford to lose
50 |         node_count = {node: G_train.degree(node) - 1 for node in G_train.nodes()}
51 |         index_train = []
52 |         index_val = []
53 |         for i in range(e):
54 |             node1 = edges[0, i]
55 |             node2 = edges[1, i]
56 |             # hold this edge out only if both endpoints keep at least one edge
57 |             if node_count[node1] > 0 and node_count[node2] > 0:  # if degree>1
58 |                 index_val.append(i)
59 |                 node_count[node1] -= 1
60 |                 node_count[node2] -= 1
61 |                 if len(index_val) == int(e * ratio_sample_pos_link / 100):
62 |                     break
63 |             else:
64 |                 index_train.append(i)
65 |         index_train = index_train + list(range(i + 1, e))
66 |         edges_train = edges[:, index_train]
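        # columns held out in index_val become the positive validation/test edges;
        # everything remaining in index_train stays in the training graph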
67 |         edges_test = edges[:, index_val]
68 |         test_edges_pos = [[edges_test[0, i], edges_test[1, i]] for i in range(edges_test.shape[1])]
69 |         G_train.remove_edges_from(test_edges_pos)
70 |         if len(test_edges_pos) < int(graph.number_of_edges() * (ratio_sample_pos_link / 100)):
71 |             print('For graph {}, there are only {} positive instances.'.format(idx, len(test_edges_pos)))
72 |             sys.exit("Can not remove more edges.")
73 |         print('Generating {} positive instances uses {:.2f} seconds.'.format(idx, time.time()-start))
74 |         graphs[idx] = G_train
75 |         ls_test_edges_pos.append(test_edges_pos)
76 | 
77 |     # friends collections
78 |     ls_df_friends = []
79 |     for idx in range(len(graphs)):
80 |         df_friends = nx.to_pandas_edgelist(graphs[idx])
81 | 
82 |         _x = deepcopy(df_friends)
83 |         _x.columns = ['target', 'source']
84 |         df_friends = pd.concat([df_friends, _x]).reset_index(drop=True)
85 | 
86 |         ls_df_friends.append(df_friends)
87 | 
88 |     # test and valid edges collections
89 |     ls_valid_edges = []
90 |     ls_test_edges = []
91 |     for idx in range(len(graphs)):
92 |         valid_edges_pos = random.sample(
93 |             ls_test_edges_pos[idx], int(graphs_complete[idx].number_of_edges() * (ratio_sample_pos_link / 100) / 2)
94 |         )
95 |         valid_edges_neg = random.sample(
96 |             ls_test_edges_neg[idx], int(graphs_complete[idx].number_of_edges() * (ratio_sample_pos_link / 100) / 2)
97 |         )
98 |         test_edges_pos = [item for item in ls_test_edges_pos[idx] if item not in valid_edges_pos]
99 |         test_edges_neg = [item for item in ls_test_edges_neg[idx] if item not in valid_edges_neg]
100 |         test_edges = {
101 |             'positive': test_edges_pos,
102 |             'negative': test_edges_neg
103 |         }
104 |         valid_edges = {
105 |             'positive': valid_edges_pos,
106 |             'negative': valid_edges_neg
107 |         }
108 |         ls_valid_edges.append(valid_edges)
109 |         ls_test_edges.append(test_edges)
110 | 
111 |     return ls_df_friends, graphs_complete, graphs, ls_valid_edges, ls_test_edges
112 | 
113 | 
114 | def get_dataset(dataset_name, use_features, task, ratio_sample: int = 0):
115 |     if dataset_name == 'grid':
116 |         print('is reading {} dataset...'.format(dataset_name))
117 |         graph = nx.grid_2d_graph(20, 20)
118 |         graph = nx.convert_node_labels_to_integers(graph)
119 |         keys = list(graph.nodes)
120 |         vals = range(graph.number_of_nodes())
121 |         mapping = dict(zip(keys, vals))
122 |         graph = nx.relabel_nodes(graph, mapping, copy=True)
123 |         identify_oh_feature = np.identity(graph.number_of_nodes())
124 |         graphs = [graph]
125 |         features = [identify_oh_feature]
126 |         print('dataset reading is done.')
127 | 
128 |     elif dataset_name == 'emails':
129 |         print('is reading {} dataset...'.format(dataset_name))
130 |         df = pd.read_csv('./data/emails/email.txt', header=None, sep=' ', names=['source', 'target'])
131 |         graph = nx.from_pandas_edgelist(df=df, source='source', target='target', edge_attr=None)
132 | 
133 |         df_label = pd.read_csv('./data/emails/email_labels.txt', header=None, sep=' ', names=['node_id', 'label'])
134 |         df_label = df_label[df_label['label'].isin(df_label['label'].value_counts()[df_label['label'].value_counts()>20].index)]
135 |         available_nodes = df_label['node_id'].unique()
136 | 
137 |         graph = graph.subgraph(available_nodes)
138 |         keys = list(graph.nodes)
139 |         vals = range(graph.number_of_nodes())
140 |         mapping = dict(zip(keys, vals))
141 |         graph = nx.relabel_nodes(graph, mapping, copy=True)
142 | 
143 |         df_label['node_id'] = df_label['node_id'].replace(mapping)
144 |         df_label = df_label.sort_values('node_id', ascending=True).reset_index(drop=True)
145 |         # encode label into numeric
146 |         le = preprocessing.LabelEncoder()
147 |         df_label['label'] = le.fit_transform(df_label['label'])
148 | 
149 |         identify_oh_feature = np.identity(graph.number_of_nodes())
150 | 
151 |         graphs = [graph]
152 |         features = [identify_oh_feature]
153 |         df_labels = [df_label]
154 | 
155 |     elif dataset_name == 'cora':
156 |         print('is reading {} dataset...'.format(dataset_name))
157 |         df = pd.read_csv('./data/cora/cora.cites', header=None, sep='\t', names=['source', 'target'])
158 |         graph = nx.from_pandas_edgelist(df=df, source='source', target='target', edge_attr=None)
159 |         keys = list(graph.nodes)
160 |         vals = range(graph.number_of_nodes())
161 |         mapping = dict(zip(keys, vals))
162 |         graph = nx.relabel_nodes(graph, mapping, copy=True)
163 | 
164 |         # cora feature
165 |         content = pd.read_csv('./data/cora/cora.content', header=None, sep='\t')
166 |         df_feat = content[range(1434)].rename(columns={0: 'node_id'})
167 |         df_label = content[[0, 1434]].rename(
168 |             columns={
169 |                 0: 'node_id',
170 |                 1434: 'label'
171 |             }
172 |         )
173 | 
174 |         df_feat['node_id'] = df_feat['node_id'].replace(mapping)
175 |         df_feat = df_feat.sort_values('node_id', ascending=True).reset_index(drop=True)
176 | 
177 |         df_label['node_id'] = df_label['node_id'].replace(mapping)
178 |         df_label = df_label.sort_values('node_id', ascending=True).reset_index(drop=True)
179 |         # encode label into numeric
180 |         le = preprocessing.LabelEncoder()
181 |         df_label['label'] = le.fit_transform(df_label['label'])
182 | 
183 |         graphs = [graph]
184 |         if use_features:
185 |             features = [df_feat[range(1, 1434)].values]
186 |         else:
187 |             identify_oh_feature = np.identity(graph.number_of_nodes())
188 |             features = [identify_oh_feature]
189 |         df_labels = [df_label]
190 | 
191 |     elif dataset_name == 'citeseer':
192 |         print('is reading {} dataset...'.format(dataset_name))
193 |         df = pd.read_csv('./data/citeseer/citeseer.cites', header=None, sep='\t', names=['source', 'target'])
194 |         graph = nx.from_pandas_edgelist(df=df, source='source', target='target', edge_attr=None)
195 | 
196 |         # citeseer feature
197 |         content = pd.read_csv('./data/citeseer/citeseer.content', header=None, sep='\t')
198 |         content[0] = content[0].apply(str)
199 |         available_nodes = content[0].unique()
200 |         df_feat = content[range(3704)].rename(columns={0: 'node_id'})
201 |         df_label = content[[0, 3704]].rename(
202 |             columns={
203 |                 0: 'node_id',
204 |                 3704: 'label'
205 |             }
206 |         )
207 | 
208 |         graph = graph.subgraph(available_nodes)
209 |         keys = list(graph.nodes)
210 |         vals = range(graph.number_of_nodes())
211 |         mapping = dict(zip(keys, vals))
212 |         graph = nx.relabel_nodes(graph, mapping, copy=True)
213 | 
214 |         df_feat['node_id'] = df_feat['node_id'].replace(mapping)
215 |         df_feat = df_feat.sort_values('node_id', ascending=True).reset_index(drop=True)
216 | 
217 |         df_label['node_id'] = df_label['node_id'].replace(mapping)
218 |         df_label = df_label.sort_values('node_id', ascending=True).reset_index(drop=True)
219 |         # encode label into numeric
220 |         le = preprocessing.LabelEncoder()
221 |         df_label['label'] = le.fit_transform(df_label['label'])
222 | 
223 |         graphs = [graph]
224 |         if use_features:
225 |             features = [df_feat[range(1, 3704)].values]  # 3703 binary word indicators
226 |         else:
227 |             identify_oh_feature = np.identity(graph.number_of_nodes())
228 |             features = [identify_oh_feature]
229 |         df_labels = [df_label]
230 | 
231 |     print(nx.info(graphs[0]))
232 |     print('is processing dataset...')
233 |     if task == 'LP':
234 |         ls_df_friends, graphs_complete, graphs, ls_valid_edges, ls_test_edges = LP_preprocessing(
235 |             graphs=graphs,
ratio_sample_pos_link=ratio_sample 236 | ) 237 | df_labels = 0 238 | print('data processing is done') 239 | return ls_df_friends, graphs_complete, graphs, ls_valid_edges, ls_test_edges, features, df_labels 240 | else: 241 | return graphs, features, df_labels 242 | -------------------------------------------------------------------------------- /community_detection.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | import community 5 | from networkx.algorithms.community.centrality import girvan_newman 6 | import itertools 7 | 8 | from utils import max_lists, min_lists, generate_bigram, parallel_set_edge_between_community 9 | 10 | 11 | def Louvain_community_detection(graphs): 12 | print('Is doing community detection....') 13 | # community_detection 14 | ls_hierarchical_community = [] 15 | 16 | for idx, G in enumerate(graphs): 17 | print('start {} subgraph...'.format(idx)) 18 | dendrogram_old = community.generate_dendrogram(G) 19 | dendrogram_new = deepcopy(dendrogram_old) 20 | hierarchical_community = [] 21 | 22 | for level in range(len(dendrogram_old)): 23 | community_tmp = {} 24 | print('iteration {}: from {} items to {} items, modularity is {}.'.format( 25 | level, len(dendrogram_old[level]), len(set(dendrogram_old[level].values())), 26 | community.modularity( 27 | community.partition_at_level(dendrogram_old, level), graph=G 28 | ) 29 | )) 30 | community_tmp['before_num_partitions'] = len(dendrogram_old[level]) 31 | community_tmp['after_num_partitions'] = len(set(dendrogram_old[level].values())) 32 | community_tmp['modularity'] = community.modularity( 33 | community.partition_at_level(dendrogram_old, level), graph=G 34 | ) 35 | 36 | if level == 0: 37 | new_update = defaultdict(list) 38 | for key, value in dendrogram_new[level].items(): 39 | new_update[value+max(dendrogram_new[level])+1].append(key) 40 | dendrogram_new[level] = new_update 41 | community_tmp['edges_to_lower'] = new_update 42 | community_tmp['edges_to_lowest'] = new_update 43 | else: 44 | # edges to lower 45 | new_update = defaultdict(list) 46 | for key, value in dendrogram_new[level].items(): 47 | new_update[value+max(dendrogram_new[level-1])+1].append( 48 | key + max_lists(dendrogram_new[level - 1].values()) + 1 49 | ) 50 | dendrogram_new[level] = new_update 51 | community_tmp['edges_to_lower'] = new_update 52 | # edges to lowest 53 | new_update_lowest = deepcopy(new_update) 54 | for key, values in new_update.items(): 55 | new_values = [] 56 | for value in values: 57 | new_values += hierarchical_community[level - 1]['edges_to_lowest'][value] 58 | new_update_lowest[key] = new_values 59 | community_tmp['edges_to_lowest'] = new_update_lowest 60 | community_tmp['partitions'] = list(dendrogram_new[level].keys()) 61 | 62 | hierarchical_community.append(community_tmp) 63 | ls_hierarchical_community.append(hierarchical_community) 64 | print('Community detection is done') 65 | return ls_hierarchical_community 66 | 67 | 68 | def GN_community_detection(graphs, dataset_name): 69 | k = {'emails': 2, 70 | 'grid': 4, 71 | 'cora': 4, 72 | 'power': 5, 73 | 'citeseer': 4, 74 | 'pubmed': 4}[dataset_name] 75 | 76 | ls_hierarchical_community = [] 77 | 78 | for idx, G in enumerate(graphs): 79 | print('start {} subgraph...'.format(idx)) 80 | G = graphs[0] 81 | comp = girvan_newman(G) 82 | ls_community = [] 83 | for idx, communities in enumerate(itertools.islice(comp, k)): 84 | ls_community.append(communities) 85 | 
ls_community = ls_community[::-1]  # we need a pyramidal structure
86 | 
87 |         hierarchical_community = []
88 |         for level, community in enumerate(ls_community):
89 |             community_tmp = {}
90 |             if level == 0:
91 |                 community_tmp['before_num_partitions'] = G.number_of_nodes()
92 |                 community_tmp['after_num_partitions'] = len(community)
93 |                 new_update = defaultdict(list)
94 |                 for idx, com in enumerate(community):
95 |                     new_update[idx+max(G.node)+1] = list(com)
96 |                 community_tmp['edges_to_lower'] = new_update
97 |                 community_tmp['edges_to_lowest'] = new_update
98 |             else:
99 |                 community_tmp['before_num_partitions'] = hierarchical_community[-1]['after_num_partitions']
100 |                 community_tmp['after_num_partitions'] = len(community)
101 |                 # edges to lowest
102 |                 new_update_lowest = defaultdict(list)
103 |                 for idx, com in enumerate(community):
104 |                     new_update_lowest[idx+max(hierarchical_community[-1]['partitions'])+1] = list(com)
105 |                 # edges to lower
106 |                 new_update_lower = defaultdict(list)
107 |                 for idx, com in enumerate(community):
108 |                     for key, values in hierarchical_community[-1]['edges_to_lowest'].items():
109 |                         if set(values).issubset(com):
110 |                             new_update_lower[idx+max(hierarchical_community[-1]['partitions'])+1].append(key)
111 |                 community_tmp['edges_to_lower'] = new_update_lower
112 |                 community_tmp['edges_to_lowest'] = new_update_lowest
113 | 
114 |             community_tmp['partitions'] = list(community_tmp['edges_to_lowest'].keys())
115 |             hierarchical_community.append(community_tmp)
116 |         ls_hierarchical_community.append(hierarchical_community)
117 |     return ls_hierarchical_community
118 | 
119 | 
120 | def present_community_detection_results(ls_hierarchical_community):
121 |     print('is presenting hierarchical structure....')
122 |     for idx, hierarchical_community in enumerate(ls_hierarchical_community):
123 |         print('start graph community {}'.format(idx))
124 |         for i in range(len(hierarchical_community)):
125 |             print('layer {}: keys: [{}, {}], values: [{}, {}]'.format(
126 |                 i, min(hierarchical_community[i]['edges_to_lower']),
127 |                 max(hierarchical_community[i]['edges_to_lower']),
128 |                 min_lists(hierarchical_community[i]['edges_to_lower'].values()),
129 |                 max_lists(hierarchical_community[i]['edges_to_lower'].values())))
130 | 
131 | 
132 | def up2down_pipeline(graphs, ls_hierarchical_community, threshold):
133 |     print('Is setting up hierarchical pipelines....')
134 |     # set up the up-down pipeline
135 |     for idx, hierarchical_community in enumerate(ls_hierarchical_community):
136 |         graph = graphs[idx]
137 |         print('start graph {}'.format(idx))
138 |         for id_community, community in enumerate(hierarchical_community):
139 |             df_all_bigrams = pd.DataFrame({'all_bigrams': generate_bigram(community['partitions'])})
140 |             df_all_bigrams = parallel_set_edge_between_community(
141 |                 graph=graph, community=community, df=df_all_bigrams, threshold=threshold
142 |             )
143 |             edges_tmp = df_all_bigrams[df_all_bigrams['result']]['all_bigrams'].values.tolist()
144 |             hierarchical_community[id_community]['edges'] = edges_tmp
145 |         ls_hierarchical_community[idx] = hierarchical_community
146 |     # # verify the correctness of the up-down pipeline
147 |     # for idx, hierarchical_community in enumerate(ls_hierarchical_community):
148 |     #     if len(hierarchical_community)==4:
149 |     #         for key, values in hierarchical_community[3]['edges_to_lower'].items():
150 |     #             res = []
151 |     #             third_keys = []
152 |     #             for sec_key in values:
153 |     #                 third_keys += hierarchical_community[2]['edges_to_lower'][sec_key]
154 |     #             fourth_keys = []
155 |     #             for third_key in 
third_keys: 156 | # fourth_keys += hierarchical_community[1]['edges_to_lower'][third_key] 157 | # for fourth_key in fourth_keys: 158 | # res += hierarchical_community[0]['edges_to_lower'][fourth_key] 159 | # if res != hierarchical_community[3]['edges_to_lowest'][key]: 160 | # print(key, res, hierarchical_community[3]['edges_to_lowest'][key]) 161 | print('Hierarchical pipelines are ready') 162 | return ls_hierarchical_community 163 | 164 | 165 | def up2down_edges(ls_hierarchical_community): 166 | print('is recording up2down edges....') 167 | # all above layers messages pass to first layer 168 | ls_up2down_edges = [] 169 | ls_up2down_dicts = [] 170 | for idx, hierarchical_community in enumerate(ls_hierarchical_community): 171 | up2down_edges = defaultdict(list) 172 | up2down_dicts = defaultdict(list) 173 | 174 | for com in hierarchical_community: 175 | up2down_edges.update(com['edges_to_lowest']) 176 | ls_up2down_edges.append(dict(up2down_edges)) 177 | 178 | for com in hierarchical_community: 179 | for (key, values) in com['edges_to_lowest'].items(): 180 | for value in values: 181 | up2down_dicts[value].append(key) 182 | # set high level order 183 | for (key, values) in up2down_dicts.items(): 184 | up2down_dicts[key] = sorted(values) 185 | ls_up2down_dicts.append(dict(up2down_dicts)) 186 | 187 | for idx, up2down_edges in enumerate(ls_up2down_edges): 188 | print('keys in [{}, {}], values in [{}, {}]'.format( 189 | min(up2down_edges.keys()), 190 | max(up2down_edges.keys()), 191 | min([value for values in up2down_edges.values() for value in values]), 192 | max([value for values in up2down_edges.values() for value in values]) 193 | )) 194 | return ls_up2down_edges, ls_up2down_dicts 195 | 196 | 197 | def down2up_edges(ls_hierarchical_community): 198 | print('is recording down2up edges....') 199 | # Down to the above one 200 | ls_down2up_edges = [] 201 | for idx, hierarchical_community in enumerate(ls_hierarchical_community): 202 | down2up_edges = defaultdict(list) 203 | 204 | for com in hierarchical_community: 205 | for key, values in com['edges_to_lower'].items(): 206 | for value in values: 207 | down2up_edges[value].append(key) 208 | 209 | ls_down2up_edges.append(dict(down2up_edges)) 210 | # verify down2up 211 | for idx, down2up_edges in enumerate(ls_down2up_edges): 212 | print('keys in [{}, {}], values in [{}, {}]'.format( 213 | min(down2up_edges.keys()), 214 | max(down2up_edges.keys()), 215 | min([value for values in down2up_edges.values() for value in values]), 216 | max([value for values in down2up_edges.values() for value in values]) 217 | )) 218 | return ls_down2up_edges 219 | 220 | 221 | def hierarchical_structure_generation(dataset_name, graphs, method, threshold): 222 | print('is generating hierarchical structure....') 223 | if method == 'Louvain': 224 | ls_hierarchical_community = Louvain_community_detection(graphs=graphs) 225 | elif method == 'GN': 226 | ls_hierarchical_community = GN_community_detection(graphs=graphs, dataset_name=dataset_name) 227 | else: 228 | ls_hierarchical_community = None 229 | present_community_detection_results(ls_hierarchical_community) 230 | 231 | ls_hierarchical_community = up2down_pipeline( 232 | graphs=graphs, ls_hierarchical_community=ls_hierarchical_community, threshold=threshold 233 | ) 234 | ls_up2down_edges, ls_up2down_dicts = up2down_edges(ls_hierarchical_community) 235 | ls_down2up_edges = down2up_edges(ls_hierarchical_community) 236 | 237 | print('hierarchical structure generation is done') 238 | return ls_hierarchical_community, 
ls_up2down_edges, ls_down2up_edges
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, GraphUNet
4 | from torch_geometric.utils import dropout_adj
5 | 
6 | 
7 | # baseline models
8 | class Baseline_GNN(torch.nn.Module):
9 |     def __init__(self, config, graphs, features, ls_labels):
10 |         super(Baseline_GNN, self).__init__()
11 |         self.config = config
12 |         self.gnn_layers = [512, 256, 128, 64, 32][-(self.config.layer_num+1):]
13 | 
14 |         self.same_level_gnn_layers = torch.nn.ModuleList()
15 |         if self.config.model == 'GCN':
16 |             if self.config.feature_pre:
17 |                 self.linear_pre = torch.nn.Linear(features[0].shape[1], self.gnn_layers[0])
18 |             else:
19 |                 self.gnn_layers[0] = features[0].shape[1]
20 |             for idx, (in_size, out_size) in enumerate(zip(self.gnn_layers[:-1], self.gnn_layers[1:])):
21 |                 self.same_level_gnn_layers.append(GCNConv(in_size, out_size))
22 |         elif self.config.model == 'SAGE':
23 |             if self.config.feature_pre:
24 |                 self.linear_pre = torch.nn.Linear(features[0].shape[1], self.gnn_layers[0])
25 |             else:
26 |                 self.gnn_layers[0] = features[0].shape[1]
27 |             for idx, (in_size, out_size) in enumerate(zip(self.gnn_layers[:-1], self.gnn_layers[1:])):
28 |                 self.same_level_gnn_layers.append(SAGEConv(in_size, out_size))
29 |         elif self.config.model == 'GAT':
30 |             if self.config.feature_pre:
31 |                 self.linear_pre = torch.nn.Linear(features[0].shape[1], self.gnn_layers[0])
32 |             else:
33 |                 self.gnn_layers[0] = features[0].shape[1]
34 |             for idx, (in_size, out_size) in enumerate(zip(self.gnn_layers[:-1], self.gnn_layers[1:])):
35 |                 self.same_level_gnn_layers.append(GATConv(in_size, out_size))
36 |         elif self.config.model == 'GIN':
37 |             self.same_level_lnn_layers = torch.nn.ModuleList()
38 |             if self.config.feature_pre:
39 |                 self.linear_pre = torch.nn.Linear(features[0].shape[1], self.gnn_layers[0])
40 |             else:
41 |                 self.gnn_layers[0] = features[0].shape[1]
42 |             for idx, (in_size, out_size) in enumerate(zip(self.gnn_layers[:-1], self.gnn_layers[1:])):
43 |                 self.same_level_lnn_layers.append(torch.nn.Linear(in_size, out_size))
44 |                 self.same_level_gnn_layers.append(GINConv(self.same_level_lnn_layers[idx]))
45 |         elif self.config.model == 'GUNET':
46 |             if graphs[0].number_of_nodes() < 2000:
47 |                 pool_ratios = [200 / graphs[0].number_of_nodes(), 0.5]
48 |             else:
49 |                 pool_ratios = [2000 / graphs[0].number_of_nodes(), 0.5]
50 |             self.unet = GraphUNet(
51 |                 features[0].shape[1], 32, ls_labels[0].shape[1], depth=self.config.layer_num, pool_ratios=pool_ratios
52 |             )
53 | 
54 |     def forward(self, data):
55 |         if self.config.model == 'GUNET':
56 |             edge_index, _ = dropout_adj(data.edge_index, p=self.config.drop_ratio,
57 |                                         force_undirected=True,
58 |                                         num_nodes=data.x.shape[0], training=self.training)
59 |             embed = torch.nn.functional.dropout(data.x, p=self.config.drop_ratio, training=self.training)
60 | 
61 |             embed = self.unet(embed, edge_index)
62 | 
63 |         else:
64 |             x, same_level_edge_index = data.x, data.edge_index
65 | 
66 |             if self.config.feature_pre:
67 |                 embed = self.linear_pre(x)
68 |             else:
69 |                 embed = x
70 | 
71 |             for idx, _ in enumerate(range(len(self.same_level_gnn_layers))):
72 |                 if idx != len(self.same_level_gnn_layers)-1:
73 |                     # same level
74 |                     embed = self.same_level_gnn_layers[idx](embed, same_level_edge_index)
75 |                     if self.config.relu:
76 |                         embed = 
torch.nn.functional.relu(embed) # Note: optional! 77 | if self.config.dropout: 78 | embed = torch.nn.functional.dropout(embed, p=self.config.drop_ratio, training=self.training) 79 | else: 80 | # same level 81 | embed = self.same_level_gnn_layers[idx](embed, same_level_edge_index) 82 | 83 | if self.config.task=='NC': 84 | embed = torch.nn.functional.log_softmax(embed, dim=1) 85 | else: 86 | embed = torch.nn.functional.normalize(embed, p=2, dim=-1) 87 | 88 | return embed 89 | 90 | 91 | # components of HC-GNN layer 92 | class Down2Up_layer(torch.nn.Module): 93 | def __init__(self, config, in_size, out_size): 94 | super(Down2Up_layer, self).__init__() 95 | self.config = config 96 | 97 | if self.config.down2up_gnn == 'GAT': 98 | self.nn = GATConv(in_size, out_size) 99 | elif self.config.down2up_gnn == 'GCN': 100 | self.nn = GCNConv(in_size, out_size) 101 | elif self.config.down2up_gnn == 'SAGE': 102 | self.nn = SAGEConv(in_size, out_size) 103 | elif self.config.down2up_gnn == 'MEAN': 104 | self.nn = False 105 | 106 | def forward(self, embedding, down2up_paths): 107 | if type(self.nn) == bool: 108 | for down2up_array in down2up_paths: 109 | update_message = torch.mm(down2up_array, embedding) 110 | embedding = embedding + update_message 111 | embedding = torch.mul(embedding, 1.0/(down2up_array.sum(-1)+1).unsqueeze(1)) 112 | else: 113 | embedding = self.nn(embedding, down2up_paths) 114 | 115 | return embedding 116 | 117 | 118 | class Up2Down_layer(torch.nn.Module): 119 | def __init__(self, config, in_size, out_size): 120 | super(Up2Down_layer, self).__init__() 121 | self.config = config 122 | if self.config.up2down_gnn == 'GAT': 123 | self.nn = GATConv(in_size, out_size) 124 | elif self.config.up2down_gnn == 'GCN': 125 | self.nn = GCNConv(in_size, out_size) 126 | elif self.config.up2down_gnn == 'SAGE': 127 | self.nn = SAGEConv(in_size, out_size) 128 | 129 | def forward(self, embedding, up2down_edge_index): 130 | embedding = self.nn(embedding, up2down_edge_index) 131 | 132 | return embedding 133 | 134 | 135 | class HCGNN_layer(torch.nn.Module): 136 | def __init__(self, config, in_size, out_size, gnn_type): 137 | super(HCGNN_layer, self).__init__() 138 | self.in_size = in_size 139 | self.out_size = out_size 140 | self.config = config 141 | 142 | self.down2up_layer = Down2Up_layer(self.config, self.in_size, self.in_size) 143 | if gnn_type == 'GCN': 144 | self.samle_level_layer = GCNConv(in_size, out_size) 145 | elif gnn_type == 'SAGE': 146 | self.samle_level_layer = SAGEConv(in_size, out_size) 147 | elif gnn_type == 'GAT': 148 | self.samle_level_layer = GATConv(in_size, out_size) 149 | elif gnn_type == 'GIN': 150 | self.same_level_lnn_layer = torch.nn.Linear(in_size, out_size) 151 | self.samle_level_layer = GINConv(self.same_level_lnn_layer) 152 | self.up2down_layer = Up2Down_layer(self.config, self.out_size, self.out_size) 153 | 154 | def forward(self, embedding, down2up_path, same_level_edge_index, up2down_edge_index): 155 | # down2up 156 | embed = self.down2up_layer(embedding=embedding, down2up_paths=down2up_path) 157 | # same level 158 | embed = self.samle_level_layer(embed, same_level_edge_index) 159 | # up2down 160 | embed = self.up2down_layer(embed, up2down_edge_index) 161 | 162 | return embed 163 | 164 | 165 | # class HCGNN_layer(torch.nn.Module): 166 | # def __init__(self, in_size, out_size, gnn_type, idx_start, idx_end): 167 | # super(HCGNN_layer, self).__init__() 168 | # self.in_size = in_size 169 | # self.out_size = out_size 170 | # self.idx_start = idx_start 171 | # self.idx_end = 
idx_end 172 | 173 | # if self.idx_start: 174 | # self.down2up_layer = Down2Up_layer(self.config, self.in_size, self.in_size) 175 | 176 | # if gnn_type == 'GCN': 177 | # self.samle_level_layer = GCNConv(self.in_size, self.out_size) 178 | # elif gnn_type == 'SAGE': 179 | # self.samle_level_layer = SAGEConv(self.in_size, self.out_size) 180 | # elif gnn_type == 'GAT': 181 | # self.samle_level_layer = GATConv(self.in_size, self.out_size) 182 | # elif gnn_type == 'GIN': 183 | # self.same_level_lnn_layer = torch.nn.Linear(self.in_size, self.out_size) 184 | # self.samle_level_layer = GINConv(self.same_level_lnn_layer) 185 | 186 | # if self.idx_end: 187 | # self.up2down_layer = Up2Down_layer(self.config, self.out_size, self.out_size) 188 | 189 | # def forward(self, embedding, down2up_path, same_level_edge_index, up2down_edge_index): 190 | # if self.idx_start: 191 | # # down2up 192 | # embedding = self.down2up_layer(embedding=embedding, down2up_paths=down2up_path) 193 | # # same level 194 | # embedding = self.samle_level_layer(embedding, same_level_edge_index) 195 | # if self.idx_end: 196 | # # up2down 197 | # embedding = self.up2down_layer(embedding, up2down_edge_index) 198 | 199 | # return embedding 200 | 201 | 202 | # basemodel 203 | class HCGNN(torch.nn.Module): 204 | def __init__(self, config, features): 205 | super(HCGNN, self).__init__() 206 | self.config = config 207 | self.gnn_layers = [512, 256, 128, 64, 32][-(self.config.layer_num+1):] 208 | self.cgnn_layers = torch.nn.ModuleList() 209 | 210 | if self.config.feature_pre: 211 | self.linear_pre = torch.nn.Linear(features[0].shape[1], self.gnn_layers[0]) 212 | else: 213 | self.gnn_layers[0] = features[0].shape[1] 214 | 215 | for idx, (in_size, out_size) in enumerate(zip(self.gnn_layers[:-1], self.gnn_layers[1:])): 216 | self.cgnn_layers.append(HCGNN_layer(self.config, in_size, out_size, gnn_type=self.config.same_level_gnn)) 217 | 218 | def forward(self, data, data_up2down, data_down2up, down2up_torch_arrays): 219 | x, same_level_edge_index = data.x, data.edge_index 220 | _, up2down_edge_index = data_up2down.x, data_up2down.edge_index 221 | if self.config.down2up_gnn != 'MEAN': 222 | _, down2up_edge_index = data_up2down.x, data_down2up.edge_index 223 | 224 | if self.config.feature_pre: 225 | embed = self.linear_pre(x) 226 | else: 227 | embed = x 228 | 229 | if len(self.cgnn_layers)==1: 230 | if self.config.down2up_gnn=='MEAN': 231 | embed = self.cgnn_layers[0](embed, down2up_torch_arrays, same_level_edge_index, up2down_edge_index) 232 | else: 233 | embed = self.cgnn_layers[0](embed, down2up_edge_index, same_level_edge_index, up2down_edge_index) 234 | else: 235 | for idx in range(len(self.cgnn_layers)): 236 | if idx != len(self.cgnn_layers)-1: 237 | if self.config.down2up_gnn=='MEAN': 238 | embed = self.cgnn_layers[idx]( 239 | embed, down2up_torch_arrays, same_level_edge_index, up2down_edge_index 240 | ) 241 | else: 242 | embed = self.cgnn_layers[idx]( 243 | embed, down2up_edge_index, same_level_edge_index, up2down_edge_index 244 | ) 245 | if self.config.relu: 246 | embed = torch.nn.functional.relu(embed) # Note: optional! 
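                    # relu/dropout are applied between stacked HC-GNN layers only, not after the final layer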
247 | if self.config.dropout: 248 | embed = torch.nn.functional.dropout(embed, p=self.config.drop_ratio, training=self.training) 249 | else: 250 | if self.config.down2up_gnn=='MEAN': 251 | embed = self.cgnn_layers[idx]( 252 | embed, down2up_torch_arrays, same_level_edge_index, up2down_edge_index 253 | ) 254 | else: 255 | embed = self.cgnn_layers[idx]( 256 | embed, down2up_edge_index, same_level_edge_index, up2down_edge_index 257 | ) 258 | 259 | embed = torch.nn.functional.normalize(embed, p=2, dim=-1) 260 | if self.config.task == 'NC': 261 | embed = torch.nn.functional.log_softmax(embed, dim=1) 262 | return embed 263 | -------------------------------------------------------------------------------- /data/emails/email_labels.txt: -------------------------------------------------------------------------------- 1 | 0 1 2 | 1 1 3 | 2 21 4 | 3 21 5 | 4 21 6 | 5 25 7 | 6 25 8 | 7 14 9 | 8 14 10 | 9 14 11 | 10 9 12 | 11 14 13 | 12 14 14 | 13 26 15 | 14 4 16 | 15 17 17 | 16 34 18 | 17 1 19 | 18 1 20 | 19 14 21 | 20 9 22 | 21 9 23 | 22 9 24 | 23 11 25 | 24 11 26 | 25 11 27 | 26 11 28 | 27 11 29 | 28 11 30 | 29 11 31 | 30 11 32 | 31 11 33 | 32 11 34 | 33 11 35 | 34 11 36 | 35 11 37 | 36 11 38 | 37 11 39 | 38 11 40 | 39 11 41 | 40 11 42 | 41 5 43 | 42 34 44 | 43 14 45 | 44 14 46 | 45 17 47 | 46 17 48 | 47 10 49 | 48 10 50 | 49 36 51 | 50 37 52 | 51 5 53 | 52 7 54 | 53 4 55 | 54 22 56 | 55 22 57 | 56 21 58 | 57 21 59 | 58 21 60 | 59 21 61 | 60 7 62 | 61 7 63 | 62 36 64 | 63 21 65 | 64 25 66 | 65 4 67 | 66 8 68 | 67 15 69 | 68 15 70 | 69 15 71 | 70 37 72 | 71 37 73 | 72 9 74 | 73 1 75 | 74 1 76 | 75 10 77 | 76 10 78 | 77 3 79 | 78 3 80 | 79 3 81 | 80 29 82 | 81 15 83 | 82 36 84 | 83 36 85 | 84 37 86 | 85 1 87 | 86 36 88 | 87 34 89 | 88 20 90 | 89 20 91 | 90 8 92 | 91 15 93 | 92 9 94 | 93 4 95 | 94 5 96 | 95 4 97 | 96 20 98 | 97 16 99 | 98 16 100 | 99 16 101 | 100 16 102 | 101 16 103 | 102 38 104 | 103 7 105 | 104 7 106 | 105 34 107 | 106 38 108 | 107 36 109 | 108 8 110 | 109 27 111 | 110 8 112 | 111 8 113 | 112 8 114 | 113 10 115 | 114 10 116 | 115 13 117 | 116 13 118 | 117 6 119 | 118 26 120 | 119 10 121 | 120 1 122 | 121 36 123 | 122 0 124 | 123 13 125 | 124 16 126 | 125 16 127 | 126 22 128 | 127 6 129 | 128 5 130 | 129 4 131 | 130 0 132 | 131 28 133 | 132 28 134 | 133 4 135 | 134 2 136 | 135 13 137 | 136 13 138 | 137 21 139 | 138 21 140 | 139 17 141 | 140 17 142 | 141 14 143 | 142 36 144 | 143 8 145 | 144 40 146 | 145 35 147 | 146 15 148 | 147 23 149 | 148 0 150 | 149 0 151 | 150 7 152 | 151 10 153 | 152 37 154 | 153 27 155 | 154 35 156 | 155 35 157 | 156 0 158 | 157 0 159 | 158 19 160 | 159 19 161 | 160 36 162 | 161 14 163 | 162 37 164 | 163 24 165 | 164 17 166 | 165 13 167 | 166 36 168 | 167 4 169 | 168 4 170 | 169 13 171 | 170 13 172 | 171 10 173 | 172 4 174 | 173 38 175 | 174 32 176 | 175 32 177 | 176 4 178 | 177 1 179 | 178 0 180 | 179 0 181 | 180 0 182 | 181 7 183 | 182 7 184 | 183 4 185 | 184 15 186 | 185 16 187 | 186 40 188 | 187 15 189 | 188 15 190 | 189 15 191 | 190 15 192 | 191 0 193 | 192 21 194 | 193 21 195 | 194 21 196 | 195 21 197 | 196 5 198 | 197 4 199 | 198 4 200 | 199 4 201 | 200 4 202 | 201 4 203 | 202 4 204 | 203 4 205 | 204 5 206 | 205 5 207 | 206 4 208 | 207 4 209 | 208 22 210 | 209 19 211 | 210 19 212 | 211 22 213 | 212 34 214 | 213 14 215 | 214 0 216 | 215 1 217 | 216 17 218 | 217 37 219 | 218 1 220 | 219 1 221 | 220 1 222 | 221 1 223 | 222 1 224 | 223 1 225 | 224 1 226 | 225 1 227 | 226 1 228 | 227 1 229 | 228 1 230 | 229 10 231 | 230 23 232 | 231 0 233 | 232 4 234 | 233 19 
235 | 234 19 236 | 235 19 237 | 236 19 238 | 237 19 239 | 238 19 240 | 239 19 241 | 240 19 242 | 241 19 243 | 242 19 244 | 243 19 245 | 244 19 246 | 245 10 247 | 246 14 248 | 247 14 249 | 248 1 250 | 249 14 251 | 250 7 252 | 251 13 253 | 252 20 254 | 253 31 255 | 254 40 256 | 255 6 257 | 256 4 258 | 257 0 259 | 258 8 260 | 259 9 261 | 260 9 262 | 261 10 263 | 262 0 264 | 263 10 265 | 264 14 266 | 265 14 267 | 266 14 268 | 267 14 269 | 268 39 270 | 269 17 271 | 270 4 272 | 271 28 273 | 272 17 274 | 273 17 275 | 274 17 276 | 275 4 277 | 276 4 278 | 277 0 279 | 278 0 280 | 279 23 281 | 280 4 282 | 281 21 283 | 282 36 284 | 283 36 285 | 284 0 286 | 285 22 287 | 286 21 288 | 287 15 289 | 288 37 290 | 289 0 291 | 290 4 292 | 291 4 293 | 292 4 294 | 293 14 295 | 294 4 296 | 295 7 297 | 296 7 298 | 297 1 299 | 298 15 300 | 299 15 301 | 300 38 302 | 301 26 303 | 302 20 304 | 303 20 305 | 304 20 306 | 305 21 307 | 306 9 308 | 307 1 309 | 308 1 310 | 309 1 311 | 310 1 312 | 311 1 313 | 312 1 314 | 313 1 315 | 314 1 316 | 315 1 317 | 316 1 318 | 317 1 319 | 318 10 320 | 319 19 321 | 320 7 322 | 321 7 323 | 322 17 324 | 323 16 325 | 324 14 326 | 325 9 327 | 326 9 328 | 327 9 329 | 328 8 330 | 329 8 331 | 330 13 332 | 331 39 333 | 332 14 334 | 333 10 335 | 334 17 336 | 335 17 337 | 336 13 338 | 337 13 339 | 338 13 340 | 339 13 341 | 340 2 342 | 341 1 343 | 342 0 344 | 343 0 345 | 344 0 346 | 345 0 347 | 346 0 348 | 347 0 349 | 348 0 350 | 349 0 351 | 350 0 352 | 351 0 353 | 352 0 354 | 353 16 355 | 354 16 356 | 355 27 357 | 356 8 358 | 357 8 359 | 358 14 360 | 359 14 361 | 360 14 362 | 361 10 363 | 362 14 364 | 363 35 365 | 364 37 366 | 365 14 367 | 366 36 368 | 367 10 369 | 368 7 370 | 369 20 371 | 370 10 372 | 371 16 373 | 372 36 374 | 373 36 375 | 374 14 376 | 375 8 377 | 376 7 378 | 377 7 379 | 378 7 380 | 379 7 381 | 380 7 382 | 381 7 383 | 382 7 384 | 383 7 385 | 384 7 386 | 385 7 387 | 386 7 388 | 387 7 389 | 388 7 390 | 389 7 391 | 390 7 392 | 391 7 393 | 392 7 394 | 393 7 395 | 394 7 396 | 395 7 397 | 396 7 398 | 397 7 399 | 398 7 400 | 399 4 401 | 400 9 402 | 401 4 403 | 402 0 404 | 403 4 405 | 404 16 406 | 405 38 407 | 406 14 408 | 407 14 409 | 408 21 410 | 409 26 411 | 410 27 412 | 411 28 413 | 412 21 414 | 413 4 415 | 414 1 416 | 415 1 417 | 416 9 418 | 417 10 419 | 418 15 420 | 419 4 421 | 420 26 422 | 421 14 423 | 422 35 424 | 423 10 425 | 424 34 426 | 425 4 427 | 426 4 428 | 427 12 429 | 428 17 430 | 429 17 431 | 430 14 432 | 431 37 433 | 432 37 434 | 433 37 435 | 434 34 436 | 435 6 437 | 436 13 438 | 437 13 439 | 438 13 440 | 439 13 441 | 440 4 442 | 441 14 443 | 442 10 444 | 443 10 445 | 444 10 446 | 445 3 447 | 446 17 448 | 447 17 449 | 448 17 450 | 449 1 451 | 450 4 452 | 451 14 453 | 452 14 454 | 453 6 455 | 454 27 456 | 455 22 457 | 456 21 458 | 457 4 459 | 458 4 460 | 459 1 461 | 460 34 462 | 461 17 463 | 462 30 464 | 463 30 465 | 464 4 466 | 465 23 467 | 466 14 468 | 467 15 469 | 468 1 470 | 469 22 471 | 470 12 472 | 471 31 473 | 472 6 474 | 473 15 475 | 474 15 476 | 475 8 477 | 476 15 478 | 477 8 479 | 478 8 480 | 479 1 481 | 480 15 482 | 481 22 483 | 482 2 484 | 483 3 485 | 484 4 486 | 485 10 487 | 486 4 488 | 487 14 489 | 488 14 490 | 489 25 491 | 490 6 492 | 491 6 493 | 492 40 494 | 493 4 495 | 494 36 496 | 495 23 497 | 496 14 498 | 497 3 499 | 498 14 500 | 499 14 501 | 500 14 502 | 501 14 503 | 502 14 504 | 503 14 505 | 504 14 506 | 505 14 507 | 506 14 508 | 507 31 509 | 508 15 510 | 509 15 511 | 510 14 512 | 511 0 513 | 512 23 514 | 513 35 515 | 514 8 516 | 515 4 517 | 516 
1 518 | 517 1 519 | 518 35 520 | 519 23 521 | 520 21 522 | 521 2 523 | 522 4 524 | 523 4 525 | 524 9 526 | 525 14 527 | 526 4 528 | 527 10 529 | 528 25 530 | 529 14 531 | 530 14 532 | 531 3 533 | 532 21 534 | 533 35 535 | 534 4 536 | 535 9 537 | 536 15 538 | 537 6 539 | 538 9 540 | 539 3 541 | 540 15 542 | 541 23 543 | 542 4 544 | 543 4 545 | 544 4 546 | 545 11 547 | 546 35 548 | 547 10 549 | 548 6 550 | 549 15 551 | 550 15 552 | 551 15 553 | 552 22 554 | 553 2 555 | 554 2 556 | 555 14 557 | 556 4 558 | 557 3 559 | 558 14 560 | 559 27 561 | 560 31 562 | 561 34 563 | 562 4 564 | 563 4 565 | 564 19 566 | 565 14 567 | 566 14 568 | 567 4 569 | 568 4 570 | 569 14 571 | 570 14 572 | 571 21 573 | 572 4 574 | 573 14 575 | 574 4 576 | 575 0 577 | 576 4 578 | 577 27 579 | 578 27 580 | 579 17 581 | 580 16 582 | 581 3 583 | 582 15 584 | 583 2 585 | 584 4 586 | 585 4 587 | 586 21 588 | 587 21 589 | 588 11 590 | 589 23 591 | 590 11 592 | 591 23 593 | 592 17 594 | 593 5 595 | 594 36 596 | 595 15 597 | 596 23 598 | 597 23 599 | 598 2 600 | 599 19 601 | 600 4 602 | 601 36 603 | 602 14 604 | 603 1 605 | 604 22 606 | 605 1 607 | 606 21 608 | 607 34 609 | 608 14 610 | 609 13 611 | 610 6 612 | 611 4 613 | 612 37 614 | 613 6 615 | 614 24 616 | 615 35 617 | 616 6 618 | 617 17 619 | 618 16 620 | 619 6 621 | 620 4 622 | 621 0 623 | 622 21 624 | 623 4 625 | 624 26 626 | 625 21 627 | 626 4 628 | 627 15 629 | 628 7 630 | 629 1 631 | 630 20 632 | 631 19 633 | 632 7 634 | 633 21 635 | 634 21 636 | 635 21 637 | 636 21 638 | 637 19 639 | 638 38 640 | 639 19 641 | 640 16 642 | 641 23 643 | 642 6 644 | 643 37 645 | 644 25 646 | 645 1 647 | 646 22 648 | 647 6 649 | 648 21 650 | 649 14 651 | 650 1 652 | 651 26 653 | 652 8 654 | 653 7 655 | 654 37 656 | 655 4 657 | 656 0 658 | 657 17 659 | 658 14 660 | 659 6 661 | 660 17 662 | 661 14 663 | 662 16 664 | 663 15 665 | 664 4 666 | 665 32 667 | 666 14 668 | 667 15 669 | 668 0 670 | 669 23 671 | 670 21 672 | 671 29 673 | 672 14 674 | 673 23 675 | 674 14 676 | 675 1 677 | 676 17 678 | 677 26 679 | 678 15 680 | 679 29 681 | 680 0 682 | 681 0 683 | 682 0 684 | 683 22 685 | 684 34 686 | 685 21 687 | 686 6 688 | 687 16 689 | 688 4 690 | 689 4 691 | 690 15 692 | 691 21 693 | 692 0 694 | 693 36 695 | 694 4 696 | 695 23 697 | 696 1 698 | 697 1 699 | 698 22 700 | 699 14 701 | 700 14 702 | 701 30 703 | 702 4 704 | 703 9 705 | 704 10 706 | 705 4 707 | 706 4 708 | 707 14 709 | 708 16 710 | 709 16 711 | 710 15 712 | 711 21 713 | 712 0 714 | 713 15 715 | 714 4 716 | 715 15 717 | 716 29 718 | 717 24 719 | 718 21 720 | 719 7 721 | 720 14 722 | 721 11 723 | 722 11 724 | 723 9 725 | 724 13 726 | 725 10 727 | 726 31 728 | 727 4 729 | 728 22 730 | 729 14 731 | 730 23 732 | 731 1 733 | 732 4 734 | 733 9 735 | 734 1 736 | 735 17 737 | 736 27 738 | 737 28 739 | 738 22 740 | 739 14 741 | 740 20 742 | 741 7 743 | 742 23 744 | 743 1 745 | 744 4 746 | 745 6 747 | 746 15 748 | 747 15 749 | 748 23 750 | 749 4 751 | 750 20 752 | 751 5 753 | 752 36 754 | 753 10 755 | 754 14 756 | 755 21 757 | 756 39 758 | 757 10 759 | 758 41 760 | 759 31 761 | 760 17 762 | 761 7 763 | 762 21 764 | 763 34 765 | 764 1 766 | 765 14 767 | 766 2 768 | 767 18 769 | 768 16 770 | 769 27 771 | 770 16 772 | 771 38 773 | 772 7 774 | 773 38 775 | 774 21 776 | 775 1 777 | 776 5 778 | 777 9 779 | 778 15 780 | 779 15 781 | 780 15 782 | 781 0 783 | 782 6 784 | 783 23 785 | 784 28 786 | 785 11 787 | 786 23 788 | 787 34 789 | 788 24 790 | 789 4 791 | 790 4 792 | 791 4 793 | 792 24 794 | 793 23 795 | 794 17 796 | 795 10 797 | 796 17 798 | 797 1 
799 | 798 1 800 | 799 15 801 | 800 15 802 | 801 4 803 | 802 4 804 | 803 21 805 | 804 14 806 | 805 14 807 | 806 20 808 | 807 28 809 | 808 20 810 | 809 22 811 | 810 26 812 | 811 3 813 | 812 32 814 | 813 4 815 | 814 0 816 | 815 21 817 | 816 13 818 | 817 4 819 | 818 15 820 | 819 17 821 | 820 5 822 | 821 24 823 | 822 4 824 | 823 14 825 | 824 0 826 | 825 9 827 | 826 21 828 | 827 14 829 | 828 38 830 | 829 4 831 | 830 14 832 | 831 31 833 | 832 21 834 | 833 14 835 | 834 6 836 | 835 4 837 | 836 4 838 | 837 6 839 | 838 17 840 | 839 0 841 | 840 4 842 | 841 7 843 | 842 16 844 | 843 4 845 | 844 4 846 | 845 21 847 | 846 1 848 | 847 10 849 | 848 3 850 | 849 21 851 | 850 4 852 | 851 0 853 | 852 1 854 | 853 7 855 | 854 17 856 | 855 15 857 | 856 14 858 | 857 0 859 | 858 9 860 | 859 32 861 | 860 13 862 | 861 5 863 | 862 2 864 | 863 21 865 | 864 28 866 | 865 21 867 | 866 22 868 | 867 22 869 | 868 7 870 | 869 7 871 | 870 33 872 | 871 0 873 | 872 1 874 | 873 15 875 | 874 4 876 | 875 31 877 | 876 30 878 | 877 15 879 | 878 11 880 | 879 19 881 | 880 21 882 | 881 9 883 | 882 21 884 | 883 13 885 | 884 21 886 | 885 9 887 | 886 32 888 | 887 9 889 | 888 32 890 | 889 38 891 | 890 9 892 | 891 38 893 | 892 38 894 | 893 14 895 | 894 9 896 | 895 10 897 | 896 38 898 | 897 10 899 | 898 22 900 | 899 21 901 | 900 13 902 | 901 21 903 | 902 4 904 | 903 0 905 | 904 1 906 | 905 1 907 | 906 23 908 | 907 0 909 | 908 5 910 | 909 4 911 | 910 4 912 | 911 15 913 | 912 14 914 | 913 14 915 | 914 13 916 | 915 11 917 | 916 1 918 | 917 5 919 | 918 5 920 | 919 10 921 | 920 23 922 | 921 21 923 | 922 14 924 | 923 9 925 | 924 20 926 | 925 10 927 | 926 19 928 | 927 19 929 | 928 21 930 | 929 17 931 | 930 19 932 | 931 19 933 | 932 36 934 | 933 17 935 | 934 35 936 | 935 16 937 | 936 4 938 | 937 16 939 | 938 4 940 | 939 6 941 | 940 4 942 | 941 41 943 | 942 6 944 | 943 7 945 | 944 23 946 | 945 9 947 | 946 23 948 | 947 7 949 | 948 6 950 | 949 22 951 | 950 36 952 | 951 14 953 | 952 15 954 | 953 11 955 | 954 35 956 | 955 5 957 | 956 14 958 | 957 14 959 | 958 15 960 | 959 4 961 | 960 6 962 | 961 4 963 | 962 9 964 | 963 19 965 | 964 11 966 | 965 4 967 | 966 29 968 | 967 14 969 | 968 15 970 | 969 15 971 | 970 5 972 | 971 32 973 | 972 15 974 | 973 14 975 | 974 5 976 | 975 9 977 | 976 10 978 | 977 19 979 | 978 13 980 | 979 23 981 | 980 12 982 | 981 10 983 | 982 21 984 | 983 10 985 | 984 35 986 | 985 7 987 | 986 22 988 | 987 22 989 | 988 22 990 | 989 8 991 | 990 21 992 | 991 32 993 | 992 4 994 | 993 21 995 | 994 21 996 | 995 6 997 | 996 14 998 | 997 11 999 | 998 14 1000 | 999 15 1001 | 1000 4 1002 | 1001 21 1003 | 1002 1 1004 | 1003 6 1005 | 1004 22 1006 | -------------------------------------------------------------------------------- /execution.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import random 4 | import torch 5 | 6 | from utils import weights_init, evaluate_results 7 | from preparation import LP_set_up, NC_set_up 8 | from model import Baseline_GNN, HCGNN 9 | 10 | 11 | def execute_LP( 12 | args, graphs, features, ls_hierarchical_community, 13 | ls_adj_same_level, ls_up2down_edges, ls_down2up_edges, 14 | ls_df_train, ls_df_valid, ls_df_test 15 | ): 16 | ls_data, ls_data_up2down, ls_data_down2up, ls_down2up_torch_arrays, ls_train_user_left, ls_train_user_right, ls_valid_user_left, ls_valid_user_right, ls_test_user_left, ls_test_user_right, ls_train_labels, ls_train_labels_tensor, ls_valid_labels, ls_test_labels = LP_set_up( 17 | config=args, graphs=graphs, 18 | 
        features=features,
        ls_hierarchical_community=ls_hierarchical_community,
        ls_adj_same_level=ls_adj_same_level,
        ls_up2down_edges=ls_up2down_edges,
        ls_down2up_edges=ls_down2up_edges,
        ls_df_train=ls_df_train,
        ls_df_valid=ls_df_valid,
        ls_df_test=ls_df_test, device=args.device
    )

    if args.mode == 'baseline':
        model = Baseline_GNN(config=args, graphs=graphs, features=features, ls_labels=None)
    else:
        model = HCGNN(config=args, features=features)
    model = model.to(args.device)
    model.apply(weights_init)
    print(model)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=5e-4
    )

    loss_func = torch.nn.BCEWithLogitsLoss()
    out_act = torch.nn.Sigmoid()

    train_results = []
    valid_results = []
    test_results = []

    for epoch_id in range(args.epoch_num):
        start_epoch = time.time()
        if epoch_id % args.epoch_log == 0:
            print('Epoch {} starts !'.format(epoch_id))
            print('-' * 80)
        total_loss = 0

        for idx, data in enumerate(ls_data):
            data_up2down = ls_data_up2down[idx]
            # with MEAN aggregation the bottom-up pass only needs the index arrays;
            # otherwise it needs the bottom-up edge data
            if args.down2up_gnn == 'MEAN':
                data_down2up = [0]
                down2up_torch_arrays = ls_down2up_torch_arrays[idx]
            else:
                data_down2up = ls_data_down2up[idx]
                down2up_torch_arrays = [0]
            train_user_left = ls_train_user_left[idx]
            train_user_right = ls_train_user_right[idx]
            train_labels_tensor = ls_train_labels_tensor[idx]

            model.train()
            optimizer.zero_grad()

            if args.mode == 'baseline':
                out = model.forward(data)
            else:
                out = model.forward(
                    data=data, data_up2down=data_up2down, data_down2up=data_down2up,
                    down2up_torch_arrays=down2up_torch_arrays
                )

            # score a candidate link by the inner product of its endpoint embeddings
            nodes_left = torch.index_select(out, 0, train_user_left)
            nodes_right = torch.index_select(out, 0, train_user_right)
            preds = torch.sum(nodes_left * nodes_right, dim=-1)
            loss = loss_func(preds, train_labels_tensor)

            # update
            loss.backward()
            optimizer.step()
            total_loss += loss.cpu().item()

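        # The periodic evaluation below reuses the same decoder as training: a link's
        # predicted probability is the sigmoid of the inner product of its two endpoint
        # embeddings, scored with ROC-AUC per graph and then averaged over ls_data.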
        # evaluate epoch
        if epoch_id % args.epoch_log == 0:
            model.eval()
            epoch_train_results = []
            epoch_valid_results = []
            epoch_test_results = []
            for idx, data in enumerate(ls_data):
                data_up2down = ls_data_up2down[idx]
                if args.down2up_gnn == 'MEAN':
                    data_down2up = [0]
                    down2up_torch_arrays = ls_down2up_torch_arrays[idx]
                else:
                    data_down2up = ls_data_down2up[idx]
                    down2up_torch_arrays = [0]
                train_user_left = ls_train_user_left[idx]
                train_user_right = ls_train_user_right[idx]
                train_labels = ls_train_labels[idx]
                valid_user_left = ls_valid_user_left[idx]
                valid_user_right = ls_valid_user_right[idx]
                valid_labels = ls_valid_labels[idx]
                test_user_left = ls_test_user_left[idx]
                test_user_right = ls_test_user_right[idx]
                test_labels = ls_test_labels[idx]

                if args.mode == 'baseline':
                    out = model.forward(data)
                else:
                    out = model.forward(
                        data=data, data_up2down=data_up2down, data_down2up=data_down2up,
                        down2up_torch_arrays=down2up_torch_arrays
                    )
                nodes_left_train = torch.index_select(out, 0, train_user_left)
                nodes_right_train = torch.index_select(out, 0, train_user_right)
                pred_train = out_act(torch.sum(nodes_left_train * nodes_right_train, dim=-1))
                pred_train = np.array(pred_train.view(-1).tolist())
                y_train = np.array(train_labels)

                nodes_left_valid = torch.index_select(out, 0, valid_user_left)
                nodes_right_valid = torch.index_select(out, 0, valid_user_right)
                pred_valid = out_act(torch.sum(nodes_left_valid * nodes_right_valid, dim=-1))
                pred_valid = np.array(pred_valid.view(-1).tolist())
                y_valid = np.array(valid_labels)

                nodes_left_test = torch.index_select(out, 0, test_user_left)
                nodes_right_test = torch.index_select(out, 0, test_user_right)
                pred_test = out_act(torch.sum(nodes_left_test * nodes_right_test, dim=-1))
                pred_test = np.array(pred_test.view(-1).tolist())
                y_test = np.array(test_labels)

                epoch_train_results.append(evaluate_results(
                    pred=pred_train, y=y_train, method='roc-auc'
                ))
                epoch_valid_results.append(evaluate_results(
                    pred=pred_valid, y=y_valid, method='roc-auc'
                ))
                epoch_test_results.append(evaluate_results(
                    pred=pred_test, y=y_test, method='roc-auc'
                ))
            print('Evaluating Epoch {}, time {:.3f}, ROC-AUC: Train = {:.4f}, Valid = {:.4f}, Test = {:.4f}'.format(
                epoch_id, time.time() - start_epoch,
                np.mean(epoch_train_results), np.mean(epoch_valid_results), np.mean(epoch_test_results)
            ))
            train_results.append(np.mean(epoch_train_results))
            valid_results.append(np.mean(epoch_valid_results))
            test_results.append(np.mean(epoch_test_results))
            print('Best valid performance is {:.4f}, best test performance is {:.4f} and epoch_id is {}'.format(
                max(valid_results),
                test_results[valid_results.index(max(valid_results))],
                args.epoch_log * valid_results.index(max(valid_results))
            ))


def execute_NC(
        args, graphs, df_labels, features, ls_hierarchical_community,
        ls_adj_same_level, ls_up2down_edges, ls_down2up_edges
):
    ls_data, features, ls_data_up2down, ls_data_down2up, ls_down2up_torch_arrays, ls_train_nodes, ls_valid_nodes, ls_test_nodes, ls_labels, ls_labels_tensor = NC_set_up(
        config=args, graphs=graphs, df_labels=df_labels, features=features,
        ls_hierarchical_community=ls_hierarchical_community,
        ls_adj_same_level=ls_adj_same_level, ls_up2down_edges=ls_up2down_edges, ls_down2up_edges=ls_down2up_edges,
        device=args.device
    )

    if args.mode == 'baseline':
        model = Baseline_GNN(config=args, graphs=graphs, features=features, ls_labels=None)
    else:
        model = HCGNN(config=args, features=features)
    model = model.to(args.device)
    model.apply(weights_init)
    print(model)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=5e-4
    )
    loss_func = torch.nn.NLLLoss()

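    # NLLLoss consumes log-probabilities, so the classification head is presumably
    # expected to end in a log_softmax (see model.py); if it emitted raw logits,
    # torch.nn.CrossEntropyLoss would be the matching criterion instead.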
    train_f1_micro = []
    valid_f1_micro = []
    test_f1_micro = []
    train_f1_macro = []
    valid_f1_macro = []
    test_f1_macro = []
    train_nmi = []
    valid_nmi = []
    test_nmi = []

    for epoch_id in range(args.epoch_num):
        start_epoch = time.time()
        if epoch_id % args.epoch_log == 0:
            print('Epoch {} starts !'.format(epoch_id))
            print('-' * 80)
        total_loss = 0

        for idx, _ in enumerate(graphs):
            data = ls_data[idx]
            data_up2down = ls_data_up2down[idx]
            if args.down2up_gnn == 'MEAN':
                data_down2up = [0]
                down2up_torch_arrays = ls_down2up_torch_arrays[idx]
            else:
                data_down2up = ls_data_down2up[idx]
                down2up_torch_arrays = [0]
            train_nodes = ls_train_nodes[idx]
            labels_tensor = ls_labels_tensor[idx]

            model.train()
            optimizer.zero_grad()

            if args.mode == 'baseline':
                out = model.forward(data)
            else:
                out = model.forward(
                    data=data, data_up2down=data_up2down, data_down2up=data_down2up,
                    down2up_torch_arrays=down2up_torch_arrays
                )

            # the loss is computed on the training nodes only
            pred_train = torch.index_select(out, 0, torch.from_numpy(train_nodes).long().to(args.device))
            loss = loss_func(pred_train, labels_tensor[train_nodes])

            # update
            loss.backward()
            optimizer.step()
            total_loss += loss.cpu().item()

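        # Every args.epoch_log epochs the current embeddings are re-scored on all three
        # splits: micro/macro-F1 for the class-labeled datasets, or NMI for the
        # community labels of the 'emails' dataset.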
        # evaluate epoch
        if epoch_id % args.epoch_log == 0:
            model.eval()
            epoch_train_f1_micro = []
            epoch_valid_f1_micro = []
            epoch_test_f1_micro = []
            epoch_train_f1_macro = []
            epoch_valid_f1_macro = []
            epoch_test_f1_macro = []
            epoch_train_nmi = []
            epoch_valid_nmi = []
            epoch_test_nmi = []
            for idx, data in enumerate(ls_data):
                data_up2down = ls_data_up2down[idx]
                if args.down2up_gnn == 'MEAN':
                    data_down2up = [0]
                    down2up_torch_arrays = ls_down2up_torch_arrays[idx]
                else:
                    data_down2up = ls_data_down2up[idx]
                    down2up_torch_arrays = [0]
                train_nodes = ls_train_nodes[idx]
                valid_nodes = ls_valid_nodes[idx]
                test_nodes = ls_test_nodes[idx]
                labels = ls_labels[idx]

                if args.mode == 'baseline':
                    out = model.forward(data)
                else:
                    out = model.forward(
                        data=data, data_up2down=data_up2down, data_down2up=data_down2up,
                        down2up_torch_arrays=down2up_torch_arrays
                    )

                pred_train = torch.index_select(out, 0, torch.from_numpy(train_nodes).long().to(args.device))
                pred_valid = torch.index_select(out, 0, torch.from_numpy(valid_nodes).long().to(args.device))
                pred_test = torch.index_select(out, 0, torch.from_numpy(test_nodes).long().to(args.device))

                if args.dataset not in ['emails']:
                    epoch_train_f1_micro.append(evaluate_results(
                        pred=pred_train, y=labels, idx=train_nodes, method='mic-f1'
                    ))
                    epoch_valid_f1_micro.append(evaluate_results(
                        pred=pred_valid, y=labels, idx=valid_nodes, method='mic-f1'
                    ))
                    epoch_test_f1_micro.append(evaluate_results(
                        pred=pred_test, y=labels, idx=test_nodes, method='mic-f1'
                    ))
                    epoch_train_f1_macro.append(evaluate_results(
                        pred=pred_train, y=labels, idx=train_nodes, method='mac-f1'
                    ))
                    epoch_valid_f1_macro.append(evaluate_results(
                        pred=pred_valid, y=labels, idx=valid_nodes, method='mac-f1'
                    ))
                    epoch_test_f1_macro.append(evaluate_results(
                        pred=pred_test, y=labels, idx=test_nodes, method='mac-f1'
                    ))
                else:
                    epoch_train_nmi.append(evaluate_results(
                        pred=pred_train, y=labels, idx=train_nodes, method='nmi'
                    ))
                    epoch_valid_nmi.append(evaluate_results(
                        pred=pred_valid, y=labels, idx=valid_nodes, method='nmi'
                    ))
                    epoch_test_nmi.append(evaluate_results(
                        pred=pred_test, y=labels, idx=test_nodes, method='nmi'
                    ))

            print('Evaluating Epoch {}, time {:.3f}'.format(epoch_id, time.time() - start_epoch))
            if args.dataset not in ['emails']:
                print('Micro-f1: Train = {:.4f}, Valid = {:.4f}, Test = {:.4f}'.format(
                    np.mean(epoch_train_f1_micro), np.mean(epoch_valid_f1_micro), np.mean(epoch_test_f1_micro)
                ))
                print('Macro-f1: Train = {:.4f}, Valid = {:.4f}, Test = {:.4f}'.format(
                    np.mean(epoch_train_f1_macro), np.mean(epoch_valid_f1_macro), np.mean(epoch_test_f1_macro)
                ))
                train_f1_micro.append(np.mean(epoch_train_f1_micro))
                valid_f1_micro.append(np.mean(epoch_valid_f1_micro))
                test_f1_micro.append(np.mean(epoch_test_f1_micro))
                train_f1_macro.append(np.mean(epoch_train_f1_macro))
                valid_f1_macro.append(np.mean(epoch_valid_f1_macro))
                test_f1_macro.append(np.mean(epoch_test_f1_macro))
                print('Best Valid Mic-f1 is {:.4f}, best Test Mic-f1 is {:.4f} and epoch_id is {}'.format(
                    max(valid_f1_micro),
                    test_f1_micro[valid_f1_micro.index(max(valid_f1_micro))],
                    args.epoch_log * valid_f1_micro.index(max(valid_f1_micro))
                ))
                print('Best Valid Mac-f1 is {:.4f}, best Test Mac-f1 is {:.4f} and epoch_id is {}'.format(
                    max(valid_f1_macro),
                    test_f1_macro[valid_f1_macro.index(max(valid_f1_macro))],
                    args.epoch_log * valid_f1_macro.index(max(valid_f1_macro))
                ))
            else:
                print('NMI: Train = {:.4f}, Valid = {:.4f}, Test = {:.4f}'.format(
                    np.mean(epoch_train_nmi), np.mean(epoch_valid_nmi), np.mean(epoch_test_nmi)
                ))
                train_nmi.append(np.mean(epoch_train_nmi))
                valid_nmi.append(np.mean(epoch_valid_nmi))
                test_nmi.append(np.mean(epoch_test_nmi))
                print('Best Valid NMI is {:.4f}, best Test NMI is {:.4f} and epoch_id is {}'.format(
                    max(valid_nmi),
                    test_nmi[valid_nmi.index(max(valid_nmi))], args.epoch_log * valid_nmi.index(max(valid_nmi))
                ))
--------------------------------------------------------------------------------
/preparation.py:
--------------------------------------------------------------------------------
import numpy as np
import networkx as nx
import pandas as pd
import random
from copy import deepcopy
import torch
import torch_geometric as tg


def graph_to_adj(graphs, ls_hierarchical_community):
    # add the super-nodes and edges induced by the hierarchical communities to each graph
    ls_adj_same_level = []
    for idx, graph in enumerate(graphs):
        G_same_level = deepcopy(graph)
        hierarchical_community = ls_hierarchical_community[idx]

        add_nodes = []
        add_edges = []
        for com in hierarchical_community:
            add_nodes += com['partitions']
            add_edges += com['edges']
        G_same_level.add_nodes_from(add_nodes)
        G_same_level.add_edges_from(add_edges)
        adj_same_level = nx.to_scipy_sparse_matrix(G_same_level)

        ls_adj_same_level.append(adj_same_level)
    return ls_adj_same_level

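# Illustrative usage only (variable names assumed from the calling pipeline in main.py):
#   ls_adj = graph_to_adj(graphs, ls_hierarchical_community)
#   ls_adj[0]  # sparse adjacency over original nodes plus community super-nodes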

def set_up_train_test_valid(graphs, ls_df_friends, ls_valid_edges, ls_test_edges, seed=123):
    # valid data
    ls_df_valid = []
    for idx, valid_edges in enumerate(ls_valid_edges):
        df_valid_pos_samples = pd.DataFrame(valid_edges['positive'], columns=['source', 'target'])
        df_valid_pos_samples['label'] = 1
        df_valid_neg_samples = pd.DataFrame(valid_edges['negative'], columns=['source', 'target'])
        df_valid_neg_samples['label'] = 0

        df_valid = pd.concat([df_valid_pos_samples, df_valid_neg_samples], axis=0)

        ls_df_valid.append(df_valid)
    # test data
    ls_df_test = []
    for idx, test_edges in enumerate(ls_test_edges):
        df_test_pos_samples = pd.DataFrame(test_edges['positive'], columns=['source', 'target'])
        df_test_pos_samples['label'] = 1
        df_test_neg_samples = pd.DataFrame(test_edges['negative'], columns=['source', 'target'])
        df_test_neg_samples['label'] = 0

        df_test = pd.concat([df_test_pos_samples, df_test_neg_samples], axis=0)

        ls_df_test.append(df_test)
    # train data: negative candidates are sampled uniformly at random (10x the edge count)
    ls_df_train = []
    for idx, friends in enumerate(ls_df_friends):
        graph = graphs[idx]
        df_train_neg = pd.DataFrame(
            np.random.choice(list(graph.nodes()), 10 * graph.number_of_edges()), columns=['source']
        )
        df_train_neg['target'] = np.random.choice(list(graph.nodes()), 10 * graph.number_of_edges())
        df_train_neg = df_train_neg[df_train_neg['source']