├── .gitignore ├── LICENSE ├── README.md ├── data_preprocess.py ├── dataloader.py ├── experiments ├── ablation_feature_noise.sh ├── ablation_gnn.sh ├── ablation_ind_split_rate.sh ├── ga_glnn_arxiv.sh ├── glnn_arxiv.sh ├── glnn_cpf.sh ├── glnn_products.sh ├── sage_arxiv.sh ├── sage_cpf.sh └── sage_products.sh ├── imgs ├── glnn.png └── trade_off.png ├── models.py ├── prepare_env.sh ├── requirements.txt ├── train.conf.yaml ├── train_and_eval.py ├── train_student.py ├── train_teacher.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/data 2 | **/outputs 3 | **/*.npz 4 | **/*.ipynb 5 | **/__pycache__ 6 | **/*.ipynb_checkpoints 7 | **/log 8 | **/plots 9 | **/temp 10 | ./.idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Snap Inc. 2021. 2 | 3 | All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Graph-less Neural Networks (GLNN) 3 | 4 | Code for [Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation](https://arxiv.org/pdf/2110.08727.pdf) by [Shichang Zhang](https://shichangzh.github.io/), [Yozen Liu](https://research.snap.com/team/yozen-liu/), [Yizhou Sun](http://web.cs.ucla.edu/~yzsun/), and [Neil Shah](http://nshah.net/). 5 | 6 | 7 | ## Overview 8 | ### Distillation framework 9 |

10 | ![GLNN distillation framework](imgs/glnn.png) 11 | 12 | 13 |

14 | 15 | 16 | ### Accuracy vs. inference time on the `ogbn-products` dataset 17 | 18 |

19 | ![Accuracy vs. inference time trade-off on ogbn-products](imgs/trade_off.png) 20 | 21 |

23 | 24 | 25 | ## Getting Started 26 | 27 | ### Setup Environment 28 | 29 | We use conda for environment setup. You can use 30 | 31 | `bash ./prepare_env.sh` 32 | 33 | which will create a conda environment named `glnn` and install relevant requirements (from `requirements.txt`). For simplicity, we use CPU-based `torch` and `dgl` versions in this guide, as specified in requirements. To run experiments with CUDA, please install `torch` and `dgl` with proper CUDA support, remove them from `requirements.txt`, and properly set the `--device` argument in the scripts. See https://pytorch.org/ and https://www.dgl.ai/pages/start.html for more installation details. 34 | 35 | Be sure to activate the environment with 36 | 37 | `conda activate glnn` 38 | 39 | before running experiments as described below. 40 | 41 | 42 | 43 | ### Preparing datasets 44 | To run experiments for dataset used in the paper, please download from the following links and put them under `data/` (see below for instructions on organizing the datasets). 45 | 46 | - *CPF data* (`cora`, `citeseer`, `pubmed`, `a-computer`, and `a-photo`): Download the '.npz' files from [here](https://github.com/BUPT-GAMMA/CPF/tree/master/data/npz). Rename `amazon_electronics_computers.npz` and `amazon_electronics_photo.npz` to `a-computer.npz` and `a-photo.npz` respectively. 47 | 48 | - *OGB data* (`ogbn-arxiv` and `ogbn-products`): Datasets will be automatically downloaded when running the `load_data` function in `dataloader.py`. More details [here](https://ogb.stanford.edu/). 49 | 50 | - *BGNN data* (`house_class` and `vk_class`): Follow the instructions [here](https://github.com/dmlc/dgl/tree/473d5e0a4c4e4735f1c9dc9d783e0374328cca9a/examples/pytorch/bgnn) and download dataset pre-processed in DGL format from [here](https://www.dropbox.com/s/verx1evkykzli88/datasets.zip). 51 | 52 | - *NonHom data* (`penn94` and `pokec`): Follow the instructions [here](https://github.com/CUAI/Non-Homophily-Benchmarks) to download the `penn94` dataset and its splits. The `pokec` dataset will be automatically downloaded when running the `load_data` function in `dataloader.py`. 53 | 54 | - Your favourite datasets: download and add to the `load_data` function in `dataloader.py`. 55 | 56 | 57 | ### Usage 58 | 59 | To quickly train a teacher model you can run `train_teacher.py` by specifying the experiment setting, i.e. transductive (`tran`) or inductive (`ind`), teacher model, e.g. `GCN`, and dataset, e.g. `cora`, as per the example below. 60 | 61 | ``` 62 | python train_teacher.py --exp_setting tran --teacher GCN --dataset cora 63 | ``` 64 | 65 | To quickly train a student model with a pretrained teacher you can run `train_student.py` by specifying the experiment setting, teacher model, student model, and dataset like the example below. Make sure you train the teacher using the `train_teacher.py` first and have its result stored in the correct path specified by `--out_t_path`. 66 | 67 | ``` 68 | python train_student.py --exp_setting ind --teacher SAGE --student MLP --dataset citeseer --out_t_path outputs 69 | ``` 70 | 71 | For more examples, and to reproduce results in the paper, please refer to scripts in `experiments/` as below. 72 | 73 | ``` 74 | bash experiments/sage_cpf.sh 75 | ``` 76 | 77 | To extend GLNN to your own model, you may do one of the following. 78 | - Add your favourite model architectures to the `Model` class in `model.py`. Then follow the examples above. 79 | - Train teacher model and store its output (log-probabilities). 
Then train the student by `train_student.py` with the correct `--out_t_path`. 80 | 81 | 82 | ## Results 83 | 84 | GraphSAGE vs. MLP vs. GLNN under the production setting described in the paper (transductive and inductive combined). Delta_MLP (Delta_GNN) represents difference between the GLNN and the MLP (GNN). Results show classification accuracy (higher is better); Delta_GNN > 0 indicates GLNN outperforms GNN. We observe that GLNNs always improve from MLPs by large margins and achieve competitive results as GNN on 6/7 datasets. Please see Table 3 in the paper for more details. 85 | 86 | | Datasets | GNN(SAGE) | MLP | GLNN | Delta_MLP | Delta_GNN | 87 | |------------|----------------|--------------|----------------|-----------------|-------------------| 88 | | Cora | **79.29** | 58.98 | 78.28 | 19.30 (32.72\%) | -1.01 (-1.28\%) | 89 | | Citseer | 68.38 | 59.81 | **69.27** | 9.46 (15.82\%) | 0.89 (1.30\%) | 90 | | Pubmed | **74.88** | 66.80 | 74.71 | 7.91 (11.83\%) | -0.17 (-0.22\%) | 91 | | A-computer | 82.14 | 67.38 | **82.29** | 14.90 (22.12\%) | 0.15 (0.19\%) | 92 | | A-photo | 91.08 | 79.25 | **92.38** | 13.13 (16.57\%) | 1.30 (1.42\%) | 93 | | Arxiv | **70.73** | 55.30 | 65.09 | 9.79 (17.70\%) | -5.64 (-7.97\%) | 94 | | Products | **76.60** | 63.72 | 75.77 | 12.05 (18.91\%) | -0.83 (-1.09\%) | 95 | 96 | 121 | 122 | ## Citation 123 | 124 | If you find our work useful, please cite the following: 125 | 126 | ```BibTeX 127 | @inproceedings{zhang2021graphless, 128 | title={Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation}, 129 | author={Shichang Zhang and Yozen Liu and Yizhou Sun and Neil Shah}, 130 | booktitle={International Conference on Learning Representations} 131 | year={2022}, 132 | url={https://arxiv.org/abs/2110.08727} 133 | } 134 | ``` 135 | 136 | ## Contact Us 137 | 138 | Please open an issue or contact `shichang@cs.ucla.edu` if you have any questions. 
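A note on the teacher output format used by the "extend GLNN to your own model" steps in the Usage section above: the distillation targets are read back via `load_out_t` in `dataloader.py`, which loads an `out.npz` file (the array is stored under NumPy's default `arr_0` key) from a directory derived from `--out_t_path` and the experiment settings. The sketch below shows one way to dump a custom teacher's log-probabilities in that layout; `teacher_logits`, `save_teacher_output`, and the exact output subdirectory are placeholder names rather than names from this repo, so please check `train_teacher.py` for the directory naming it actually uses.

```python
# Minimal sketch (assumptions noted above), not part of this repo:
# save a custom teacher's log-probabilities so that `load_out_t` in
# dataloader.py can read them back for distillation.
from pathlib import Path

import numpy as np
import torch


def save_teacher_output(teacher_logits: torch.Tensor, out_t_dir: Path) -> None:
    """Write log-probabilities for all nodes to <out_t_dir>/out.npz."""
    out_t_dir.mkdir(parents=True, exist_ok=True)
    # Convert raw [num_nodes, num_classes] scores to log-probabilities,
    # which is what the instructions above ask the teacher to store.
    log_probs = torch.log_softmax(teacher_logits, dim=1)
    # np.savez stores an unnamed array under the key "arr_0",
    # which is the key `load_out_t` in dataloader.py reads.
    np.savez(out_t_dir.joinpath("out.npz"), log_probs.detach().cpu().numpy())
```

Storing log-probabilities (rather than raw logits or hard labels) matches the teacher output described above and keeps the student's distillation loss numerically stable.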
139 | 140 | 141 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from the CPF implementation 3 | https://github.com/BUPT-GAMMA/CPF/tree/389c01aaf238689ee7b1e5aba127842341e123b6/data 4 | """ 5 | 6 | import numpy as np 7 | import scipy.sparse as sp 8 | from collections import Counter 9 | from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer 10 | 11 | 12 | def is_binary_bag_of_words(features): 13 | features_coo = features.tocoo() 14 | return all( 15 | single_entry == 1.0 16 | for _, _, single_entry in zip( 17 | features_coo.row, features_coo.col, features_coo.data 18 | ) 19 | ) 20 | 21 | 22 | def to_binary_bag_of_words(features): 23 | """Converts TF/IDF features to binary bag-of-words features.""" 24 | features_copy = features.tocsr() 25 | features_copy.data[:] = 1.0 26 | return features_copy 27 | 28 | 29 | def normalize(mx): 30 | """Row-normalize sparse matrix""" 31 | rowsum = np.array(mx.sum(1)) 32 | r_inv = np.power(rowsum, -1).flatten() 33 | r_inv[np.isinf(r_inv)] = 0.0 34 | r_mat_inv = sp.diags(r_inv) 35 | mx = r_mat_inv.dot(mx) 36 | return mx 37 | 38 | 39 | def normalize_adj(adj): 40 | adj = normalize(adj + sp.eye(adj.shape[0])) 41 | return adj 42 | 43 | 44 | def eliminate_self_loops_adj(A): 45 | """Remove self-loops from the adjacency matrix.""" 46 | A = A.tolil() 47 | A.setdiag(0) 48 | A = A.tocsr() 49 | A.eliminate_zeros() 50 | return A 51 | 52 | 53 | def largest_connected_components(sparse_graph, n_components=1): 54 | """Select the largest connected components in the graph. 55 | 56 | Parameters 57 | ---------- 58 | sparse_graph : SparseGraph 59 | Input graph. 60 | n_components : int, default 1 61 | Number of largest connected components to keep. 62 | 63 | Returns 64 | ------- 65 | sparse_graph : SparseGraph 66 | Subgraph of the input graph where only the nodes in largest n_components are kept. 67 | 68 | """ 69 | _, component_indices = sp.csgraph.connected_components(sparse_graph.adj_matrix) 70 | component_sizes = np.bincount(component_indices) 71 | components_to_keep = np.argsort(component_sizes)[::-1][ 72 | :n_components 73 | ] # reverse order to sort descending 74 | nodes_to_keep = [ 75 | idx 76 | for (idx, component) in enumerate(component_indices) 77 | if component in components_to_keep 78 | ] 79 | return create_subgraph(sparse_graph, nodes_to_keep=nodes_to_keep) 80 | 81 | 82 | def create_subgraph( 83 | sparse_graph, _sentinel=None, nodes_to_remove=None, nodes_to_keep=None 84 | ): 85 | """Create a graph with the specified subset of nodes. 86 | 87 | Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None. 88 | Note that to avoid confusion, it is required to pass node indices as named arguments to this function. 89 | 90 | Parameters 91 | ---------- 92 | sparse_graph : SparseGraph 93 | Input graph. 94 | _sentinel : None 95 | Internal, to prevent passing positional arguments. Do not use. 96 | nodes_to_remove : array-like of int 97 | Indices of nodes that have to removed. 98 | nodes_to_keep : array-like of int 99 | Indices of nodes that have to be kept. 100 | 101 | Returns 102 | ------- 103 | sparse_graph : SparseGraph 104 | Graph with specified nodes removed. 105 | 106 | """ 107 | # Check that arguments are passed correctly 108 | if _sentinel is not None: 109 | raise ValueError( 110 | "Only call `create_subgraph` with named arguments'," 111 | " (nodes_to_remove=...) 
or (nodes_to_keep=...)" 112 | ) 113 | if nodes_to_remove is None and nodes_to_keep is None: 114 | raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.") 115 | elif nodes_to_remove is not None and nodes_to_keep is not None: 116 | raise ValueError( 117 | "Only one of nodes_to_remove or nodes_to_keep must be provided." 118 | ) 119 | elif nodes_to_remove is not None: 120 | nodes_to_keep = [ 121 | i for i in range(sparse_graph.num_nodes()) if i not in nodes_to_remove 122 | ] 123 | elif nodes_to_keep is not None: 124 | nodes_to_keep = sorted(nodes_to_keep) 125 | else: 126 | raise RuntimeError("This should never happen.") 127 | 128 | sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep] 129 | if sparse_graph.attr_matrix is not None: 130 | sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep] 131 | if sparse_graph.labels is not None: 132 | sparse_graph.labels = sparse_graph.labels[nodes_to_keep] 133 | if sparse_graph.node_names is not None: 134 | sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep] 135 | return sparse_graph 136 | 137 | 138 | def binarize_labels(labels, sparse_output=False, return_classes=False): 139 | """Convert labels vector to a binary label matrix. 140 | 141 | In the default single-label case, labels look like 142 | labels = [y1, y2, y3, ...]. 143 | Also supports the multi-label format. 144 | In this case, labels should look something like 145 | labels = [[y11, y12], [y21, y22, y23], [y31], ...]. 146 | 147 | Parameters 148 | ---------- 149 | labels : array-like, shape [num_samples] 150 | Array of node labels in categorical single- or multi-label format. 151 | sparse_output : bool, default False 152 | Whether return the label_matrix in CSR format. 153 | return_classes : bool, default False 154 | Whether return the classes corresponding to the columns of the label matrix. 155 | 156 | Returns 157 | ------- 158 | label_matrix : np.ndarray or sp.csr_matrix, shape [num_samples, num_classes] 159 | Binary matrix of class labels. 160 | num_classes = number of unique values in "labels" array. 161 | label_matrix[i, k] = 1 <=> node i belongs to class k. 162 | classes : np.array, shape [num_classes], optional 163 | Classes that correspond to each column of the label_matrix. 164 | 165 | """ 166 | if hasattr(labels[0], "__iter__"): # labels[0] is iterable <=> multilabel format 167 | binarizer = MultiLabelBinarizer(sparse_output=sparse_output) 168 | else: 169 | binarizer = LabelBinarizer(sparse_output=sparse_output) 170 | label_matrix = binarizer.fit_transform(labels).astype(np.float32) 171 | return (label_matrix, binarizer.classes_) if return_classes else label_matrix 172 | 173 | 174 | def remove_underrepresented_classes( 175 | g, train_examples_per_class, val_examples_per_class 176 | ): 177 | """Remove nodes from graph that correspond to a class of which there are less than 178 | num_classes * train_examples_per_class + num_classes * val_examples_per_class nodes. 179 | 180 | Those classes would otherwise break the training procedure. 
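(A class with fewer than train_examples_per_class + val_examples_per_class labeled nodes cannot supply the per-class train and validation samples that are drawn without replacement in get_train_val_test_split.)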
181 | """ 182 | min_examples_per_class = train_examples_per_class + val_examples_per_class 183 | examples_counter = Counter(g.labels) 184 | keep_classes = set( 185 | class_ 186 | for class_, count in examples_counter.items() 187 | if count > min_examples_per_class 188 | ) 189 | keep_indices = [i for i in range(len(g.labels)) if g.labels[i] in keep_classes] 190 | 191 | return create_subgraph(g, nodes_to_keep=keep_indices) 192 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader of CPF datasets are adapted from the CPF implementation 3 | https://github.com/BUPT-GAMMA/CPF/tree/389c01aaf238689ee7b1e5aba127842341e123b6/data 4 | 5 | Dataloader of NonHom datasets are adapted from the Non-homophily benchmarks 6 | https://github.com/CUAI/Non-Homophily-Benchmarks 7 | 8 | Dataloader of BGNN datasets are adapted from the BGNN implementation and dgl example of BGNN 9 | https://github.com/nd7141/bgnn 10 | https://github.com/dmlc/dgl/tree/473d5e0a4c4e4735f1c9dc9d783e0374328cca9a/examples/pytorch/bgnn 11 | """ 12 | 13 | import numpy as np 14 | import scipy.sparse as sp 15 | import torch 16 | import dgl 17 | import os 18 | import scipy 19 | import pandas as pd 20 | import json 21 | from dgl.data.utils import load_graphs 22 | from os import path 23 | from category_encoders import CatBoostEncoder 24 | from pathlib import Path 25 | from google_drive_downloader import GoogleDriveDownloader as gdd 26 | from sklearn.preprocessing import label_binarize 27 | from sklearn import preprocessing 28 | from data_preprocess import ( 29 | normalize_adj, 30 | eliminate_self_loops_adj, 31 | largest_connected_components, 32 | binarize_labels, 33 | ) 34 | from ogb.nodeproppred import DglNodePropPredDataset 35 | 36 | CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"] 37 | OGB_data = ["ogbn-arxiv", "ogbn-products"] 38 | NonHom_data = ["pokec", "penn94"] 39 | BGNN_data = ["house_class", "vk_class"] 40 | 41 | 42 | def load_data(dataset, dataset_path, **kwargs): 43 | if dataset in CPF_data: 44 | return load_cpf_data( 45 | dataset, 46 | dataset_path, 47 | kwargs["seed"], 48 | kwargs["labelrate_train"], 49 | kwargs["labelrate_val"], 50 | ) 51 | elif dataset in OGB_data: 52 | return load_ogb_data(dataset, dataset_path) 53 | elif dataset in NonHom_data: 54 | return load_nonhom_data(dataset, dataset_path, kwargs["split_idx"]) 55 | elif dataset in BGNN_data: 56 | return load_bgnn_data(dataset, dataset_path, kwargs["split_idx"]) 57 | else: 58 | raise ValueError(f"Unknown dataset: {dataset}") 59 | 60 | 61 | def load_ogb_data(dataset, dataset_path): 62 | data = DglNodePropPredDataset(dataset, dataset_path) 63 | splitted_idx = data.get_idx_split() 64 | idx_train, idx_val, idx_test = ( 65 | splitted_idx["train"], 66 | splitted_idx["valid"], 67 | splitted_idx["test"], 68 | ) 69 | 70 | g, labels = data[0] 71 | labels = labels.squeeze() 72 | 73 | # Turn the graph to undirected 74 | if dataset == "ogbn-arxiv": 75 | srcs, dsts = g.all_edges() 76 | g.add_edges(dsts, srcs) 77 | g = g.remove_self_loop().add_self_loop() 78 | 79 | return g, labels, idx_train, idx_val, idx_test 80 | 81 | 82 | def load_cpf_data(dataset, dataset_path, seed, labelrate_train, labelrate_val): 83 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}.npz") 84 | if os.path.isfile(data_path): 85 | data = load_npz_to_sparse_graph(data_path) 86 | else: 87 | raise ValueError(f"{data_path} doesn't exist.") 88 | 
89 | # remove self loop and extract the largest CC 90 | data = data.standardize() 91 | adj, features, labels = data.unpack() 92 | 93 | labels = binarize_labels(labels) 94 | 95 | random_state = np.random.RandomState(seed) 96 | idx_train, idx_val, idx_test = get_train_val_test_split( 97 | random_state, labels, labelrate_train, labelrate_val 98 | ) 99 | 100 | features = torch.FloatTensor(np.array(features.todense())) 101 | labels = torch.LongTensor(labels.argmax(axis=1)) 102 | 103 | adj = normalize_adj(adj) 104 | adj_sp = adj.tocoo() 105 | g = dgl.graph((adj_sp.row, adj_sp.col)) 106 | g.ndata["feat"] = features 107 | 108 | idx_train = torch.LongTensor(idx_train) 109 | idx_val = torch.LongTensor(idx_val) 110 | idx_test = torch.LongTensor(idx_test) 111 | return g, labels, idx_train, idx_val, idx_test 112 | 113 | 114 | def load_nonhom_data(dataset, dataset_path, split_idx): 115 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}.mat") 116 | data_split_path = Path.cwd().joinpath( 117 | dataset_path, "splits", f"{dataset}-splits.npy" 118 | ) 119 | 120 | if dataset == "pokec": 121 | g, features, labels = load_pokec_mat(data_path) 122 | elif dataset == "penn94": 123 | g, features, labels = load_penn94_mat(data_path) 124 | else: 125 | raise ValueError("Invalid dataname") 126 | 127 | g = g.remove_self_loop().add_self_loop() 128 | g.ndata["feat"] = features 129 | labels = torch.LongTensor(labels) 130 | 131 | splitted_idx = load_fixed_splits(dataset, data_split_path, split_idx) 132 | idx_train, idx_val, idx_test = ( 133 | splitted_idx["train"], 134 | splitted_idx["valid"], 135 | splitted_idx["test"], 136 | ) 137 | return g, labels, idx_train, idx_val, idx_test 138 | 139 | 140 | def load_bgnn_data(dataset, dataset_path, split_idx): 141 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}") 142 | 143 | g, X, y, cat_features, masks = read_input(data_path) 144 | train_mask, val_mask, test_mask = ( 145 | masks[str(split_idx)]["train"], 146 | masks[str(split_idx)]["val"], 147 | masks[str(split_idx)]["test"], 148 | ) 149 | 150 | encoded_X = X.copy() 151 | if cat_features is not None and len(cat_features): 152 | encoded_X = encode_cat_features( 153 | encoded_X, y, cat_features, train_mask, val_mask, test_mask 154 | ) 155 | encoded_X = normalize_features(encoded_X, train_mask, val_mask, test_mask) 156 | encoded_X = replace_na(encoded_X, train_mask) 157 | features, labels = pandas_to_torch(encoded_X, y) 158 | 159 | g = g.remove_self_loop().add_self_loop() 160 | g.ndata["feat"] = features 161 | labels = labels.long() 162 | 163 | idx_train = torch.LongTensor(train_mask) 164 | idx_val = torch.LongTensor(val_mask) 165 | idx_test = torch.LongTensor(test_mask) 166 | return g, labels, idx_train, idx_val, idx_test 167 | 168 | 169 | def load_out_t(out_t_dir): 170 | return torch.from_numpy(np.load(out_t_dir.joinpath("out.npz"))["arr_0"]) 171 | 172 | 173 | """ For NonHom""" 174 | dataset_drive_url = {"pokec": "1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y"} 175 | splits_drive_url = {"pokec": "1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_"} 176 | 177 | 178 | def load_penn94_mat(data_path): 179 | mat = scipy.io.loadmat(data_path) 180 | A = mat["A"] 181 | metadata = mat["local_info"] 182 | 183 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long) 184 | metadata = metadata.astype(np.int) 185 | 186 | # make features into one-hot encodings 187 | feature_vals = np.hstack((np.expand_dims(metadata[:, 0], 1), metadata[:, 2:])) 188 | features = np.empty((A.shape[0], 0)) 189 | for col in range(feature_vals.shape[1]): 190 | feat_col = 
feature_vals[:, col] 191 | feat_onehot = label_binarize(feat_col, classes=np.unique(feat_col)) 192 | features = np.hstack((features, feat_onehot)) 193 | 194 | g = dgl.graph((edge_index[0], edge_index[1])) 195 | g = dgl.to_bidirected(g) 196 | 197 | features = torch.tensor(features, dtype=torch.float) 198 | labels = torch.tensor(metadata[:, 1] - 1) # gender label, -1 means unlabeled 199 | return g, features, labels 200 | 201 | 202 | def load_pokec_mat(data_path): 203 | if not path.exists(data_path): 204 | gdd.download_file_from_google_drive( 205 | file_id=dataset_drive_url["pokec"], dest_path=data_path, showsize=True 206 | ) 207 | 208 | fulldata = scipy.io.loadmat(data_path) 209 | edge_index = torch.tensor(fulldata["edge_index"], dtype=torch.long) 210 | g = dgl.graph((edge_index[0], edge_index[1])) 211 | g = dgl.to_bidirected(g) 212 | 213 | features = torch.tensor(fulldata["node_feat"]).float() 214 | labels = fulldata["label"].flatten() 215 | return g, features, labels 216 | 217 | 218 | class NCDataset(object): 219 | def __init__(self, name, root): 220 | """ 221 | based off of ogb NodePropPredDataset 222 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py 223 | Gives torch tensors instead of numpy arrays 224 | - name (str): name of the dataset 225 | - root (str): root directory to store the dataset folder 226 | - meta_dict: dictionary that stores all the meta-information about data. Default is None, 227 | but when something is passed, it uses its information. Useful for debugging for external contributers. 228 | 229 | Usage after construction: 230 | 231 | split_idx = dataset.get_idx_split() 232 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 233 | graph, label = dataset[0] 234 | 235 | Where the graph is a dictionary of the following form: 236 | dataset.graph = {'edge_index': edge_index, 237 | 'edge_feat': None, 238 | 'node_feat': node_feat, 239 | 'num_nodes': num_nodes} 240 | For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/ 241 | 242 | """ 243 | 244 | self.name = name # original name, e.g., ogbn-proteins 245 | self.graph = {} 246 | self.label = None 247 | 248 | def rand_train_test_idx(label, train_prop, valid_prop, ignore_negative): 249 | """ 250 | Randomly splits the dataset into train, validation, and test sets. 251 | """ 252 | if ignore_negative: 253 | non_negative_idx = np.where(label >= 0)[0] 254 | else: 255 | non_negative_idx = np.arange(len(label)) 256 | 257 | num_nodes = len(non_negative_idx) 258 | num_train = int(train_prop * num_nodes) 259 | num_valid = int(valid_prop * num_nodes) 260 | num_test = num_nodes - num_train - num_valid 261 | 262 | idx = np.random.permutation(non_negative_idx) 263 | 264 | train_idx = idx[:num_train] 265 | valid_idx = idx[num_train: num_train + num_valid] 266 | test_idx = idx[num_train + num_valid:] 267 | 268 | return train_idx, valid_idx, test_idx 269 | 270 | def get_idx_split(self, split_type="random", train_prop=0.5, valid_prop=0.25): 271 | """ 272 | train_prop: The proportion of dataset for train split. Between 0 and 1. 273 | valid_prop: The proportion of dataset for validation split. Between 0 and 1. 
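The remaining nodes are assigned to the test split.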
274 | """ 275 | 276 | if split_type == "random": 277 | ignore_negative = False if self.name == "ogbn-proteins" else True 278 | train_idx, valid_idx, test_idx = self.rand_train_test_idx( 279 | self.label, 280 | train_prop=train_prop, 281 | valid_prop=valid_prop, 282 | ignore_negative=ignore_negative, 283 | ) 284 | split_idx = {"train": train_idx, "valid": valid_idx, "test": test_idx} 285 | return split_idx 286 | 287 | def __getitem__(self, idx): 288 | assert idx == 0, "This dataset has only one graph" 289 | return self.graph, self.label 290 | 291 | def __len__(self): 292 | return 1 293 | 294 | def __repr__(self): 295 | return "{}({})".format(self.__class__.__name__, len(self)) 296 | 297 | 298 | def load_fixed_splits(dataset, data_split_path="", split_idx=0): 299 | if not os.path.exists(data_split_path): 300 | assert dataset in splits_drive_url.keys() 301 | gdd.download_file_from_google_drive( 302 | file_id=splits_drive_url[dataset], dest_path=data_split_path, showsize=True 303 | ) 304 | 305 | splits_lst = np.load(data_split_path, allow_pickle=True) 306 | splits = splits_lst[split_idx] 307 | 308 | for key in splits: 309 | if not torch.is_tensor(splits[key]): 310 | splits[key] = torch.as_tensor(splits[key]) 311 | 312 | return splits 313 | 314 | 315 | """For BGNN """ 316 | 317 | 318 | def pandas_to_torch(*args): 319 | return [torch.from_numpy(arg.to_numpy(copy=True)).float().squeeze() for arg in args] 320 | 321 | 322 | def read_input(input_folder): 323 | X = pd.read_csv(f"{input_folder}/X.csv") 324 | y = pd.read_csv(f"{input_folder}/y.csv") 325 | 326 | categorical_columns = [] 327 | if os.path.exists(f"{input_folder}/cat_features.txt"): 328 | with open(f"{input_folder}/cat_features.txt") as f: 329 | for line in f: 330 | if line.strip(): 331 | categorical_columns.append(line.strip()) 332 | 333 | cat_features = None 334 | if categorical_columns: 335 | columns = X.columns 336 | cat_features = np.where(columns.isin(categorical_columns))[0] 337 | 338 | for col in list(columns[cat_features]): 339 | X[col] = X[col].astype(str) 340 | 341 | gs, _ = load_graphs(f"{input_folder}/graph.dgl") 342 | graph = gs[0] 343 | 344 | with open(f"{input_folder}/masks.json") as f: 345 | masks = json.load(f) 346 | 347 | return graph, X, y, cat_features, masks 348 | 349 | 350 | def normalize_features(X, train_mask, val_mask, test_mask): 351 | min_max_scaler = preprocessing.MinMaxScaler() 352 | A = X.to_numpy(copy=True) 353 | A[train_mask] = min_max_scaler.fit_transform(A[train_mask]) 354 | A[val_mask + test_mask] = min_max_scaler.transform(A[val_mask + test_mask]) 355 | return pd.DataFrame(A, columns=X.columns).astype(float) 356 | 357 | 358 | def replace_na(X, train_mask): 359 | if X.isna().any().any(): 360 | return X.fillna(X.iloc[train_mask].min() - 1) 361 | return X 362 | 363 | 364 | def encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask): 365 | enc = CatBoostEncoder() 366 | A = X.to_numpy(copy=True) 367 | b = y.to_numpy(copy=True) 368 | A[np.ix_(train_mask, cat_features)] = enc.fit_transform( 369 | A[np.ix_(train_mask, cat_features)], b[train_mask] 370 | ) 371 | A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform( 372 | A[np.ix_(val_mask + test_mask, cat_features)] 373 | ) 374 | A = A.astype(float) 375 | return pd.DataFrame(A, columns=X.columns) 376 | 377 | 378 | """ For CPF""" 379 | 380 | 381 | class SparseGraph: 382 | """Attributed labeled graph stored in sparse matrix form.""" 383 | 384 | def __init__( 385 | self, 386 | adj_matrix, 387 | attr_matrix=None, 388 | labels=None, 389 | 
node_names=None, 390 | attr_names=None, 391 | class_names=None, 392 | metadata=None, 393 | ): 394 | """Create an attributed graph. 395 | 396 | Parameters 397 | ---------- 398 | adj_matrix : sp.csr_matrix, shape [num_nodes, num_nodes] 399 | Adjacency matrix in CSR format. 400 | attr_matrix : sp.csr_matrix or np.ndarray, shape [num_nodes, num_attr], optional 401 | Attribute matrix in CSR or numpy format. 402 | labels : np.ndarray, shape [num_nodes], optional 403 | Array, where each entry represents respective node's label(s). 404 | node_names : np.ndarray, shape [num_nodes], optional 405 | Names of nodes (as strings). 406 | attr_names : np.ndarray, shape [num_attr] 407 | Names of the attributes (as strings). 408 | class_names : np.ndarray, shape [num_classes], optional 409 | Names of the class labels (as strings). 410 | metadata : object 411 | Additional metadata such as text. 412 | 413 | """ 414 | # Make sure that the dimensions of matrices / arrays all agree 415 | if sp.isspmatrix(adj_matrix): 416 | adj_matrix = adj_matrix.tocsr().astype(np.float32) 417 | else: 418 | raise ValueError( 419 | "Adjacency matrix must be in sparse format (got {0} instead)".format( 420 | type(adj_matrix) 421 | ) 422 | ) 423 | 424 | if adj_matrix.shape[0] != adj_matrix.shape[1]: 425 | raise ValueError("Dimensions of the adjacency matrix don't agree") 426 | 427 | if attr_matrix is not None: 428 | if sp.isspmatrix(attr_matrix): 429 | attr_matrix = attr_matrix.tocsr().astype(np.float32) 430 | elif isinstance(attr_matrix, np.ndarray): 431 | attr_matrix = attr_matrix.astype(np.float32) 432 | else: 433 | raise ValueError( 434 | "Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)".format( 435 | type(attr_matrix) 436 | ) 437 | ) 438 | 439 | if attr_matrix.shape[0] != adj_matrix.shape[0]: 440 | raise ValueError( 441 | "Dimensions of the adjacency and attribute matrices don't agree" 442 | ) 443 | 444 | if labels is not None: 445 | if labels.shape[0] != adj_matrix.shape[0]: 446 | raise ValueError( 447 | "Dimensions of the adjacency matrix and the label vector don't agree" 448 | ) 449 | 450 | if node_names is not None: 451 | if len(node_names) != adj_matrix.shape[0]: 452 | raise ValueError( 453 | "Dimensions of the adjacency matrix and the node names don't agree" 454 | ) 455 | 456 | if attr_names is not None: 457 | if len(attr_names) != attr_matrix.shape[1]: 458 | raise ValueError( 459 | "Dimensions of the attribute matrix and the attribute names don't agree" 460 | ) 461 | 462 | self.adj_matrix = adj_matrix 463 | self.attr_matrix = attr_matrix 464 | self.labels = labels 465 | self.node_names = node_names 466 | self.attr_names = attr_names 467 | self.class_names = class_names 468 | self.metadata = metadata 469 | 470 | def num_nodes(self): 471 | """Get the number of nodes in the graph.""" 472 | return self.adj_matrix.shape[0] 473 | 474 | def num_edges(self): 475 | """Get the number of edges in the graph. 476 | 477 | For undirected graphs, (i, j) and (j, i) are counted as single edge. 478 | """ 479 | if self.is_directed(): 480 | return int(self.adj_matrix.nnz) 481 | else: 482 | return int(self.adj_matrix.nnz / 2) 483 | 484 | def get_neighbors(self, idx): 485 | """Get the indices of neighbors of a given node. 486 | 487 | Parameters 488 | ---------- 489 | idx : int 490 | Index of the node whose neighbors are of interest. 
491 | 492 | """ 493 | return self.adj_matrix[idx].indices 494 | 495 | def is_directed(self): 496 | """Check if the graph is directed (adjacency matrix is not symmetric).""" 497 | return (self.adj_matrix != self.adj_matrix.T).sum() != 0 498 | 499 | def to_undirected(self): 500 | """Convert to an undirected graph (make adjacency matrix symmetric).""" 501 | if self.is_weighted(): 502 | raise ValueError("Convert to unweighted graph first.") 503 | else: 504 | self.adj_matrix = self.adj_matrix + self.adj_matrix.T 505 | self.adj_matrix[self.adj_matrix != 0] = 1 506 | return self 507 | 508 | def is_weighted(self): 509 | """Check if the graph is weighted (edge weights other than 1).""" 510 | return np.any(np.unique(self.adj_matrix[self.adj_matrix != 0].A1) != 1) 511 | 512 | def to_unweighted(self): 513 | """Convert to an unweighted graph (set all edge weights to 1).""" 514 | self.adj_matrix.data = np.ones_like(self.adj_matrix.data) 515 | return self 516 | 517 | # Quality of life (shortcuts) 518 | def standardize(self): 519 | """Select the LCC of the unweighted/undirected/no-self-loop graph. 520 | 521 | All changes are done inplace. 522 | 523 | """ 524 | G = self.to_unweighted().to_undirected() 525 | G.adj_matrix = eliminate_self_loops_adj(G.adj_matrix) 526 | G = largest_connected_components(G, 1) 527 | return G 528 | 529 | def unpack(self): 530 | """Return the (A, X, z) triplet.""" 531 | return self.adj_matrix, self.attr_matrix, self.labels 532 | 533 | 534 | def load_npz_to_sparse_graph(file_name): 535 | """Load a SparseGraph from a Numpy binary file. 536 | 537 | Parameters 538 | ---------- 539 | file_name : str 540 | Name of the file to load. 541 | 542 | Returns 543 | ------- 544 | sparse_graph : SparseGraph 545 | Graph in sparse matrix format. 546 | 547 | """ 548 | with np.load(file_name, allow_pickle=True) as loader: 549 | loader = dict(loader) 550 | adj_matrix = sp.csr_matrix( 551 | (loader["adj_data"], loader["adj_indices"], loader["adj_indptr"]), 552 | shape=loader["adj_shape"], 553 | ) 554 | 555 | if "attr_data" in loader: 556 | # Attributes are stored as a sparse CSR matrix 557 | attr_matrix = sp.csr_matrix( 558 | (loader["attr_data"], loader["attr_indices"], loader["attr_indptr"]), 559 | shape=loader["attr_shape"], 560 | ) 561 | elif "attr_matrix" in loader: 562 | # Attributes are stored as a (dense) np.ndarray 563 | attr_matrix = loader["attr_matrix"] 564 | else: 565 | attr_matrix = None 566 | 567 | if "labels_data" in loader: 568 | # Labels are stored as a CSR matrix 569 | labels = sp.csr_matrix( 570 | ( 571 | loader["labels_data"], 572 | loader["labels_indices"], 573 | loader["labels_indptr"], 574 | ), 575 | shape=loader["labels_shape"], 576 | ) 577 | elif "labels" in loader: 578 | # Labels are stored as a numpy array 579 | labels = loader["labels"] 580 | else: 581 | labels = None 582 | 583 | node_names = loader.get("node_names") 584 | attr_names = loader.get("attr_names") 585 | class_names = loader.get("class_names") 586 | metadata = loader.get("metadata") 587 | 588 | return SparseGraph( 589 | adj_matrix, attr_matrix, labels, node_names, attr_names, class_names, metadata 590 | ) 591 | 592 | 593 | def sample_per_class( 594 | random_state, labels, num_examples_per_class, forbidden_indices=None 595 | ): 596 | """ 597 | Used in get_train_val_test_split, when we try to get a fixed number of examples per class 598 | """ 599 | 600 | num_samples, num_classes = labels.shape 601 | sample_indices_per_class = {index: [] for index in range(num_classes)} 602 | 603 | # get indices sorted by class 604 
| for class_index in range(num_classes): 605 | for sample_index in range(num_samples): 606 | if labels[sample_index, class_index] > 0.0: 607 | if forbidden_indices is None or sample_index not in forbidden_indices: 608 | sample_indices_per_class[class_index].append(sample_index) 609 | 610 | # get specified number of indices for each class 611 | return np.concatenate( 612 | [ 613 | random_state.choice( 614 | sample_indices_per_class[class_index], 615 | num_examples_per_class, 616 | replace=False, 617 | ) 618 | for class_index in range(len(sample_indices_per_class)) 619 | ] 620 | ) 621 | 622 | 623 | def get_train_val_test_split( 624 | random_state, 625 | labels, 626 | train_examples_per_class=None, 627 | val_examples_per_class=None, 628 | test_examples_per_class=None, 629 | train_size=None, 630 | val_size=None, 631 | test_size=None, 632 | ): 633 | 634 | num_samples, num_classes = labels.shape 635 | remaining_indices = list(range(num_samples)) 636 | if train_examples_per_class is not None: 637 | train_indices = sample_per_class(random_state, labels, train_examples_per_class) 638 | else: 639 | # select train examples with no respect to class distribution 640 | train_indices = random_state.choice( 641 | remaining_indices, train_size, replace=False 642 | ) 643 | 644 | if val_examples_per_class is not None: 645 | val_indices = sample_per_class( 646 | random_state, 647 | labels, 648 | val_examples_per_class, 649 | forbidden_indices=train_indices, 650 | ) 651 | else: 652 | remaining_indices = np.setdiff1d(remaining_indices, train_indices) 653 | val_indices = random_state.choice(remaining_indices, val_size, replace=False) 654 | 655 | forbidden_indices = np.concatenate((train_indices, val_indices)) 656 | if test_examples_per_class is not None: 657 | test_indices = sample_per_class( 658 | random_state, 659 | labels, 660 | test_examples_per_class, 661 | forbidden_indices=forbidden_indices, 662 | ) 663 | elif test_size is not None: 664 | remaining_indices = np.setdiff1d(remaining_indices, forbidden_indices) 665 | test_indices = random_state.choice(remaining_indices, test_size, replace=False) 666 | else: 667 | test_indices = np.setdiff1d(remaining_indices, forbidden_indices) 668 | 669 | # assert that there are no duplicates in sets 670 | assert len(set(train_indices)) == len(train_indices) 671 | assert len(set(val_indices)) == len(val_indices) 672 | assert len(set(test_indices)) == len(test_indices) 673 | # assert sets are mutually exclusive 674 | assert len(set(train_indices) - set(val_indices)) == len(set(train_indices)) 675 | assert len(set(train_indices) - set(test_indices)) == len(set(train_indices)) 676 | assert len(set(val_indices) - set(test_indices)) == len(set(val_indices)) 677 | if test_size is None and test_examples_per_class is None: 678 | # all indices must be part of the split 679 | assert ( 680 | len(np.concatenate((train_indices, val_indices, test_indices))) 681 | == num_samples 682 | ) 683 | 684 | if train_examples_per_class is not None: 685 | train_labels = labels[train_indices, :] 686 | train_sum = np.sum(train_labels, axis=0) 687 | # assert all classes have equal cardinality 688 | assert np.unique(train_sum).size == 1 689 | 690 | if val_examples_per_class is not None: 691 | val_labels = labels[val_indices, :] 692 | val_sum = np.sum(val_labels, axis=0) 693 | # assert all classes have equal cardinality 694 | assert np.unique(val_sum).size == 1 695 | 696 | if test_examples_per_class is not None: 697 | test_labels = labels[test_indices, :] 698 | test_sum = np.sum(test_labels, axis=0) 699 
| # assert all classes have equal cardinality 700 | assert np.unique(test_sum).size == 1 701 | 702 | return train_indices, val_indices, test_indices 703 | -------------------------------------------------------------------------------- /experiments/ablation_feature_noise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train SAGE with 10 different node feature noise levels: 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1 4 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo" 5 | # Then train corresponding GLNN for each level 6 | 7 | aggregated_result_file="ablation_feature_noise.txt" 8 | printf "Teacher\n" >> $aggregated_result_file 9 | 10 | for n in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1 11 | do 12 | printf "%3s\n" $n >> $aggregated_result_file 13 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 14 | do 15 | printf "%10s\t" $ds >> $aggregated_result_file 16 | python train_teacher.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --feature_noise $n \ 17 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file 18 | done 19 | printf "\n" >> $aggregated_result_file 20 | done 21 | 22 | printf "Student\n" >> $aggregated_result_file 23 | 24 | for n in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1 25 | do 26 | printf "%3s\n" $n >> $aggregated_result_file 27 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 28 | do 29 | printf "%10s\t" $ds >> $aggregated_result_file 30 | python train_student.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --feature_noise $n \ 31 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file 32 | done 33 | printf "\n" >> $aggregated_result_file 34 | done 35 | 36 | -------------------------------------------------------------------------------- /experiments/ablation_gnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Train five different teachers "GCN" "GAT" "SAGE" "MLP" "APPNP" 5 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo" 6 | # Then train corresponding GLNN for each teacher 7 | 8 | aggregated_result_file="ablation_gnn.txt" 9 | printf "Teacher\n" >> $aggregated_result_file 10 | 11 | for e in "tran" "ind" 12 | do 13 | printf "%6s\n" $e >> $aggregated_result_file 14 | for t in "GCN" "GAT" "SAGE" "MLP" "APPNP" 15 | do 16 | printf "%6s\n" $t >> $aggregated_result_file 17 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 18 | do 19 | printf "%10s\t" $ds >> $aggregated_result_file 20 | python train_teacher.py --exp_setting $e --teacher $t --dataset $ds --num_exp 5 \ 21 | --max_epoch 200 --patience 50 >> $aggregated_result_file 22 | done 23 | printf "\n" >> $aggregated_result_file 24 | done 25 | printf "\n" >> $aggregated_result_file 26 | done 27 | 28 | printf "Student\n" >> $aggregated_result_file 29 | 30 | for e in "tran" "ind" 31 | do 32 | printf "%6s\n" $e >> $aggregated_result_file 33 | for t in "GCN" "GAT" "SAGE" "APPNP" 34 | do 35 | printf "%6s\n" $t >> $aggregated_result_file 36 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 37 | do 38 | printf "%10s\t" $ds >> $aggregated_result_file 39 | python train_student.py --exp_setting $e --teacher $t --dataset $ds --num_exp 5 \ 40 | --max_epoch 200 --patience 50 >> $aggregated_result_file 41 | done 42 | printf "\n" >> $aggregated_result_file 43 | done 44 | printf "\n" >> $aggregated_result_file 45 | done 46 | 
-------------------------------------------------------------------------------- /experiments/ablation_ind_split_rate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train SAGE with 9 different inductive split rates: 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 4 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo" 5 | # Then train corresponding GLNN for each split 6 | 7 | aggregated_result_file="ablation_ind_split_rate.txt" 8 | printf "Teacher\n" >> $aggregated_result_file 9 | 10 | for r in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 11 | do 12 | printf "%3s\n" $r >> $aggregated_result_file 13 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 14 | do 15 | printf "%10s\t" $ds >> $aggregated_result_file 16 | python train_teacher.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --split_rate $r \ 17 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file 18 | done 19 | printf "\n" >> $aggregated_result_file 20 | done 21 | 22 | printf "Student\n" >> $aggregated_result_file 23 | 24 | for r in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 25 | do 26 | printf "%3s\n" $r >> $aggregated_result_file 27 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 28 | do 29 | printf "%10s\t" $ds >> $aggregated_result_file 30 | python train_student.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --split_rate $r \ 31 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file 32 | done 33 | printf "\n" >> $aggregated_result_file 34 | done 35 | 36 | -------------------------------------------------------------------------------- /experiments/ga_glnn_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train 1-hop GA-MLP and 1-hop GA-GLNN with SAGE teacher on "ogbn-arxiv" under the inductive setting 4 | 5 | python train_teacher.py --exp_setting "ind" --teacher "MLP3w4" --dataset "ogbn-arxiv" \ 6 | --num_exp 5 --max_epoch 200 --patience 50 \ 7 | --feature_aug_k 1 8 | 9 | python train_student.py --exp_setting "ind" --teacher "SAGE" --student "MLP3w4" --dataset "ogbn-arxiv" \ 10 | --num_exp 5 --max_epoch 200 --patience 50 \ 11 | --feature_aug_k 1 12 | 13 | -------------------------------------------------------------------------------- /experiments/glnn_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train GLNN with SAGE teacher on "ogbn-arxiv" 4 | 5 | for e in "tran" "ind" 6 | do 7 | python train_student.py --exp_setting $e --teacher "SAGE" --student "MLP3w4" --dataset "ogbn-arxiv" \ 8 | --num_exp 10 --max_epoch 200 --patience 50 9 | done 10 | -------------------------------------------------------------------------------- /experiments/glnn_cpf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train GLNN with SAGE teacher on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo" 4 | 5 | aggregated_result_file="glnn_cpf.txt" 6 | for e in "tran" "ind" 7 | do 8 | printf "%6s\n" $e >> $aggregated_result_file 9 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 10 | do 11 | printf "%10s\t" $ds >> $aggregated_result_file 12 | python train_student.py --exp_setting $e --teacher "SAGE" --dataset $ds --num_exp 10 \ 13 | --max_epoch 200 --patience 50 \ 14 | --save_results >> $aggregated_result_file 15 | done 16 | printf "\n" >> $aggregated_result_file 17 | done 18 | 
-------------------------------------------------------------------------------- /experiments/glnn_products.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train GLNN with SAGE teacher on "ogbn-products" 4 | 5 | for e in "tran" "ind" 6 | do 7 | python train_student.py --exp_setting $e --teacher "SAGE" --student "MLP3w8" --dataset "ogbn-products" \ 8 | --num_exp 10 --max_epoch 200 --patience 30 9 | done 10 | -------------------------------------------------------------------------------- /experiments/sage_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train SAGE teacher on "ogbn-arxiv" 4 | 5 | for e in "tran" "ind" 6 | do 7 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset "ogbn-arxiv" \ 8 | --num_exp 10 --max_epoch 200 --patience 50 \ 9 | --save_results 10 | done 11 | -------------------------------------------------------------------------------- /experiments/sage_cpf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train SAGE teacher on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo" 4 | 5 | aggregated_result_file="sage_cpf.txt" 6 | for e in "tran" "ind" 7 | do 8 | printf "%6s\n" $e >> $aggregated_result_file 9 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo" 10 | do 11 | printf "%10s\t" $ds >> $aggregated_result_file 12 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset $ds \ 13 | --num_exp 10 --max_epoch 200 --patience 50 \ 14 | --save_results >> $aggregated_result_file 15 | done 16 | printf "\n" >> $aggregated_result_file 17 | done 18 | -------------------------------------------------------------------------------- /experiments/sage_products.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train SAGE teacher on "ogbn-products" 4 | 5 | for e in "tran" "ind" 6 | do 7 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset "ogbn-products" \ 8 | --num_exp 10 --max_epoch 40 --patience 10 \ 9 | --save_results 10 | done 11 | -------------------------------------------------------------------------------- /imgs/glnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-research/graphless-neural-networks/80d0f8ea66ad49b27336f9c02691ca16a6dc1dd2/imgs/glnn.png -------------------------------------------------------------------------------- /imgs/trade_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-research/graphless-neural-networks/80d0f8ea66ad49b27336f9c02691ca16a6dc1dd2/imgs/trade_off.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dgl.nn import GraphConv, SAGEConv, APPNPConv, GATConv 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__( 9 | self, 10 | num_layers, 11 | input_dim, 12 | hidden_dim, 13 | output_dim, 14 | dropout_ratio, 15 | norm_type="none", 16 | ): 17 | super(MLP, self).__init__() 18 | self.num_layers = num_layers 19 | self.norm_type = norm_type 20 | self.dropout = nn.Dropout(dropout_ratio) 21 | self.layers = nn.ModuleList() 22 | self.norms = nn.ModuleList() 23 | 24 | if num_layers 
== 1: 25 | self.layers.append(nn.Linear(input_dim, output_dim)) 26 | else: 27 | self.layers.append(nn.Linear(input_dim, hidden_dim)) 28 | if self.norm_type == "batch": 29 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 30 | elif self.norm_type == "layer": 31 | self.norms.append(nn.LayerNorm(hidden_dim)) 32 | 33 | for i in range(num_layers - 2): 34 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 35 | if self.norm_type == "batch": 36 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 37 | elif self.norm_type == "layer": 38 | self.norms.append(nn.LayerNorm(hidden_dim)) 39 | 40 | self.layers.append(nn.Linear(hidden_dim, output_dim)) 41 | 42 | def forward(self, feats): 43 | h = feats 44 | h_list = [] 45 | for l, layer in enumerate(self.layers): 46 | h = layer(h) 47 | if l != self.num_layers - 1: 48 | h_list.append(h) 49 | if self.norm_type != "none": 50 | h = self.norms[l](h) 51 | h = F.relu(h) 52 | h = self.dropout(h) 53 | return h_list, h 54 | 55 | 56 | """ 57 | Adapted from the SAGE implementation from the official DGL example 58 | https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py 59 | """ 60 | 61 | 62 | class SAGE(nn.Module): 63 | def __init__( 64 | self, 65 | num_layers, 66 | input_dim, 67 | hidden_dim, 68 | output_dim, 69 | dropout_ratio, 70 | activation, 71 | norm_type="none", 72 | ): 73 | super().__init__() 74 | self.num_layers = num_layers 75 | self.hidden_dim = hidden_dim 76 | self.output_dim = output_dim 77 | self.norm_type = norm_type 78 | self.activation = activation 79 | self.dropout = nn.Dropout(dropout_ratio) 80 | self.layers = nn.ModuleList() 81 | self.norms = nn.ModuleList() 82 | 83 | if num_layers == 1: 84 | self.layers.append(SAGEConv(input_dim, output_dim, "gcn")) 85 | else: 86 | self.layers.append(SAGEConv(input_dim, hidden_dim, "gcn")) 87 | if self.norm_type == "batch": 88 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 89 | elif self.norm_type == "layer": 90 | self.norms.append(nn.LayerNorm(hidden_dim)) 91 | 92 | for i in range(num_layers - 2): 93 | self.layers.append(SAGEConv(hidden_dim, hidden_dim, "gcn")) 94 | if self.norm_type == "batch": 95 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 96 | elif self.norm_type == "layer": 97 | self.norms.append(nn.LayerNorm(hidden_dim)) 98 | 99 | self.layers.append(SAGEConv(hidden_dim, output_dim, "gcn")) 100 | 101 | def forward(self, blocks, feats): 102 | h = feats 103 | h_list = [] 104 | for l, (layer, block) in enumerate(zip(self.layers, blocks)): 105 | # We need to first copy the representation of nodes on the RHS from the 106 | # appropriate nodes on the LHS. 107 | # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst 108 | # would be (num_nodes_RHS, D) 109 | h_dst = h[: block.num_dst_nodes()] 110 | # Then we compute the updated representation on the RHS. 111 | # The shape of h now becomes (num_nodes_RHS, D) 112 | h = layer(block, (h, h_dst)) 113 | if l != self.num_layers - 1: 114 | h_list.append(h) 115 | if self.norm_type != "none": 116 | h = self.norms[l](h) 117 | h = self.activation(h) 118 | h = self.dropout(h) 119 | return h_list, h 120 | 121 | def inference(self, dataloader, feats): 122 | """ 123 | Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). 124 | dataloader : The entire graph loaded in blocks with full neighbors for each node. 125 | feats : The input feats of entire node set. 
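The model is applied one layer at a time over the full node set: each layer's output for all nodes is computed before being fed to the next layer.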
126 | """ 127 | device = feats.device 128 | for l, layer in enumerate(self.layers): 129 | y = torch.zeros( 130 | feats.shape[0], 131 | self.hidden_dim if l != self.num_layers - 1 else self.output_dim, 132 | ).to(device) 133 | for input_nodes, output_nodes, blocks in dataloader: 134 | block = blocks[0].int().to(device) 135 | 136 | h = feats[input_nodes] 137 | h_dst = h[: block.num_dst_nodes()] 138 | h = layer(block, (h, h_dst)) 139 | if l != self.num_layers - 1: 140 | if self.norm_type != "none": 141 | h = self.norms[l](h) 142 | h = self.activation(h) 143 | h = self.dropout(h) 144 | 145 | y[output_nodes] = h 146 | 147 | feats = y 148 | return y 149 | 150 | 151 | class GCN(nn.Module): 152 | def __init__( 153 | self, 154 | num_layers, 155 | input_dim, 156 | hidden_dim, 157 | output_dim, 158 | dropout_ratio, 159 | activation, 160 | norm_type="none", 161 | ): 162 | super().__init__() 163 | self.num_layers = num_layers 164 | self.norm_type = norm_type 165 | self.dropout = nn.Dropout(dropout_ratio) 166 | self.layers = nn.ModuleList() 167 | self.norms = nn.ModuleList() 168 | 169 | if num_layers == 1: 170 | self.layers.append(GraphConv(input_dim, output_dim, activation=activation)) 171 | else: 172 | self.layers.append(GraphConv(input_dim, hidden_dim, activation=activation)) 173 | if self.norm_type == "batch": 174 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 175 | elif self.norm_type == "layer": 176 | self.norms.append(nn.LayerNorm(hidden_dim)) 177 | 178 | for i in range(num_layers - 2): 179 | self.layers.append( 180 | GraphConv(hidden_dim, hidden_dim, activation=activation) 181 | ) 182 | if self.norm_type == "batch": 183 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 184 | elif self.norm_type == "layer": 185 | self.norms.append(nn.LayerNorm(hidden_dim)) 186 | 187 | self.layers.append(GraphConv(hidden_dim, output_dim)) 188 | 189 | def forward(self, g, feats): 190 | h = feats 191 | h_list = [] 192 | for l, layer in enumerate(self.layers): 193 | h = layer(g, h) 194 | if l != self.num_layers - 1: 195 | h_list.append(h) 196 | if self.norm_type != "none": 197 | h = self.norms[l](h) 198 | h = self.dropout(h) 199 | return h_list, h 200 | 201 | 202 | class GAT(nn.Module): 203 | def __init__( 204 | self, 205 | num_layers, 206 | input_dim, 207 | hidden_dim, 208 | output_dim, 209 | dropout_ratio, 210 | activation, 211 | num_heads=8, 212 | attn_drop=0.3, 213 | negative_slope=0.2, 214 | residual=False, 215 | ): 216 | super(GAT, self).__init__() 217 | # For GAT, the number of layers is required to be > 1 218 | assert num_layers > 1 219 | 220 | hidden_dim //= num_heads 221 | self.num_layers = num_layers 222 | self.layers = nn.ModuleList() 223 | self.activation = activation 224 | 225 | heads = ([num_heads] * num_layers) + [1] 226 | # input (no residual) 227 | self.layers.append( 228 | GATConv( 229 | input_dim, 230 | hidden_dim, 231 | heads[0], 232 | dropout_ratio, 233 | attn_drop, 234 | negative_slope, 235 | False, 236 | self.activation, 237 | ) 238 | ) 239 | 240 | for l in range(1, num_layers - 1): 241 | # due to multi-head, the in_dim = hidden_dim * num_heads 242 | self.layers.append( 243 | GATConv( 244 | hidden_dim * heads[l - 1], 245 | hidden_dim, 246 | heads[l], 247 | dropout_ratio, 248 | attn_drop, 249 | negative_slope, 250 | residual, 251 | self.activation, 252 | ) 253 | ) 254 | 255 | self.layers.append( 256 | GATConv( 257 | hidden_dim * heads[-2], 258 | output_dim, 259 | heads[-1], 260 | dropout_ratio, 261 | attn_drop, 262 | negative_slope, 263 | residual, 264 | None, 265 | ) 266 | ) 267 | 268 | def 
forward(self, g, feats): 269 | h = feats 270 | h_list = [] 271 | for l, layer in enumerate(self.layers): 272 | # [num_head, node_num, nclass] -> [num_head, node_num*nclass] 273 | h = layer(g, h) 274 | if l != self.num_layers - 1: 275 | h = h.flatten(1) 276 | h_list.append(h) 277 | else: 278 | h = h.mean(1) 279 | return h_list, h 280 | 281 | 282 | class APPNP(nn.Module): 283 | def __init__( 284 | self, 285 | num_layers, 286 | input_dim, 287 | hidden_dim, 288 | output_dim, 289 | dropout_ratio, 290 | activation, 291 | norm_type="none", 292 | edge_drop=0.5, 293 | alpha=0.1, 294 | k=10, 295 | ): 296 | 297 | super(APPNP, self).__init__() 298 | self.num_layers = num_layers 299 | self.norm_type = norm_type 300 | self.activation = activation 301 | self.dropout = nn.Dropout(dropout_ratio) 302 | self.layers = nn.ModuleList() 303 | self.norms = nn.ModuleList() 304 | 305 | if num_layers == 1: 306 | self.layers.append(nn.Linear(input_dim, output_dim)) 307 | else: 308 | self.layers.append(nn.Linear(input_dim, hidden_dim)) 309 | if self.norm_type == "batch": 310 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 311 | elif self.norm_type == "layer": 312 | self.norms.append(nn.LayerNorm(hidden_dim)) 313 | 314 | for i in range(num_layers - 2): 315 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 316 | if self.norm_type == "batch": 317 | self.norms.append(nn.BatchNorm1d(hidden_dim)) 318 | elif self.norm_type == "layer": 319 | self.norms.append(nn.LayerNorm(hidden_dim)) 320 | 321 | self.layers.append(nn.Linear(hidden_dim, output_dim)) 322 | 323 | self.propagate = APPNPConv(k, alpha, edge_drop) 324 | self.reset_parameters() 325 | 326 | def reset_parameters(self): 327 | for layer in self.layers: 328 | layer.reset_parameters() 329 | 330 | def forward(self, g, feats): 331 | h = feats 332 | h_list = [] 333 | for l, layer in enumerate(self.layers): 334 | h = layer(h) 335 | 336 | if l != self.num_layers - 1: 337 | h_list.append(h) 338 | if self.norm_type != "none": 339 | h = self.norms[l](h) 340 | h = self.activation(h) 341 | h = self.dropout(h) 342 | 343 | h = self.propagate(g, h) 344 | return h_list, h 345 | 346 | 347 | class Model(nn.Module): 348 | """ 349 | Wrapper of different models 350 | """ 351 | 352 | def __init__(self, conf): 353 | super(Model, self).__init__() 354 | self.model_name = conf["model_name"] 355 | if "MLP" in conf["model_name"]: 356 | self.encoder = MLP( 357 | num_layers=conf["num_layers"], 358 | input_dim=conf["feat_dim"], 359 | hidden_dim=conf["hidden_dim"], 360 | output_dim=conf["label_dim"], 361 | dropout_ratio=conf["dropout_ratio"], 362 | norm_type=conf["norm_type"], 363 | ).to(conf["device"]) 364 | elif "SAGE" in conf["model_name"]: 365 | self.encoder = SAGE( 366 | num_layers=conf["num_layers"], 367 | input_dim=conf["feat_dim"], 368 | hidden_dim=conf["hidden_dim"], 369 | output_dim=conf["label_dim"], 370 | dropout_ratio=conf["dropout_ratio"], 371 | activation=F.relu, 372 | norm_type=conf["norm_type"], 373 | ).to(conf["device"]) 374 | elif "GCN" in conf["model_name"]: 375 | self.encoder = GCN( 376 | num_layers=conf["num_layers"], 377 | input_dim=conf["feat_dim"], 378 | hidden_dim=conf["hidden_dim"], 379 | output_dim=conf["label_dim"], 380 | dropout_ratio=conf["dropout_ratio"], 381 | activation=F.relu, 382 | norm_type=conf["norm_type"], 383 | ).to(conf["device"]) 384 | elif "GAT" in conf["model_name"]: 385 | self.encoder = GAT( 386 | num_layers=conf["num_layers"], 387 | input_dim=conf["feat_dim"], 388 | hidden_dim=conf["hidden_dim"], 389 | output_dim=conf["label_dim"], 390 | 
dropout_ratio=conf["dropout_ratio"], 391 | activation=F.relu, 392 | attn_drop=conf["attn_dropout_ratio"], 393 | ).to(conf["device"]) 394 | elif "APPNP" in conf["model_name"]: 395 | self.encoder = APPNP( 396 | num_layers=conf["num_layers"], 397 | input_dim=conf["feat_dim"], 398 | hidden_dim=conf["hidden_dim"], 399 | output_dim=conf["label_dim"], 400 | dropout_ratio=conf["dropout_ratio"], 401 | activation=F.relu, 402 | norm_type=conf["norm_type"], 403 | ).to(conf["device"]) 404 | 405 | def forward(self, data, feats): 406 | """ 407 | data: a graph `g` or a `dataloader` of blocks 408 | """ 409 | if "MLP" in self.model_name: 410 | return self.encoder(feats)[1] 411 | else: 412 | return self.encoder(data, feats)[1] 413 | 414 | def forward_fitnet(self, data, feats): 415 | """ 416 | Return a tuple (h_list, h) 417 | h_list: intermediate hidden representation 418 | h: final output 419 | """ 420 | if "MLP" in self.model_name: 421 | return self.encoder(feats) 422 | else: 423 | return self.encoder(data, feats) 424 | 425 | def inference(self, data, feats): 426 | if "SAGE" in self.model_name: 427 | return self.encoder.inference(data, feats) 428 | else: 429 | return self.forward(data, feats) 430 | -------------------------------------------------------------------------------- /prepare_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda create -y -n glnn python=3.6.9 4 | eval "$(conda shell.bash hook)" 5 | conda activate glnn 6 | 7 | pip install --no-cache-dir -r requirements.txt 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://data.dgl.ai/wheels/repo.html 2 | 3 | ogb==1.3.3 4 | pillow==8.4.0 5 | scipy==1.4.1 6 | networkx==2.5.1 7 | numpy==1.18.1 8 | tabulate==0.8.7 9 | tqdm==4.62.3 10 | PyYAML==5.3.1 11 | scikit_learn==0.22.2 12 | googledrivedownloader==0.4 13 | category_encoders==2.3.0 14 | torch==1.7.0 15 | dgl==0.6.1 -------------------------------------------------------------------------------- /train.conf.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | num_layers: 2 3 | hidden_dim: 128 4 | 5 | cora: 6 | SAGE: 7 | fan_out: 5,5 8 | learning_rate: 0.01 9 | dropout_ratio: 0 10 | weight_decay: 0.0005 11 | 12 | GCN: 13 | hidden_dim: 64 14 | dropout_ratio: 0.8 15 | weight_decay: 0.001 16 | 17 | MLP: 18 | learning_rate: 0.01 19 | weight_decay: 0.005 20 | dropout_ratio: 0.6 21 | 22 | GAT: 23 | dropout_ratio: 0.6 24 | weight_decay: 0.01 25 | num_heads: 8 26 | attn_dropout_ratio: 0.3 27 | 28 | APPNP: 29 | dropout_ratio: 0.5 30 | weight_decay: 0.01 31 | 32 | 33 | citeseer: 34 | SAGE: 35 | fan_out: 5,5 36 | learning_rate: 0.01 37 | dropout_ratio: 0 38 | weight_decay: 0.0005 39 | 40 | GCN: 41 | hidden_dim: 64 42 | dropout_ratio: 0.8 43 | weight_decay: 0.001 44 | 45 | MLP: 46 | learning_rate: 0.01 47 | weight_decay: 0.001 48 | dropout_ratio: 0.1 49 | 50 | GAT: 51 | dropout_ratio: 0.6 52 | weight_decay: 0.01 53 | num_heads: 8 54 | attn_dropout_ratio: 0.3 55 | 56 | APPNP: 57 | dropout_ratio: 0.5 58 | weight_decay: 0.01 59 | 60 | pubmed: 61 | SAGE: 62 | fan_out: 5,5 63 | learning_rate: 0.01 64 | dropout_ratio: 0 65 | weight_decay: 0.0005 66 | 67 | GCN: 68 | hidden_dim: 64 69 | dropout_ratio: 0.8 70 | weight_decay: 0.001 71 | 72 | MLP: 73 | learning_rate: 0.005 74 | weight_decay: 0 75 | dropout_ratio: 0.4 76 | 77 | GAT: 78 | dropout_ratio: 
0.6 79 | weight_decay: 0.01 80 | num_heads: 8 81 | attn_dropout_ratio: 0.3 82 | 83 | APPNP: 84 | dropout_ratio: 0.5 85 | weight_decay: 0.01 86 | 87 | a-computer: 88 | SAGE: 89 | fan_out: 5,5 90 | learning_rate: 0.01 91 | dropout_ratio: 0 92 | weight_decay: 0.0005 93 | 94 | GCN: 95 | hidden_dim: 64 96 | dropout_ratio: 0.8 97 | weight_decay: 0.001 98 | 99 | MLP: 100 | learning_rate: 0.001 101 | weight_decay: 0.002 102 | dropout_ratio: 0.3 103 | 104 | GAT: 105 | dropout_ratio: 0.6 106 | weight_decay: 0.01 107 | num_heads: 8 108 | attn_dropout_ratio: 0.3 109 | 110 | APPNP: 111 | dropout_ratio: 0.5 112 | weight_decay: 0.01 113 | 114 | a-photo: 115 | SAGE: 116 | fan_out: 5,5 117 | learning_rate: 0.01 118 | dropout_ratio: 0 119 | weight_decay: 0.0005 120 | 121 | GCN: 122 | hidden_dim: 64 123 | dropout_ratio: 0.8 124 | weight_decay: 0.001 125 | 126 | MLP: 127 | learning_rate: 0.005 128 | weight_decay: 0.002 129 | dropout_ratio: 0.3 130 | 131 | GAT: 132 | dropout_ratio: 0.6 133 | weight_decay: 0.01 134 | num_heads: 8 135 | attn_dropout_ratio: 0.3 136 | 137 | APPNP: 138 | dropout_ratio: 0.5 139 | weight_decay: 0.01 140 | 141 | ogbn-arxiv: 142 | MLP: 143 | num_layers: 3 144 | hidden_dim: 256 145 | weight_decay: 0 146 | dropout_ratio: 0.2 147 | norm_type: batch 148 | 149 | MLP3w4: 150 | num_layers: 3 151 | hidden_dim: 1024 152 | weight_decay: 0 153 | dropout_ratio: 0.5 154 | norm_type: batch 155 | 156 | GA1MLP: 157 | num_layers: 3 158 | hidden_dim: 256 159 | weight_decay: 0 160 | dropout_ratio: 0.2 161 | norm_type: batch 162 | 163 | GA1MLP3w4: 164 | num_layers: 3 165 | hidden_dim: 1024 166 | weight_decay: 0 167 | dropout_ratio: 0.2 168 | norm_type: batch 169 | 170 | SAGE: 171 | num_layers: 3 172 | hidden_dim: 256 173 | dropout_ratio: 0.2 174 | learning_rate: 0.01 175 | weight_decay: 0 176 | norm_type: batch 177 | fan_out: 5,10,15 178 | 179 | ogbn-products: 180 | MLP: 181 | num_layers: 3 182 | hidden_dim: 256 183 | dropout_ratio: 0.5 184 | norm_type: batch 185 | batch_size: 4096 186 | 187 | MLP3w8: 188 | num_layers: 3 189 | hidden_dim: 2048 190 | dropout_ratio: 0.2 191 | learning_rate: 0.01 192 | weight_decay: 0 193 | norm_type: batch 194 | batch_size: 4096 195 | 196 | SAGE: 197 | num_layers: 3 198 | hidden_dim: 256 199 | dropout_ratio: 0.5 200 | learning_rate: 0.003 201 | weight_decay: 0 202 | norm_type: batch 203 | fan_out: 5,10,15 204 | batch_size: 4096 205 | 206 | pokec: 207 | GCN: 208 | num_layers: 2 209 | hidden_dim: 32 210 | dropout_ratio: 0.5 211 | learning_rate: 0.01 212 | weight_decay: 0 213 | norm_type: batch 214 | 215 | MLP: 216 | num_layers: 3 217 | hidden_dim: 256 218 | dropout_ratio: 0.2 219 | learning_rate: 0.001 220 | norm_type: none 221 | 222 | SAGE: 223 | num_layers: 2 224 | hidden_dim: 32 225 | dropout_ratio: 0.5 226 | learning_rate: 0.001 227 | weight_decay: 0 228 | norm_type: batch 229 | fan_out: 5,5 230 | 231 | penn94: 232 | GCN: 233 | num_layers: 2 234 | hidden_dim: 64 235 | dropout_ratio: 0.5 236 | learning_rate: 0.01 237 | weight_decay: 0.001 238 | norm_type: batch 239 | 240 | MLP: 241 | num_layers: 3 242 | hidden_dim: 256 243 | dropout_ratio: 0 244 | learning_rate: 0.001 245 | norm_type: none 246 | 247 | vk_class: 248 | MLP: 249 | num_layers: 3 250 | hidden_dim: 512 251 | learning_rate: 0.01 252 | dropout_ratio: 0 253 | weight_decay: 0 254 | norm_type: batch 255 | batch_size: 6754 256 | 257 | house_class: 258 | MLP: 259 | num_layers: 3 260 | hidden_dim: 512 261 | learning_rate: 0.01 262 | dropout_ratio: 0 263 | weight_decay: 0 264 | norm_type: layer 265 | batch_size: 2580 266 
| -------------------------------------------------------------------------------- /train_and_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import torch 4 | import dgl 5 | from utils import set_seed 6 | 7 | """ 8 | 1. Train and eval 9 | """ 10 | 11 | 12 | def train(model, data, feats, labels, criterion, optimizer, idx_train, lamb=1): 13 | """ 14 | GNN full-batch training. Input the entire graph `g` as data. 15 | lamb: weight parameter lambda 16 | """ 17 | model.train() 18 | 19 | # Compute loss and prediction 20 | logits = model(data, feats) 21 | out = logits.log_softmax(dim=1) 22 | loss = criterion(out[idx_train], labels[idx_train]) 23 | loss_val = loss.item() 24 | 25 | loss *= lamb 26 | optimizer.zero_grad() 27 | loss.backward() 28 | optimizer.step() 29 | return loss_val 30 | 31 | 32 | def train_sage(model, dataloader, feats, labels, criterion, optimizer, lamb=1): 33 | """ 34 | Train for GraphSAGE. Process the graph in mini-batches using `dataloader` instead the entire graph `g`. 35 | lamb: weight parameter lambda 36 | """ 37 | device = feats.device 38 | model.train() 39 | total_loss = 0 40 | for step, (input_nodes, output_nodes, blocks) in enumerate(dataloader): 41 | blocks = [blk.int().to(device) for blk in blocks] 42 | batch_feats = feats[input_nodes] 43 | batch_labels = labels[output_nodes] 44 | 45 | # Compute loss and prediction 46 | logits = model(blocks, batch_feats) 47 | out = logits.log_softmax(dim=1) 48 | loss = criterion(out, batch_labels) 49 | total_loss += loss.item() 50 | 51 | loss *= lamb 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | return total_loss / len(dataloader) 57 | 58 | 59 | def train_mini_batch(model, feats, labels, batch_size, criterion, optimizer, lamb=1): 60 | """ 61 | Train MLP for large datasets. Process the data in mini-batches. The graph is ignored, node features only. 62 | lamb: weight parameter lambda 63 | """ 64 | model.train() 65 | num_batches = max(1, feats.shape[0] // batch_size) 66 | idx_batch = torch.randperm(feats.shape[0])[: num_batches * batch_size] 67 | 68 | if num_batches == 1: 69 | idx_batch = idx_batch.view(1, -1) 70 | else: 71 | idx_batch = idx_batch.view(num_batches, batch_size) 72 | 73 | total_loss = 0 74 | for i in range(num_batches): 75 | # No graph needed for the forward function 76 | logits = model(None, feats[idx_batch[i]]) 77 | out = logits.log_softmax(dim=1) 78 | 79 | loss = criterion(out, labels[idx_batch[i]]) 80 | total_loss += loss.item() 81 | 82 | loss *= lamb 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | return total_loss / num_batches 87 | 88 | 89 | def evaluate(model, data, feats, labels, criterion, evaluator, idx_eval=None): 90 | """ 91 | Returns: 92 | out: log probability of all input data 93 | loss & score (float): evaluated loss & score, if idx_eval is not None, only loss & score on those idx. 94 | """ 95 | model.eval() 96 | with torch.no_grad(): 97 | logits = model.inference(data, feats) 98 | out = logits.log_softmax(dim=1) 99 | if idx_eval is None: 100 | loss = criterion(out, labels) 101 | score = evaluator(out, labels) 102 | else: 103 | loss = criterion(out[idx_eval], labels[idx_eval]) 104 | score = evaluator(out[idx_eval], labels[idx_eval]) 105 | return out, loss.item(), score 106 | 107 | 108 | def evaluate_mini_batch( 109 | model, feats, labels, criterion, batch_size, evaluator, idx_eval=None 110 | ): 111 | """ 112 | Evaluate MLP for large datasets. 
Process the data in mini-batches. The graph is ignored, node features only. 113 | Return: 114 | out: log probability of all input data 115 | loss & score (float): evaluated loss & score, if idx_eval is not None, only loss & score on those idx. 116 | """ 117 | 118 | model.eval() 119 | with torch.no_grad(): 120 | num_batches = int(np.ceil(len(feats) / batch_size)) 121 | out_list = [] 122 | for i in range(num_batches): 123 | logits = model.inference(None, feats[batch_size * i : batch_size * (i + 1)]) 124 | out = logits.log_softmax(dim=1) 125 | out_list += [out.detach()] 126 | 127 | out_all = torch.cat(out_list) 128 | 129 | if idx_eval is None: 130 | loss = criterion(out_all, labels) 131 | score = evaluator(out_all, labels) 132 | else: 133 | loss = criterion(out_all[idx_eval], labels[idx_eval]) 134 | score = evaluator(out_all[idx_eval], labels[idx_eval]) 135 | 136 | return out_all, loss.item(), score 137 | 138 | 139 | """ 140 | 2. Run teacher 141 | """ 142 | 143 | 144 | def run_transductive( 145 | conf, 146 | model, 147 | g, 148 | feats, 149 | labels, 150 | indices, 151 | criterion, 152 | evaluator, 153 | optimizer, 154 | logger, 155 | loss_and_score, 156 | ): 157 | """ 158 | Train and eval under the transductive setting. 159 | The train/valid/test split is specified by `indices`. 160 | The input graph is assumed to be large. Thus, SAGE is used for GNNs, mini-batch is used for MLPs. 161 | 162 | loss_and_score: Stores losses and scores. 163 | """ 164 | set_seed(conf["seed"]) 165 | device = conf["device"] 166 | batch_size = conf["batch_size"] 167 | 168 | idx_train, idx_val, idx_test = indices 169 | 170 | feats = feats.to(device) 171 | labels = labels.to(device) 172 | 173 | if "SAGE" in model.model_name: 174 | # Create dataloader for SAGE 175 | 176 | # Create csr/coo/csc formats before launching sampling processes 177 | # This avoids creating certain formats in each data loader process, which saves momory and CPU. 
178 | g.create_formats_() 179 | sampler = dgl.dataloading.MultiLayerNeighborSampler( 180 | [eval(fanout) for fanout in conf["fan_out"].split(",")] 181 | ) 182 | dataloader = dgl.dataloading.NodeDataLoader( 183 | g, 184 | idx_train, 185 | sampler, 186 | batch_size=batch_size, 187 | shuffle=True, 188 | drop_last=False, 189 | num_workers=conf["num_workers"], 190 | ) 191 | 192 | # SAGE inference is implemented as layer by layer, so the full-neighbor sampler only collects one-hop neighors 193 | sampler_eval = dgl.dataloading.MultiLayerFullNeighborSampler(1) 194 | dataloader_eval = dgl.dataloading.NodeDataLoader( 195 | g, 196 | torch.arange(g.num_nodes()), 197 | sampler_eval, 198 | batch_size=batch_size, 199 | shuffle=False, 200 | drop_last=False, 201 | num_workers=conf["num_workers"], 202 | ) 203 | 204 | data = dataloader 205 | data_eval = dataloader_eval 206 | elif "MLP" in model.model_name: 207 | feats_train, labels_train = feats[idx_train], labels[idx_train] 208 | feats_val, labels_val = feats[idx_val], labels[idx_val] 209 | feats_test, labels_test = feats[idx_test], labels[idx_test] 210 | else: 211 | g = g.to(device) 212 | data = g 213 | data_eval = g 214 | 215 | best_epoch, best_score_val, count = 0, 0, 0 216 | for epoch in range(1, conf["max_epoch"] + 1): 217 | if "SAGE" in model.model_name: 218 | loss = train_sage(model, data, feats, labels, criterion, optimizer) 219 | elif "MLP" in model.model_name: 220 | loss = train_mini_batch( 221 | model, feats_train, labels_train, batch_size, criterion, optimizer 222 | ) 223 | else: 224 | loss = train(model, data, feats, labels, criterion, optimizer, idx_train) 225 | 226 | if epoch % conf["eval_interval"] == 0: 227 | if "MLP" in model.model_name: 228 | _, loss_train, score_train = evaluate_mini_batch( 229 | model, feats_train, labels_train, criterion, batch_size, evaluator 230 | ) 231 | _, loss_val, score_val = evaluate_mini_batch( 232 | model, feats_val, labels_val, criterion, batch_size, evaluator 233 | ) 234 | _, loss_test, score_test = evaluate_mini_batch( 235 | model, feats_test, labels_test, criterion, batch_size, evaluator 236 | ) 237 | else: 238 | out, loss_train, score_train = evaluate( 239 | model, data_eval, feats, labels, criterion, evaluator, idx_train 240 | ) 241 | # Use criterion & evaluator instead of evaluate to avoid redundant forward pass 242 | loss_val = criterion(out[idx_val], labels[idx_val]).item() 243 | score_val = evaluator(out[idx_val], labels[idx_val]) 244 | loss_test = criterion(out[idx_test], labels[idx_test]).item() 245 | score_test = evaluator(out[idx_test], labels[idx_test]) 246 | 247 | logger.debug( 248 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_train: {score_train:.4f} | s_val: {score_val:.4f} | s_test: {score_test:.4f}" 249 | ) 250 | loss_and_score += [ 251 | [ 252 | epoch, 253 | loss_train, 254 | loss_val, 255 | loss_test, 256 | score_train, 257 | score_val, 258 | score_test, 259 | ] 260 | ] 261 | 262 | if score_val >= best_score_val: 263 | best_epoch = epoch 264 | best_score_val = score_val 265 | state = copy.deepcopy(model.state_dict()) 266 | count = 0 267 | else: 268 | count += 1 269 | 270 | if count == conf["patience"] or epoch == conf["max_epoch"]: 271 | break 272 | 273 | model.load_state_dict(state) 274 | if "MLP" in model.model_name: 275 | out, _, score_val = evaluate_mini_batch( 276 | model, feats, labels, criterion, batch_size, evaluator, idx_val 277 | ) 278 | else: 279 | out, _, score_val = evaluate( 280 | model, data_eval, feats, labels, criterion, evaluator, idx_val 281 | ) 282 | 283 | score_test = 
evaluator(out[idx_test], labels[idx_test]) 284 | logger.info( 285 | f"Best valid model at epoch: {best_epoch: 3d}, score_val: {score_val :.4f}, score_test: {score_test :.4f}" 286 | ) 287 | return out, score_val, score_test 288 | 289 | 290 | def run_inductive( 291 | conf, 292 | model, 293 | g, 294 | feats, 295 | labels, 296 | indices, 297 | criterion, 298 | evaluator, 299 | optimizer, 300 | logger, 301 | loss_and_score, 302 | ): 303 | """ 304 | Train and eval under the inductive setting. 305 | The train/valid/test split is specified by `indices`. 306 | idx starting with `obs_idx_` contains the node idx in the observed graph `obs_g`. 307 | idx starting with `idx_` contains the node idx in the original graph `g`. 308 | The model is trained on the observed graph `obs_g`, and evaluated on both the observed test nodes (`obs_idx_test`) and inductive test nodes (`idx_test_ind`). 309 | The input graph is assumed to be large. Thus, SAGE is used for GNNs, mini-batch is used for MLPs. 310 | 311 | idx_obs: Idx of nodes in the original graph `g`, which form the observed graph 'obs_g'. 312 | loss_and_score: Stores losses and scores. 313 | """ 314 | 315 | set_seed(conf["seed"]) 316 | device = conf["device"] 317 | batch_size = conf["batch_size"] 318 | obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = indices 319 | 320 | feats = feats.to(device) 321 | labels = labels.to(device) 322 | obs_feats = feats[idx_obs] 323 | obs_labels = labels[idx_obs] 324 | obs_g = g.subgraph(idx_obs) 325 | 326 | if "SAGE" in model.model_name: 327 | # Create dataloader for SAGE 328 | 329 | # Create csr/coo/csc formats before launching sampling processes 330 | # This avoids creating certain formats in each data loader process, which saves momory and CPU. 331 | obs_g.create_formats_() 332 | g.create_formats_() 333 | sampler = dgl.dataloading.MultiLayerNeighborSampler( 334 | [eval(fanout) for fanout in conf["fan_out"].split(",")] 335 | ) 336 | obs_dataloader = dgl.dataloading.NodeDataLoader( 337 | obs_g, 338 | obs_idx_train, 339 | sampler, 340 | batch_size=batch_size, 341 | shuffle=True, 342 | drop_last=False, 343 | num_workers=conf["num_workers"], 344 | ) 345 | 346 | sampler_eval = dgl.dataloading.MultiLayerFullNeighborSampler(1) 347 | obs_dataloader_eval = dgl.dataloading.NodeDataLoader( 348 | obs_g, 349 | torch.arange(obs_g.num_nodes()), 350 | sampler_eval, 351 | batch_size=batch_size, 352 | shuffle=False, 353 | drop_last=False, 354 | num_workers=conf["num_workers"], 355 | ) 356 | dataloader_eval = dgl.dataloading.NodeDataLoader( 357 | g, 358 | torch.arange(g.num_nodes()), 359 | sampler_eval, 360 | batch_size=batch_size, 361 | shuffle=False, 362 | drop_last=False, 363 | num_workers=conf["num_workers"], 364 | ) 365 | 366 | obs_data = obs_dataloader 367 | obs_data_eval = obs_dataloader_eval 368 | data_eval = dataloader_eval 369 | elif "MLP" in model.model_name: 370 | feats_train, labels_train = obs_feats[obs_idx_train], obs_labels[obs_idx_train] 371 | feats_val, labels_val = obs_feats[obs_idx_val], obs_labels[obs_idx_val] 372 | feats_test_tran, labels_test_tran = ( 373 | obs_feats[obs_idx_test], 374 | obs_labels[obs_idx_test], 375 | ) 376 | feats_test_ind, labels_test_ind = feats[idx_test_ind], labels[idx_test_ind] 377 | 378 | else: 379 | obs_g = obs_g.to(device) 380 | g = g.to(device) 381 | 382 | obs_data = obs_g 383 | obs_data_eval = obs_g 384 | data_eval = g 385 | 386 | best_epoch, best_score_val, count = 0, 0, 0 387 | for epoch in range(1, conf["max_epoch"] + 1): 388 | if "SAGE" in model.model_name: 389 | loss = 
train_sage( 390 | model, obs_data, obs_feats, obs_labels, criterion, optimizer 391 | ) 392 | elif "MLP" in model.model_name: 393 | loss = train_mini_batch( 394 | model, feats_train, labels_train, batch_size, criterion, optimizer 395 | ) 396 | else: 397 | loss = train( 398 | model, 399 | obs_data, 400 | obs_feats, 401 | obs_labels, 402 | criterion, 403 | optimizer, 404 | obs_idx_train, 405 | ) 406 | 407 | if epoch % conf["eval_interval"] == 0: 408 | if "MLP" in model.model_name: 409 | _, loss_train, score_train = evaluate_mini_batch( 410 | model, feats_train, labels_train, criterion, batch_size, evaluator 411 | ) 412 | _, loss_val, score_val = evaluate_mini_batch( 413 | model, feats_val, labels_val, criterion, batch_size, evaluator 414 | ) 415 | _, loss_test_tran, score_test_tran = evaluate_mini_batch( 416 | model, 417 | feats_test_tran, 418 | labels_test_tran, 419 | criterion, 420 | batch_size, 421 | evaluator, 422 | ) 423 | _, loss_test_ind, score_test_ind = evaluate_mini_batch( 424 | model, 425 | feats_test_ind, 426 | labels_test_ind, 427 | criterion, 428 | batch_size, 429 | evaluator, 430 | ) 431 | else: 432 | obs_out, loss_train, score_train = evaluate( 433 | model, 434 | obs_data_eval, 435 | obs_feats, 436 | obs_labels, 437 | criterion, 438 | evaluator, 439 | obs_idx_train, 440 | ) 441 | # Use criterion & evaluator instead of evaluate to avoid redundant forward pass 442 | loss_val = criterion( 443 | obs_out[obs_idx_val], obs_labels[obs_idx_val] 444 | ).item() 445 | score_val = evaluator(obs_out[obs_idx_val], obs_labels[obs_idx_val]) 446 | loss_test_tran = criterion( 447 | obs_out[obs_idx_test], obs_labels[obs_idx_test] 448 | ).item() 449 | score_test_tran = evaluator( 450 | obs_out[obs_idx_test], obs_labels[obs_idx_test] 451 | ) 452 | 453 | # Evaluate the inductive part with the full graph 454 | out, loss_test_ind, score_test_ind = evaluate( 455 | model, data_eval, feats, labels, criterion, evaluator, idx_test_ind 456 | ) 457 | logger.debug( 458 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_train: {score_train:.4f} | s_val: {score_val:.4f} | s_tt: {score_test_tran:.4f} | s_ti: {score_test_ind:.4f}" 459 | ) 460 | loss_and_score += [ 461 | [ 462 | epoch, 463 | loss_train, 464 | loss_val, 465 | loss_test_tran, 466 | loss_test_ind, 467 | score_train, 468 | score_val, 469 | score_test_tran, 470 | score_test_ind, 471 | ] 472 | ] 473 | if score_val >= best_score_val: 474 | best_epoch = epoch 475 | best_score_val = score_val 476 | state = copy.deepcopy(model.state_dict()) 477 | count = 0 478 | else: 479 | count += 1 480 | 481 | if count == conf["patience"] or epoch == conf["max_epoch"]: 482 | break 483 | 484 | model.load_state_dict(state) 485 | if "MLP" in model.model_name: 486 | obs_out, _, score_val = evaluate_mini_batch( 487 | model, obs_feats, obs_labels, criterion, batch_size, evaluator, obs_idx_val 488 | ) 489 | out, _, score_test_ind = evaluate_mini_batch( 490 | model, feats, labels, criterion, batch_size, evaluator, idx_test_ind 491 | ) 492 | 493 | else: 494 | obs_out, _, score_val = evaluate( 495 | model, 496 | obs_data_eval, 497 | obs_feats, 498 | obs_labels, 499 | criterion, 500 | evaluator, 501 | obs_idx_val, 502 | ) 503 | out, _, score_test_ind = evaluate( 504 | model, data_eval, feats, labels, criterion, evaluator, idx_test_ind 505 | ) 506 | 507 | score_test_tran = evaluator(obs_out[obs_idx_test], obs_labels[obs_idx_test]) 508 | out[idx_obs] = obs_out 509 | logger.info( 510 | f"Best valid model at epoch: {best_epoch :3d}, score_val: {score_val :.4f}, score_test_tran: {score_test_tran 
:.4f}, score_test_ind: {score_test_ind :.4f}" 511 | ) 512 | return out, score_val, score_test_tran, score_test_ind 513 | 514 | 515 | """ 516 | 3. Distill 517 | """ 518 | 519 | 520 | def distill_run_transductive( 521 | conf, 522 | model, 523 | feats, 524 | labels, 525 | out_t_all, 526 | distill_indices, 527 | criterion_l, 528 | criterion_t, 529 | evaluator, 530 | optimizer, 531 | logger, 532 | loss_and_score, 533 | ): 534 | """ 535 | Distill training and eval under the transductive setting. 536 | The hard_label_train/soft_label_train/valid/test split is specified by `distill_indices`. 537 | The input graph is assumed to be large, and MLP is assumed to be the student model. Thus, node feature only and mini-batch is used. 538 | 539 | out_t: Soft labels produced by the teacher model. 540 | criterion_l & criterion_t: Loss used for hard labels (`labels`) and soft labels (`out_t`) respectively 541 | loss_and_score: Stores losses and scores. 542 | """ 543 | set_seed(conf["seed"]) 544 | device = conf["device"] 545 | batch_size = conf["batch_size"] 546 | lamb = conf["lamb"] 547 | idx_l, idx_t, idx_val, idx_test = distill_indices 548 | 549 | feats = feats.to(device) 550 | labels = labels.to(device) 551 | out_t_all = out_t_all.to(device) 552 | 553 | feats_l, labels_l = feats[idx_l], labels[idx_l] 554 | feats_t, out_t = feats[idx_t], out_t_all[idx_t] 555 | feats_val, labels_val = feats[idx_val], labels[idx_val] 556 | feats_test, labels_test = feats[idx_test], labels[idx_test] 557 | 558 | best_epoch, best_score_val, count = 0, 0, 0 559 | for epoch in range(1, conf["max_epoch"] + 1): 560 | loss_l = train_mini_batch( 561 | model, feats_l, labels_l, batch_size, criterion_l, optimizer, lamb 562 | ) 563 | loss_t = train_mini_batch( 564 | model, feats_t, out_t, batch_size, criterion_t, optimizer, 1 - lamb 565 | ) 566 | loss = loss_l + loss_t 567 | if epoch % conf["eval_interval"] == 0: 568 | _, loss_l, score_l = evaluate_mini_batch( 569 | model, feats_l, labels_l, criterion_l, batch_size, evaluator 570 | ) 571 | _, loss_val, score_val = evaluate_mini_batch( 572 | model, feats_val, labels_val, criterion_l, batch_size, evaluator 573 | ) 574 | _, loss_test, score_test = evaluate_mini_batch( 575 | model, feats_test, labels_test, criterion_l, batch_size, evaluator 576 | ) 577 | 578 | logger.debug( 579 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_l: {score_l:.4f} | s_val: {score_val:.4f} | s_test: {score_test:.4f}" 580 | ) 581 | loss_and_score += [ 582 | [epoch, loss_l, loss_val, loss_test, score_l, score_val, score_test] 583 | ] 584 | 585 | if score_val >= best_score_val: 586 | best_epoch = epoch 587 | best_score_val = score_val 588 | state = copy.deepcopy(model.state_dict()) 589 | count = 0 590 | else: 591 | count += 1 592 | 593 | if count == conf["patience"] or epoch == conf["max_epoch"]: 594 | break 595 | 596 | model.load_state_dict(state) 597 | out, _, score_val = evaluate_mini_batch( 598 | model, feats, labels, criterion_l, batch_size, evaluator, idx_val 599 | ) 600 | # Use evaluator instead of evaluate to avoid redundant forward pass 601 | score_test = evaluator(out[idx_test], labels_test) 602 | 603 | logger.info( 604 | f"Best valid model at epoch: {best_epoch: 3d}, score_val: {score_val :.4f}, score_test: {score_test :.4f}" 605 | ) 606 | return out, score_val, score_test 607 | 608 | 609 | def distill_run_inductive( 610 | conf, 611 | model, 612 | feats, 613 | labels, 614 | out_t_all, 615 | distill_indices, 616 | criterion_l, 617 | criterion_t, 618 | evaluator, 619 | optimizer, 620 | logger, 621 | loss_and_score, 
622 | ): 623 | """ 624 | Distill training and eval under the inductive setting. 625 | The hard_label_train/soft_label_train/valid/test split is specified by `distill_indices`. 626 | idx starting with `obs_idx_` contains the node idx in the observed graph `obs_g`. 627 | idx starting with `idx_` contains the node idx in the original graph `g`. 628 | The model is trained on the observed graph `obs_g`, and evaluated on both the observed test nodes (`obs_idx_test`) and inductive test nodes (`idx_test_ind`). 629 | The input graph is assumed to be large, and MLP is assumed to be the student model. Thus, node feature only and mini-batch is used. 630 | 631 | idx_obs: Idx of nodes in the original graph `g`, which form the observed graph 'obs_g'. 632 | out_t: Soft labels produced by the teacher model. 633 | criterion_l & criterion_t: Loss used for hard labels (`labels`) and soft labels (`out_t`) respectively. 634 | loss_and_score: Stores losses and scores. 635 | """ 636 | 637 | set_seed(conf["seed"]) 638 | device = conf["device"] 639 | batch_size = conf["batch_size"] 640 | lamb = conf["lamb"] 641 | ( 642 | obs_idx_l, 643 | obs_idx_t, 644 | obs_idx_val, 645 | obs_idx_test, 646 | idx_obs, 647 | idx_test_ind, 648 | ) = distill_indices 649 | 650 | feats = feats.to(device) 651 | labels = labels.to(device) 652 | out_t_all = out_t_all.to(device) 653 | obs_feats = feats[idx_obs] 654 | obs_labels = labels[idx_obs] 655 | obs_out_t = out_t_all[idx_obs] 656 | 657 | feats_l, labels_l = obs_feats[obs_idx_l], obs_labels[obs_idx_l] 658 | feats_t, out_t = obs_feats[obs_idx_t], obs_out_t[obs_idx_t] 659 | feats_val, labels_val = obs_feats[obs_idx_val], obs_labels[obs_idx_val] 660 | feats_test_tran, labels_test_tran = ( 661 | obs_feats[obs_idx_test], 662 | obs_labels[obs_idx_test], 663 | ) 664 | feats_test_ind, labels_test_ind = feats[idx_test_ind], labels[idx_test_ind] 665 | 666 | best_epoch, best_score_val, count = 0, 0, 0 667 | for epoch in range(1, conf["max_epoch"] + 1): 668 | loss_l = train_mini_batch( 669 | model, feats_l, labels_l, batch_size, criterion_l, optimizer, lamb 670 | ) 671 | loss_t = train_mini_batch( 672 | model, feats_t, out_t, batch_size, criterion_t, optimizer, 1 - lamb 673 | ) 674 | loss = loss_l + loss_t 675 | if epoch % conf["eval_interval"] == 0: 676 | _, loss_l, score_l = evaluate_mini_batch( 677 | model, feats_l, labels_l, criterion_l, batch_size, evaluator 678 | ) 679 | _, loss_val, score_val = evaluate_mini_batch( 680 | model, feats_val, labels_val, criterion_l, batch_size, evaluator 681 | ) 682 | _, loss_test_tran, score_test_tran = evaluate_mini_batch( 683 | model, 684 | feats_test_tran, 685 | labels_test_tran, 686 | criterion_l, 687 | batch_size, 688 | evaluator, 689 | ) 690 | _, loss_test_ind, score_test_ind = evaluate_mini_batch( 691 | model, 692 | feats_test_ind, 693 | labels_test_ind, 694 | criterion_l, 695 | batch_size, 696 | evaluator, 697 | ) 698 | 699 | logger.debug( 700 | f"Ep {epoch:3d} | l: {loss:.4f} | s_l: {score_l:.4f} | s_val: {score_val:.4f} | s_tt: {score_test_tran:.4f} | s_ti: {score_test_ind:.4f}" 701 | ) 702 | loss_and_score += [ 703 | [ 704 | epoch, 705 | loss_l, 706 | loss_val, 707 | loss_test_tran, 708 | loss_test_ind, 709 | score_l, 710 | score_val, 711 | score_test_tran, 712 | score_test_ind, 713 | ] 714 | ] 715 | 716 | if score_val >= best_score_val: 717 | best_epoch = epoch 718 | best_score_val = score_val 719 | state = copy.deepcopy(model.state_dict()) 720 | count = 0 721 | else: 722 | count += 1 723 | 724 | if count == conf["patience"] or epoch == 
conf["max_epoch"]: 725 | break 726 | 727 | model.load_state_dict(state) 728 | obs_out, _, score_val = evaluate_mini_batch( 729 | model, obs_feats, obs_labels, criterion_l, batch_size, evaluator, obs_idx_val 730 | ) 731 | out, _, score_test_ind = evaluate_mini_batch( 732 | model, feats, labels, criterion_l, batch_size, evaluator, idx_test_ind 733 | ) 734 | 735 | # Use evaluator instead of evaluate to avoid redundant forward pass 736 | score_test_tran = evaluator(obs_out[obs_idx_test], labels_test_tran) 737 | out[idx_obs] = obs_out 738 | 739 | logger.info( 740 | f"Best valid model at epoch: {best_epoch: 3d} score_val: {score_val :.4f}, score_test_tran: {score_test_tran :.4f}, score_test_ind: {score_test_ind :.4f}" 741 | ) 742 | return out, score_val, score_test_tran, score_test_ind 743 | -------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | from pathlib import Path 6 | from models import Model 7 | from dataloader import load_data, load_out_t 8 | from utils import ( 9 | get_logger, 10 | get_evaluator, 11 | set_seed, 12 | get_training_config, 13 | check_writable, 14 | check_readable, 15 | compute_min_cut_loss, 16 | graph_split, 17 | feature_prop, 18 | ) 19 | from train_and_eval import distill_run_transductive, distill_run_inductive 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser(description="PyTorch DGL implementation") 24 | parser.add_argument("--device", type=int, default=-1, help="CUDA device, -1 means CPU") 25 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 26 | parser.add_argument( 27 | "--log_level", 28 | type=int, 29 | default=20, 30 | help="Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}", 31 | ) 32 | parser.add_argument( 33 | "--console_log", 34 | action="store_true", 35 | help="Set to True to display log info in console", 36 | ) 37 | parser.add_argument( 38 | "--output_path", type=str, default="outputs", help="Path to save outputs" 39 | ) 40 | parser.add_argument( 41 | "--num_exp", type=int, default=1, help="Repeat how many experiments" 42 | ) 43 | parser.add_argument( 44 | "--exp_setting", 45 | type=str, 46 | default="tran", 47 | help="Experiment setting, one of [tran, ind]", 48 | ) 49 | parser.add_argument( 50 | "--eval_interval", type=int, default=1, help="Evaluate once per how many epochs" 51 | ) 52 | parser.add_argument( 53 | "--save_results", 54 | action="store_true", 55 | help="Set to True to save the loss curves, trained model, and min-cut loss for the transductive setting", 56 | ) 57 | 58 | """Dataset""" 59 | parser.add_argument("--dataset", type=str, default="cora", help="Dataset") 60 | parser.add_argument("--data_path", type=str, default="./data", help="Path to data") 61 | parser.add_argument( 62 | "--labelrate_train", 63 | type=int, 64 | default=20, 65 | help="How many labeled data per class as train set", 66 | ) 67 | parser.add_argument( 68 | "--labelrate_val", 69 | type=int, 70 | default=30, 71 | help="How many labeled data per class in valid set", 72 | ) 73 | parser.add_argument( 74 | "--split_idx", 75 | type=int, 76 | default=0, 77 | help="For Non-Homo datasets only, one of [0,1,2,3,4]", 78 | ) 79 | 80 | """Model""" 81 | parser.add_argument( 82 | "--model_config_path", 83 | type=str, 84 | default="./train.conf.yaml", 85 | help="Path to model configeration", 86 | ) 87 | 
parser.add_argument("--teacher", type=str, default="SAGE", help="Teacher model") 88 | parser.add_argument("--student", type=str, default="MLP", help="Student model") 89 | parser.add_argument( 90 | "--num_layers", type=int, default=2, help="Student model number of layers" 91 | ) 92 | parser.add_argument( 93 | "--hidden_dim", 94 | type=int, 95 | default=64, 96 | help="Student model hidden layer dimensions", 97 | ) 98 | parser.add_argument("--dropout_ratio", type=float, default=0) 99 | parser.add_argument( 100 | "--norm_type", type=str, default="none", help="One of [none, batch, layer]" 101 | ) 102 | 103 | """SAGE Specific""" 104 | parser.add_argument("--batch_size", type=int, default=512) 105 | parser.add_argument( 106 | "--fan_out", 107 | type=str, 108 | default="5,5", 109 | help="Number of samples for each layer in SAGE. Length = num_layers", 110 | ) 111 | parser.add_argument( 112 | "--num_workers", type=int, default=0, help="Number of workers for sampler" 113 | ) 114 | 115 | """Optimization""" 116 | parser.add_argument("--learning_rate", type=float, default=0.01) 117 | parser.add_argument("--weight_decay", type=float, default=0.0005) 118 | parser.add_argument( 119 | "--max_epoch", type=int, default=500, help="Evaluate once per how many epochs" 120 | ) 121 | parser.add_argument( 122 | "--patience", 123 | type=int, 124 | default=50, 125 | help="Early stop is the score on validation set does not improve for how many epochs", 126 | ) 127 | 128 | """Ablation""" 129 | parser.add_argument( 130 | "--feature_noise", 131 | type=float, 132 | default=0, 133 | help="add white noise to features for analysis, value in [0, 1] for noise level", 134 | ) 135 | parser.add_argument( 136 | "--split_rate", 137 | type=float, 138 | default=0.2, 139 | help="Rate for graph split, see comment of graph_split for more details", 140 | ) 141 | parser.add_argument( 142 | "--compute_min_cut", 143 | action="store_true", 144 | help="Set to True to compute and store the min-cut loss", 145 | ) 146 | parser.add_argument( 147 | "--feature_aug_k", 148 | type=int, 149 | default=0, 150 | help="Augment node futures by aggregating feature_aug_k-hop neighbor features", 151 | ) 152 | 153 | """Distiall""" 154 | parser.add_argument( 155 | "--lamb", 156 | type=float, 157 | default=0, 158 | help="Parameter balances loss from hard labels and teacher outputs, take values in [0, 1]", 159 | ) 160 | parser.add_argument( 161 | "--out_t_path", type=str, default="outputs", help="Path to load teacher outputs" 162 | ) 163 | args = parser.parse_args() 164 | 165 | return args 166 | 167 | 168 | def run(args): 169 | """ 170 | Returns: 171 | score_lst: a list of evaluation results on test set. 172 | len(score_lst) = 1 for the transductive setting. 173 | len(score_lst) = 2 for the inductive/production setting. 174 | """ 175 | 176 | """ Set seed, device, and logger """ 177 | set_seed(args.seed) 178 | if torch.cuda.is_available() and args.device >= 0: 179 | device = torch.device("cuda:" + str(args.device)) 180 | else: 181 | device = "cpu" 182 | 183 | if args.feature_noise != 0 and args.seed == 0: 184 | args.output_path = Path.cwd().joinpath( 185 | args.output_path, "noisy_features", f"noise_{args.feature_noise}" 186 | ) 187 | # Teacher is assumed to be trained on the same noisy features as well. 
188 | args.out_t_path = args.output_path 189 | 190 | if args.feature_aug_k > 0 and args.seed == 0: 191 | args.output_path = Path.cwd().joinpath( 192 | args.output_path, "aug_features", f"aug_hop_{args.feature_aug_k}" 193 | ) 194 | # NOTE: Teacher may or may not have augmented features, specify args.out_t_path explicitly. 195 | # args.out_t_path = 196 | args.student = f"GA{args.feature_aug_k}{args.student}" 197 | 198 | if args.exp_setting == "tran": 199 | output_dir = Path.cwd().joinpath( 200 | args.output_path, 201 | "transductive", 202 | args.dataset, 203 | f"{args.teacher}_{args.student}", 204 | f"seed_{args.seed}", 205 | ) 206 | out_t_dir = Path.cwd().joinpath( 207 | args.out_t_path, 208 | "transductive", 209 | args.dataset, 210 | args.teacher, 211 | f"seed_{args.seed}", 212 | ) 213 | elif args.exp_setting == "ind": 214 | output_dir = Path.cwd().joinpath( 215 | args.output_path, 216 | "inductive", 217 | f"split_rate_{args.split_rate}", 218 | args.dataset, 219 | f"{args.teacher}_{args.student}", 220 | f"seed_{args.seed}", 221 | ) 222 | out_t_dir = Path.cwd().joinpath( 223 | args.out_t_path, 224 | "inductive", 225 | f"split_rate_{args.split_rate}", 226 | args.dataset, 227 | args.teacher, 228 | f"seed_{args.seed}", 229 | ) 230 | else: 231 | raise ValueError(f"Unknown experiment setting! {args.exp_setting}") 232 | args.output_dir = output_dir 233 | 234 | check_writable(output_dir, overwrite=False) 235 | check_readable(out_t_dir) 236 | 237 | logger = get_logger(output_dir.joinpath("log"), args.console_log, args.log_level) 238 | logger.info(f"output_dir: {output_dir}") 239 | logger.info(f"out_t_dir: {out_t_dir}") 240 | 241 | """ Load data and model config""" 242 | g, labels, idx_train, idx_val, idx_test = load_data( 243 | args.dataset, 244 | args.data_path, 245 | split_idx=args.split_idx, 246 | seed=args.seed, 247 | labelrate_train=args.labelrate_train, 248 | labelrate_val=args.labelrate_val, 249 | ) 250 | 251 | logger.info(f"Total {g.number_of_nodes()} nodes.") 252 | logger.info(f"Total {g.number_of_edges()} edges.") 253 | 254 | feats = g.ndata["feat"] 255 | args.feat_dim = g.ndata["feat"].shape[1] 256 | args.label_dim = labels.int().max().item() + 1 257 | 258 | if 0 < args.feature_noise <= 1: 259 | feats = ( 260 | 1 - args.feature_noise 261 | ) * feats + args.feature_noise * torch.randn_like(feats) 262 | 263 | """ Model config """ 264 | conf = {} 265 | if args.model_config_path is not None: 266 | conf = get_training_config( 267 | args.model_config_path, args.student, args.dataset 268 | ) # Note: student config 269 | conf = dict(args.__dict__, **conf) 270 | conf["device"] = device 271 | logger.info(f"conf: {conf}") 272 | 273 | """ Model init """ 274 | model = Model(conf) 275 | optimizer = optim.Adam( 276 | model.parameters(), lr=conf["learning_rate"], weight_decay=conf["weight_decay"] 277 | ) 278 | criterion_l = torch.nn.NLLLoss() 279 | criterion_t = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) 280 | evaluator = get_evaluator(conf["dataset"]) 281 | 282 | """Load teacher model output""" 283 | out_t = load_out_t(out_t_dir) 284 | logger.debug( 285 | f"teacher score on train data: {evaluator(out_t[idx_train], labels[idx_train])}" 286 | ) 287 | logger.debug( 288 | f"teacher score on val data: {evaluator(out_t[idx_val], labels[idx_val])}" 289 | ) 290 | logger.debug( 291 | f"teacher score on test data: {evaluator(out_t[idx_test], labels[idx_test])}" 292 | ) 293 | 294 | """Data split and run""" 295 | loss_and_score = [] 296 | if args.exp_setting == "tran": 297 | idx_l = idx_train 298 | 
idx_t = torch.cat([idx_train, idx_val, idx_test]) 299 | distill_indices = (idx_l, idx_t, idx_val, idx_test) 300 | 301 | # propagate node feature 302 | if args.feature_aug_k > 0: 303 | feats = feature_prop(feats, g, args.feature_aug_k) 304 | 305 | out, score_val, score_test = distill_run_transductive( 306 | conf, 307 | model, 308 | feats, 309 | labels, 310 | out_t, 311 | distill_indices, 312 | criterion_l, 313 | criterion_t, 314 | evaluator, 315 | optimizer, 316 | logger, 317 | loss_and_score, 318 | ) 319 | score_lst = [score_test] 320 | 321 | elif args.exp_setting == "ind": 322 | # Create inductive split 323 | obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = graph_split( 324 | idx_train, idx_val, idx_test, args.split_rate, args.seed 325 | ) 326 | obs_idx_l = obs_idx_train 327 | obs_idx_t = torch.cat([obs_idx_train, obs_idx_val, obs_idx_test]) 328 | distill_indices = ( 329 | obs_idx_l, 330 | obs_idx_t, 331 | obs_idx_val, 332 | obs_idx_test, 333 | idx_obs, 334 | idx_test_ind, 335 | ) 336 | 337 | # propagate node feature. The propagation for the observed graph only happens within the subgraph obs_g 338 | if args.feature_aug_k > 0: 339 | obs_g = g.subgraph(idx_obs) 340 | obs_feats = feature_prop(feats[idx_obs], obs_g, args.feature_aug_k) 341 | feats = feature_prop(feats, g, args.feature_aug_k) 342 | feats[idx_obs] = obs_feats 343 | 344 | out, score_val, score_test_tran, score_test_ind = distill_run_inductive( 345 | conf, 346 | model, 347 | feats, 348 | labels, 349 | out_t, 350 | distill_indices, 351 | criterion_l, 352 | criterion_t, 353 | evaluator, 354 | optimizer, 355 | logger, 356 | loss_and_score, 357 | ) 358 | score_lst = [score_test_tran, score_test_ind] 359 | 360 | logger.info( 361 | f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. 
dropout_ratio: {conf['dropout_ratio']}" 362 | ) 363 | logger.info(f"# params {sum(p.numel() for p in model.parameters())}") 364 | 365 | """ Saving student outputs """ 366 | out_np = out.detach().cpu().numpy() 367 | np.savez(output_dir.joinpath("out"), out_np) 368 | 369 | """ Saving loss curve and model """ 370 | if args.save_results: 371 | # Loss curves 372 | loss_and_score = np.array(loss_and_score) 373 | np.savez(output_dir.joinpath("loss_and_score"), loss_and_score) 374 | 375 | # Model 376 | torch.save(model.state_dict(), output_dir.joinpath("model.pth")) 377 | 378 | """ Saving min-cut loss""" 379 | if args.exp_setting == "tran" and args.compute_min_cut: 380 | min_cut = compute_min_cut_loss(g, out) 381 | with open(output_dir.parent.joinpath("min_cut_loss"), "a+") as f: 382 | f.write(f"{min_cut :.4f}\n") 383 | 384 | return score_lst 385 | 386 | 387 | def repeat_run(args): 388 | scores = [] 389 | for seed in range(args.num_exp): 390 | args.seed = seed 391 | scores.append(run(args)) 392 | scores_np = np.array(scores) 393 | return scores_np.mean(axis=0), scores_np.std(axis=0) 394 | 395 | 396 | def main(): 397 | args = get_args() 398 | if args.num_exp == 1: 399 | score = run(args) 400 | score_str = "".join([f"{s : .4f}\t" for s in score]) 401 | 402 | elif args.num_exp > 1: 403 | score_mean, score_std = repeat_run(args) 404 | score_str = "".join( 405 | [f"{s : .4f}\t" for s in score_mean] + [f"{s : .4f}\t" for s in score_std] 406 | ) 407 | 408 | with open(args.output_dir.parent.joinpath("exp_results"), "a+") as f: 409 | f.write(f"{score_str}\n") 410 | 411 | # for collecting aggregated results 412 | print(score_str) 413 | 414 | 415 | if __name__ == "__main__": 416 | main() 417 | -------------------------------------------------------------------------------- /train_teacher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | from pathlib import Path 6 | from models import Model 7 | from dataloader import load_data 8 | from utils import ( 9 | get_logger, 10 | get_evaluator, 11 | set_seed, 12 | get_training_config, 13 | check_writable, 14 | compute_min_cut_loss, 15 | graph_split, 16 | feature_prop, 17 | ) 18 | from train_and_eval import run_transductive, run_inductive 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser(description="PyTorch DGL implementation") 23 | parser.add_argument("--device", type=int, default=-1, help="CUDA device, -1 means CPU") 24 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 25 | parser.add_argument( 26 | "--log_level", 27 | type=int, 28 | default=20, 29 | help="Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}", 30 | ) 31 | parser.add_argument( 32 | "--console_log", 33 | action="store_true", 34 | help="Set to True to display log info in console", 35 | ) 36 | parser.add_argument( 37 | "--output_path", type=str, default="outputs", help="Path to save outputs" 38 | ) 39 | parser.add_argument( 40 | "--num_exp", type=int, default=1, help="Repeat how many experiments" 41 | ) 42 | parser.add_argument( 43 | "--exp_setting", 44 | type=str, 45 | default="tran", 46 | help="Experiment setting, one of [tran, ind]", 47 | ) 48 | parser.add_argument( 49 | "--eval_interval", type=int, default=1, help="Evaluate once per how many epochs" 50 | ) 51 | parser.add_argument( 52 | "--save_results", 53 | action="store_true", 54 | help="Set to True to save the loss curves, trained model, and min-cut loss for the 
transductive setting", 55 | ) 56 | 57 | """Dataset""" 58 | parser.add_argument("--dataset", type=str, default="cora", help="Dataset") 59 | parser.add_argument("--data_path", type=str, default="./data", help="Path to data") 60 | parser.add_argument( 61 | "--labelrate_train", 62 | type=int, 63 | default=20, 64 | help="How many labeled data per class as train set", 65 | ) 66 | parser.add_argument( 67 | "--labelrate_val", 68 | type=int, 69 | default=30, 70 | help="How many labeled data per class in valid set", 71 | ) 72 | parser.add_argument( 73 | "--split_idx", 74 | type=int, 75 | default=0, 76 | help="For Non-Homo datasets only, one of [0,1,2,3,4]", 77 | ) 78 | 79 | """Model""" 80 | parser.add_argument( 81 | "--model_config_path", 82 | type=str, 83 | default="./train.conf.yaml", 84 | help="Path to model configeration", 85 | ) 86 | parser.add_argument("--teacher", type=str, default="SAGE", help="Teacher model") 87 | parser.add_argument( 88 | "--num_layers", type=int, default=2, help="Model number of layers" 89 | ) 90 | parser.add_argument( 91 | "--hidden_dim", type=int, default=128, help="Model hidden layer dimensions" 92 | ) 93 | parser.add_argument("--dropout_ratio", type=float, default=0) 94 | parser.add_argument( 95 | "--norm_type", type=str, default="none", help="One of [none, batch, layer]" 96 | ) 97 | 98 | """SAGE Specific""" 99 | parser.add_argument("--batch_size", type=int, default=512) 100 | parser.add_argument( 101 | "--fan_out", 102 | type=str, 103 | default="5,5", 104 | help="Number of samples for each layer in SAGE. Length = num_layers", 105 | ) 106 | parser.add_argument( 107 | "--num_workers", type=int, default=0, help="Number of workers for sampler" 108 | ) 109 | 110 | """Optimization""" 111 | parser.add_argument("--learning_rate", type=float, default=0.01) 112 | parser.add_argument("--weight_decay", type=float, default=0.0005) 113 | parser.add_argument( 114 | "--max_epoch", type=int, default=500, help="Evaluate once per how many epochs" 115 | ) 116 | parser.add_argument( 117 | "--patience", 118 | type=int, 119 | default=50, 120 | help="Early stop is the score on validation set does not improve for how many epochs", 121 | ) 122 | 123 | """Ablation""" 124 | parser.add_argument( 125 | "--feature_noise", 126 | type=float, 127 | default=0, 128 | help="add white noise to features for analysis, value in [0, 1] for noise level", 129 | ) 130 | parser.add_argument( 131 | "--split_rate", 132 | type=float, 133 | default=0.2, 134 | help="Rate for graph split, see comment of graph_split for more details", 135 | ) 136 | parser.add_argument( 137 | "--compute_min_cut", 138 | action="store_true", 139 | help="Set to True to compute and store the min-cut loss", 140 | ) 141 | parser.add_argument( 142 | "--feature_aug_k", 143 | type=int, 144 | default=0, 145 | help="Augment node futures by aggregating feature_aug_k-hop neighbor features", 146 | ) 147 | 148 | args = parser.parse_args() 149 | return args 150 | 151 | 152 | def run(args): 153 | """ 154 | Returns: 155 | score_lst: a list of evaluation results on test set. 156 | len(score_lst) = 1 for the transductive setting. 157 | len(score_lst) = 2 for the inductive/production setting. 
158 | """ 159 | 160 | """ Set seed, device, and logger """ 161 | set_seed(args.seed) 162 | if torch.cuda.is_available() and args.device >= 0: 163 | device = torch.device("cuda:" + str(args.device)) 164 | else: 165 | device = "cpu" 166 | 167 | if args.feature_noise != 0 and args.seed == 0: 168 | args.output_path = Path.cwd().joinpath( 169 | args.output_path, "noisy_features", f"noise_{args.feature_noise}" 170 | ) 171 | 172 | if args.feature_aug_k > 0 and args.seed == 0: 173 | args.output_path = Path.cwd().joinpath( 174 | args.output_path, "aug_features", f"aug_hop_{args.feature_aug_k}" 175 | ) 176 | args.teacher = f"GA{args.feature_aug_k}{args.teacher}" 177 | 178 | if args.exp_setting == "tran": 179 | output_dir = Path.cwd().joinpath( 180 | args.output_path, 181 | "transductive", 182 | args.dataset, 183 | args.teacher, 184 | f"seed_{args.seed}", 185 | ) 186 | elif args.exp_setting == "ind": 187 | output_dir = Path.cwd().joinpath( 188 | args.output_path, 189 | "inductive", 190 | f"split_rate_{args.split_rate}", 191 | args.dataset, 192 | args.teacher, 193 | f"seed_{args.seed}", 194 | ) 195 | else: 196 | raise ValueError(f"Unknown experiment setting! {args.exp_setting}") 197 | args.output_dir = output_dir 198 | 199 | check_writable(output_dir, overwrite=False) 200 | logger = get_logger(output_dir.joinpath("log"), args.console_log, args.log_level) 201 | logger.info(f"output_dir: {output_dir}") 202 | 203 | """ Load data """ 204 | g, labels, idx_train, idx_val, idx_test = load_data( 205 | args.dataset, 206 | args.data_path, 207 | split_idx=args.split_idx, 208 | seed=args.seed, 209 | labelrate_train=args.labelrate_train, 210 | labelrate_val=args.labelrate_val, 211 | ) 212 | logger.info(f"Total {g.number_of_nodes()} nodes.") 213 | logger.info(f"Total {g.number_of_edges()} edges.") 214 | 215 | feats = g.ndata["feat"] 216 | args.feat_dim = g.ndata["feat"].shape[1] 217 | args.label_dim = labels.int().max().item() + 1 218 | 219 | if 0 < args.feature_noise <= 1: 220 | feats = ( 221 | 1 - args.feature_noise 222 | ) * feats + args.feature_noise * torch.randn_like(feats) 223 | 224 | """ Model config """ 225 | conf = {} 226 | if args.model_config_path is not None: 227 | conf = get_training_config(args.model_config_path, args.teacher, args.dataset) 228 | conf = dict(args.__dict__, **conf) 229 | conf["device"] = device 230 | logger.info(f"conf: {conf}") 231 | 232 | """ Model init """ 233 | model = Model(conf) 234 | optimizer = optim.Adam( 235 | model.parameters(), lr=conf["learning_rate"], weight_decay=conf["weight_decay"] 236 | ) 237 | criterion = torch.nn.NLLLoss() 238 | evaluator = get_evaluator(conf["dataset"]) 239 | 240 | """ Data split and run """ 241 | loss_and_score = [] 242 | if args.exp_setting == "tran": 243 | indices = (idx_train, idx_val, idx_test) 244 | 245 | # propagate node feature 246 | if args.feature_aug_k > 0: 247 | feats = feature_prop(feats, g, args.feature_aug_k) 248 | 249 | out, score_val, score_test = run_transductive( 250 | conf, 251 | model, 252 | g, 253 | feats, 254 | labels, 255 | indices, 256 | criterion, 257 | evaluator, 258 | optimizer, 259 | logger, 260 | loss_and_score, 261 | ) 262 | score_lst = [score_test] 263 | 264 | elif args.exp_setting == "ind": 265 | indices = graph_split(idx_train, idx_val, idx_test, args.split_rate, args.seed) 266 | 267 | # propagate node feature. 
The propagation for the observed graph only happens within the subgraph obs_g 268 | if args.feature_aug_k > 0: 269 | idx_obs = indices[3] 270 | obs_g = g.subgraph(idx_obs) 271 | obs_feats = feature_prop(feats[idx_obs], obs_g, args.feature_aug_k) 272 | feats = feature_prop(feats, g, args.feature_aug_k) 273 | feats[idx_obs] = obs_feats 274 | 275 | out, score_val, score_test_tran, score_test_ind = run_inductive( 276 | conf, 277 | model, 278 | g, 279 | feats, 280 | labels, 281 | indices, 282 | criterion, 283 | evaluator, 284 | optimizer, 285 | logger, 286 | loss_and_score, 287 | ) 288 | score_lst = [score_test_tran, score_test_ind] 289 | 290 | logger.info( 291 | f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. dropout_ratio: {conf['dropout_ratio']}" 292 | ) 293 | logger.info(f"# params {sum(p.numel() for p in model.parameters())}") 294 | 295 | """ Saving teacher outputs """ 296 | out_np = out.detach().cpu().numpy() 297 | np.savez(output_dir.joinpath("out"), out_np) 298 | 299 | """ Saving loss curve and model """ 300 | if args.save_results: 301 | # Loss curves 302 | loss_and_score = np.array(loss_and_score) 303 | np.savez(output_dir.joinpath("loss_and_score"), loss_and_score) 304 | 305 | # Model 306 | torch.save(model.state_dict(), output_dir.joinpath("model.pth")) 307 | 308 | """ Saving min-cut loss """ 309 | if args.exp_setting == "tran" and args.compute_min_cut: 310 | min_cut = compute_min_cut_loss(g, out) 311 | with open(output_dir.parent.joinpath("min_cut_loss"), "a+") as f: 312 | f.write(f"{min_cut :.4f}\n") 313 | 314 | return score_lst 315 | 316 | 317 | def repeat_run(args): 318 | scores = [] 319 | for seed in range(args.num_exp): 320 | args.seed = seed 321 | scores.append(run(args)) 322 | scores_np = np.array(scores) 323 | return scores_np.mean(axis=0), scores_np.std(axis=0) 324 | 325 | 326 | def main(): 327 | args = get_args() 328 | if args.num_exp == 1: 329 | score = run(args) 330 | score_str = "".join([f"{s : .4f}\t" for s in score]) 331 | 332 | elif args.num_exp > 1: 333 | score_mean, score_std = repeat_run(args) 334 | score_str = "".join( 335 | [f"{s : .4f}\t" for s in score_mean] + [f"{s : .4f}\t" for s in score_std] 336 | ) 337 | 338 | with open(args.output_dir.parent.joinpath("exp_results"), "a+") as f: 339 | f.write(f"{score_str}\n") 340 | 341 | # for collecting aggregated results 342 | print(score_str) 343 | 344 | 345 | if __name__ == "__main__": 346 | main() 347 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | import pytz 5 | import random 6 | import os 7 | import yaml 8 | import shutil 9 | from datetime import datetime 10 | from ogb.nodeproppred import Evaluator 11 | from dgl import function as fn 12 | 13 | CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"] 14 | OGB_data = ["ogbn-arxiv", "ogbn-products"] 15 | NonHom_data = ["pokec", "penn94"] 16 | BGNN_data = ["house_class", "vk_class"] 17 | 18 | 19 | def set_seed(seed): 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.benchmark = False 24 | torch.backends.cudnn.deterministic = True 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | 29 | def get_training_config(config_path, model_name, dataset): 30 | with open(config_path, "r") as conf: 31 | full_config = yaml.load(conf, Loader=yaml.FullLoader) 32 | 
dataset_specific_config = full_config["global"] 33 | model_specific_config = full_config[dataset][model_name] 34 | 35 | if model_specific_config is not None: 36 | specific_config = dict(dataset_specific_config, **model_specific_config) 37 | else: 38 | specific_config = dataset_specific_config 39 | 40 | specific_config["model_name"] = model_name 41 | return specific_config 42 | 43 | 44 | def check_writable(path, overwrite=True): 45 | if not os.path.exists(path): 46 | os.makedirs(path) 47 | elif overwrite: 48 | shutil.rmtree(path) 49 | os.makedirs(path) 50 | else: 51 | pass 52 | 53 | 54 | def check_readable(path): 55 | if not os.path.exists(path): 56 | raise ValueError(f"No such file or directory! {path}") 57 | 58 | 59 | def timetz(*args): 60 | tz = pytz.timezone("US/Pacific") 61 | return datetime.now(tz).timetuple() 62 | 63 | 64 | def get_logger(filename, console_log=False, log_level=logging.INFO): 65 | tz = pytz.timezone("US/Pacific") 66 | log_time = datetime.now(tz).strftime("%b%d_%H_%M_%S") 67 | logger = logging.getLogger(__name__) 68 | logger.propagate = False # avoid duplicate logging 69 | logger.setLevel(log_level) 70 | 71 | # Clean logger first to avoid duplicated handlers 72 | for hdlr in logger.handlers[:]: 73 | logger.removeHandler(hdlr) 74 | 75 | file_handler = logging.FileHandler(filename) 76 | formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%b%d %H-%M-%S") 77 | formatter.converter = timetz 78 | file_handler.setFormatter(formatter) 79 | logger.addHandler(file_handler) 80 | 81 | if console_log: 82 | console_handler = logging.StreamHandler() 83 | console_handler.setFormatter(formatter) 84 | logger.addHandler(console_handler) 85 | return logger 86 | 87 | 88 | def idx_split(idx, ratio, seed=0): 89 | """ 90 | randomly split idx into two portions with ratio% elements and (1 - ratio)% elements 91 | """ 92 | set_seed(seed) 93 | n = len(idx) 94 | cut = int(n * ratio) 95 | idx_idx_shuffle = torch.randperm(n) 96 | 97 | idx1_idx, idx2_idx = idx_idx_shuffle[:cut], idx_idx_shuffle[cut:] 98 | idx1, idx2 = idx[idx1_idx], idx[idx2_idx] 99 | # assert((torch.cat([idx1, idx2]).sort()[0] == idx.sort()[0]).all()) 100 | return idx1, idx2 101 | 102 | 103 | def graph_split(idx_train, idx_val, idx_test, rate, seed): 104 | """ 105 | Args: 106 | The original setting was transductive. Full graph is observed, and idx_train takes up a small portion. 107 | Split the graph by further divide idx_test into [idx_test_tran, idx_test_ind]. 108 | rate = idx_test_ind : idx_test (how much test to hide for the inductive evaluation) 109 | 110 | Ex. 
Ogbn-products 111 | loaded : train : val : test = 8 : 2 : 90, rate = 0.2 112 | after split: train : val : test_tran : test_ind = 8 : 2 : 72 : 18 113 | 114 | Return: 115 | Indices start with 'obs_' correspond to the node indices within the observed subgraph, 116 | where as indices start directly with 'idx_' correspond to the node indices in the original graph 117 | """ 118 | idx_test_ind, idx_test_tran = idx_split(idx_test, rate, seed) 119 | 120 | idx_obs = torch.cat([idx_train, idx_val, idx_test_tran]) 121 | N1, N2 = idx_train.shape[0], idx_val.shape[0] 122 | obs_idx_all = torch.arange(idx_obs.shape[0]) 123 | obs_idx_train = obs_idx_all[:N1] 124 | obs_idx_val = obs_idx_all[N1 : N1 + N2] 125 | obs_idx_test = obs_idx_all[N1 + N2 :] 126 | 127 | return obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind 128 | 129 | 130 | def get_evaluator(dataset): 131 | if dataset in CPF_data + NonHom_data + BGNN_data: 132 | 133 | def evaluator(out, labels): 134 | pred = out.argmax(1) 135 | return pred.eq(labels).float().mean().item() 136 | 137 | elif dataset in OGB_data: 138 | ogb_evaluator = Evaluator(dataset) 139 | 140 | def evaluator(out, labels): 141 | pred = out.argmax(1, keepdim=True) 142 | input_dict = {"y_true": labels.unsqueeze(1), "y_pred": pred} 143 | return ogb_evaluator.eval(input_dict)["acc"] 144 | 145 | else: 146 | raise ValueError("Unknown dataset") 147 | 148 | return evaluator 149 | 150 | 151 | def get_evaluator(dataset): 152 | def evaluator(out, labels): 153 | pred = out.argmax(1) 154 | return pred.eq(labels).float().mean().item() 155 | 156 | return evaluator 157 | 158 | 159 | def compute_min_cut_loss(g, out): 160 | out = out.to("cpu") 161 | S = out.exp() 162 | A = g.adj().to_dense() 163 | D = g.in_degrees().float().diag() 164 | min_cut = ( 165 | torch.matmul(torch.matmul(S.transpose(1, 0), A), S).trace() 166 | / torch.matmul(torch.matmul(S.transpose(1, 0), D), S).trace() 167 | ) 168 | return min_cut.item() 169 | 170 | 171 | def feature_prop(feats, g, k): 172 | """ 173 | Augment node feature by propagating the node features within k-hop neighborhood. 174 | The propagation is done in the SGC fashion, i.e. hop by hop and symmetrically normalized by node degrees. 175 | """ 176 | assert feats.shape[0] == g.num_nodes() 177 | 178 | degs = g.in_degrees().float().clamp(min=1) 179 | norm = torch.pow(degs, -0.5).unsqueeze(1) 180 | 181 | # compute (D^-1/2 A D^-1/2)^k X 182 | for _ in range(k): 183 | feats = feats * norm 184 | g.ndata["h"] = feats 185 | g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) 186 | feats = g.ndata.pop("h") 187 | feats = feats * norm 188 | 189 | return feats 190 | --------------------------------------------------------------------------------