├── README.md
├── args.py
├── data
│   ├── citeseer.npz
│   └── cora.npz
├── generator
│   ├── __init__.py
│   ├── cluster.py
│   └── gpt
│       ├── __init__.py
│       ├── dataset.py
│       ├── gpt.py
│       ├── model.py
│       ├── trainer.py
│       └── utils.py
├── requirement.txt
├── run.sh
├── task
│   ├── aggregation
│   │   ├── __init__.py
│   │   └── gcn.py
│   └── utils
│       ├── dataset.py
│       └── utils.py
└── test.py

/README.md:
--------------------------------------------------------------------------------
1 | # Graph Generative Model for Benchmarking Graph Neural Networks
2 | 
3 | We propose a novel, modern graph generation problem to enable generating privacy-controlled, synthetic substitutes of large-scale real-world graphs that can be effectively used to evaluate GNN models.
4 | Our proposed graph generative model, Computation Graph Transformer (CGT), 1) operates on minibatches rather than the whole graph, avoiding scalability issues, and 2) reduces the task of learning graph distributions to learning feature-vector sequence distributions, which we approach with a novel Transformer architecture.
5 | 
6 | See our [ICML 2023 paper](https://arxiv.org/abs/2207.04396) for more details.
7 | 
8 | ## Setup
9 | Create a new conda environment, then install [PyTorch](https://pytorch.org) and the remaining requirements:
10 | ```
11 | conda create python==3.7 -n cgt
12 | conda activate cgt
13 | pip install -r requirement.txt
14 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
15 | ```
16 | The code is implemented using PyTorch DataParallel.
17 | 
18 | ## Dataset
19 | You can download public graph datasets in the npz format from [GNN-Benchmark](https://github.com/shchur/gnn-benchmark).
20 | Place the datasets in the `data/` directory.
21 | For your convenience, `cora.npz` and `citeseer.npz` are already saved in `data/`.
22 | We also support the ogbn-arxiv and ogbn-products datasets from the [OGB benchmark](https://ogb.stanford.edu/docs/nodeprop/).
23 | 
24 | ## Usage
25 | In `run.sh`, list the graph datasets whose distributions you want to learn in `DATASETS`.
26 | First, we augment the original graphs by adding different numbers of noisy neighbors, specified in `NOISES`.
27 | Executing `run.sh` learns three different distributions with noise sizes `NOISES=(0 2 4)` for each dataset.
28 | For each dataset, we train three different GNN models (GCN, GIN, SGC) on a pair of original and synthetic graphs and then compare their performance.
29 | The details of other hyperparameters can be found in `args.py`.
30 | 
31 | ## Differential Privacy module
32 | 
33 | As described in the main paper, DP-SGD on the transformer performs poorly.
34 | Thus, we provide only the DP-k-means module in this repository.
35 | To run DP-k-means, you need to download the open-source clustering library from: https://github.com/google/differential-privacy/tree/main/learning/clustering
36 | Then uncomment lines 11-12 in `generator/cluster.py` and set `dp_feature` in `args.py` to True.
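As a minimal sketch, the command below runs the aggregation task on Cora with DP-k-means feature clustering enabled; it reuses the flags from `args.py` and `run.sh`, and assumes the clustering library above has been placed under `generator/clustering` so that the commented-out relative imports in `generator/cluster.py` resolve:
```
python test.py --dataset cora --noise_num 0 --task_name aggregation -n gcn sgc gin --dp_feature
```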
37 | 
38 | 
39 | ## File description
40 | 
41 | We provide brief descriptions for each file as follows:
42 | 
43 | | Directory/File | Description |
44 | | ---- | ---- |
45 | | run.sh | script to run experiments |
46 | | args.py | set hyperparameters |
47 | | test.py | main file: prepare models, read datasets, graph generation, GNN evaluation |
48 | | data/ | downloaded datasets |
49 | | generator/ | code related to the graph transformer |
50 | | generator/cluster.py | k-means or DP k-means clustering |
51 | | generator/gpt | CGT main directory |
52 | | generator/gpt/gpt.py | prepare models, prepare datasets, train/generation loops |
53 | | generator/gpt/dataset.py | dataset of flattened computation graphs |
54 | | generator/gpt/model.py | XLNet model |
55 | | generator/gpt/trainer.py | training loop |
56 | | generator/gpt/utils.py | generation loop |
57 | | task/ | GNN models |
58 | | task/aggregation | GNN models with different aggregation strategies (GCN, GAT, SGC, GIN) |
59 | | task/utils/dataset.py | Computation Graph Dataset for PyTorch DataParallel |
60 | | task/utils/utils.py | loading of ogbn/npz-format datasets, utility functions |
61 | 
62 | 
63 | ### Citation
64 | Please consider citing the following paper when using our code in your application.
65 | 
66 | ```
67 | @article{yoon2022scalable,
68 |   title={Scalable Privacy-enhanced Benchmark Graph Generative Model for Graph Convolutional Networks},
69 |   author={Yoon, Minji and Wu, Yue and Palowitch, John and Perozzi, Bryan and Salakhutdinov, Ruslan},
70 |   journal={arXiv preprint arXiv:2207.04396},
71 |   year={2022}
72 | }
73 | ```
74 | 
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | 
4 | def get_args():
5 |     """Argument parser from command line"""
6 |     parser = argparse.ArgumentParser()
7 |     # GNN training-related hyperparameters
8 |     parser.add_argument('--epochs', type=int, default=200,
9 |                         help='Number of epochs to train.')
10 |     parser.add_argument('--batch_size', type=int, default=64,
11 |                         help='Size of batch.')
12 |     parser.add_argument('--num_workers', type=int, default=10,
13 |                         help='Number of workers for data loader.')
14 |     parser.add_argument('--prefetch_factor', type=int, default=2,
15 |                         help='Number of prefetched batches.')
16 |     parser.add_argument('--lr', type=float, default=0.001,
17 |                         help='Initial learning rate.')
18 |     parser.add_argument('--early_stopping', type=int, default=10,
19 |                         help='Number of epochs to wait before early stop.')
20 |     parser.add_argument('--dropout', type=float, default=0,
21 |                         help='dropout')
22 |     parser.add_argument("--weight_decay", type=float, default=5e-4,
23 |                         help="Weight for L2 loss")
24 | 
25 |     # Dataset-related hyperparameters
26 |     parser.add_argument('--data_dir', type=str, default="./data",
27 |                         help='Dataset location.')
28 |     parser.add_argument('--save_dir', type=str, default="save",
29 |                         help='Save location.')
30 |     parser.add_argument('--dataset', type=str, default="cora",
31 |                         help='Dataset to use.')
32 | 
33 |     # GNN structure-related hyperparameters
34 |     parser.add_argument('--hidden_dim', type=int, default=64,
35 |                         help='Hidden dimension')
36 |     parser.add_argument('--step_num', type=int, default=2,
37 |                         help='Number of propagating steps')
38 |     parser.add_argument('--sample_num', type=int, default=5,
39 |                         help='Number of sampled neighbors')
40 |     parser.add_argument('--subgraph_step_num', type=int, default=2,
41 |                         help='Number of propagating
steps') 42 | parser.add_argument('--subgraph_sample_num', type=int, default=5, 43 | help='Number of propagating steps') 44 | 45 | # Privacy-related hyperparameters 46 | parser.add_argument('--dp_feature', dest='dp_feature', action='store_true') 47 | parser.set_defaults(dp_feature=False) 48 | parser.add_argument('--dp_edge', dest='dp_edge', action='store_true') 49 | parser.set_defaults(dp_edge=False) 50 | parser.add_argument( 51 | "--secure-rng", 52 | action="store_true", 53 | default=False, 54 | help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost", 55 | ) 56 | parser.add_argument( 57 | "--dp-sigma", 58 | type=float, 59 | default=0.01, 60 | metavar="S", 61 | help="Noise multiplier (default 1.0)", 62 | ) 63 | parser.add_argument( 64 | "-c", 65 | "--dp-max-per-sample-grad_norm", 66 | type=float, 67 | default=1.0, 68 | metavar="C", 69 | help="Clip per-sample gradients to this norm (default 1.0)", 70 | ) 71 | parser.add_argument( 72 | "--dp_delta", 73 | type=float, 74 | default=100, 75 | metavar="D", 76 | help="Target delta (default: 1e-5)", 77 | ) 78 | parser.add_argument( 79 | "--dp_epsilon", 80 | type=float, 81 | default=100., 82 | metavar="D", 83 | help="Target epsilon", 84 | ) 85 | 86 | # SR-GNN-related hyperparameters 87 | parser.add_argument("--arch", type=int, default=0, 88 | help="use which variant of the model") 89 | parser.add_argument("--alpha", type=float, default=0., 90 | help="restart coefficient in biased sampling") 91 | parser.add_argument('--iid_sample', dest='iid_sample', action='store_true') 92 | parser.add_argument('--bias_sample', dest='iid_sample', action='store_false') 93 | parser.set_defaults(iid_sample=True) 94 | 95 | # Computation Graph encoding-related hyperparameters 96 | parser.add_argument('--org_code', dest='org_code', action='store_true') 97 | parser.add_argument('--dup_code', dest='org_code', action='store_false') 98 | parser.set_defaults(org_code=True) 99 | parser.add_argument('--self_connection', dest='self_connection', action='store_true') 100 | parser.set_defaults(self_connection=True) 101 | 102 | # Task-related hyperparameters 103 | parser.add_argument('--task_name', type=str, default="aggregation", 104 | help='aggregation, depth, shift, width') 105 | parser.add_argument('-n', '--model_list', nargs='+', default=['gcn', 'sgc', 'gin', 'gat'], 106 | help='a list of GNN models to be tested') 107 | parser.add_argument('-p', '--predictor_list', nargs='+', default=['dot', 'mlp'], 108 | help='a list of link predictor models to be tested') 109 | parser.add_argument("--noise_num", type=int, default=0, 110 | help="Number of noise edges") 111 | 112 | # Cluster-related hyperparameters 113 | parser.add_argument('--cluster_num', type=int, default=512, 114 | help='Number of clusters used to discretize feature vectors') 115 | parser.add_argument('--cluster_size', type=int, default=1, 116 | help='Size of mininum cluster') 117 | parser.add_argument('--cluster_sample_num', type=int, default=5000, 118 | help='Number of nodes participated in kmeans') 119 | 120 | # GPT-related hyperparameters 121 | parser.add_argument('--gpt_train_name', type=str, default="default", 122 | help='wandb run name') 123 | parser.add_argument('--gpt_model', type=str, default="XLNet", 124 | help='GPT, XLNet, or Bayes') 125 | parser.add_argument('--gpt_softmax_temperature', type=float, default=1., 126 | help='Temperature used to sample') 127 | parser.add_argument('--gpt_epochs', type=int, default=50, 128 | help='Number of epochs to train.') 129 | 
parser.add_argument('--gpt_batch_size', type=int, default=128, 130 | help='Size of batch.') 131 | parser.add_argument('--gpt_lr', type=float, default=0.003/8, 132 | help='Initial learning rate.') 133 | parser.add_argument('--gpt_layers', type=int, default=3, 134 | help='Number of layers') 135 | parser.add_argument('--gpt_heads', type=int, default=12, 136 | help='Number of heads') 137 | parser.add_argument('--gpt_dropout', type=float, default=0.2, 138 | help='Dropout rate (1 - keep probability).') 139 | parser.add_argument('--gpt_weight_decay', type=float, default=5, 140 | help='Weight decay (L2 loss on parameters).') 141 | parser.add_argument('--gpt_hidden_dim', type=int, default=64, 142 | help='Hidden dimension') 143 | parser.add_argument('--gpt_early_stopping', type=int, default=10, 144 | help='Number of epochs to wait before early stop.') 145 | 146 | # CGT-related hyperparameters 147 | parser.add_argument('--give_start_id', dest='gpt_start_id', action='store_true') 148 | parser.set_defaults(gpt_start_id=False) 149 | parser.add_argument('--no_label_con', dest='gpt_label_con', action='store_false') 150 | parser.set_defaults(gpt_label_con=True) 151 | parser.add_argument('--long_seq', dest='gpt_long_seq', action='store_true') 152 | parser.set_defaults(gpt_long_seq=False) 153 | parser.add_argument('--inv_pos', dest='gpt_inv_pos', action='store_true') 154 | parser.set_defaults(gpt_inv_pos=False) 155 | 156 | parser.add_argument('--label_wise', dest='label_wise', action='store_true') 157 | parser.add_argument('--no_label_wise', dest='label_wise', action='store_false') 158 | parser.set_defaults(label_wise=True) 159 | 160 | # Save intermediate graph infomation 161 | parser.set_defaults(save_org_graph=False) 162 | parser.set_defaults(save_cluster_graph=False) 163 | parser.set_defaults(save_synthetic_graph=False) 164 | 165 | 166 | args, _ = parser.parse_known_args() 167 | return args 168 | -------------------------------------------------------------------------------- /data/citeseer.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/CGT/5daba9acfba5a84cd0c5ce1271ef7ed608312a93/data/citeseer.npz -------------------------------------------------------------------------------- /data/cora.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/CGT/5daba9acfba5a84cd0c5ce1271ef7ed608312a93/data/cora.npz -------------------------------------------------------------------------------- /generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/CGT/5daba9acfba5a84cd0c5ce1271ef7ed608312a93/generator/__init__.py -------------------------------------------------------------------------------- /generator/cluster.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import torch 5 | 6 | from time import perf_counter 7 | from sklearn.decomposition import PCA 8 | from k_means_constrained import KMeansConstrained 9 | 10 | # To run differential private k-means, you need to download an open-source library from: https://github.com/google/differential-privacy/tree/main/learning/clustering 11 | #from .clustering import clustering_algorithm 12 | #from .clustering import clustering_params 13 | 14 | def kmeans(feats, cluster_num, cluster_size, cluster_sample_num): 15 | """ 16 | k-means clustering 
17 | Args: 18 | feats: feature vectors 19 | cluster_num: number of clusters 20 | cluster_size: minimum size of clusters 21 | cluster_sample_num: number of samples for clustering 22 | Return: 23 | centers: cluster centers 24 | """ 25 | if cluster_sample_num < feats.shape[0]: 26 | px_ids = random.sample(list(range(feats.shape[0])), cluster_sample_num) 27 | x = feats[px_ids] 28 | else: 29 | x = feats 30 | 31 | pca = PCA(n_components=min(feats.shape[1], 128)) 32 | x_pca = pca.fit_transform(x) 33 | 34 | clf = KMeansConstrained(n_clusters=cluster_num, size_min=cluster_size, init='random', n_init=1, max_iter=8) 35 | clf.fit(x_pca) 36 | centers = pca.inverse_transform(clf.cluster_centers_) 37 | 38 | return centers 39 | 40 | 41 | def DP_kmeans(feats, cluster_num, cluster_sample_num, epsilon=10, delta=1e-6): 42 | """ 43 | Differential private k-means clustering 44 | Args: 45 | feats: feature vectors 46 | cluster_num: number of clusters 47 | cluster_sample_num: number of samples for clustering 48 | epsilon: privacy budget 49 | delta: privacy budget 50 | Return: 51 | centers: cluster centers 52 | cluster_num: number of clusters 53 | """ 54 | if cluster_sample_num < feats.shape[0]: 55 | px_ids = random.sample(list(range(feats.shape[0])), cluster_sample_num) 56 | x = feats[px_ids] 57 | else: 58 | x = feats 59 | 60 | pca = PCA(n_components=128) 61 | x_pca = pca.fit_transform(x) 62 | x_pca_total = pca.transform(feats) 63 | 64 | data = clustering_params.Data(x_pca, radius=1.0) 65 | privacy_param = clustering_params.DifferentialPrivacyParam(epsilon=epsilon, delta=delta) 66 | clustering_result: clustering_algorithm.ClusteringResult = (clustering_algorithm.private_lsh_clustering(cluster_num, data, privacy_param)) 67 | 68 | centers = pca.inverse_transform(clustering_result.centers) 69 | return centers, centers.shape[0] 70 | 71 | def cluster_feats(args, feats): 72 | """ 73 | Cluster feature vectors 74 | 75 | Input: 76 | org_feats: original feature matrices 77 | Return: 78 | cluster_ids: list of cluster ids where each feature belongs to 79 | cluster_centers: centers of clusters 80 | 81 | """ 82 | # Define cluster centers 83 | start_time = perf_counter() 84 | if args.dp_feature: 85 | cluster_centers, cluster_num = DP_kmeans(feats, args.cluster_num, args.cluster_sample_num) 86 | args.cluster_num = cluster_num 87 | else: 88 | cluster_centers = kmeans(feats, args.cluster_num, args.cluster_size, args.cluster_sample_num) 89 | 90 | # Cluster the original dataset 91 | batch_size = 1000 92 | cluster_ids = np.zeros(feats.shape[0]) 93 | for batch in range(feats.shape[0] // batch_size + 1): 94 | if batch < feats.shape[0] // batch_size: 95 | idx = list(range(batch * batch_size, (batch + 1) * batch_size)) 96 | else: 97 | idx = list(range(batch * batch_size, feats.shape[0])) 98 | cluster_ids[idx] = ((feats[idx, None, :] - cluster_centers[None, :, :]) ** 2).sum(-1).argmin(1) 99 | 100 | # Append empty_id 101 | cluster_ids = torch.LongTensor(np.append(cluster_ids, args.cluster_num)) 102 | cluster_centers = torch.FloatTensor(np.concatenate((cluster_centers, np.zeros((1, cluster_centers.shape[1]))), axis=0)) 103 | 104 | print("Clustering time: {:.3f}".format(perf_counter() - start_time)) 105 | 106 | return cluster_ids, cluster_centers 107 | 108 | -------------------------------------------------------------------------------- /generator/gpt/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/minjiyoon/CGT/5daba9acfba5a84cd0c5ce1271ef7ed608312a93/generator/gpt/__init__.py -------------------------------------------------------------------------------- /generator/gpt/dataset.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import permutations 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | class Dataset(torch.utils.data.Dataset): 8 | """ 9 | Flatten Computation Graph Dataset for training CGT 10 | Args: 11 | args: arguments 12 | adjs: a list of adjacency matrix of each computation graph 13 | cluster_ids: cluster ids for each nodes 14 | labels: labels for each nodes 15 | ids: ids 16 | """ 17 | def __init__(self, args, adjs, cluster_ids, labels, ids): 18 | self.adjs = adjs 19 | self.adjs_list = isinstance(adjs, list) 20 | self.cluster_ids = cluster_ids 21 | self.labels = labels 22 | self.ids = ids 23 | 24 | self.step_num = args.subgraph_step_num 25 | self.sample_num = args.subgraph_sample_num 26 | self.noise_num = args.noise_num 27 | self.total_sample_num = self.sample_num + self.noise_num 28 | self.short_seq_num = (self.sample_num + self.noise_num) ** self.step_num 29 | 30 | self.empty_id = args.cluster_num 31 | self.start_id = args.cluster_num + 1 32 | self.vocab_size = args.cluster_num + 2 # cluster centers + start_id + empty_id 33 | self.block_size = 1 + 1 + self.step_num # start_node + root_node + num_layers 34 | 35 | self.compute_short_seq() 36 | 37 | def __len__(self): 38 | return len(self.ids) 39 | 40 | def __getitem__(self, index): 41 | org_empty_id = self.cluster_ids.shape[0] - 1 42 | seed_id = self.ids[index] 43 | sampled_cluster_ids = [self.start_id, self.cluster_ids[seed_id]] 44 | curr_target_list = [seed_id] 45 | for _ in range(self.step_num): 46 | new_target_list = [] 47 | for target_id in curr_target_list: 48 | # Get neighbor list 49 | if target_id == org_empty_id: 50 | source_ids = [] 51 | else: 52 | if self.adjs_list: 53 | source_ids = self.adjs[target_id] 54 | else: 55 | source_ids = np.nonzero(self.adjs[target_id])[0].tolist() 56 | # Sample fixed number of neighbors 57 | if len(source_ids) == 0: 58 | sampled_ids = self.sample_num * [org_empty_id] 59 | elif len(source_ids) < self.sample_num: 60 | sampled_ids = source_ids + (self.sample_num - len(source_ids)) * [org_empty_id] 61 | else: 62 | sampled_ids = np.random.choice(source_ids, self.sample_num, replace = False).tolist() 63 | 64 | if self.noise_num > 0: 65 | perm = np.random.permutation(len(self.adjs))[:self.noise_num] 66 | sampled_ids = np.concatenate((sampled_ids, perm), axis=0) 67 | 68 | sampled_cluster_ids.extend(self.cluster_ids[sampled_ids]) 69 | new_target_list.extend(sampled_ids) 70 | 71 | curr_target_list = new_target_list 72 | 73 | return {"ids": torch.LongTensor(sampled_cluster_ids), 74 | "label": torch.LongTensor([self.labels[seed_id]]) 75 | } 76 | 77 | def collate(self, items): 78 | items = [(item["ids"], item["label"]) for item in items] 79 | (idses, labels) = zip(*items) 80 | 81 | idses = torch.stack(idses, dim=0) 82 | labels = torch.stack(labels, dim=0) 83 | idses = idses[:, self.seq_id_list].view(-1, self.block_size) 84 | labels = labels.repeat_interleave(self.short_seq_num) 85 | 86 | result= dict( 87 | query = idses[:, :-1], 88 | predict = idses[:, 1:], 89 | label = labels 90 | ) 91 | return result 92 | 93 | def compute_short_seq(self): 94 | self.seq_id_list = [] 95 | sample_num = self.sample_num + self.noise_num 96 | def recursion(layer, 
locat_list): 97 | for i in range(sample_num): 98 | new_locat_list = locat_list + [i] 99 | if layer == self.step_num: 100 | seq_id = [0, 1] 101 | abs = 2 102 | for j in range(1, self.step_num + 1): 103 | new_id = abs + sample_num * new_locat_list[j - 1] + new_locat_list[j] 104 | seq_id.append(new_id) 105 | abs += sample_num ** j 106 | self.seq_id_list.extend(seq_id) 107 | else: 108 | recursion(layer + 1, new_locat_list) 109 | recursion(1, [0]) 110 | 111 | 112 | class QuantizedDataset(torch.utils.data.Dataset): 113 | def __init__(self, args, sequences, labels, cluster_centers): 114 | self.sequences = sequences 115 | self.labels = labels 116 | self.cluster_centers = cluster_centers 117 | 118 | self.step_num = args.subgraph_step_num 119 | self.sample_num = args.subgraph_sample_num + args.noise_num 120 | self.self_connection = args.self_connection 121 | self.dup_adj = self.compute_dup_adj() 122 | 123 | def __len__(self): 124 | return len(self.sequences) 125 | 126 | def __getitem__(self, index): 127 | return {"feat": self.cluster_centers[self.sequences[index]], 128 | "adj": self.dup_adj, 129 | "label": torch.LongTensor([self.labels[index]]) 130 | } 131 | 132 | def compute_dup_adj(self): 133 | """ duplicate-encoded adjacency matrix (fixed shape)""" 134 | seed_id = 0 135 | sampled_nodes = [seed_id] 136 | curr_target_list = [seed_id] 137 | sampled_edges = defaultdict(list) 138 | for _ in range(self.step_num): 139 | new_target_list = [] 140 | for target_id in curr_target_list: 141 | # Get neighbor list 142 | source_ids = list(range(len(sampled_nodes), len(sampled_nodes) + self.sample_num)) 143 | 144 | sampled_nodes.extend(source_ids) 145 | new_target_list.extend(source_ids) 146 | sampled_edges[target_id].extend(source_ids) 147 | 148 | curr_target_list = new_target_list 149 | 150 | # Generate adjacency matrix 151 | rows = [] 152 | cols = [] 153 | for target_id in sampled_edges.keys(): 154 | for source_id in sampled_edges[target_id]: 155 | rows.append(target_id) 156 | cols.append(source_id) 157 | 158 | # Define adjacency matrix 159 | indices = torch.stack([torch.LongTensor(rows), torch.LongTensor(cols)], dim=0) 160 | attention = torch.ones(len(cols)) 161 | dense_shape = torch.Size([len(sampled_nodes), len(sampled_nodes)]) 162 | sampled_adj = torch.sparse.FloatTensor(indices, attention, dense_shape).to_dense() 163 | 164 | # Remove zero_in_degree 165 | if self.self_connection: 166 | sampled_adj = sampled_adj + torch.diag(torch.ones(len(sampled_nodes))) 167 | 168 | return sampled_adj 169 | 170 | 171 | -------------------------------------------------------------------------------- /generator/gpt/gpt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | from time import perf_counter 6 | 7 | from generator.gpt.dataset import Dataset, QuantizedDataset 8 | from generator.gpt.model import GPTConfig, XLNet 9 | from generator.gpt.trainer import Trainer, TrainerConfig 10 | from generator.gpt.utils import sample 11 | from generator.cluster import cluster_feats 12 | 13 | 14 | def train(args, adjs, cluster_ids, labels, ids, split): 15 | """ 16 | Train CGT 17 | Args: 18 | args: arguments 19 | adjs: a list of adjacency matrix of each computation graphs 20 | cluster_ids: cluster ids of each nodes 21 | labels: label of each nodes 22 | ids: node id list 23 | split: split name ('train', 'val', 'test') 24 | """ 25 | # hyperparameters of computation graphs 26 | dataset = Dataset(args, adjs, cluster_ids, labels, ids) 
27 | params = {'batch_size': args.gpt_batch_size, 28 | 'num_workers': args.num_workers, 29 | 'prefetch_factor': args.prefetch_factor, 30 | 'collate_fn': dataset.collate, 31 | 'shuffle': True, 32 | 'drop_last': True} 33 | data_loader = torch.utils.data.DataLoader(dataset, **params) 34 | 35 | # hyperparameters of XLNet architecture 36 | mconf = GPTConfig(dataset.vocab_size, dataset.block_size, step_num=dataset.step_num, sample_num=dataset.total_sample_num, 37 | embd_pdrop=args.gpt_dropout, resid_pdrop=args.gpt_dropout, attn_pdrop=args.gpt_dropout, 38 | n_layer=args.gpt_layers, n_head=args.gpt_heads, n_embd=args.gpt_hidden_dim*args.gpt_heads, n_class=args.label_size) 39 | model = XLNet(mconf) 40 | 41 | # hyperparameters of XLNet training 42 | tokens_per_epoch = dataset.block_size * dataset.short_seq_num * len(dataset) 43 | final_tokens = tokens_per_epoch * args.gpt_epochs 44 | tconf = TrainerConfig(batch_size=args.gpt_batch_size, block_size=dataset.block_size, short_seq_num=dataset.short_seq_num, 45 | max_epochs=args.gpt_epochs, learning_rate=args.gpt_lr, betas = (0.9, 0.95), weight_decay=args.gpt_weight_decay, 46 | lr_decay=True, warmup_tokens=tokens_per_epoch, final_tokens=final_tokens, 47 | ckpt_path='generator/gpt/save/{}_{}.pt'.format(args.gpt_train_name, split)) 48 | 49 | start_time = perf_counter() 50 | trainer = Trainer(args, tconf, model, data_loader) 51 | trainer.train() 52 | print("[GPT] name: {}, split: {}, train time: {:.3f}".format(args.gpt_train_name, split, perf_counter() - start_time)) 53 | 54 | return model 55 | 56 | 57 | def generate(args, model, labels, ids, split): 58 | """ 59 | Generate cluster ids for each node using CGT 60 | Args: 61 | args: arguments 62 | model: CGT model 63 | labels: label of each nodes 64 | ids: node id list 65 | split: split name ('train', 'val', 'test') 66 | """ 67 | start_time = perf_counter() 68 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 69 | 70 | checkpoint = torch.load('generator/gpt/save/{}_{}.pt'.format(args.gpt_train_name, split)) 71 | model.load_state_dict(checkpoint) 72 | model.eval() 73 | 74 | n_samples = len(labels) 75 | start_id = args.cluster_num + 1 76 | start_ids = [start_id for _ in range(n_samples)] 77 | start_ids = torch.LongTensor(start_ids).unsqueeze(1).to(device) 78 | labels = torch.LongTensor(labels).to(device) 79 | 80 | result = [] 81 | for b in range(int(math.ceil(n_samples / args.gpt_batch_size))): 82 | ind_start = b * args.gpt_batch_size 83 | ind_end = min(ind_start + args.gpt_batch_size, n_samples) 84 | generated_ids = sample(model, start_ids[ind_start:ind_end].contiguous(), labels[ind_start:ind_end].contiguous(), temperature=args.gpt_softmax_temperature) 85 | result.append(generated_ids[:, 1:]) 86 | result = torch.cat(result, dim = 0) 87 | print("[GPT] name: {}, split: {}, generation time: {:.3f}".format(args.gpt_train_name, split, perf_counter() - start_time)) 88 | 89 | return result.cpu() 90 | 91 | 92 | def run(args, graphs, feats, labels, ids): 93 | """ 94 | Learn graph distribution and generate new graphs using CGT 95 | Args: 96 | args: arguments 97 | graphs: a list of computation graphs 98 | feats: a list of feature matrices 99 | labels: a list of labels 100 | ids: a list of node ids 101 | """ 102 | # Create save directory 103 | save_dir = 'generator/gpt/save' 104 | if not os.path.isdir(save_dir): 105 | os.makedirs(save_dir) 106 | 107 | # STEP 1: quantize features using cluster ids they are belonging to 108 | cluster_ids, cluster_centers = cluster_feats(args, feats) 109 | 110 | # STEP 
2-1: train GPT on the original (training + validation) set 111 | target_ids = ids["train"] + ids["val"] 112 | model = train(args, graphs, cluster_ids, labels, target_ids, split="train") 113 | # STEP 2-2: generate cluster_ids for each (training + validation) set 114 | gen_train_ids = generate(args, model, labels[ids["train"]], ids["train"], split="train") 115 | gen_val_ids = generate(args, model, labels[ids["val"]], ids["val"], split="train") 116 | # STEP 2-3: creat dataset that map cluster ids to feature vectors 117 | train_dataset = QuantizedDataset(args, gen_train_ids, labels[ids["train"]], cluster_centers) 118 | val_dataset = QuantizedDataset(args, gen_val_ids, labels[ids["val"]], cluster_centers) 119 | 120 | # STEP 3-1: train GPT on the original test set 121 | target_ids = ids["test"] 122 | model = train(args, graphs, cluster_ids, labels, target_ids, split="test") 123 | # STEP 3-2: generate cluster_ids for the test set 124 | gen_test_ids = generate(args, model, labels[ids["test"]], ids["test"], split="test") 125 | # STEP 3-3: creat dataset that map cluster ids to feature vectors 126 | test_dataset = QuantizedDataset(args, gen_test_ids, labels[ids["test"]], cluster_centers) 127 | 128 | return train_dataset, val_dataset, test_dataset 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /generator/gpt/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code block is adapted from: 3 | - Repository: minGPT 4 | - Link: https://github.com/karpathy/minGPT/tree/master 5 | """ 6 | 7 | import math 8 | import logging 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class GPTConfig: 18 | embd_pdrop = 0.1 19 | resid_pdrop = 0.1 20 | attn_pdrop = 0.1 21 | 22 | def __init__(self, vocab_size, block_size, **kwargs): 23 | self.vocab_size = vocab_size 24 | self.block_size = block_size 25 | for k,v in kwargs.items(): 26 | setattr(self, k, v) 27 | 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | 35 | # key, query, value projections for all heads 36 | self.key = nn.Linear(config.n_embd, config.n_embd) 37 | self.query = nn.Linear(config.n_embd, config.n_embd) 38 | self.value = nn.Linear(config.n_embd, config.n_embd) 39 | 40 | # regularization 41 | self.attn_drop = nn.Dropout(config.attn_pdrop) 42 | self.resid_drop = nn.Dropout(config.resid_pdrop) 43 | 44 | # output projection 45 | self.proj = nn.Linear(config.n_embd, config.n_embd) 46 | 47 | # causal mask to ensure that attention is only applied to the left in the input sequence 48 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 49 | .view(1, 1, config.block_size, config.block_size)) 50 | self.n_head = config.n_head 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 57 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | 60 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 61 | att = (q @ k.transpose(-2, -1)) * 
(1.0 / math.sqrt(k.size(-1))) 62 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 63 | att = F.softmax(att, dim=-1) 64 | att = self.attn_drop(att) 65 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 66 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 67 | 68 | # output projection 69 | y = self.resid_drop(self.proj(y)) 70 | return y 71 | 72 | 73 | class XLNetAttention(nn.Module): 74 | 75 | def __init__(self, config): 76 | super().__init__() 77 | assert config.n_embd % config.n_head == 0 78 | 79 | # key, query, value projections for all heads 80 | self.key = nn.Linear(config.n_embd, config.n_embd) 81 | self.query = nn.Linear(config.n_embd, config.n_embd) 82 | self.value = nn.Linear(config.n_embd, config.n_embd) 83 | 84 | # regularization 85 | self.attn_drop = nn.Dropout(config.attn_pdrop) 86 | self.resid_drop = nn.Dropout(config.resid_pdrop) 87 | 88 | # output projection 89 | self.proj = nn.Linear(config.n_embd, config.n_embd) 90 | 91 | # causal mask to ensure that attention is only applied to the left in the input sequence 92 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 93 | .view(1, 1, config.block_size, config.block_size)) 94 | self.n_head = config.n_head 95 | 96 | def forward(self, x, q): 97 | B, T, C = x.size() 98 | 99 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 100 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 101 | q = self.query(q).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 102 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 103 | 104 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 105 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 106 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 107 | att = F.softmax(att, dim=-1) 108 | att = self.attn_drop(att) 109 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 110 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 111 | 112 | # output projection 113 | y = self.resid_drop(self.proj(y)) 114 | return y 115 | 116 | class XLNetBlock(nn.Module): 117 | 118 | def __init__(self, config): 119 | super().__init__() 120 | self.ln1 = nn.LayerNorm(config.n_embd) 121 | self.ln2 = nn.LayerNorm(config.n_embd) 122 | self.attn = CausalSelfAttention(config) 123 | self.ln1q = nn.LayerNorm(config.n_embd) 124 | self.ln2q = nn.LayerNorm(config.n_embd) 125 | self.query_attn = XLNetAttention(config) 126 | self.mlp = nn.Sequential( 127 | nn.Linear(config.n_embd, 4 * config.n_embd), 128 | nn.GELU(), 129 | nn.Linear(4 * config.n_embd, config.n_embd), 130 | nn.Dropout(config.resid_pdrop), 131 | ) 132 | self.mlpq = nn.Sequential( 133 | nn.Linear(config.n_embd, 4 * config.n_embd), 134 | nn.GELU(), 135 | nn.Linear(4 * config.n_embd, config.n_embd), 136 | nn.Dropout(config.resid_pdrop), 137 | ) 138 | 139 | def forward(self, x): 140 | x, q = x 141 | x = x + self.attn(self.ln1(x)) 142 | x = x + self.mlp(self.ln2(x)) 143 | q = q + self.query_attn(self.ln1(x), self.ln1q(q)) 144 | q = q + self.mlpq(self.ln2q(q)) 145 | return (x, q) 146 | 147 | class XLNet(nn.Module): 148 | 149 | def __init__(self, config): 150 | super().__init__() 151 | self.config = config 152 | 153 | # input embedding stem 154 | self.tok_emb = nn.Embedding(config.vocab_size, 
config.n_embd) 155 | self.query_emb = nn.Parameter(torch.zeros(1, 1, config.n_embd)) 156 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 157 | self.class_emb = nn.Embedding(config.n_class, config.n_embd) 158 | 159 | self.step_num = config.step_num 160 | self.sample_num = config.sample_num 161 | self.block_size = config.block_size 162 | 163 | # inital embedding dropout 164 | self.drop = nn.Dropout(config.embd_pdrop) 165 | # transformer 166 | self.blocks = nn.Sequential(*[XLNetBlock(config) for _ in range(config.n_layer)]) 167 | # decoder head 168 | self.ln_f = nn.LayerNorm(config.n_embd) 169 | # remove start_id 170 | self.head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False) 171 | 172 | self.apply(self._init_weights) 173 | self.criterion = nn.CrossEntropyLoss() 174 | 175 | def _init_weights(self, module): 176 | if isinstance(module, (nn.Linear, nn.Embedding)): 177 | module.weight.data.normal_(mean=0.0, std=0.02) 178 | if isinstance(module, nn.Linear) and module.bias is not None: 179 | module.bias.data.zero_() 180 | elif isinstance(module, nn.LayerNorm): 181 | module.bias.data.zero_() 182 | module.weight.data.fill_(1.0) 183 | 184 | def configure_optimizers(self, train_config): 185 | """ 186 | This long function is unfortunately doing something very simple and is being very defensive: 187 | We are separating out all parameters of the model into two buckets: those that will experience 188 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 189 | We are then returning the PyTorch optimizer object. 190 | """ 191 | 192 | # separate out all parameters to those that will and won't experience regularizing weight decay 193 | decay = set() 194 | no_decay = set() 195 | whitelist_weight_modules = (torch.nn.Linear, ) 196 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 197 | for mn, m in self.named_modules(): 198 | for pn, p in m.named_parameters(): 199 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 200 | 201 | if pn.endswith('bias'): 202 | # all biases will not be decayed 203 | no_decay.add(fpn) 204 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 205 | # weights of whitelist modules will be weight decayed 206 | decay.add(fpn) 207 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 208 | # weights of blacklist modules will NOT be weight decayed 209 | no_decay.add(fpn) 210 | 211 | # special case the position embedding parameter in the root GPT module as not decayed 212 | no_decay.add('pos_emb') 213 | no_decay.add('query_emb') 214 | 215 | # validate that we considered every parameter 216 | param_dict = {pn: p for pn, p in self.named_parameters()} 217 | inter_params = decay & no_decay 218 | union_params = decay | no_decay 219 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 220 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 221 | % (str(param_dict.keys() - union_params), ) 222 | 223 | # create the pytorch optimizer object 224 | optim_groups = [ 225 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 226 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 227 | ] 228 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 229 | return optimizer 230 | 231 | def forward(self, idx, classes, targets=None): 232 | b, t = idx.size() 233 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 234 | 235 | token_embeddings = self.tok_emb(idx) 236 | class_embeddings = self.class_emb(classes).unsqueeze(1).expand(-1, t, -1) 237 | position_embeddings = self.pos_emb 238 | 239 | x = self.drop(token_embeddings + position_embeddings[:, :t]) 240 | q = self.drop((self.query_emb + position_embeddings[:, 1:(t+1)] + class_embeddings).expand_as(x)) 241 | _, q = self.blocks((x, q)) 242 | q = self.ln_f(q) 243 | logits = self.head(q) 244 | logits[:, 0, -1] = float('-inf') 245 | 246 | loss = None 247 | if targets is not None: 248 | logits = logits.view(-1, logits.size(-1)) 249 | targets = targets.view(-1) 250 | loss = self.criterion(logits, targets) 251 | 252 | return logits, loss 253 | -------------------------------------------------------------------------------- /generator/gpt/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from torch.utils.data.dataloader import DataLoader 11 | import time 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class TrainerConfig: 16 | batch_size = 64 17 | block_size = 4 18 | short_seq_num = 25 19 | # optimization parameters 20 | max_epochs = 100 21 | learning_rate = 3e-4 22 | betas = (0.9, 0.95) 23 | grad_norm_clip = 1.0 24 | weight_decay = 0.1 # only applied on matmul weights 25 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 26 | lr_decay = False 27 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 28 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 29 | # checkpoint settings 30 | ckpt_path = None 31 | 32 | def __init__(self, **kwargs): 33 | for k,v in kwargs.items(): 34 | setattr(self, k, v) 35 | 36 | class Trainer: 37 | """ 38 | Trainer for CGT model. 
39 | Args: 40 | args: arguments 41 | config: TrainerConfig object containing model configuration 42 | model: CGT model 43 | data_loader: DataLoader object for training data 44 | """ 45 | def __init__(self, args, config, model, data_loader): 46 | self.config = config 47 | self.args = args 48 | self.model = model 49 | self.data_loader = data_loader 50 | 51 | self.device = 'cpu' 52 | if torch.cuda.is_available(): 53 | self.device = torch.cuda.current_device() 54 | self.model = torch.nn.DataParallel(self.model).to(self.device) 55 | 56 | def save_checkpoint(self): 57 | # DataParallel wrappers keep raw model object in .module attribute 58 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 59 | torch.save(raw_model.state_dict(), self.config.ckpt_path) 60 | 61 | def train(self): 62 | model, config = self.model, self.config 63 | raw_model = model.module if hasattr(self.model, "module") else model 64 | optimizer = raw_model.configure_optimizers(config) 65 | 66 | start_time = time.time() 67 | self.tokens = 0 # counter used for learning rate decay 68 | best_loss = float('inf') 69 | for epoch in range(config.max_epochs): 70 | model.train() 71 | with tqdm(self.data_loader, unit="batch") as t_data_loader: 72 | for batch in t_data_loader: 73 | x, y, lbl = batch["query"].to(self.device), batch["predict"].to(self.device), batch["label"].to(self.device) 74 | 75 | with torch.set_grad_enabled(True): 76 | logits, loss = model(x, lbl, y) 77 | loss = loss.mean() 78 | 79 | model.zero_grad() 80 | loss.backward() 81 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 82 | optimizer.step() 83 | 84 | if config.lr_decay: 85 | # number of tokens processed this step 86 | self.tokens += config.block_size * config.short_seq_num * config.batch_size 87 | if self.tokens < config.warmup_tokens: 88 | # linear warmup 89 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) 90 | else: 91 | # cosine learning rate decay 92 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 93 | lr_mult = max(3e-5 / config.learning_rate, 0.5 * (1.0 + math.cos(math.pi * progress))) 94 | lr = config.learning_rate * lr_mult 95 | for param_group in optimizer.param_groups: 96 | param_group['lr'] = lr 97 | else: 98 | lr = config.learning_rate 99 | 100 | t_data_loader.set_description(f"Epoch {epoch}") 101 | t_data_loader.set_postfix(loss=loss.item()) 102 | time.sleep(0.1) 103 | 104 | # Save the trained model 105 | self.save_checkpoint() 106 | -------------------------------------------------------------------------------- /generator/gpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | @torch.no_grad() 9 | def sample(model, x, lbl, temperature=1.0, sample=True, top_k=None): 10 | """ 11 | Sample nodes in sequence which will be reconstructed into a computation graph. 
12 | Args: 13 | model: CGT model 14 | x: seed nodes 15 | lbl: seed labels 16 | temperature: temperature of softmax 17 | sample: whether to sample or take the most likely 18 | top_k: k for top-k sampling 19 | """ 20 | model.eval() 21 | 22 | generated_ids = [] 23 | complete_ids = x 24 | batch_size = x.size(0) 25 | 26 | for step in range(model.step_num + 1): 27 | logits, _ = model(x, lbl) 28 | # pluck the logits at the final step and scale by temperature 29 | logits = logits[:, -1, :] / temperature 30 | # apply softmax to convert to probabilities 31 | probs = F.softmax(logits, dim=-1) 32 | # sample from the distribution or take the most likely 33 | ix = torch.multinomial(probs, num_samples=1) 34 | 35 | generated_ids.append(ix) 36 | x = torch.cat((x, ix), dim=-1) 37 | x = x.repeat_interleave(model.sample_num, dim=0) 38 | lbl = lbl.repeat_interleave(model.sample_num) 39 | 40 | for step in range(model.step_num + 1): 41 | complete_ids = torch.cat((complete_ids, generated_ids[step].view(batch_size, model.sample_num ** step)), dim=1) 42 | 43 | return complete_ids 44 | 45 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | networkx 3 | scipy 4 | scikit-learn 5 | torch=1.9.1+cu111 6 | tensorflow=1.14.0 7 | dgl-cuda=0.9.1post1 8 | ogb=1.3.2 9 | k-means-constrained=0.7.0 10 | ortools=9.3.10497 11 | jinja2 12 | tqdm 13 | ipdb 14 | wandb 15 | tensorboard -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | DATASETS=("cora" "citeseer") 2 | data_length=${#DATASETS[@]} 3 | 4 | # Experiment: effects of noise to aggregation strategies 5 | NOISES=(0 2 4) 6 | noise_length=${#NOISES[@]} 7 | 8 | for ((i=0;i<$data_length;i++)) 9 | do 10 | for ((j=0;j<$noise_length;j++)) 11 | do 12 | python test.py --dataset "${DATASETS[$i]}" --noise_num ${NOISES[$j]} \ 13 | --task_name "aggregation" -n "gcn" "sgc" "gin" 14 | done 15 | done 16 | 17 | -------------------------------------------------------------------------------- /task/aggregation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/CGT/5daba9acfba5a84cd0c5ce1271ef7ed608312a93/task/aggregation/__init__.py -------------------------------------------------------------------------------- /task/aggregation/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import gc 6 | import numpy as np 7 | from tqdm import tqdm 8 | from time import perf_counter, sleep 9 | 10 | from ..utils.utils import calc_f1, calc_loss 11 | 12 | # GCN template 13 | class GCN(nn.Module): 14 | """ 15 | Graph Convolutional Networks (GCN) 16 | Args: 17 | model_name: model name ('gcn', 'gin', 'sgc', 'gat') 18 | input_dim: input dimension 19 | output_dim: output dimension 20 | hidden_dim: hidden dimension 21 | step_num: number of propagation steps 22 | output_layer: whether to use the output layer 23 | """ 24 | 25 | def __init__(self, model_name, input_dim, output_dim, hidden_dim, step_num, output_layer=True): 26 | super(GCN, self).__init__() 27 | 28 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 29 | self.model_name = model_name 30 | 31 | self.input_dim = input_dim 32 | 
self.output_dim = output_dim 33 | self.hidden_dim = hidden_dim 34 | self.step_num = step_num 35 | 36 | self.output_layer = output_layer 37 | 38 | self.W = nn.ModuleList([nn.Linear(input_dim, hidden_dim, bias=False)]) 39 | for _ in range(step_num-1): 40 | self.W.append(nn.Linear(hidden_dim, hidden_dim, bias=False)) 41 | for w in self.W: 42 | nn.init.xavier_uniform_(w.weight) 43 | self.outputW = nn.Linear(hidden_dim, output_dim, bias=False) 44 | nn.init.xavier_uniform_(self.outputW.weight) 45 | 46 | self.pooling = 'sum' if model_name in ('gin', 'gat') else 'avg' 47 | self.nonlinear = (model_name != 'sgc') 48 | self.attention = (model_name == 'gat') 49 | if self.attention: 50 | self.attentionW = nn.Parameter(torch.empty(size=(2*hidden_dim, 1))) 51 | nn.init.xavier_uniform_(self.attentionW.data) 52 | self.leakyReLU = nn.LeakyReLU(0.2) 53 | self.softmax = nn.Softmax(dim=1) 54 | 55 | def get_parameters(self): 56 | ml = list() 57 | for w in self.W: 58 | ml.append({'params': w.parameters()}) 59 | ml.append({'params': self.outputW.parameters()}) 60 | if self.attention: 61 | ml.append({'params': self.attentionW}) 62 | return ml 63 | 64 | def forward(self, feat, raw_adj): 65 | """ 66 | Args: 67 | feat: feature matrix (batch_num * node_num, input_dim) 68 | raw_adj: adjacency matrix (batch_num, node_num, node_num) 69 | Returns: 70 | output: feature matrix of target nodes (batch_num, output_dim) 71 | """ 72 | batch_num, batch_size = raw_adj.shape[0], raw_adj.shape[1] 73 | ids = torch.range(0, (batch_num - 1) * batch_size, batch_size, dtype=torch.long).to(raw_adj.device) 74 | adj = torch.block_diag(*raw_adj).to(raw_adj.device) 75 | if self.pooling == 'avg': 76 | adj = self.avg_pooling(feat, adj) 77 | X = feat 78 | for w in self.W: 79 | Z = w(X) 80 | if self.attention: 81 | adj = self.compute_attention(adj, Z) 82 | X = torch.spmm(adj, Z) 83 | if self.nonlinear: 84 | X = F.relu(X) 85 | if self.output_layer: 86 | X = self.outputW(X) 87 | return X[ids] 88 | 89 | def avg_pooling(self, feats, adj): 90 | """ 91 | Args: 92 | feats: feature matrix (batch_num * node_num, input_dim) 93 | adj: adjacency matrix (batch_num * node_num, batch_num * node_num) 94 | Returns: 95 | adj: adjacency matrix (batch_num * node_num, batch_num * node_num) 96 | """ 97 | nonzeros = torch.nonzero(torch.norm(feats, dim=1), as_tuple=True)[0] 98 | nonzero_adj = adj[:, nonzeros] 99 | row_sum = torch.sum(nonzero_adj, dim=1) 100 | row_sum = row_sum.masked_fill_(row_sum == 0, 1.) 
101 | row_sum = torch.diag(1/row_sum).to(adj.device) 102 | adj = torch.spmm(row_sum, adj) 103 | return adj 104 | 105 | def compute_attention(self, adj, X): 106 | Wh1 = torch.matmul(X, self.attentionW[:self.hidden_dim, :]) 107 | Wh2 = torch.matmul(X, self.attentionW[self.hidden_dim:, :]) 108 | e = Wh1 + Wh2.T 109 | e= self.leakyReLU(e) 110 | 111 | zero_vec = -9e15 * torch.ones_like(e) 112 | attention = torch.where(adj > 0, e, zero_vec) #torch.tensor(-9e15).to(self.device)) 113 | attention = F.softmax(attention, dim=1) 114 | return attention 115 | 116 | 117 | def run(args, model_name, train_loader, val_loader, test_loader): 118 | """ 119 | Evaluate GNN performance 120 | Args: 121 | args: arguments 122 | model_name: model name ('gcn', 'gin', 'sgc', 'gat') 123 | train_loader: training data loader 124 | val_loader: validation data loader 125 | test_loader: test data loader 126 | Returns: 127 | acc_mic: micro-F1 score 128 | acc_mac: macro-F1 score 129 | """ 130 | 131 | device = args.device 132 | model = GCN(model_name, args.feat_size, args.label_size, args.hidden_dim, args.step_num) 133 | model = nn.DataParallel(model).to(device) 134 | 135 | # Test GCN models 136 | def test_model(args, model, data_loader, split='val'): 137 | start_time = perf_counter() 138 | stack_output = [] 139 | stack_label = [] 140 | model.eval() 141 | with tqdm(data_loader, unit="batch") as t_data_loader: 142 | for batch in t_data_loader: 143 | feats, adjs, labels = batch["feat"].to(device), batch["adj"].to(device), batch["label"].to(device) 144 | outputs = model(feats, adjs) 145 | loss = calc_loss(outputs, labels) 146 | stack_output.append(outputs.detach().cpu()) 147 | stack_label.append(labels.cpu()) 148 | t_data_loader.set_description(f"{split}") 149 | t_data_loader.set_postfix(loss=loss.item()) 150 | sleep(0.1) 151 | stack_output = torch.cat(stack_output, dim=0) 152 | stack_label = torch.cat(stack_label, dim=0) 153 | loss = calc_loss(stack_output, stack_label) 154 | acc_mic, acc_mac = calc_f1(stack_output, stack_label) 155 | return loss, acc_mic, acc_mac 156 | 157 | ml = list() 158 | ml.extend(model.module.get_parameters()) 159 | optimizer = optim.Adam(ml, lr=args.lr) 160 | 161 | patient = 0 162 | min_loss = np.inf 163 | for epoch in range(args.epochs): 164 | with tqdm(train_loader, unit="batch") as t_train_loader: 165 | for batch in t_train_loader: 166 | feats, adjs, labels = batch["feat"].to(device), batch["adj"].to(device), batch["label"].to(device) 167 | 168 | model.train() 169 | optimizer.zero_grad() 170 | outputs = model(feats, adjs) 171 | loss = calc_loss(outputs, labels) 172 | loss.backward() 173 | optimizer.step() 174 | 175 | t_train_loader.set_description(f"Epoch {epoch}") 176 | t_train_loader.set_postfix(loss=loss.item()) 177 | sleep(0.1) 178 | 179 | with torch.no_grad(): 180 | new_loss, acc_mic, acc_mac = test_model(args, model, val_loader, 'val') 181 | if new_loss >= min_loss: 182 | patient = patient + 1 183 | else: 184 | min_loss = new_loss 185 | patient = 0 186 | 187 | if patient == args.early_stopping: 188 | break 189 | 190 | _, acc_mic, acc_mac = test_model(args, model, test_loader, 'test') 191 | 192 | del model 193 | return acc_mic, acc_mac 194 | 195 | -------------------------------------------------------------------------------- /task/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | class Dataset(torch.utils.data.Dataset): 6 | """ 7 | Computation Graph Dataset for 
PyTorch DataParallel 8 | Args: 9 | args: arguments 10 | split: split name ('train', 'val', 'test') 11 | adjs: adjacency matrix 12 | feats: feature matrix 13 | labels: label vector 14 | ids: node id list 15 | """ 16 | def __init__(self, args, split, adjs, feats, labels, ids): 17 | self.adjs = adjs 18 | self.adjs_list = isinstance(adjs, list) 19 | self.feats = feats 20 | self.labels = labels 21 | self.ids = ids[split] 22 | 23 | self.node_num = feats.shape[0] 24 | self.empty_id = feats.shape[0] 25 | self.feats = np.concatenate((self.feats, np.zeros((1, feats.shape[1]))), axis=0) 26 | 27 | self.step_num = args.subgraph_step_num 28 | self.sample_num = args.subgraph_sample_num 29 | self.noise_num = args.noise_num 30 | self.self_connection = args.self_connection 31 | self.dup_adj = self.compute_dup_adj() 32 | 33 | def __len__(self): 34 | return len(self.ids) 35 | 36 | def __getitem__(self, index): 37 | """ 38 | Provide a computation graph sampled around a seed node (index) 39 | Args: 40 | index: seed node index 41 | Returns: 42 | feat: feature matrix (sampled_node_num, input_dim) 43 | adj: adjacency matrix (sampled_node_num, sampled_node_num) 44 | label: label vector (1) 45 | """ 46 | seed_id = self.ids[index] 47 | sampled_nodes = [seed_id] 48 | curr_target_list = [seed_id] 49 | for _ in range(self.step_num): 50 | new_target_list = [] 51 | for target_id in curr_target_list: 52 | # Get neighbor list 53 | if target_id == self.empty_id: 54 | source_ids = [] 55 | else: 56 | if self.adjs_list: 57 | source_ids = self.adjs[target_id] 58 | else: 59 | source_ids = np.nonzero(self.adjs[target_id])[0].tolist() 60 | # Sample fixed number of neighbors 61 | if len(source_ids) == 0: 62 | sampled_ids = self.sample_num * [self.empty_id] 63 | elif len(source_ids) < self.sample_num: 64 | sampled_ids = source_ids + (self.sample_num - len(source_ids)) * [self.empty_id] 65 | else: 66 | sampled_ids = np.random.choice(source_ids, self.sample_num, replace = False).tolist() 67 | 68 | if self.noise_num > 0: 69 | perm = np.random.permutation(self.node_num)[:self.noise_num] 70 | sampled_ids = np.concatenate((sampled_ids, perm), axis=0) 71 | 72 | sampled_nodes.extend(sampled_ids) 73 | new_target_list.extend(sampled_ids) 74 | 75 | curr_target_list = new_target_list 76 | 77 | return {"feat": torch.FloatTensor(self.feats[sampled_nodes]), 78 | "adj": self.dup_adj, 79 | "label": torch.LongTensor([self.labels[seed_id]]) 80 | } 81 | 82 | def compute_dup_adj(self): 83 | """ duplicate-encoded adjacency matrix (fixed shape for all nodes)""" 84 | seed_id = 0 85 | sampled_nodes = [seed_id] 86 | curr_target_list = [seed_id] 87 | sampled_edges = defaultdict(list) 88 | for _ in range(self.step_num): 89 | new_target_list = [] 90 | for target_id in curr_target_list: 91 | # Get neighbor list 92 | source_ids = list(range(len(sampled_nodes), \ 93 | len(sampled_nodes) + self.sample_num + self.noise_num)) 94 | 95 | sampled_nodes.extend(source_ids) 96 | new_target_list.extend(source_ids) 97 | sampled_edges[target_id].extend(source_ids) 98 | 99 | curr_target_list = new_target_list 100 | 101 | # Generate adjacency matrix 102 | rows = [] 103 | cols = [] 104 | for target_id in sampled_edges.keys(): 105 | for source_id in sampled_edges[target_id]: 106 | rows.append(target_id) 107 | cols.append(source_id) 108 | 109 | # Define adjacency matrix 110 | indices = torch.stack([torch.LongTensor(rows), torch.LongTensor(cols)], dim=0) 111 | attention = torch.ones(len(cols)) 112 | dense_shape = torch.Size([len(sampled_nodes), len(sampled_nodes)]) 113 | sampled_adj 
= torch.sparse.FloatTensor(indices, attention, dense_shape).to_dense() 114 | 115 | # Remove zero_in_degree 116 | if self.self_connection: 117 | sampled_adj = sampled_adj + torch.diag(torch.ones(len(sampled_nodes))) 118 | 119 | return sampled_adj 120 | 121 | def collate(items): 122 | """Collate function for PyTorch DataLoader""" 123 | items = [(item["feat"], item["adj"], item["label"]) for item in items] 124 | (feats, adjs, labels) = zip(*items) 125 | 126 | result= dict( 127 | feat = torch.cat(feats), 128 | adj = torch.stack(adjs, dim=0), 129 | label = torch.cat(labels) 130 | ) 131 | return result 132 | -------------------------------------------------------------------------------- /task/utils/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path as osp 3 | import networkx as nx 4 | import numpy as np 5 | import random 6 | import scipy.sparse as sp 7 | import pandas as pd 8 | from collections import defaultdict 9 | from sklearn import metrics 10 | from sklearn.preprocessing import normalize 11 | from os.path import exists 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from ogb.io.read_graph_pyg import read_graph_pyg 17 | from torch_geometric.utils import to_undirected 18 | 19 | 20 | def set_seed(seed): 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | train_ratio = 0.4 27 | val_ratio = 0.2 28 | def split_ids(args, node_num): 29 | node_ids = list(range(node_num)) 30 | random.shuffle(node_ids) 31 | 32 | ids = {} 33 | ids['train'] = node_ids[:int(train_ratio * len(node_ids))] 34 | ids['val'] = node_ids[int(train_ratio * len(node_ids)):int((train_ratio + val_ratio) * len(node_ids))] 35 | ids['test'] = node_ids[int((train_ratio + val_ratio) * len(node_ids)):] 36 | 37 | return ids 38 | 39 | 40 | def convert_to_edge_list(edge_index, X): 41 | edge_list = [] 42 | sorted, indices = torch.sort(edge_index[1]) 43 | source_ids = edge_index[0][indices] 44 | target_ids = edge_index[1][indices] 45 | 46 | j = 0 47 | for i in range(X.shape[0]): 48 | neighbor_list = [] 49 | while j < target_ids.shape[0] and target_ids[j] == i: 50 | neighbor_list.append(source_ids[j].item()) 51 | j += 1 52 | edge_list.append(neighbor_list) 53 | 54 | return edge_list 55 | 56 | 57 | def normalize_features(features): 58 | features = features - features.min() 59 | features.div_(features.sum(dim=-1, keepdim=True).clamp_(min=1.)) 60 | return features 61 | 62 | 63 | def load_ogbn(args): 64 | master_file = args.data_dir + "/ogbn-master.csv" 65 | master = pd.read_csv(master_file, index_col = 0) 66 | meta_dict = master[args.dataset] 67 | 68 | add_inverse_edge = meta_dict['add_inverse_edge'] == 'True' 69 | binary = meta_dict['binary'] == 'True' 70 | additional_node_files = [] 71 | additional_edge_files = [] 72 | 73 | data_dir = args.data_dir + "/" + args.dataset + "/" 74 | data = read_graph_pyg(data_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=binary)[0] 75 | #data.x = normalize_features(data.x) 76 | node_feat = data.x.numpy() 77 | 78 | data.edge_index = to_undirected(data.edge_index) 79 | graph = convert_to_edge_list(data.edge_index, node_feat) 80 | 81 | label = pd.read_csv(osp.join(data_dir, 'node-label.csv.gz'), compression='gzip', header = None).values 82 | label = label.squeeze() 83 | 84 | feat_size = node_feat.shape[1] 85 | label_size = label.max() - label.min() + 1 86 
 87 |     return graph, node_feat, label, feat_size, label_size
 88 | 
 89 | 
 90 | def load_graph(args):
 91 |     if args.dataset in ("ogbn-arxiv", "ogbn-products"):
 92 |         return load_ogbn(args)
 93 | 
 94 |     dataset = args.data_dir + "/" + args.dataset + ".npz"
 95 |     with np.load(dataset, allow_pickle=True) as loader:
 96 |         loader = dict(loader)
 97 | 
 98 |         # Adjacency matrix
 99 |         graph = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], loader['adj_indptr']), shape=loader['adj_shape'])
100 |         graph = graph + graph.transpose()  # symmetrize the adjacency matrix
101 |         if args.noise_num > 0:
102 |             graph = graph + sp.identity(graph.shape[0])
103 |         graph = sp.csr_matrix.toarray(graph)
104 | 
105 |         # Feature matrix
106 |         if 'attr_data' in loader:
107 |             # Attributes are stored as a sparse CSR matrix
108 |             features = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], loader['attr_indptr']), shape=loader['attr_shape'])
109 |             features = sp.csr_matrix.toarray(features)
110 |             # Normalize
111 |             features = normalize(features, axis=1, norm='l2')
112 |             #features = features - np.mean(features, axis=0)
113 |         elif 'attr_matrix' in loader:
114 |             # Attributes are stored as a (dense) np.ndarray
115 |             features = loader['attr_matrix']
116 |         else:
117 |             features = None
118 | 
119 |         # Labels
120 |         if 'labels_data' in loader:
121 |             # Labels are stored as a CSR matrix
122 |             labels = sp.csr_matrix((loader['labels_data'], loader['labels_indices'], loader['labels_indptr']), shape=loader['labels_shape'])
123 |             labels = sp.csr_matrix.toarray(labels)
124 |         elif 'labels' in loader:
125 |             # Labels are stored as a numpy array
126 |             labels = loader['labels']
127 |         else:
128 |             labels = None
129 | 
130 |     feat_size = features.shape[1]
131 |     labels = labels - labels.min()  # shift labels so that they start from 0
132 |     label_size = labels.max() - labels.min() + 1
133 | 
134 |     return graph, features, labels, feat_size, label_size
135 | 
136 | 
137 | def calc_loss(y_pred, y_true):
138 |     if len(y_pred.shape) == 1:
139 |         y_pred = torch.unsqueeze(y_pred, 0)
140 |     if len(y_true.shape) == 2:
141 |         y_true = torch.squeeze(y_true)
142 |     loss_train = F.cross_entropy(y_pred, y_true)
143 |     return loss_train
144 | 
145 | 
146 | def calc_f1(y_pred, y_true):
147 |     y_pred = torch.argmax(y_pred, dim=1).cpu()
148 |     y_true = y_true.cpu()
149 |     return metrics.f1_score(y_true, y_pred, average="micro"), metrics.f1_score(y_true, y_pred, average="macro")
150 | 
151 | 
152 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
  1 | import warnings
  2 | warnings.filterwarnings("ignore")
  3 | warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
  4 | 
  5 | import gc
  6 | import os
  7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  8 | 
  9 | import numpy as np
 10 | import torch
 11 | from operator import itemgetter
 12 | from datetime import datetime
 13 | from time import perf_counter
 14 | 
 15 | from args import get_args
 16 | from task.utils.dataset import Dataset, collate
 17 | from task.utils.utils import load_graph, split_ids
 18 | 
 19 | from task.aggregation.gcn import run as aggregation
 20 | import generator.gpt.gpt as gpt
 21 | 
 22 | 
 23 | def evaluate(args, train_set, val_set, test_set):
 24 |     """
 25 |     Evaluate the performance of GNNs on the given dataset
 26 |     Args:
 27 |         args: arguments
 28 |         train_set: training set
 29 |         val_set: validation set
 30 |         test_set: test set
 31 |     Returns:
 32 |         acc_mic: micro-F1 scores, one entry per model in args.model_list
 33 |         acc_mac: macro-F1 scores, one entry per model in args.model_list
 34 |     """
 35 |     acc_mic = np.zeros(len(args.model_list))
 36 |     acc_mac = np.zeros(len(args.model_list))
 37 | 
 38 |     params = {'batch_size': args.batch_size,
 39 |               'num_workers': args.num_workers,
 40 |               'prefetch_factor': args.prefetch_factor,
 41 |               'collate_fn': collate,
 42 |               'shuffle': True,
 43 |               'drop_last': True}
 44 |     train_loader = torch.utils.data.DataLoader(train_set, **params)
 45 |     val_loader = torch.utils.data.DataLoader(val_set, **params)
 46 |     test_loader = torch.utils.data.DataLoader(test_set, **params)
 47 | 
 48 |     for i, model_name in enumerate(args.model_list):
 49 |         if args.task_name == "aggregation":
 50 |             acc_mic[i], acc_mac[i] = aggregation(args, model_name, train_loader, val_loader, test_loader)
 51 |     return acc_mic, acc_mac
 52 | 
 53 | 
 54 | def main():
 55 |     args = get_args()
 56 |     args.gpt_train_name = args.task_name + '_' + args.dataset + datetime.now().strftime("_%Y%m%d_%H%M%S")
 57 | 
 58 |     # Load the original graph dataset
 59 |     adj, feat, label, feat_size, label_size = load_graph(args)
 60 |     ids = split_ids(args, feat.shape[0])
 61 |     args.feat_size = feat_size
 62 |     args.label_size = label_size
 63 |     args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 64 | 
 65 |     trials = 1
 66 |     acc_mic_list = np.zeros((2, len(args.model_list), trials))
 67 |     acc_mac_list = np.zeros((2, len(args.model_list), trials))
 68 | 
 69 |     for t in range(trials):
 70 | 
 71 |         # Prepare duplicate-encoded computation graphs
 72 |         train_set = Dataset(args, "train", adj, feat, label, ids)
 73 |         val_set = Dataset(args, "val", adj, feat, label, ids)
 74 |         test_set = Dataset(args, "test", adj, feat, label, ids)
 75 |         # Check GNN performance on the original dataset
 76 |         start_time = perf_counter()
 77 |         acc_mic, acc_mac = evaluate(args, train_set, val_set, test_set)
 78 |         acc_mic_list[0, :, t] = acc_mic
 79 |         acc_mac_list[0, :, t] = acc_mac
 80 |         print('Original evaluation time: {:.3f}, acc: {}'.format(perf_counter() - start_time, acc_mic))
 81 | 
 82 |         # Train the graph transformer on the original graph and generate a synthetic one
 83 |         start_time = perf_counter()
 84 |         gen_train_set, gen_val_set, gen_test_set = gpt.run(args, adj, feat, label, ids)
 85 |         print('GPT training/generation total time: {:.3f}'.format(perf_counter() - start_time))
 86 | 
 87 |         # Check GNN performance on the generated dataset
 88 |         start_time = perf_counter()
 89 |         acc_mic, acc_mac = evaluate(args, gen_train_set, gen_val_set, gen_test_set)
 90 |         acc_mic_list[1, :, t] = acc_mic
 91 |         acc_mac_list[1, :, t] = acc_mac
 92 |         print('Synthetic evaluation time: {:.3f}, acc: {}'.format(perf_counter() - start_time, acc_mic))
 93 | 
 94 |     test_acc_avg = np.average(acc_mic_list, axis=2)
 95 |     test_acc_std = np.std(acc_mic_list, axis=2)
 96 | 
 97 |     print('Task: ' + args.task_name + ', Dataset: ' + args.dataset)
 98 |     for model_name in args.model_list:
 99 |         print(model_name, end=', ')
100 |     print()
101 |     for model_id in range(len(args.model_list)):
102 |         print("ORG: {:.2f} {:.3f}, GEN: {:.2f} {:.3f}".format(test_acc_avg[0][model_id], test_acc_std[0][model_id],
103 |                                                               test_acc_avg[1][model_id], test_acc_std[1][model_id]))
104 | 
105 | 
106 | if __name__ == "__main__":
107 |     main()
108 | 
109 | 
--------------------------------------------------------------------------------
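The `collate` function in `task/utils/dataset.py` concatenates the per-item features and labels and stacks the per-item adjacency matrices into one batch dictionary, which is then passed to `DataLoader` via `collate_fn` in `evaluate`. The snippet below is a minimal sketch (not part of the repository) illustrating the resulting shapes; the batch size, node count `n`, and feature dimension `d` are hypothetical toy values, and each item is assumed to hold one fixed-size flattened computation graph.

```
# Hedged sketch: toy inputs only, shapes chosen for illustration.
import torch
from task.utils.dataset import collate

n, d = 6, 8  # hypothetical: nodes per computation graph, feature dimension
items = [
    {
        "feat": torch.randn(n, d),    # node features of one computation graph
        "adj": torch.eye(n),          # its (dense) adjacency matrix
        "label": torch.tensor([1]),   # label of the root node
    }
    for _ in range(4)                 # a batch of 4 computation graphs
]

batch = collate(items)
print(batch["feat"].shape)   # torch.Size([24, 8]); features concatenated along dim 0
print(batch["adj"].shape)    # torch.Size([4, 6, 6]); adjacencies stacked into a batch
print(batch["label"].shape)  # torch.Size([4])
```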
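The data-loading utilities can also be exercised outside `test.py`. The following is a hedged sketch, not an official entry point: it stands in for the parsed arguments with a `SimpleNamespace` carrying only the attributes that `load_graph` and `split_ids` actually read (`data_dir`, `dataset`, `noise_num`), and assumes the bundled `cora.npz` is present in `data/`.

```
# Hedged sketch: a minimal stand-in for args, assuming data/cora.npz exists.
from types import SimpleNamespace
from task.utils.utils import load_graph, split_ids, set_seed

args = SimpleNamespace(data_dir="./data", dataset="cora", noise_num=0)

set_seed(0)  # make the random split reproducible
adj, feat, label, feat_size, label_size = load_graph(args)
ids = split_ids(args, feat.shape[0])  # 40% train / 20% val / 40% test

print(adj.shape, feat.shape, label.shape)  # dense adjacency, features, labels
print(feat_size, label_size, len(ids["train"]), len(ids["val"]), len(ids["test"]))
```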