├── .gitignore
├── LICENSE
├── README.md
├── data_preprocess.py
├── dataloader.py
├── experiments
├── ablation_feature_noise.sh
├── ablation_gnn.sh
├── ablation_ind_split_rate.sh
├── ga_glnn_arxiv.sh
├── glnn_arxiv.sh
├── glnn_cpf.sh
├── glnn_products.sh
├── sage_arxiv.sh
├── sage_cpf.sh
└── sage_products.sh
├── imgs
├── glnn.png
└── trade_off.png
├── models.py
├── prepare_env.sh
├── requirements.txt
├── train.conf.yaml
├── train_and_eval.py
├── train_student.py
├── train_teacher.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/data
2 | **/outputs
3 | **/*.npz
4 | **/*.ipynb
5 | **/__pycache__
6 | **/*.ipynb_checkpoints
7 | **/log
8 | **/plots
9 | **/temp
10 | ./.idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Snap Inc. 2021.
2 |
3 | All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Graph-less Neural Networks (GLNN)
3 |
4 | Code for [Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation](https://arxiv.org/pdf/2110.08727.pdf) by [Shichang Zhang](https://shichangzh.github.io/), [Yozen Liu](https://research.snap.com/team/yozen-liu/), [Yizhou Sun](http://web.cs.ucla.edu/~yzsun/), and [Neil Shah](http://nshah.net/).
5 |
6 |
7 | ## Overview
8 | ### Distillation framework
9 | 
10 | ![Distillation framework](imgs/glnn.png)
11 | 
16 | ### Accuracy vs. inference time on the `ogbn-products` dataset
17 | 
18 | ![Accuracy vs. inference time trade-off](imgs/trade_off.png)
19 | 
25 | ## Getting Started
26 |
27 | ### Setup Environment
28 |
29 | We use conda for environment setup. You can use
30 |
31 | `bash ./prepare_env.sh`
32 |
33 | which will create a conda environment named `glnn` and install the relevant requirements (from `requirements.txt`). For simplicity, this guide uses the CPU-based `torch` and `dgl` builds specified in `requirements.txt`. To run experiments with CUDA, install `torch` and `dgl` with the proper CUDA support, remove the CPU versions from `requirements.txt`, and set the `--device` argument in the scripts accordingly. See https://pytorch.org/ and https://www.dgl.ai/pages/start.html for more installation details.
34 |
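For example, once CUDA builds are installed, a teacher run might look like the following (a hypothetical invocation; it assumes `--device` accepts a CUDA device index, so check the argument parser in `train_teacher.py` for the exact semantics):

```
python train_teacher.py --exp_setting tran --teacher SAGE --dataset cora --device 0
```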
35 | Be sure to activate the environment with
36 |
37 | `conda activate glnn`
38 |
39 | before running experiments as described below.
40 |
41 |
42 |
43 | ### Preparing datasets
44 | To run experiments on the datasets used in the paper, please download them from the following links and put them under `data/` (see below for instructions on organizing the datasets; a short loading sketch follows the list).
45 |
46 | - *CPF data* (`cora`, `citeseer`, `pubmed`, `a-computer`, and `a-photo`): Download the '.npz' files from [here](https://github.com/BUPT-GAMMA/CPF/tree/master/data/npz). Rename `amazon_electronics_computers.npz` and `amazon_electronics_photo.npz` to `a-computer.npz` and `a-photo.npz` respectively.
47 |
48 | - *OGB data* (`ogbn-arxiv` and `ogbn-products`): Datasets will be automatically downloaded when running the `load_data` function in `dataloader.py`. More details [here](https://ogb.stanford.edu/).
49 |
50 | - *BGNN data* (`house_class` and `vk_class`): Follow the instructions [here](https://github.com/dmlc/dgl/tree/473d5e0a4c4e4735f1c9dc9d783e0374328cca9a/examples/pytorch/bgnn) and download dataset pre-processed in DGL format from [here](https://www.dropbox.com/s/verx1evkykzli88/datasets.zip).
51 |
52 | - *NonHom data* (`penn94` and `pokec`): Follow the instructions [here](https://github.com/CUAI/Non-Homophily-Benchmarks) to download the `penn94` dataset and its splits. The `pokec` dataset will be automatically downloaded when running the `load_data` function in `dataloader.py`.
53 |
54 | - Your favourite datasets: download and add to the `load_data` function in `dataloader.py`.
55 |
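Once the files are in place, everything is loaded through `load_data` in `dataloader.py`, which returns the DGL graph, the labels, and the train/val/test index tensors. A minimal sketch, assuming the data lives under `./data` (the seed and per-class label rates below are illustrative; see `train.conf.yaml` and the training scripts for the values actually used):

```python
from dataloader import load_data

# CPF datasets take a seed and per-class label rates for the random split
g, labels, idx_train, idx_val, idx_test = load_data(
    "cora", "./data", seed=0, labelrate_train=20, labelrate_val=30
)

# OGB datasets are downloaded automatically and use the official split
g, labels, idx_train, idx_val, idx_test = load_data("ogbn-arxiv", "./data")
```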
56 |
57 | ### Usage
58 |
59 | To quickly train a teacher model, run `train_teacher.py` and specify the experiment setting, i.e. transductive (`tran`) or inductive (`ind`), the teacher model, e.g. `GCN`, and the dataset, e.g. `cora`, as in the example below.
60 |
61 | ```
62 | python train_teacher.py --exp_setting tran --teacher GCN --dataset cora
63 | ```
64 |
65 | To quickly train a student model with a pretrained teacher, run `train_student.py` and specify the experiment setting, the teacher model, the student model, and the dataset, as in the example below. Make sure to train the teacher with `train_teacher.py` first and that its output is stored in the path specified by `--out_t_path`.
66 |
67 | ```
68 | python train_student.py --exp_setting ind --teacher SAGE --student MLP --dataset citeseer --out_t_path outputs
69 | ```
70 |
71 | For more examples, and to reproduce the results in the paper, please refer to the scripts in `experiments/`, e.g.
72 |
73 | ```
74 | bash experiments/sage_cpf.sh
75 | ```
76 |
77 | To extend GLNN to your own model, you may do one of the following.
78 | - Add your favourite model architectures to the `Model` class in `models.py`. Then follow the examples above.
79 | - Train your own teacher model and store its output (log-probabilities), then train the student with `train_student.py` and the correct `--out_t_path` (see the sketch after this list).
80 |
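For the second option, the student loader (`load_out_t` in `dataloader.py`) reads the teacher's soft labels from the first array of an `out.npz` file in the output directory. A minimal sketch of producing that file from your own teacher; the logits tensor and the output directory are placeholders, and the directory must match what `train_student.py` resolves from `--out_t_path` and the experiment settings:

```python
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

out_dir = Path("outputs/tran/cora/SAGE")  # placeholder; align with your --out_t_path layout
out_dir.mkdir(parents=True, exist_ok=True)

# Placeholder teacher output: one row of class scores per node (Cora-sized here)
logits = torch.randn(2708, 7)

# Save log-probabilities as the first (unnamed) array, i.e. "arr_0", which load_out_t reads
out = F.log_softmax(logits, dim=1).detach().cpu().numpy()
np.savez(out_dir.joinpath("out.npz"), out)
```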
81 |
82 | ## Results
83 |
84 | GraphSAGE vs. MLP vs. GLNN under the production setting described in the paper (transductive and inductive combined). Delta_MLP (Delta_GNN) is the difference between the GLNN and the MLP (GNN). Results show classification accuracy (higher is better); Delta_GNN > 0 indicates that the GLNN outperforms the GNN. We observe that GLNNs always improve over MLPs by large margins and achieve results competitive with the GNN on 6/7 datasets. Please see Table 3 in the paper for more details.
85 |
86 | | Datasets | GNN(SAGE) | MLP | GLNN | Delta_MLP | Delta_GNN |
87 | |------------|----------------|--------------|----------------|-----------------|-------------------|
88 | | Cora | **79.29** | 58.98 | 78.28 | 19.30 (32.72\%) | -1.01 (-1.28\%) |
89 | | Citeseer   | 68.38          | 59.81        | **69.27**      | 9.46 (15.82\%)  | 0.89 (1.30\%)     |
90 | | Pubmed | **74.88** | 66.80 | 74.71 | 7.91 (11.83\%) | -0.17 (-0.22\%) |
91 | | A-computer | 82.14 | 67.38 | **82.29** | 14.90 (22.12\%) | 0.15 (0.19\%) |
92 | | A-photo | 91.08 | 79.25 | **92.38** | 13.13 (16.57\%) | 1.30 (1.42\%) |
93 | | Arxiv | **70.73** | 55.30 | 65.09 | 9.79 (17.70\%) | -5.64 (-7.97\%) |
94 | | Products | **76.60** | 63.72 | 75.77 | 12.05 (18.91\%) | -0.83 (-1.09\%) |
95 |
96 |
121 |
122 | ## Citation
123 |
124 | If you find our work useful, please cite the following:
125 |
126 | ```BibTeX
127 | @inproceedings{zhang2021graphless,
128 | title={Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation},
129 | author={Shichang Zhang and Yozen Liu and Yizhou Sun and Neil Shah},
130 |   booktitle={International Conference on Learning Representations},
131 | year={2022},
132 | url={https://arxiv.org/abs/2110.08727}
133 | }
134 | ```
135 |
136 | ## Contact Us
137 |
138 | Please open an issue or contact `shichang@cs.ucla.edu` if you have any questions.
139 |
140 |
141 |
--------------------------------------------------------------------------------
/data_preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from the CPF implementation
3 | https://github.com/BUPT-GAMMA/CPF/tree/389c01aaf238689ee7b1e5aba127842341e123b6/data
4 | """
5 |
6 | import numpy as np
7 | import scipy.sparse as sp
8 | from collections import Counter
9 | from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
10 |
11 |
12 | def is_binary_bag_of_words(features):
13 | features_coo = features.tocoo()
14 | return all(
15 | single_entry == 1.0
16 | for _, _, single_entry in zip(
17 | features_coo.row, features_coo.col, features_coo.data
18 | )
19 | )
20 |
21 |
22 | def to_binary_bag_of_words(features):
23 | """Converts TF/IDF features to binary bag-of-words features."""
24 | features_copy = features.tocsr()
25 | features_copy.data[:] = 1.0
26 | return features_copy
27 |
28 |
29 | def normalize(mx):
30 | """Row-normalize sparse matrix"""
31 | rowsum = np.array(mx.sum(1))
32 | r_inv = np.power(rowsum, -1).flatten()
33 | r_inv[np.isinf(r_inv)] = 0.0
34 | r_mat_inv = sp.diags(r_inv)
35 | mx = r_mat_inv.dot(mx)
36 | return mx
37 |
38 |
39 | def normalize_adj(adj):
40 | adj = normalize(adj + sp.eye(adj.shape[0]))
41 | return adj
42 |
43 |
44 | def eliminate_self_loops_adj(A):
45 | """Remove self-loops from the adjacency matrix."""
46 | A = A.tolil()
47 | A.setdiag(0)
48 | A = A.tocsr()
49 | A.eliminate_zeros()
50 | return A
51 |
52 |
53 | def largest_connected_components(sparse_graph, n_components=1):
54 | """Select the largest connected components in the graph.
55 |
56 | Parameters
57 | ----------
58 | sparse_graph : SparseGraph
59 | Input graph.
60 | n_components : int, default 1
61 | Number of largest connected components to keep.
62 |
63 | Returns
64 | -------
65 | sparse_graph : SparseGraph
66 | Subgraph of the input graph where only the nodes in largest n_components are kept.
67 |
68 | """
69 | _, component_indices = sp.csgraph.connected_components(sparse_graph.adj_matrix)
70 | component_sizes = np.bincount(component_indices)
71 | components_to_keep = np.argsort(component_sizes)[::-1][
72 | :n_components
73 | ] # reverse order to sort descending
74 | nodes_to_keep = [
75 | idx
76 | for (idx, component) in enumerate(component_indices)
77 | if component in components_to_keep
78 | ]
79 | return create_subgraph(sparse_graph, nodes_to_keep=nodes_to_keep)
80 |
81 |
82 | def create_subgraph(
83 | sparse_graph, _sentinel=None, nodes_to_remove=None, nodes_to_keep=None
84 | ):
85 | """Create a graph with the specified subset of nodes.
86 |
87 | Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None.
88 | Note that to avoid confusion, it is required to pass node indices as named arguments to this function.
89 |
90 | Parameters
91 | ----------
92 | sparse_graph : SparseGraph
93 | Input graph.
94 | _sentinel : None
95 | Internal, to prevent passing positional arguments. Do not use.
96 | nodes_to_remove : array-like of int
97 |         Indices of nodes that have to be removed.
98 | nodes_to_keep : array-like of int
99 | Indices of nodes that have to be kept.
100 |
101 | Returns
102 | -------
103 | sparse_graph : SparseGraph
104 | Graph with specified nodes removed.
105 |
106 | """
107 | # Check that arguments are passed correctly
108 | if _sentinel is not None:
109 | raise ValueError(
110 |             "Only call `create_subgraph` with named arguments,"
111 | " (nodes_to_remove=...) or (nodes_to_keep=...)"
112 | )
113 | if nodes_to_remove is None and nodes_to_keep is None:
114 | raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.")
115 | elif nodes_to_remove is not None and nodes_to_keep is not None:
116 | raise ValueError(
117 | "Only one of nodes_to_remove or nodes_to_keep must be provided."
118 | )
119 | elif nodes_to_remove is not None:
120 | nodes_to_keep = [
121 | i for i in range(sparse_graph.num_nodes()) if i not in nodes_to_remove
122 | ]
123 | elif nodes_to_keep is not None:
124 | nodes_to_keep = sorted(nodes_to_keep)
125 | else:
126 | raise RuntimeError("This should never happen.")
127 |
128 | sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep]
129 | if sparse_graph.attr_matrix is not None:
130 | sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep]
131 | if sparse_graph.labels is not None:
132 | sparse_graph.labels = sparse_graph.labels[nodes_to_keep]
133 | if sparse_graph.node_names is not None:
134 | sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep]
135 | return sparse_graph
136 |
137 |
138 | def binarize_labels(labels, sparse_output=False, return_classes=False):
139 | """Convert labels vector to a binary label matrix.
140 |
141 | In the default single-label case, labels look like
142 | labels = [y1, y2, y3, ...].
143 | Also supports the multi-label format.
144 | In this case, labels should look something like
145 | labels = [[y11, y12], [y21, y22, y23], [y31], ...].
146 |
147 | Parameters
148 | ----------
149 | labels : array-like, shape [num_samples]
150 | Array of node labels in categorical single- or multi-label format.
151 | sparse_output : bool, default False
152 |         Whether to return the label_matrix in CSR format.
153 |     return_classes : bool, default False
154 |         Whether to return the classes corresponding to the columns of the label matrix.
155 |
156 | Returns
157 | -------
158 | label_matrix : np.ndarray or sp.csr_matrix, shape [num_samples, num_classes]
159 | Binary matrix of class labels.
160 | num_classes = number of unique values in "labels" array.
161 | label_matrix[i, k] = 1 <=> node i belongs to class k.
162 | classes : np.array, shape [num_classes], optional
163 | Classes that correspond to each column of the label_matrix.
164 |
165 | """
166 | if hasattr(labels[0], "__iter__"): # labels[0] is iterable <=> multilabel format
167 | binarizer = MultiLabelBinarizer(sparse_output=sparse_output)
168 | else:
169 | binarizer = LabelBinarizer(sparse_output=sparse_output)
170 | label_matrix = binarizer.fit_transform(labels).astype(np.float32)
171 | return (label_matrix, binarizer.classes_) if return_classes else label_matrix
172 |
173 |
174 | def remove_underrepresented_classes(
175 | g, train_examples_per_class, val_examples_per_class
176 | ):
177 |     """Remove nodes from the graph that belong to a class with at most
178 |     train_examples_per_class + val_examples_per_class examples.
179 |
180 | Those classes would otherwise break the training procedure.
181 | """
182 | min_examples_per_class = train_examples_per_class + val_examples_per_class
183 | examples_counter = Counter(g.labels)
184 | keep_classes = set(
185 | class_
186 | for class_, count in examples_counter.items()
187 | if count > min_examples_per_class
188 | )
189 | keep_indices = [i for i in range(len(g.labels)) if g.labels[i] in keep_classes]
190 |
191 | return create_subgraph(g, nodes_to_keep=keep_indices)
192 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloader of CPF datasets are adapted from the CPF implementation
3 | https://github.com/BUPT-GAMMA/CPF/tree/389c01aaf238689ee7b1e5aba127842341e123b6/data
4 |
5 | Dataloader of NonHom datasets are adapted from the Non-homophily benchmarks
6 | https://github.com/CUAI/Non-Homophily-Benchmarks
7 |
8 | Dataloader of BGNN datasets are adapted from the BGNN implementation and dgl example of BGNN
9 | https://github.com/nd7141/bgnn
10 | https://github.com/dmlc/dgl/tree/473d5e0a4c4e4735f1c9dc9d783e0374328cca9a/examples/pytorch/bgnn
11 | """
12 |
13 | import numpy as np
14 | import scipy.sparse as sp
15 | import torch
16 | import dgl
17 | import os
18 | import scipy
19 | import pandas as pd
20 | import json
21 | from dgl.data.utils import load_graphs
22 | from os import path
23 | from category_encoders import CatBoostEncoder
24 | from pathlib import Path
25 | from google_drive_downloader import GoogleDriveDownloader as gdd
26 | from sklearn.preprocessing import label_binarize
27 | from sklearn import preprocessing
28 | from data_preprocess import (
29 | normalize_adj,
30 | eliminate_self_loops_adj,
31 | largest_connected_components,
32 | binarize_labels,
33 | )
34 | from ogb.nodeproppred import DglNodePropPredDataset
35 |
36 | CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"]
37 | OGB_data = ["ogbn-arxiv", "ogbn-products"]
38 | NonHom_data = ["pokec", "penn94"]
39 | BGNN_data = ["house_class", "vk_class"]
40 |
41 |
42 | def load_data(dataset, dataset_path, **kwargs):
43 | if dataset in CPF_data:
44 | return load_cpf_data(
45 | dataset,
46 | dataset_path,
47 | kwargs["seed"],
48 | kwargs["labelrate_train"],
49 | kwargs["labelrate_val"],
50 | )
51 | elif dataset in OGB_data:
52 | return load_ogb_data(dataset, dataset_path)
53 | elif dataset in NonHom_data:
54 | return load_nonhom_data(dataset, dataset_path, kwargs["split_idx"])
55 | elif dataset in BGNN_data:
56 | return load_bgnn_data(dataset, dataset_path, kwargs["split_idx"])
57 | else:
58 | raise ValueError(f"Unknown dataset: {dataset}")
59 |
60 |
61 | def load_ogb_data(dataset, dataset_path):
62 | data = DglNodePropPredDataset(dataset, dataset_path)
63 | splitted_idx = data.get_idx_split()
64 | idx_train, idx_val, idx_test = (
65 | splitted_idx["train"],
66 | splitted_idx["valid"],
67 | splitted_idx["test"],
68 | )
69 |
70 | g, labels = data[0]
71 | labels = labels.squeeze()
72 |
73 | # Turn the graph to undirected
74 | if dataset == "ogbn-arxiv":
75 | srcs, dsts = g.all_edges()
76 | g.add_edges(dsts, srcs)
77 | g = g.remove_self_loop().add_self_loop()
78 |
79 | return g, labels, idx_train, idx_val, idx_test
80 |
81 |
82 | def load_cpf_data(dataset, dataset_path, seed, labelrate_train, labelrate_val):
83 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}.npz")
84 | if os.path.isfile(data_path):
85 | data = load_npz_to_sparse_graph(data_path)
86 | else:
87 | raise ValueError(f"{data_path} doesn't exist.")
88 |
89 | # remove self loop and extract the largest CC
90 | data = data.standardize()
91 | adj, features, labels = data.unpack()
92 |
93 | labels = binarize_labels(labels)
94 |
95 | random_state = np.random.RandomState(seed)
96 | idx_train, idx_val, idx_test = get_train_val_test_split(
97 | random_state, labels, labelrate_train, labelrate_val
98 | )
99 |
100 | features = torch.FloatTensor(np.array(features.todense()))
101 | labels = torch.LongTensor(labels.argmax(axis=1))
102 |
103 | adj = normalize_adj(adj)
104 | adj_sp = adj.tocoo()
105 | g = dgl.graph((adj_sp.row, adj_sp.col))
106 | g.ndata["feat"] = features
107 |
108 | idx_train = torch.LongTensor(idx_train)
109 | idx_val = torch.LongTensor(idx_val)
110 | idx_test = torch.LongTensor(idx_test)
111 | return g, labels, idx_train, idx_val, idx_test
112 |
113 |
114 | def load_nonhom_data(dataset, dataset_path, split_idx):
115 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}.mat")
116 | data_split_path = Path.cwd().joinpath(
117 | dataset_path, "splits", f"{dataset}-splits.npy"
118 | )
119 |
120 | if dataset == "pokec":
121 | g, features, labels = load_pokec_mat(data_path)
122 | elif dataset == "penn94":
123 | g, features, labels = load_penn94_mat(data_path)
124 | else:
125 | raise ValueError("Invalid dataname")
126 |
127 | g = g.remove_self_loop().add_self_loop()
128 | g.ndata["feat"] = features
129 | labels = torch.LongTensor(labels)
130 |
131 | splitted_idx = load_fixed_splits(dataset, data_split_path, split_idx)
132 | idx_train, idx_val, idx_test = (
133 | splitted_idx["train"],
134 | splitted_idx["valid"],
135 | splitted_idx["test"],
136 | )
137 | return g, labels, idx_train, idx_val, idx_test
138 |
139 |
140 | def load_bgnn_data(dataset, dataset_path, split_idx):
141 | data_path = Path.cwd().joinpath(dataset_path, f"{dataset}")
142 |
143 | g, X, y, cat_features, masks = read_input(data_path)
144 | train_mask, val_mask, test_mask = (
145 | masks[str(split_idx)]["train"],
146 | masks[str(split_idx)]["val"],
147 | masks[str(split_idx)]["test"],
148 | )
149 |
150 | encoded_X = X.copy()
151 | if cat_features is not None and len(cat_features):
152 | encoded_X = encode_cat_features(
153 | encoded_X, y, cat_features, train_mask, val_mask, test_mask
154 | )
155 | encoded_X = normalize_features(encoded_X, train_mask, val_mask, test_mask)
156 | encoded_X = replace_na(encoded_X, train_mask)
157 | features, labels = pandas_to_torch(encoded_X, y)
158 |
159 | g = g.remove_self_loop().add_self_loop()
160 | g.ndata["feat"] = features
161 | labels = labels.long()
162 |
163 | idx_train = torch.LongTensor(train_mask)
164 | idx_val = torch.LongTensor(val_mask)
165 | idx_test = torch.LongTensor(test_mask)
166 | return g, labels, idx_train, idx_val, idx_test
167 |
168 |
169 | def load_out_t(out_t_dir):
170 | return torch.from_numpy(np.load(out_t_dir.joinpath("out.npz"))["arr_0"])
171 |
172 |
173 | """ For NonHom"""
174 | dataset_drive_url = {"pokec": "1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y"}
175 | splits_drive_url = {"pokec": "1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_"}
176 |
177 |
178 | def load_penn94_mat(data_path):
179 | mat = scipy.io.loadmat(data_path)
180 | A = mat["A"]
181 | metadata = mat["local_info"]
182 |
183 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long)
184 |     metadata = metadata.astype(int)  # np.int is removed in recent NumPy; use the builtin int
185 |
186 | # make features into one-hot encodings
187 | feature_vals = np.hstack((np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
188 | features = np.empty((A.shape[0], 0))
189 | for col in range(feature_vals.shape[1]):
190 | feat_col = feature_vals[:, col]
191 | feat_onehot = label_binarize(feat_col, classes=np.unique(feat_col))
192 | features = np.hstack((features, feat_onehot))
193 |
194 | g = dgl.graph((edge_index[0], edge_index[1]))
195 | g = dgl.to_bidirected(g)
196 |
197 | features = torch.tensor(features, dtype=torch.float)
198 | labels = torch.tensor(metadata[:, 1] - 1) # gender label, -1 means unlabeled
199 | return g, features, labels
200 |
201 |
202 | def load_pokec_mat(data_path):
203 | if not path.exists(data_path):
204 | gdd.download_file_from_google_drive(
205 | file_id=dataset_drive_url["pokec"], dest_path=data_path, showsize=True
206 | )
207 |
208 | fulldata = scipy.io.loadmat(data_path)
209 | edge_index = torch.tensor(fulldata["edge_index"], dtype=torch.long)
210 | g = dgl.graph((edge_index[0], edge_index[1]))
211 | g = dgl.to_bidirected(g)
212 |
213 | features = torch.tensor(fulldata["node_feat"]).float()
214 | labels = fulldata["label"].flatten()
215 | return g, features, labels
216 |
217 |
218 | class NCDataset(object):
219 | def __init__(self, name, root):
220 | """
221 | based off of ogb NodePropPredDataset
222 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py
223 | Gives torch tensors instead of numpy arrays
224 | - name (str): name of the dataset
225 | - root (str): root directory to store the dataset folder
226 | - meta_dict: dictionary that stores all the meta-information about data. Default is None,
227 |              but when something is passed, it uses its information. Useful for debugging for external contributors.
228 |
229 | Usage after construction:
230 |
231 | split_idx = dataset.get_idx_split()
232 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
233 | graph, label = dataset[0]
234 |
235 | Where the graph is a dictionary of the following form:
236 | dataset.graph = {'edge_index': edge_index,
237 | 'edge_feat': None,
238 | 'node_feat': node_feat,
239 | 'num_nodes': num_nodes}
240 | For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/
241 |
242 | """
243 |
244 | self.name = name # original name, e.g., ogbn-proteins
245 | self.graph = {}
246 | self.label = None
247 |
248 |     def rand_train_test_idx(self, label, train_prop, valid_prop, ignore_negative):
249 | """
250 | Randomly splits the dataset into train, validation, and test sets.
251 | """
252 | if ignore_negative:
253 | non_negative_idx = np.where(label >= 0)[0]
254 | else:
255 | non_negative_idx = np.arange(len(label))
256 |
257 | num_nodes = len(non_negative_idx)
258 | num_train = int(train_prop * num_nodes)
259 | num_valid = int(valid_prop * num_nodes)
260 | num_test = num_nodes - num_train - num_valid
261 |
262 | idx = np.random.permutation(non_negative_idx)
263 |
264 | train_idx = idx[:num_train]
265 | valid_idx = idx[num_train: num_train + num_valid]
266 | test_idx = idx[num_train + num_valid:]
267 |
268 | return train_idx, valid_idx, test_idx
269 |
270 | def get_idx_split(self, split_type="random", train_prop=0.5, valid_prop=0.25):
271 | """
272 | train_prop: The proportion of dataset for train split. Between 0 and 1.
273 | valid_prop: The proportion of dataset for validation split. Between 0 and 1.
274 | """
275 |
276 | if split_type == "random":
277 | ignore_negative = False if self.name == "ogbn-proteins" else True
278 | train_idx, valid_idx, test_idx = self.rand_train_test_idx(
279 | self.label,
280 | train_prop=train_prop,
281 | valid_prop=valid_prop,
282 | ignore_negative=ignore_negative,
283 | )
284 | split_idx = {"train": train_idx, "valid": valid_idx, "test": test_idx}
285 | return split_idx
286 |
287 | def __getitem__(self, idx):
288 | assert idx == 0, "This dataset has only one graph"
289 | return self.graph, self.label
290 |
291 | def __len__(self):
292 | return 1
293 |
294 | def __repr__(self):
295 | return "{}({})".format(self.__class__.__name__, len(self))
296 |
297 |
298 | def load_fixed_splits(dataset, data_split_path="", split_idx=0):
299 | if not os.path.exists(data_split_path):
300 | assert dataset in splits_drive_url.keys()
301 | gdd.download_file_from_google_drive(
302 | file_id=splits_drive_url[dataset], dest_path=data_split_path, showsize=True
303 | )
304 |
305 | splits_lst = np.load(data_split_path, allow_pickle=True)
306 | splits = splits_lst[split_idx]
307 |
308 | for key in splits:
309 | if not torch.is_tensor(splits[key]):
310 | splits[key] = torch.as_tensor(splits[key])
311 |
312 | return splits
313 |
314 |
315 | """For BGNN """
316 |
317 |
318 | def pandas_to_torch(*args):
319 | return [torch.from_numpy(arg.to_numpy(copy=True)).float().squeeze() for arg in args]
320 |
321 |
322 | def read_input(input_folder):
323 | X = pd.read_csv(f"{input_folder}/X.csv")
324 | y = pd.read_csv(f"{input_folder}/y.csv")
325 |
326 | categorical_columns = []
327 | if os.path.exists(f"{input_folder}/cat_features.txt"):
328 | with open(f"{input_folder}/cat_features.txt") as f:
329 | for line in f:
330 | if line.strip():
331 | categorical_columns.append(line.strip())
332 |
333 | cat_features = None
334 | if categorical_columns:
335 | columns = X.columns
336 | cat_features = np.where(columns.isin(categorical_columns))[0]
337 |
338 | for col in list(columns[cat_features]):
339 | X[col] = X[col].astype(str)
340 |
341 | gs, _ = load_graphs(f"{input_folder}/graph.dgl")
342 | graph = gs[0]
343 |
344 | with open(f"{input_folder}/masks.json") as f:
345 | masks = json.load(f)
346 |
347 | return graph, X, y, cat_features, masks
348 |
349 |
350 | def normalize_features(X, train_mask, val_mask, test_mask):
351 | min_max_scaler = preprocessing.MinMaxScaler()
352 | A = X.to_numpy(copy=True)
353 | A[train_mask] = min_max_scaler.fit_transform(A[train_mask])
354 | A[val_mask + test_mask] = min_max_scaler.transform(A[val_mask + test_mask])
355 | return pd.DataFrame(A, columns=X.columns).astype(float)
356 |
357 |
358 | def replace_na(X, train_mask):
359 | if X.isna().any().any():
360 | return X.fillna(X.iloc[train_mask].min() - 1)
361 | return X
362 |
363 |
364 | def encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask):
365 | enc = CatBoostEncoder()
366 | A = X.to_numpy(copy=True)
367 | b = y.to_numpy(copy=True)
368 | A[np.ix_(train_mask, cat_features)] = enc.fit_transform(
369 | A[np.ix_(train_mask, cat_features)], b[train_mask]
370 | )
371 | A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform(
372 | A[np.ix_(val_mask + test_mask, cat_features)]
373 | )
374 | A = A.astype(float)
375 | return pd.DataFrame(A, columns=X.columns)
376 |
377 |
378 | """ For CPF"""
379 |
380 |
381 | class SparseGraph:
382 | """Attributed labeled graph stored in sparse matrix form."""
383 |
384 | def __init__(
385 | self,
386 | adj_matrix,
387 | attr_matrix=None,
388 | labels=None,
389 | node_names=None,
390 | attr_names=None,
391 | class_names=None,
392 | metadata=None,
393 | ):
394 | """Create an attributed graph.
395 |
396 | Parameters
397 | ----------
398 | adj_matrix : sp.csr_matrix, shape [num_nodes, num_nodes]
399 | Adjacency matrix in CSR format.
400 | attr_matrix : sp.csr_matrix or np.ndarray, shape [num_nodes, num_attr], optional
401 | Attribute matrix in CSR or numpy format.
402 | labels : np.ndarray, shape [num_nodes], optional
403 | Array, where each entry represents respective node's label(s).
404 | node_names : np.ndarray, shape [num_nodes], optional
405 | Names of nodes (as strings).
406 | attr_names : np.ndarray, shape [num_attr]
407 | Names of the attributes (as strings).
408 | class_names : np.ndarray, shape [num_classes], optional
409 | Names of the class labels (as strings).
410 | metadata : object
411 | Additional metadata such as text.
412 |
413 | """
414 | # Make sure that the dimensions of matrices / arrays all agree
415 | if sp.isspmatrix(adj_matrix):
416 | adj_matrix = adj_matrix.tocsr().astype(np.float32)
417 | else:
418 | raise ValueError(
419 | "Adjacency matrix must be in sparse format (got {0} instead)".format(
420 | type(adj_matrix)
421 | )
422 | )
423 |
424 | if adj_matrix.shape[0] != adj_matrix.shape[1]:
425 | raise ValueError("Dimensions of the adjacency matrix don't agree")
426 |
427 | if attr_matrix is not None:
428 | if sp.isspmatrix(attr_matrix):
429 | attr_matrix = attr_matrix.tocsr().astype(np.float32)
430 | elif isinstance(attr_matrix, np.ndarray):
431 | attr_matrix = attr_matrix.astype(np.float32)
432 | else:
433 | raise ValueError(
434 | "Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)".format(
435 | type(attr_matrix)
436 | )
437 | )
438 |
439 | if attr_matrix.shape[0] != adj_matrix.shape[0]:
440 | raise ValueError(
441 | "Dimensions of the adjacency and attribute matrices don't agree"
442 | )
443 |
444 | if labels is not None:
445 | if labels.shape[0] != adj_matrix.shape[0]:
446 | raise ValueError(
447 | "Dimensions of the adjacency matrix and the label vector don't agree"
448 | )
449 |
450 | if node_names is not None:
451 | if len(node_names) != adj_matrix.shape[0]:
452 | raise ValueError(
453 | "Dimensions of the adjacency matrix and the node names don't agree"
454 | )
455 |
456 | if attr_names is not None:
457 | if len(attr_names) != attr_matrix.shape[1]:
458 | raise ValueError(
459 | "Dimensions of the attribute matrix and the attribute names don't agree"
460 | )
461 |
462 | self.adj_matrix = adj_matrix
463 | self.attr_matrix = attr_matrix
464 | self.labels = labels
465 | self.node_names = node_names
466 | self.attr_names = attr_names
467 | self.class_names = class_names
468 | self.metadata = metadata
469 |
470 | def num_nodes(self):
471 | """Get the number of nodes in the graph."""
472 | return self.adj_matrix.shape[0]
473 |
474 | def num_edges(self):
475 | """Get the number of edges in the graph.
476 |
477 | For undirected graphs, (i, j) and (j, i) are counted as single edge.
478 | """
479 | if self.is_directed():
480 | return int(self.adj_matrix.nnz)
481 | else:
482 | return int(self.adj_matrix.nnz / 2)
483 |
484 | def get_neighbors(self, idx):
485 | """Get the indices of neighbors of a given node.
486 |
487 | Parameters
488 | ----------
489 | idx : int
490 | Index of the node whose neighbors are of interest.
491 |
492 | """
493 | return self.adj_matrix[idx].indices
494 |
495 | def is_directed(self):
496 | """Check if the graph is directed (adjacency matrix is not symmetric)."""
497 | return (self.adj_matrix != self.adj_matrix.T).sum() != 0
498 |
499 | def to_undirected(self):
500 | """Convert to an undirected graph (make adjacency matrix symmetric)."""
501 | if self.is_weighted():
502 | raise ValueError("Convert to unweighted graph first.")
503 | else:
504 | self.adj_matrix = self.adj_matrix + self.adj_matrix.T
505 | self.adj_matrix[self.adj_matrix != 0] = 1
506 | return self
507 |
508 | def is_weighted(self):
509 | """Check if the graph is weighted (edge weights other than 1)."""
510 | return np.any(np.unique(self.adj_matrix[self.adj_matrix != 0].A1) != 1)
511 |
512 | def to_unweighted(self):
513 | """Convert to an unweighted graph (set all edge weights to 1)."""
514 | self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
515 | return self
516 |
517 | # Quality of life (shortcuts)
518 | def standardize(self):
519 | """Select the LCC of the unweighted/undirected/no-self-loop graph.
520 |
521 | All changes are done inplace.
522 |
523 | """
524 | G = self.to_unweighted().to_undirected()
525 | G.adj_matrix = eliminate_self_loops_adj(G.adj_matrix)
526 | G = largest_connected_components(G, 1)
527 | return G
528 |
529 | def unpack(self):
530 | """Return the (A, X, z) triplet."""
531 | return self.adj_matrix, self.attr_matrix, self.labels
532 |
533 |
534 | def load_npz_to_sparse_graph(file_name):
535 | """Load a SparseGraph from a Numpy binary file.
536 |
537 | Parameters
538 | ----------
539 | file_name : str
540 | Name of the file to load.
541 |
542 | Returns
543 | -------
544 | sparse_graph : SparseGraph
545 | Graph in sparse matrix format.
546 |
547 | """
548 | with np.load(file_name, allow_pickle=True) as loader:
549 | loader = dict(loader)
550 | adj_matrix = sp.csr_matrix(
551 | (loader["adj_data"], loader["adj_indices"], loader["adj_indptr"]),
552 | shape=loader["adj_shape"],
553 | )
554 |
555 | if "attr_data" in loader:
556 | # Attributes are stored as a sparse CSR matrix
557 | attr_matrix = sp.csr_matrix(
558 | (loader["attr_data"], loader["attr_indices"], loader["attr_indptr"]),
559 | shape=loader["attr_shape"],
560 | )
561 | elif "attr_matrix" in loader:
562 | # Attributes are stored as a (dense) np.ndarray
563 | attr_matrix = loader["attr_matrix"]
564 | else:
565 | attr_matrix = None
566 |
567 | if "labels_data" in loader:
568 | # Labels are stored as a CSR matrix
569 | labels = sp.csr_matrix(
570 | (
571 | loader["labels_data"],
572 | loader["labels_indices"],
573 | loader["labels_indptr"],
574 | ),
575 | shape=loader["labels_shape"],
576 | )
577 | elif "labels" in loader:
578 | # Labels are stored as a numpy array
579 | labels = loader["labels"]
580 | else:
581 | labels = None
582 |
583 | node_names = loader.get("node_names")
584 | attr_names = loader.get("attr_names")
585 | class_names = loader.get("class_names")
586 | metadata = loader.get("metadata")
587 |
588 | return SparseGraph(
589 | adj_matrix, attr_matrix, labels, node_names, attr_names, class_names, metadata
590 | )
591 |
592 |
593 | def sample_per_class(
594 | random_state, labels, num_examples_per_class, forbidden_indices=None
595 | ):
596 | """
597 | Used in get_train_val_test_split, when we try to get a fixed number of examples per class
598 | """
599 |
600 | num_samples, num_classes = labels.shape
601 | sample_indices_per_class = {index: [] for index in range(num_classes)}
602 |
603 | # get indices sorted by class
604 | for class_index in range(num_classes):
605 | for sample_index in range(num_samples):
606 | if labels[sample_index, class_index] > 0.0:
607 | if forbidden_indices is None or sample_index not in forbidden_indices:
608 | sample_indices_per_class[class_index].append(sample_index)
609 |
610 | # get specified number of indices for each class
611 | return np.concatenate(
612 | [
613 | random_state.choice(
614 | sample_indices_per_class[class_index],
615 | num_examples_per_class,
616 | replace=False,
617 | )
618 | for class_index in range(len(sample_indices_per_class))
619 | ]
620 | )
621 |
622 |
623 | def get_train_val_test_split(
624 | random_state,
625 | labels,
626 | train_examples_per_class=None,
627 | val_examples_per_class=None,
628 | test_examples_per_class=None,
629 | train_size=None,
630 | val_size=None,
631 | test_size=None,
632 | ):
633 |
634 | num_samples, num_classes = labels.shape
635 | remaining_indices = list(range(num_samples))
636 | if train_examples_per_class is not None:
637 | train_indices = sample_per_class(random_state, labels, train_examples_per_class)
638 | else:
639 | # select train examples with no respect to class distribution
640 | train_indices = random_state.choice(
641 | remaining_indices, train_size, replace=False
642 | )
643 |
644 | if val_examples_per_class is not None:
645 | val_indices = sample_per_class(
646 | random_state,
647 | labels,
648 | val_examples_per_class,
649 | forbidden_indices=train_indices,
650 | )
651 | else:
652 | remaining_indices = np.setdiff1d(remaining_indices, train_indices)
653 | val_indices = random_state.choice(remaining_indices, val_size, replace=False)
654 |
655 | forbidden_indices = np.concatenate((train_indices, val_indices))
656 | if test_examples_per_class is not None:
657 | test_indices = sample_per_class(
658 | random_state,
659 | labels,
660 | test_examples_per_class,
661 | forbidden_indices=forbidden_indices,
662 | )
663 | elif test_size is not None:
664 | remaining_indices = np.setdiff1d(remaining_indices, forbidden_indices)
665 | test_indices = random_state.choice(remaining_indices, test_size, replace=False)
666 | else:
667 | test_indices = np.setdiff1d(remaining_indices, forbidden_indices)
668 |
669 | # assert that there are no duplicates in sets
670 | assert len(set(train_indices)) == len(train_indices)
671 | assert len(set(val_indices)) == len(val_indices)
672 | assert len(set(test_indices)) == len(test_indices)
673 | # assert sets are mutually exclusive
674 | assert len(set(train_indices) - set(val_indices)) == len(set(train_indices))
675 | assert len(set(train_indices) - set(test_indices)) == len(set(train_indices))
676 | assert len(set(val_indices) - set(test_indices)) == len(set(val_indices))
677 | if test_size is None and test_examples_per_class is None:
678 | # all indices must be part of the split
679 | assert (
680 | len(np.concatenate((train_indices, val_indices, test_indices)))
681 | == num_samples
682 | )
683 |
684 | if train_examples_per_class is not None:
685 | train_labels = labels[train_indices, :]
686 | train_sum = np.sum(train_labels, axis=0)
687 | # assert all classes have equal cardinality
688 | assert np.unique(train_sum).size == 1
689 |
690 | if val_examples_per_class is not None:
691 | val_labels = labels[val_indices, :]
692 | val_sum = np.sum(val_labels, axis=0)
693 | # assert all classes have equal cardinality
694 | assert np.unique(val_sum).size == 1
695 |
696 | if test_examples_per_class is not None:
697 | test_labels = labels[test_indices, :]
698 | test_sum = np.sum(test_labels, axis=0)
699 | # assert all classes have equal cardinality
700 | assert np.unique(test_sum).size == 1
701 |
702 | return train_indices, val_indices, test_indices
703 |
--------------------------------------------------------------------------------
/experiments/ablation_feature_noise.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train SAGE with 10 different node feature noise levels: 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1
4 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo"
5 | # Then train the corresponding GLNN for each noise level
6 |
7 | aggregated_result_file="ablation_feature_noise.txt"
8 | printf "Teacher\n" >> $aggregated_result_file
9 |
10 | for n in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1
11 | do
12 | printf "%3s\n" $n >> $aggregated_result_file
13 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
14 | do
15 | printf "%10s\t" $ds >> $aggregated_result_file
16 | python train_teacher.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --feature_noise $n \
17 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file
18 | done
19 | printf "\n" >> $aggregated_result_file
20 | done
21 |
22 | printf "Student\n" >> $aggregated_result_file
23 |
24 | for n in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1
25 | do
26 | printf "%3s\n" $n >> $aggregated_result_file
27 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
28 | do
29 | printf "%10s\t" $ds >> $aggregated_result_file
30 | python train_student.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --feature_noise $n \
31 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file
32 | done
33 | printf "\n" >> $aggregated_result_file
34 | done
35 |
36 |
--------------------------------------------------------------------------------
/experiments/ablation_gnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Train five different teachers "GCN" "GAT" "SAGE" "MLP" "APPNP"
5 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo"
6 | # Then train the corresponding GLNN for each teacher
7 |
8 | aggregated_result_file="ablation_gnn.txt"
9 | printf "Teacher\n" >> $aggregated_result_file
10 |
11 | for e in "tran" "ind"
12 | do
13 | printf "%6s\n" $e >> $aggregated_result_file
14 | for t in "GCN" "GAT" "SAGE" "MLP" "APPNP"
15 | do
16 | printf "%6s\n" $t >> $aggregated_result_file
17 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
18 | do
19 | printf "%10s\t" $ds >> $aggregated_result_file
20 | python train_teacher.py --exp_setting $e --teacher $t --dataset $ds --num_exp 5 \
21 | --max_epoch 200 --patience 50 >> $aggregated_result_file
22 | done
23 | printf "\n" >> $aggregated_result_file
24 | done
25 | printf "\n" >> $aggregated_result_file
26 | done
27 |
28 | printf "Student\n" >> $aggregated_result_file
29 |
30 | for e in "tran" "ind"
31 | do
32 | printf "%6s\n" $e >> $aggregated_result_file
33 | for t in "GCN" "GAT" "SAGE" "APPNP"
34 | do
35 | printf "%6s\n" $t >> $aggregated_result_file
36 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
37 | do
38 | printf "%10s\t" $ds >> $aggregated_result_file
39 | python train_student.py --exp_setting $e --teacher $t --dataset $ds --num_exp 5 \
40 | --max_epoch 200 --patience 50 >> $aggregated_result_file
41 | done
42 | printf "\n" >> $aggregated_result_file
43 | done
44 | printf "\n" >> $aggregated_result_file
45 | done
46 |
--------------------------------------------------------------------------------
/experiments/ablation_ind_split_rate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train SAGE with 9 different inductive split rates: 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
4 | # on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo"
5 | # Then train the corresponding GLNN for each split rate
6 |
7 | aggregated_result_file="ablation_ind_split_rate.txt"
8 | printf "Teacher\n" >> $aggregated_result_file
9 |
10 | for r in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
11 | do
12 | printf "%3s\n" $r >> $aggregated_result_file
13 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
14 | do
15 | printf "%10s\t" $ds >> $aggregated_result_file
16 | python train_teacher.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --split_rate $r \
17 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file
18 | done
19 | printf "\n" >> $aggregated_result_file
20 | done
21 |
22 | printf "Student\n" >> $aggregated_result_file
23 |
24 | for r in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
25 | do
26 | printf "%3s\n" $r >> $aggregated_result_file
27 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
28 | do
29 | printf "%10s\t" $ds >> $aggregated_result_file
30 | python train_student.py --exp_setting "ind" --teacher "SAGE" --dataset $ds --split_rate $r \
31 | --num_exp 5 --max_epoch 200 --patience 50 >> $aggregated_result_file
32 | done
33 | printf "\n" >> $aggregated_result_file
34 | done
35 |
36 |
--------------------------------------------------------------------------------
/experiments/ga_glnn_arxiv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train 1-hop GA-MLP and 1-hop GA-GLNN with SAGE teacher on "ogbn-arxiv" under the inductive setting
4 |
5 | python train_teacher.py --exp_setting "ind" --teacher "MLP3w4" --dataset "ogbn-arxiv" \
6 | --num_exp 5 --max_epoch 200 --patience 50 \
7 | --feature_aug_k 1
8 |
9 | python train_student.py --exp_setting "ind" --teacher "SAGE" --student "MLP3w4" --dataset "ogbn-arxiv" \
10 | --num_exp 5 --max_epoch 200 --patience 50 \
11 | --feature_aug_k 1
12 |
13 |
--------------------------------------------------------------------------------
/experiments/glnn_arxiv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train GLNN with SAGE teacher on "ogbn-arxiv"
4 |
5 | for e in "tran" "ind"
6 | do
7 | python train_student.py --exp_setting $e --teacher "SAGE" --student "MLP3w4" --dataset "ogbn-arxiv" \
8 | --num_exp 10 --max_epoch 200 --patience 50
9 | done
10 |
--------------------------------------------------------------------------------
/experiments/glnn_cpf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train GLNN with SAGE teacher on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo"
4 |
5 | aggregated_result_file="glnn_cpf.txt"
6 | for e in "tran" "ind"
7 | do
8 | printf "%6s\n" $e >> $aggregated_result_file
9 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
10 | do
11 | printf "%10s\t" $ds >> $aggregated_result_file
12 | python train_student.py --exp_setting $e --teacher "SAGE" --dataset $ds --num_exp 10 \
13 | --max_epoch 200 --patience 50 \
14 | --save_results >> $aggregated_result_file
15 | done
16 | printf "\n" >> $aggregated_result_file
17 | done
18 |
--------------------------------------------------------------------------------
/experiments/glnn_products.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train GLNN with SAGE teacher on "ogbn-products"
4 |
5 | for e in "tran" "ind"
6 | do
7 | python train_student.py --exp_setting $e --teacher "SAGE" --student "MLP3w8" --dataset "ogbn-products" \
8 | --num_exp 10 --max_epoch 200 --patience 30
9 | done
10 |
--------------------------------------------------------------------------------
/experiments/sage_arxiv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train SAGE teacher on "ogbn-arxiv"
4 |
5 | for e in "tran" "ind"
6 | do
7 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset "ogbn-arxiv" \
8 | --num_exp 10 --max_epoch 200 --patience 50 \
9 | --save_results
10 | done
11 |
--------------------------------------------------------------------------------
/experiments/sage_cpf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train SAGE teacher on five datasets: "cora" "citeseer" "pubmed" "a-computer" "a-photo"
4 |
5 | aggregated_result_file="sage_cpf.txt"
6 | for e in "tran" "ind"
7 | do
8 | printf "%6s\n" $e >> $aggregated_result_file
9 | for ds in "cora" "citeseer" "pubmed" "a-computer" "a-photo"
10 | do
11 | printf "%10s\t" $ds >> $aggregated_result_file
12 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset $ds \
13 | --num_exp 10 --max_epoch 200 --patience 50 \
14 | --save_results >> $aggregated_result_file
15 | done
16 | printf "\n" >> $aggregated_result_file
17 | done
18 |
--------------------------------------------------------------------------------
/experiments/sage_products.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Train SAGE teacher on "ogbn-products"
4 |
5 | for e in "tran" "ind"
6 | do
7 | python train_teacher.py --exp_setting $e --teacher "SAGE" --dataset "ogbn-products" \
8 | --num_exp 10 --max_epoch 40 --patience 10 \
9 | --save_results
10 | done
11 |
--------------------------------------------------------------------------------
/imgs/glnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-research/graphless-neural-networks/80d0f8ea66ad49b27336f9c02691ca16a6dc1dd2/imgs/glnn.png
--------------------------------------------------------------------------------
/imgs/trade_off.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-research/graphless-neural-networks/80d0f8ea66ad49b27336f9c02691ca16a6dc1dd2/imgs/trade_off.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from dgl.nn import GraphConv, SAGEConv, APPNPConv, GATConv
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(
9 | self,
10 | num_layers,
11 | input_dim,
12 | hidden_dim,
13 | output_dim,
14 | dropout_ratio,
15 | norm_type="none",
16 | ):
17 | super(MLP, self).__init__()
18 | self.num_layers = num_layers
19 | self.norm_type = norm_type
20 | self.dropout = nn.Dropout(dropout_ratio)
21 | self.layers = nn.ModuleList()
22 | self.norms = nn.ModuleList()
23 |
24 | if num_layers == 1:
25 | self.layers.append(nn.Linear(input_dim, output_dim))
26 | else:
27 | self.layers.append(nn.Linear(input_dim, hidden_dim))
28 | if self.norm_type == "batch":
29 | self.norms.append(nn.BatchNorm1d(hidden_dim))
30 | elif self.norm_type == "layer":
31 | self.norms.append(nn.LayerNorm(hidden_dim))
32 |
33 | for i in range(num_layers - 2):
34 | self.layers.append(nn.Linear(hidden_dim, hidden_dim))
35 | if self.norm_type == "batch":
36 | self.norms.append(nn.BatchNorm1d(hidden_dim))
37 | elif self.norm_type == "layer":
38 | self.norms.append(nn.LayerNorm(hidden_dim))
39 |
40 | self.layers.append(nn.Linear(hidden_dim, output_dim))
41 |
42 | def forward(self, feats):
43 | h = feats
44 | h_list = []
45 | for l, layer in enumerate(self.layers):
46 | h = layer(h)
47 | if l != self.num_layers - 1:
48 | h_list.append(h)
49 | if self.norm_type != "none":
50 | h = self.norms[l](h)
51 | h = F.relu(h)
52 | h = self.dropout(h)
53 | return h_list, h
54 |
55 |
56 | """
57 | Adapted from the SAGE implementation from the official DGL example
58 | https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py
59 | """
60 |
61 |
62 | class SAGE(nn.Module):
63 | def __init__(
64 | self,
65 | num_layers,
66 | input_dim,
67 | hidden_dim,
68 | output_dim,
69 | dropout_ratio,
70 | activation,
71 | norm_type="none",
72 | ):
73 | super().__init__()
74 | self.num_layers = num_layers
75 | self.hidden_dim = hidden_dim
76 | self.output_dim = output_dim
77 | self.norm_type = norm_type
78 | self.activation = activation
79 | self.dropout = nn.Dropout(dropout_ratio)
80 | self.layers = nn.ModuleList()
81 | self.norms = nn.ModuleList()
82 |
83 | if num_layers == 1:
84 | self.layers.append(SAGEConv(input_dim, output_dim, "gcn"))
85 | else:
86 | self.layers.append(SAGEConv(input_dim, hidden_dim, "gcn"))
87 | if self.norm_type == "batch":
88 | self.norms.append(nn.BatchNorm1d(hidden_dim))
89 | elif self.norm_type == "layer":
90 | self.norms.append(nn.LayerNorm(hidden_dim))
91 |
92 | for i in range(num_layers - 2):
93 | self.layers.append(SAGEConv(hidden_dim, hidden_dim, "gcn"))
94 | if self.norm_type == "batch":
95 | self.norms.append(nn.BatchNorm1d(hidden_dim))
96 | elif self.norm_type == "layer":
97 | self.norms.append(nn.LayerNorm(hidden_dim))
98 |
99 | self.layers.append(SAGEConv(hidden_dim, output_dim, "gcn"))
100 |
101 | def forward(self, blocks, feats):
102 | h = feats
103 | h_list = []
104 | for l, (layer, block) in enumerate(zip(self.layers, blocks)):
105 | # We need to first copy the representation of nodes on the RHS from the
106 | # appropriate nodes on the LHS.
107 | # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
108 | # would be (num_nodes_RHS, D)
109 | h_dst = h[: block.num_dst_nodes()]
110 | # Then we compute the updated representation on the RHS.
111 | # The shape of h now becomes (num_nodes_RHS, D)
112 | h = layer(block, (h, h_dst))
113 | if l != self.num_layers - 1:
114 | h_list.append(h)
115 | if self.norm_type != "none":
116 | h = self.norms[l](h)
117 | h = self.activation(h)
118 | h = self.dropout(h)
119 | return h_list, h
120 |
121 | def inference(self, dataloader, feats):
122 | """
123 | Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
124 | dataloader : The entire graph loaded in blocks with full neighbors for each node.
125 | feats : The input feats of entire node set.
126 | """
127 | device = feats.device
128 | for l, layer in enumerate(self.layers):
129 | y = torch.zeros(
130 | feats.shape[0],
131 | self.hidden_dim if l != self.num_layers - 1 else self.output_dim,
132 | ).to(device)
133 | for input_nodes, output_nodes, blocks in dataloader:
134 | block = blocks[0].int().to(device)
135 |
136 | h = feats[input_nodes]
137 | h_dst = h[: block.num_dst_nodes()]
138 | h = layer(block, (h, h_dst))
139 | if l != self.num_layers - 1:
140 | if self.norm_type != "none":
141 | h = self.norms[l](h)
142 | h = self.activation(h)
143 | h = self.dropout(h)
144 |
145 | y[output_nodes] = h
146 |
147 | feats = y
148 | return y
149 |
150 |
151 | class GCN(nn.Module):
152 | def __init__(
153 | self,
154 | num_layers,
155 | input_dim,
156 | hidden_dim,
157 | output_dim,
158 | dropout_ratio,
159 | activation,
160 | norm_type="none",
161 | ):
162 | super().__init__()
163 | self.num_layers = num_layers
164 | self.norm_type = norm_type
165 | self.dropout = nn.Dropout(dropout_ratio)
166 | self.layers = nn.ModuleList()
167 | self.norms = nn.ModuleList()
168 |
169 | if num_layers == 1:
170 | self.layers.append(GraphConv(input_dim, output_dim, activation=activation))
171 | else:
172 | self.layers.append(GraphConv(input_dim, hidden_dim, activation=activation))
173 | if self.norm_type == "batch":
174 | self.norms.append(nn.BatchNorm1d(hidden_dim))
175 | elif self.norm_type == "layer":
176 | self.norms.append(nn.LayerNorm(hidden_dim))
177 |
178 | for i in range(num_layers - 2):
179 | self.layers.append(
180 | GraphConv(hidden_dim, hidden_dim, activation=activation)
181 | )
182 | if self.norm_type == "batch":
183 | self.norms.append(nn.BatchNorm1d(hidden_dim))
184 | elif self.norm_type == "layer":
185 | self.norms.append(nn.LayerNorm(hidden_dim))
186 |
187 | self.layers.append(GraphConv(hidden_dim, output_dim))
188 |
189 | def forward(self, g, feats):
190 | h = feats
191 | h_list = []
192 | for l, layer in enumerate(self.layers):
193 | h = layer(g, h)
194 | if l != self.num_layers - 1:
195 | h_list.append(h)
196 | if self.norm_type != "none":
197 | h = self.norms[l](h)
198 | h = self.dropout(h)
199 | return h_list, h
200 |
201 |
202 | class GAT(nn.Module):
203 | def __init__(
204 | self,
205 | num_layers,
206 | input_dim,
207 | hidden_dim,
208 | output_dim,
209 | dropout_ratio,
210 | activation,
211 | num_heads=8,
212 | attn_drop=0.3,
213 | negative_slope=0.2,
214 | residual=False,
215 | ):
216 | super(GAT, self).__init__()
217 | # For GAT, the number of layers is required to be > 1
218 | assert num_layers > 1
219 |
220 | hidden_dim //= num_heads
221 | self.num_layers = num_layers
222 | self.layers = nn.ModuleList()
223 | self.activation = activation
224 |
225 | heads = ([num_heads] * num_layers) + [1]
226 | # input (no residual)
227 | self.layers.append(
228 | GATConv(
229 | input_dim,
230 | hidden_dim,
231 | heads[0],
232 | dropout_ratio,
233 | attn_drop,
234 | negative_slope,
235 | False,
236 | self.activation,
237 | )
238 | )
239 |
240 | for l in range(1, num_layers - 1):
241 | # due to multi-head, the in_dim = hidden_dim * num_heads
242 | self.layers.append(
243 | GATConv(
244 | hidden_dim * heads[l - 1],
245 | hidden_dim,
246 | heads[l],
247 | dropout_ratio,
248 | attn_drop,
249 | negative_slope,
250 | residual,
251 | self.activation,
252 | )
253 | )
254 |
255 | self.layers.append(
256 | GATConv(
257 | hidden_dim * heads[-2],
258 | output_dim,
259 | heads[-1],
260 | dropout_ratio,
261 | attn_drop,
262 | negative_slope,
263 | residual,
264 | None,
265 | )
266 | )
267 |
268 | def forward(self, g, feats):
269 | h = feats
270 | h_list = []
271 | for l, layer in enumerate(self.layers):
272 |             # GATConv output: [num_nodes, num_heads, out_dim]; flatten heads for hidden layers, average them for the output layer
273 | h = layer(g, h)
274 | if l != self.num_layers - 1:
275 | h = h.flatten(1)
276 | h_list.append(h)
277 | else:
278 | h = h.mean(1)
279 | return h_list, h
280 |
281 |
282 | class APPNP(nn.Module):
283 | def __init__(
284 | self,
285 | num_layers,
286 | input_dim,
287 | hidden_dim,
288 | output_dim,
289 | dropout_ratio,
290 | activation,
291 | norm_type="none",
292 | edge_drop=0.5,
293 | alpha=0.1,
294 | k=10,
295 | ):
296 |
297 | super(APPNP, self).__init__()
298 | self.num_layers = num_layers
299 | self.norm_type = norm_type
300 | self.activation = activation
301 | self.dropout = nn.Dropout(dropout_ratio)
302 | self.layers = nn.ModuleList()
303 | self.norms = nn.ModuleList()
304 |
305 | if num_layers == 1:
306 | self.layers.append(nn.Linear(input_dim, output_dim))
307 | else:
308 | self.layers.append(nn.Linear(input_dim, hidden_dim))
309 | if self.norm_type == "batch":
310 | self.norms.append(nn.BatchNorm1d(hidden_dim))
311 | elif self.norm_type == "layer":
312 | self.norms.append(nn.LayerNorm(hidden_dim))
313 |
314 | for i in range(num_layers - 2):
315 | self.layers.append(nn.Linear(hidden_dim, hidden_dim))
316 | if self.norm_type == "batch":
317 | self.norms.append(nn.BatchNorm1d(hidden_dim))
318 | elif self.norm_type == "layer":
319 | self.norms.append(nn.LayerNorm(hidden_dim))
320 |
321 | self.layers.append(nn.Linear(hidden_dim, output_dim))
322 |
323 | self.propagate = APPNPConv(k, alpha, edge_drop)
324 | self.reset_parameters()
325 |
326 | def reset_parameters(self):
327 | for layer in self.layers:
328 | layer.reset_parameters()
329 |
330 | def forward(self, g, feats):
331 | h = feats
332 | h_list = []
333 | for l, layer in enumerate(self.layers):
334 | h = layer(h)
335 |
336 | if l != self.num_layers - 1:
337 | h_list.append(h)
338 | if self.norm_type != "none":
339 | h = self.norms[l](h)
340 | h = self.activation(h)
341 | h = self.dropout(h)
342 |
343 | h = self.propagate(g, h)
344 | return h_list, h
345 |
346 |
347 | class Model(nn.Module):
348 | """
349 | Wrapper of different models
350 | """
351 |
352 | def __init__(self, conf):
353 | super(Model, self).__init__()
354 | self.model_name = conf["model_name"]
355 | if "MLP" in conf["model_name"]:
356 | self.encoder = MLP(
357 | num_layers=conf["num_layers"],
358 | input_dim=conf["feat_dim"],
359 | hidden_dim=conf["hidden_dim"],
360 | output_dim=conf["label_dim"],
361 | dropout_ratio=conf["dropout_ratio"],
362 | norm_type=conf["norm_type"],
363 | ).to(conf["device"])
364 | elif "SAGE" in conf["model_name"]:
365 | self.encoder = SAGE(
366 | num_layers=conf["num_layers"],
367 | input_dim=conf["feat_dim"],
368 | hidden_dim=conf["hidden_dim"],
369 | output_dim=conf["label_dim"],
370 | dropout_ratio=conf["dropout_ratio"],
371 | activation=F.relu,
372 | norm_type=conf["norm_type"],
373 | ).to(conf["device"])
374 | elif "GCN" in conf["model_name"]:
375 | self.encoder = GCN(
376 | num_layers=conf["num_layers"],
377 | input_dim=conf["feat_dim"],
378 | hidden_dim=conf["hidden_dim"],
379 | output_dim=conf["label_dim"],
380 | dropout_ratio=conf["dropout_ratio"],
381 | activation=F.relu,
382 | norm_type=conf["norm_type"],
383 | ).to(conf["device"])
384 | elif "GAT" in conf["model_name"]:
385 | self.encoder = GAT(
386 | num_layers=conf["num_layers"],
387 | input_dim=conf["feat_dim"],
388 | hidden_dim=conf["hidden_dim"],
389 | output_dim=conf["label_dim"],
390 | dropout_ratio=conf["dropout_ratio"],
391 | activation=F.relu,
392 | attn_drop=conf["attn_dropout_ratio"],
393 | ).to(conf["device"])
394 | elif "APPNP" in conf["model_name"]:
395 | self.encoder = APPNP(
396 | num_layers=conf["num_layers"],
397 | input_dim=conf["feat_dim"],
398 | hidden_dim=conf["hidden_dim"],
399 | output_dim=conf["label_dim"],
400 | dropout_ratio=conf["dropout_ratio"],
401 | activation=F.relu,
402 | norm_type=conf["norm_type"],
403 | ).to(conf["device"])
404 |
405 | def forward(self, data, feats):
406 | """
407 | data: a graph `g` or a `dataloader` of blocks
408 | """
409 | if "MLP" in self.model_name:
410 | return self.encoder(feats)[1]
411 | else:
412 | return self.encoder(data, feats)[1]
413 |
414 | def forward_fitnet(self, data, feats):
415 | """
416 | Return a tuple (h_list, h)
417 |         h_list: intermediate hidden representations
418 | h: final output
419 | """
420 | if "MLP" in self.model_name:
421 | return self.encoder(feats)
422 | else:
423 | return self.encoder(data, feats)
424 |
425 | def inference(self, data, feats):
426 | if "SAGE" in self.model_name:
427 | return self.encoder.inference(data, feats)
428 | else:
429 | return self.forward(data, feats)
430 |
--------------------------------------------------------------------------------
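A minimal usage sketch of the `Model` wrapper defined in `models.py` above. The conf keys mirror those read in `Model.__init__`; the concrete values (feature/label dimensions, the random feature batch) are illustrative assumptions, not taken from the repo.

```python
# Sketch only: instantiate the MLP student through the Model wrapper above.
# The conf keys follow Model.__init__; the numeric values are illustrative.
import torch
from models import Model

conf = {
    "model_name": "MLP",          # one of MLP / SAGE / GCN / GAT / APPNP
    "num_layers": 2,
    "feat_dim": 1433,             # input feature dimension (illustrative)
    "hidden_dim": 128,
    "label_dim": 7,               # number of classes (illustrative)
    "dropout_ratio": 0.5,
    "norm_type": "none",
    "device": torch.device("cpu"),
}
model = Model(conf)
feats = torch.randn(32, conf["feat_dim"])
logits = model(None, feats)       # the MLP branch ignores the graph argument
print(logits.shape)               # torch.Size([32, 7])
```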
/prepare_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | conda create -y -n glnn python=3.6.9
4 | eval "$(conda shell.bash hook)"
5 | conda activate glnn
6 |
7 | pip install --no-cache-dir -r requirements.txt
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://data.dgl.ai/wheels/repo.html
2 |
3 | ogb==1.3.3
4 | pillow==8.4.0
5 | scipy==1.4.1
6 | networkx==2.5.1
7 | numpy==1.18.1
8 | tabulate==0.8.7
9 | tqdm==4.62.3
10 | PyYAML==5.3.1
11 | scikit_learn==0.22.2
12 | googledrivedownloader==0.4
13 | category_encoders==2.3.0
14 | torch==1.7.0
15 | dgl==0.6.1
--------------------------------------------------------------------------------
/train.conf.yaml:
--------------------------------------------------------------------------------
1 | global:
2 | num_layers: 2
3 | hidden_dim: 128
4 |
5 | cora:
6 | SAGE:
7 | fan_out: 5,5
8 | learning_rate: 0.01
9 | dropout_ratio: 0
10 | weight_decay: 0.0005
11 |
12 | GCN:
13 | hidden_dim: 64
14 | dropout_ratio: 0.8
15 | weight_decay: 0.001
16 |
17 | MLP:
18 | learning_rate: 0.01
19 | weight_decay: 0.005
20 | dropout_ratio: 0.6
21 |
22 | GAT:
23 | dropout_ratio: 0.6
24 | weight_decay: 0.01
25 | num_heads: 8
26 | attn_dropout_ratio: 0.3
27 |
28 | APPNP:
29 | dropout_ratio: 0.5
30 | weight_decay: 0.01
31 |
32 |
33 | citeseer:
34 | SAGE:
35 | fan_out: 5,5
36 | learning_rate: 0.01
37 | dropout_ratio: 0
38 | weight_decay: 0.0005
39 |
40 | GCN:
41 | hidden_dim: 64
42 | dropout_ratio: 0.8
43 | weight_decay: 0.001
44 |
45 | MLP:
46 | learning_rate: 0.01
47 | weight_decay: 0.001
48 | dropout_ratio: 0.1
49 |
50 | GAT:
51 | dropout_ratio: 0.6
52 | weight_decay: 0.01
53 | num_heads: 8
54 | attn_dropout_ratio: 0.3
55 |
56 | APPNP:
57 | dropout_ratio: 0.5
58 | weight_decay: 0.01
59 |
60 | pubmed:
61 | SAGE:
62 | fan_out: 5,5
63 | learning_rate: 0.01
64 | dropout_ratio: 0
65 | weight_decay: 0.0005
66 |
67 | GCN:
68 | hidden_dim: 64
69 | dropout_ratio: 0.8
70 | weight_decay: 0.001
71 |
72 | MLP:
73 | learning_rate: 0.005
74 | weight_decay: 0
75 | dropout_ratio: 0.4
76 |
77 | GAT:
78 | dropout_ratio: 0.6
79 | weight_decay: 0.01
80 | num_heads: 8
81 | attn_dropout_ratio: 0.3
82 |
83 | APPNP:
84 | dropout_ratio: 0.5
85 | weight_decay: 0.01
86 |
87 | a-computer:
88 | SAGE:
89 | fan_out: 5,5
90 | learning_rate: 0.01
91 | dropout_ratio: 0
92 | weight_decay: 0.0005
93 |
94 | GCN:
95 | hidden_dim: 64
96 | dropout_ratio: 0.8
97 | weight_decay: 0.001
98 |
99 | MLP:
100 | learning_rate: 0.001
101 | weight_decay: 0.002
102 | dropout_ratio: 0.3
103 |
104 | GAT:
105 | dropout_ratio: 0.6
106 | weight_decay: 0.01
107 | num_heads: 8
108 | attn_dropout_ratio: 0.3
109 |
110 | APPNP:
111 | dropout_ratio: 0.5
112 | weight_decay: 0.01
113 |
114 | a-photo:
115 | SAGE:
116 | fan_out: 5,5
117 | learning_rate: 0.01
118 | dropout_ratio: 0
119 | weight_decay: 0.0005
120 |
121 | GCN:
122 | hidden_dim: 64
123 | dropout_ratio: 0.8
124 | weight_decay: 0.001
125 |
126 | MLP:
127 | learning_rate: 0.005
128 | weight_decay: 0.002
129 | dropout_ratio: 0.3
130 |
131 | GAT:
132 | dropout_ratio: 0.6
133 | weight_decay: 0.01
134 | num_heads: 8
135 | attn_dropout_ratio: 0.3
136 |
137 | APPNP:
138 | dropout_ratio: 0.5
139 | weight_decay: 0.01
140 |
141 | ogbn-arxiv:
142 | MLP:
143 | num_layers: 3
144 | hidden_dim: 256
145 | weight_decay: 0
146 | dropout_ratio: 0.2
147 | norm_type: batch
148 |
149 | MLP3w4:
150 | num_layers: 3
151 | hidden_dim: 1024
152 | weight_decay: 0
153 | dropout_ratio: 0.5
154 | norm_type: batch
155 |
156 | GA1MLP:
157 | num_layers: 3
158 | hidden_dim: 256
159 | weight_decay: 0
160 | dropout_ratio: 0.2
161 | norm_type: batch
162 |
163 | GA1MLP3w4:
164 | num_layers: 3
165 | hidden_dim: 1024
166 | weight_decay: 0
167 | dropout_ratio: 0.2
168 | norm_type: batch
169 |
170 | SAGE:
171 | num_layers: 3
172 | hidden_dim: 256
173 | dropout_ratio: 0.2
174 | learning_rate: 0.01
175 | weight_decay: 0
176 | norm_type: batch
177 | fan_out: 5,10,15
178 |
179 | ogbn-products:
180 | MLP:
181 | num_layers: 3
182 | hidden_dim: 256
183 | dropout_ratio: 0.5
184 | norm_type: batch
185 | batch_size: 4096
186 |
187 | MLP3w8:
188 | num_layers: 3
189 | hidden_dim: 2048
190 | dropout_ratio: 0.2
191 | learning_rate: 0.01
192 | weight_decay: 0
193 | norm_type: batch
194 | batch_size: 4096
195 |
196 | SAGE:
197 | num_layers: 3
198 | hidden_dim: 256
199 | dropout_ratio: 0.5
200 | learning_rate: 0.003
201 | weight_decay: 0
202 | norm_type: batch
203 | fan_out: 5,10,15
204 | batch_size: 4096
205 |
206 | pokec:
207 | GCN:
208 | num_layers: 2
209 | hidden_dim: 32
210 | dropout_ratio: 0.5
211 | learning_rate: 0.01
212 | weight_decay: 0
213 | norm_type: batch
214 |
215 | MLP:
216 | num_layers: 3
217 | hidden_dim: 256
218 | dropout_ratio: 0.2
219 | learning_rate: 0.001
220 | norm_type: none
221 |
222 | SAGE:
223 | num_layers: 2
224 | hidden_dim: 32
225 | dropout_ratio: 0.5
226 | learning_rate: 0.001
227 | weight_decay: 0
228 | norm_type: batch
229 | fan_out: 5,5
230 |
231 | penn94:
232 | GCN:
233 | num_layers: 2
234 | hidden_dim: 64
235 | dropout_ratio: 0.5
236 | learning_rate: 0.01
237 | weight_decay: 0.001
238 | norm_type: batch
239 |
240 | MLP:
241 | num_layers: 3
242 | hidden_dim: 256
243 | dropout_ratio: 0
244 | learning_rate: 0.001
245 | norm_type: none
246 |
247 | vk_class:
248 | MLP:
249 | num_layers: 3
250 | hidden_dim: 512
251 | learning_rate: 0.01
252 | dropout_ratio: 0
253 | weight_decay: 0
254 | norm_type: batch
255 | batch_size: 6754
256 |
257 | house_class:
258 | MLP:
259 | num_layers: 3
260 | hidden_dim: 512
261 | learning_rate: 0.01
262 | dropout_ratio: 0
263 | weight_decay: 0
264 | norm_type: layer
265 | batch_size: 2580
266 |
--------------------------------------------------------------------------------
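To illustrate how this YAML is consumed, here is a small sketch that mirrors `get_training_config` in `utils.py` further below: the `global` block is merged with the dataset- and model-specific block, with the specific values taking precedence. The dataset/model pair chosen here is just an example.

```python
# Sketch only: resolve the effective config for one dataset/model pair,
# mirroring get_training_config in utils.py (global block + specific block).
import yaml

with open("train.conf.yaml") as f:
    full_config = yaml.load(f, Loader=yaml.FullLoader)

model_name, dataset = "SAGE", "cora"          # illustrative choice
specific = full_config[dataset][model_name] or {}
conf = dict(full_config["global"], **specific)
conf["model_name"] = model_name
print(conf)
# {'num_layers': 2, 'hidden_dim': 128, 'fan_out': '5,5', 'learning_rate': 0.01,
#  'dropout_ratio': 0, 'weight_decay': 0.0005, 'model_name': 'SAGE'}
```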
/train_and_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import torch
4 | import dgl
5 | from utils import set_seed
6 |
7 | """
8 | 1. Train and eval
9 | """
10 |
11 |
12 | def train(model, data, feats, labels, criterion, optimizer, idx_train, lamb=1):
13 | """
14 | GNN full-batch training. Input the entire graph `g` as data.
15 | lamb: weight parameter lambda
16 | """
17 | model.train()
18 |
19 | # Compute loss and prediction
20 | logits = model(data, feats)
21 | out = logits.log_softmax(dim=1)
22 | loss = criterion(out[idx_train], labels[idx_train])
23 | loss_val = loss.item()
24 |
25 | loss *= lamb
26 | optimizer.zero_grad()
27 | loss.backward()
28 | optimizer.step()
29 | return loss_val
30 |
31 |
32 | def train_sage(model, dataloader, feats, labels, criterion, optimizer, lamb=1):
33 | """
34 |     Train for GraphSAGE. Process the graph in mini-batches using `dataloader` instead of the entire graph `g`.
35 | lamb: weight parameter lambda
36 | """
37 | device = feats.device
38 | model.train()
39 | total_loss = 0
40 | for step, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
41 | blocks = [blk.int().to(device) for blk in blocks]
42 | batch_feats = feats[input_nodes]
43 | batch_labels = labels[output_nodes]
44 |
45 | # Compute loss and prediction
46 | logits = model(blocks, batch_feats)
47 | out = logits.log_softmax(dim=1)
48 | loss = criterion(out, batch_labels)
49 | total_loss += loss.item()
50 |
51 | loss *= lamb
52 | optimizer.zero_grad()
53 | loss.backward()
54 | optimizer.step()
55 |
56 | return total_loss / len(dataloader)
57 |
58 |
59 | def train_mini_batch(model, feats, labels, batch_size, criterion, optimizer, lamb=1):
60 | """
61 |     Train an MLP on large datasets. Process the data in mini-batches; the graph is ignored and only node features are used.
62 | lamb: weight parameter lambda
63 | """
64 | model.train()
65 | num_batches = max(1, feats.shape[0] // batch_size)
66 | idx_batch = torch.randperm(feats.shape[0])[: num_batches * batch_size]
67 |
68 | if num_batches == 1:
69 | idx_batch = idx_batch.view(1, -1)
70 | else:
71 | idx_batch = idx_batch.view(num_batches, batch_size)
72 |
73 | total_loss = 0
74 | for i in range(num_batches):
75 | # No graph needed for the forward function
76 | logits = model(None, feats[idx_batch[i]])
77 | out = logits.log_softmax(dim=1)
78 |
79 | loss = criterion(out, labels[idx_batch[i]])
80 | total_loss += loss.item()
81 |
82 | loss *= lamb
83 | optimizer.zero_grad()
84 | loss.backward()
85 | optimizer.step()
86 | return total_loss / num_batches
87 |
88 |
89 | def evaluate(model, data, feats, labels, criterion, evaluator, idx_eval=None):
90 | """
91 | Returns:
92 |     out: log probabilities of all input data
93 |     loss & score (float): evaluated loss & score; if idx_eval is not None, computed only on those indices.
94 | """
95 | model.eval()
96 | with torch.no_grad():
97 | logits = model.inference(data, feats)
98 | out = logits.log_softmax(dim=1)
99 | if idx_eval is None:
100 | loss = criterion(out, labels)
101 | score = evaluator(out, labels)
102 | else:
103 | loss = criterion(out[idx_eval], labels[idx_eval])
104 | score = evaluator(out[idx_eval], labels[idx_eval])
105 | return out, loss.item(), score
106 |
107 |
108 | def evaluate_mini_batch(
109 | model, feats, labels, criterion, batch_size, evaluator, idx_eval=None
110 | ):
111 | """
112 |     Evaluate an MLP on large datasets. Process the data in mini-batches; the graph is ignored and only node features are used.
113 |     Returns:
114 |     out: log probabilities of all input data
115 |     loss & score (float): evaluated loss & score; if idx_eval is not None, computed only on those indices.
116 | """
117 |
118 | model.eval()
119 | with torch.no_grad():
120 | num_batches = int(np.ceil(len(feats) / batch_size))
121 | out_list = []
122 | for i in range(num_batches):
123 | logits = model.inference(None, feats[batch_size * i : batch_size * (i + 1)])
124 | out = logits.log_softmax(dim=1)
125 | out_list += [out.detach()]
126 |
127 | out_all = torch.cat(out_list)
128 |
129 | if idx_eval is None:
130 | loss = criterion(out_all, labels)
131 | score = evaluator(out_all, labels)
132 | else:
133 | loss = criterion(out_all[idx_eval], labels[idx_eval])
134 | score = evaluator(out_all[idx_eval], labels[idx_eval])
135 |
136 | return out_all, loss.item(), score
137 |
138 |
139 | """
140 | 2. Run teacher
141 | """
142 |
143 |
144 | def run_transductive(
145 | conf,
146 | model,
147 | g,
148 | feats,
149 | labels,
150 | indices,
151 | criterion,
152 | evaluator,
153 | optimizer,
154 | logger,
155 | loss_and_score,
156 | ):
157 | """
158 | Train and eval under the transductive setting.
159 | The train/valid/test split is specified by `indices`.
160 |     The input graph is assumed to be large. Thus, SAGE is used for GNNs and mini-batch training is used for MLPs.
161 |
162 | loss_and_score: Stores losses and scores.
163 | """
164 | set_seed(conf["seed"])
165 | device = conf["device"]
166 | batch_size = conf["batch_size"]
167 |
168 | idx_train, idx_val, idx_test = indices
169 |
170 | feats = feats.to(device)
171 | labels = labels.to(device)
172 |
173 | if "SAGE" in model.model_name:
174 | # Create dataloader for SAGE
175 |
176 | # Create csr/coo/csc formats before launching sampling processes
177 |         # This avoids creating certain formats in each data loader process, which saves memory and CPU.
178 | g.create_formats_()
179 | sampler = dgl.dataloading.MultiLayerNeighborSampler(
180 | [eval(fanout) for fanout in conf["fan_out"].split(",")]
181 | )
182 | dataloader = dgl.dataloading.NodeDataLoader(
183 | g,
184 | idx_train,
185 | sampler,
186 | batch_size=batch_size,
187 | shuffle=True,
188 | drop_last=False,
189 | num_workers=conf["num_workers"],
190 | )
191 |
192 |         # SAGE inference is implemented layer by layer, so the full-neighbor sampler only collects one-hop neighbors
193 | sampler_eval = dgl.dataloading.MultiLayerFullNeighborSampler(1)
194 | dataloader_eval = dgl.dataloading.NodeDataLoader(
195 | g,
196 | torch.arange(g.num_nodes()),
197 | sampler_eval,
198 | batch_size=batch_size,
199 | shuffle=False,
200 | drop_last=False,
201 | num_workers=conf["num_workers"],
202 | )
203 |
204 | data = dataloader
205 | data_eval = dataloader_eval
206 | elif "MLP" in model.model_name:
207 | feats_train, labels_train = feats[idx_train], labels[idx_train]
208 | feats_val, labels_val = feats[idx_val], labels[idx_val]
209 | feats_test, labels_test = feats[idx_test], labels[idx_test]
210 | else:
211 | g = g.to(device)
212 | data = g
213 | data_eval = g
214 |
215 | best_epoch, best_score_val, count = 0, 0, 0
216 | for epoch in range(1, conf["max_epoch"] + 1):
217 | if "SAGE" in model.model_name:
218 | loss = train_sage(model, data, feats, labels, criterion, optimizer)
219 | elif "MLP" in model.model_name:
220 | loss = train_mini_batch(
221 | model, feats_train, labels_train, batch_size, criterion, optimizer
222 | )
223 | else:
224 | loss = train(model, data, feats, labels, criterion, optimizer, idx_train)
225 |
226 | if epoch % conf["eval_interval"] == 0:
227 | if "MLP" in model.model_name:
228 | _, loss_train, score_train = evaluate_mini_batch(
229 | model, feats_train, labels_train, criterion, batch_size, evaluator
230 | )
231 | _, loss_val, score_val = evaluate_mini_batch(
232 | model, feats_val, labels_val, criterion, batch_size, evaluator
233 | )
234 | _, loss_test, score_test = evaluate_mini_batch(
235 | model, feats_test, labels_test, criterion, batch_size, evaluator
236 | )
237 | else:
238 | out, loss_train, score_train = evaluate(
239 | model, data_eval, feats, labels, criterion, evaluator, idx_train
240 | )
241 | # Use criterion & evaluator instead of evaluate to avoid redundant forward pass
242 | loss_val = criterion(out[idx_val], labels[idx_val]).item()
243 | score_val = evaluator(out[idx_val], labels[idx_val])
244 | loss_test = criterion(out[idx_test], labels[idx_test]).item()
245 | score_test = evaluator(out[idx_test], labels[idx_test])
246 |
247 | logger.debug(
248 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_train: {score_train:.4f} | s_val: {score_val:.4f} | s_test: {score_test:.4f}"
249 | )
250 | loss_and_score += [
251 | [
252 | epoch,
253 | loss_train,
254 | loss_val,
255 | loss_test,
256 | score_train,
257 | score_val,
258 | score_test,
259 | ]
260 | ]
261 |
262 | if score_val >= best_score_val:
263 | best_epoch = epoch
264 | best_score_val = score_val
265 | state = copy.deepcopy(model.state_dict())
266 | count = 0
267 | else:
268 | count += 1
269 |
270 | if count == conf["patience"] or epoch == conf["max_epoch"]:
271 | break
272 |
273 | model.load_state_dict(state)
274 | if "MLP" in model.model_name:
275 | out, _, score_val = evaluate_mini_batch(
276 | model, feats, labels, criterion, batch_size, evaluator, idx_val
277 | )
278 | else:
279 | out, _, score_val = evaluate(
280 | model, data_eval, feats, labels, criterion, evaluator, idx_val
281 | )
282 |
283 | score_test = evaluator(out[idx_test], labels[idx_test])
284 | logger.info(
285 | f"Best valid model at epoch: {best_epoch: 3d}, score_val: {score_val :.4f}, score_test: {score_test :.4f}"
286 | )
287 | return out, score_val, score_test
288 |
289 |
290 | def run_inductive(
291 | conf,
292 | model,
293 | g,
294 | feats,
295 | labels,
296 | indices,
297 | criterion,
298 | evaluator,
299 | optimizer,
300 | logger,
301 | loss_and_score,
302 | ):
303 | """
304 | Train and eval under the inductive setting.
305 | The train/valid/test split is specified by `indices`.
306 | idx starting with `obs_idx_` contains the node idx in the observed graph `obs_g`.
307 | idx starting with `idx_` contains the node idx in the original graph `g`.
308 | The model is trained on the observed graph `obs_g`, and evaluated on both the observed test nodes (`obs_idx_test`) and inductive test nodes (`idx_test_ind`).
309 |     The input graph is assumed to be large. Thus, SAGE is used for GNNs and mini-batch training is used for MLPs.
310 |
311 |     idx_obs: Idx of nodes in the original graph `g`, which form the observed graph `obs_g`.
312 | loss_and_score: Stores losses and scores.
313 | """
314 |
315 | set_seed(conf["seed"])
316 | device = conf["device"]
317 | batch_size = conf["batch_size"]
318 | obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = indices
319 |
320 | feats = feats.to(device)
321 | labels = labels.to(device)
322 | obs_feats = feats[idx_obs]
323 | obs_labels = labels[idx_obs]
324 | obs_g = g.subgraph(idx_obs)
325 |
326 | if "SAGE" in model.model_name:
327 | # Create dataloader for SAGE
328 |
329 | # Create csr/coo/csc formats before launching sampling processes
330 |         # This avoids creating certain formats in each data loader process, which saves memory and CPU.
331 | obs_g.create_formats_()
332 | g.create_formats_()
333 | sampler = dgl.dataloading.MultiLayerNeighborSampler(
334 | [eval(fanout) for fanout in conf["fan_out"].split(",")]
335 | )
336 | obs_dataloader = dgl.dataloading.NodeDataLoader(
337 | obs_g,
338 | obs_idx_train,
339 | sampler,
340 | batch_size=batch_size,
341 | shuffle=True,
342 | drop_last=False,
343 | num_workers=conf["num_workers"],
344 | )
345 |
346 | sampler_eval = dgl.dataloading.MultiLayerFullNeighborSampler(1)
347 | obs_dataloader_eval = dgl.dataloading.NodeDataLoader(
348 | obs_g,
349 | torch.arange(obs_g.num_nodes()),
350 | sampler_eval,
351 | batch_size=batch_size,
352 | shuffle=False,
353 | drop_last=False,
354 | num_workers=conf["num_workers"],
355 | )
356 | dataloader_eval = dgl.dataloading.NodeDataLoader(
357 | g,
358 | torch.arange(g.num_nodes()),
359 | sampler_eval,
360 | batch_size=batch_size,
361 | shuffle=False,
362 | drop_last=False,
363 | num_workers=conf["num_workers"],
364 | )
365 |
366 | obs_data = obs_dataloader
367 | obs_data_eval = obs_dataloader_eval
368 | data_eval = dataloader_eval
369 | elif "MLP" in model.model_name:
370 | feats_train, labels_train = obs_feats[obs_idx_train], obs_labels[obs_idx_train]
371 | feats_val, labels_val = obs_feats[obs_idx_val], obs_labels[obs_idx_val]
372 | feats_test_tran, labels_test_tran = (
373 | obs_feats[obs_idx_test],
374 | obs_labels[obs_idx_test],
375 | )
376 | feats_test_ind, labels_test_ind = feats[idx_test_ind], labels[idx_test_ind]
377 |
378 | else:
379 | obs_g = obs_g.to(device)
380 | g = g.to(device)
381 |
382 | obs_data = obs_g
383 | obs_data_eval = obs_g
384 | data_eval = g
385 |
386 | best_epoch, best_score_val, count = 0, 0, 0
387 | for epoch in range(1, conf["max_epoch"] + 1):
388 | if "SAGE" in model.model_name:
389 | loss = train_sage(
390 | model, obs_data, obs_feats, obs_labels, criterion, optimizer
391 | )
392 | elif "MLP" in model.model_name:
393 | loss = train_mini_batch(
394 | model, feats_train, labels_train, batch_size, criterion, optimizer
395 | )
396 | else:
397 | loss = train(
398 | model,
399 | obs_data,
400 | obs_feats,
401 | obs_labels,
402 | criterion,
403 | optimizer,
404 | obs_idx_train,
405 | )
406 |
407 | if epoch % conf["eval_interval"] == 0:
408 | if "MLP" in model.model_name:
409 | _, loss_train, score_train = evaluate_mini_batch(
410 | model, feats_train, labels_train, criterion, batch_size, evaluator
411 | )
412 | _, loss_val, score_val = evaluate_mini_batch(
413 | model, feats_val, labels_val, criterion, batch_size, evaluator
414 | )
415 | _, loss_test_tran, score_test_tran = evaluate_mini_batch(
416 | model,
417 | feats_test_tran,
418 | labels_test_tran,
419 | criterion,
420 | batch_size,
421 | evaluator,
422 | )
423 | _, loss_test_ind, score_test_ind = evaluate_mini_batch(
424 | model,
425 | feats_test_ind,
426 | labels_test_ind,
427 | criterion,
428 | batch_size,
429 | evaluator,
430 | )
431 | else:
432 | obs_out, loss_train, score_train = evaluate(
433 | model,
434 | obs_data_eval,
435 | obs_feats,
436 | obs_labels,
437 | criterion,
438 | evaluator,
439 | obs_idx_train,
440 | )
441 | # Use criterion & evaluator instead of evaluate to avoid redundant forward pass
442 | loss_val = criterion(
443 | obs_out[obs_idx_val], obs_labels[obs_idx_val]
444 | ).item()
445 | score_val = evaluator(obs_out[obs_idx_val], obs_labels[obs_idx_val])
446 | loss_test_tran = criterion(
447 | obs_out[obs_idx_test], obs_labels[obs_idx_test]
448 | ).item()
449 | score_test_tran = evaluator(
450 | obs_out[obs_idx_test], obs_labels[obs_idx_test]
451 | )
452 |
453 | # Evaluate the inductive part with the full graph
454 | out, loss_test_ind, score_test_ind = evaluate(
455 | model, data_eval, feats, labels, criterion, evaluator, idx_test_ind
456 | )
457 | logger.debug(
458 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_train: {score_train:.4f} | s_val: {score_val:.4f} | s_tt: {score_test_tran:.4f} | s_ti: {score_test_ind:.4f}"
459 | )
460 | loss_and_score += [
461 | [
462 | epoch,
463 | loss_train,
464 | loss_val,
465 | loss_test_tran,
466 | loss_test_ind,
467 | score_train,
468 | score_val,
469 | score_test_tran,
470 | score_test_ind,
471 | ]
472 | ]
473 | if score_val >= best_score_val:
474 | best_epoch = epoch
475 | best_score_val = score_val
476 | state = copy.deepcopy(model.state_dict())
477 | count = 0
478 | else:
479 | count += 1
480 |
481 | if count == conf["patience"] or epoch == conf["max_epoch"]:
482 | break
483 |
484 | model.load_state_dict(state)
485 | if "MLP" in model.model_name:
486 | obs_out, _, score_val = evaluate_mini_batch(
487 | model, obs_feats, obs_labels, criterion, batch_size, evaluator, obs_idx_val
488 | )
489 | out, _, score_test_ind = evaluate_mini_batch(
490 | model, feats, labels, criterion, batch_size, evaluator, idx_test_ind
491 | )
492 |
493 | else:
494 | obs_out, _, score_val = evaluate(
495 | model,
496 | obs_data_eval,
497 | obs_feats,
498 | obs_labels,
499 | criterion,
500 | evaluator,
501 | obs_idx_val,
502 | )
503 | out, _, score_test_ind = evaluate(
504 | model, data_eval, feats, labels, criterion, evaluator, idx_test_ind
505 | )
506 |
507 | score_test_tran = evaluator(obs_out[obs_idx_test], obs_labels[obs_idx_test])
508 | out[idx_obs] = obs_out
509 | logger.info(
510 | f"Best valid model at epoch: {best_epoch :3d}, score_val: {score_val :.4f}, score_test_tran: {score_test_tran :.4f}, score_test_ind: {score_test_ind :.4f}"
511 | )
512 | return out, score_val, score_test_tran, score_test_ind
513 |
514 |
515 | """
516 | 3. Distill
517 | """
518 |
519 |
520 | def distill_run_transductive(
521 | conf,
522 | model,
523 | feats,
524 | labels,
525 | out_t_all,
526 | distill_indices,
527 | criterion_l,
528 | criterion_t,
529 | evaluator,
530 | optimizer,
531 | logger,
532 | loss_and_score,
533 | ):
534 | """
535 | Distill training and eval under the transductive setting.
536 | The hard_label_train/soft_label_train/valid/test split is specified by `distill_indices`.
537 |     The input graph is assumed to be large, and MLP is assumed to be the student model. Thus, only node features are used and training is done in mini-batches.
538 |
539 |     out_t_all: Soft labels produced by the teacher model.
540 |     criterion_l & criterion_t: Losses used for hard labels (`labels`) and soft labels (`out_t_all`) respectively.
541 | loss_and_score: Stores losses and scores.
542 | """
543 | set_seed(conf["seed"])
544 | device = conf["device"]
545 | batch_size = conf["batch_size"]
546 | lamb = conf["lamb"]
547 | idx_l, idx_t, idx_val, idx_test = distill_indices
548 |
549 | feats = feats.to(device)
550 | labels = labels.to(device)
551 | out_t_all = out_t_all.to(device)
552 |
553 | feats_l, labels_l = feats[idx_l], labels[idx_l]
554 | feats_t, out_t = feats[idx_t], out_t_all[idx_t]
555 | feats_val, labels_val = feats[idx_val], labels[idx_val]
556 | feats_test, labels_test = feats[idx_test], labels[idx_test]
557 |
558 | best_epoch, best_score_val, count = 0, 0, 0
559 | for epoch in range(1, conf["max_epoch"] + 1):
560 | loss_l = train_mini_batch(
561 | model, feats_l, labels_l, batch_size, criterion_l, optimizer, lamb
562 | )
563 | loss_t = train_mini_batch(
564 | model, feats_t, out_t, batch_size, criterion_t, optimizer, 1 - lamb
565 | )
566 | loss = loss_l + loss_t
567 | if epoch % conf["eval_interval"] == 0:
568 | _, loss_l, score_l = evaluate_mini_batch(
569 | model, feats_l, labels_l, criterion_l, batch_size, evaluator
570 | )
571 | _, loss_val, score_val = evaluate_mini_batch(
572 | model, feats_val, labels_val, criterion_l, batch_size, evaluator
573 | )
574 | _, loss_test, score_test = evaluate_mini_batch(
575 | model, feats_test, labels_test, criterion_l, batch_size, evaluator
576 | )
577 |
578 | logger.debug(
579 | f"Ep {epoch:3d} | loss: {loss:.4f} | s_l: {score_l:.4f} | s_val: {score_val:.4f} | s_test: {score_test:.4f}"
580 | )
581 | loss_and_score += [
582 | [epoch, loss_l, loss_val, loss_test, score_l, score_val, score_test]
583 | ]
584 |
585 | if score_val >= best_score_val:
586 | best_epoch = epoch
587 | best_score_val = score_val
588 | state = copy.deepcopy(model.state_dict())
589 | count = 0
590 | else:
591 | count += 1
592 |
593 | if count == conf["patience"] or epoch == conf["max_epoch"]:
594 | break
595 |
596 | model.load_state_dict(state)
597 | out, _, score_val = evaluate_mini_batch(
598 | model, feats, labels, criterion_l, batch_size, evaluator, idx_val
599 | )
600 | # Use evaluator instead of evaluate to avoid redundant forward pass
601 | score_test = evaluator(out[idx_test], labels_test)
602 |
603 | logger.info(
604 | f"Best valid model at epoch: {best_epoch: 3d}, score_val: {score_val :.4f}, score_test: {score_test :.4f}"
605 | )
606 | return out, score_val, score_test
607 |
608 |
609 | def distill_run_inductive(
610 | conf,
611 | model,
612 | feats,
613 | labels,
614 | out_t_all,
615 | distill_indices,
616 | criterion_l,
617 | criterion_t,
618 | evaluator,
619 | optimizer,
620 | logger,
621 | loss_and_score,
622 | ):
623 | """
624 | Distill training and eval under the inductive setting.
625 | The hard_label_train/soft_label_train/valid/test split is specified by `distill_indices`.
626 | idx starting with `obs_idx_` contains the node idx in the observed graph `obs_g`.
627 | idx starting with `idx_` contains the node idx in the original graph `g`.
628 | The model is trained on the observed graph `obs_g`, and evaluated on both the observed test nodes (`obs_idx_test`) and inductive test nodes (`idx_test_ind`).
629 |     The input graph is assumed to be large, and MLP is assumed to be the student model. Thus, only node features are used and training is done in mini-batches.
630 |
631 |     idx_obs: Idx of nodes in the original graph `g`, which form the observed graph `obs_g`.
632 |     out_t_all: Soft labels produced by the teacher model.
633 |     criterion_l & criterion_t: Losses used for hard labels (`labels`) and soft labels (`out_t_all`) respectively.
634 | loss_and_score: Stores losses and scores.
635 | """
636 |
637 | set_seed(conf["seed"])
638 | device = conf["device"]
639 | batch_size = conf["batch_size"]
640 | lamb = conf["lamb"]
641 | (
642 | obs_idx_l,
643 | obs_idx_t,
644 | obs_idx_val,
645 | obs_idx_test,
646 | idx_obs,
647 | idx_test_ind,
648 | ) = distill_indices
649 |
650 | feats = feats.to(device)
651 | labels = labels.to(device)
652 | out_t_all = out_t_all.to(device)
653 | obs_feats = feats[idx_obs]
654 | obs_labels = labels[idx_obs]
655 | obs_out_t = out_t_all[idx_obs]
656 |
657 | feats_l, labels_l = obs_feats[obs_idx_l], obs_labels[obs_idx_l]
658 | feats_t, out_t = obs_feats[obs_idx_t], obs_out_t[obs_idx_t]
659 | feats_val, labels_val = obs_feats[obs_idx_val], obs_labels[obs_idx_val]
660 | feats_test_tran, labels_test_tran = (
661 | obs_feats[obs_idx_test],
662 | obs_labels[obs_idx_test],
663 | )
664 | feats_test_ind, labels_test_ind = feats[idx_test_ind], labels[idx_test_ind]
665 |
666 | best_epoch, best_score_val, count = 0, 0, 0
667 | for epoch in range(1, conf["max_epoch"] + 1):
668 | loss_l = train_mini_batch(
669 | model, feats_l, labels_l, batch_size, criterion_l, optimizer, lamb
670 | )
671 | loss_t = train_mini_batch(
672 | model, feats_t, out_t, batch_size, criterion_t, optimizer, 1 - lamb
673 | )
674 | loss = loss_l + loss_t
675 | if epoch % conf["eval_interval"] == 0:
676 | _, loss_l, score_l = evaluate_mini_batch(
677 | model, feats_l, labels_l, criterion_l, batch_size, evaluator
678 | )
679 | _, loss_val, score_val = evaluate_mini_batch(
680 | model, feats_val, labels_val, criterion_l, batch_size, evaluator
681 | )
682 | _, loss_test_tran, score_test_tran = evaluate_mini_batch(
683 | model,
684 | feats_test_tran,
685 | labels_test_tran,
686 | criterion_l,
687 | batch_size,
688 | evaluator,
689 | )
690 | _, loss_test_ind, score_test_ind = evaluate_mini_batch(
691 | model,
692 | feats_test_ind,
693 | labels_test_ind,
694 | criterion_l,
695 | batch_size,
696 | evaluator,
697 | )
698 |
699 | logger.debug(
700 | f"Ep {epoch:3d} | l: {loss:.4f} | s_l: {score_l:.4f} | s_val: {score_val:.4f} | s_tt: {score_test_tran:.4f} | s_ti: {score_test_ind:.4f}"
701 | )
702 | loss_and_score += [
703 | [
704 | epoch,
705 | loss_l,
706 | loss_val,
707 | loss_test_tran,
708 | loss_test_ind,
709 | score_l,
710 | score_val,
711 | score_test_tran,
712 | score_test_ind,
713 | ]
714 | ]
715 |
716 | if score_val >= best_score_val:
717 | best_epoch = epoch
718 | best_score_val = score_val
719 | state = copy.deepcopy(model.state_dict())
720 | count = 0
721 | else:
722 | count += 1
723 |
724 | if count == conf["patience"] or epoch == conf["max_epoch"]:
725 | break
726 |
727 | model.load_state_dict(state)
728 | obs_out, _, score_val = evaluate_mini_batch(
729 | model, obs_feats, obs_labels, criterion_l, batch_size, evaluator, obs_idx_val
730 | )
731 | out, _, score_test_ind = evaluate_mini_batch(
732 | model, feats, labels, criterion_l, batch_size, evaluator, idx_test_ind
733 | )
734 |
735 | # Use evaluator instead of evaluate to avoid redundant forward pass
736 | score_test_tran = evaluator(obs_out[obs_idx_test], labels_test_tran)
737 | out[idx_obs] = obs_out
738 |
739 | logger.info(
740 | f"Best valid model at epoch: {best_epoch: 3d} score_val: {score_val :.4f}, score_test_tran: {score_test_tran :.4f}, score_test_ind: {score_test_ind :.4f}"
741 | )
742 | return out, score_val, score_test_tran, score_test_ind
743 |
--------------------------------------------------------------------------------
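As a small illustration of the loss weighting used in the distillation runs above: the hard-label loss is scaled by `lamb` and the soft-label loss by `1 - lamb`, matching the two `train_mini_batch` calls. The criteria follow those constructed in `train_student.py` below; the tensors here are random placeholders, and in the repo the two terms are backpropagated separately rather than summed.

```python
# Sketch only: lamb-weighted combination of hard-label and soft-label losses,
# as in distill_run_transductive / distill_run_inductive above.
import torch

criterion_l = torch.nn.NLLLoss()
criterion_t = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

lamb = 0.5                                              # illustrative value
student_logits = torch.randn(4, 3, requires_grad=True)  # placeholder student output
student_out = student_logits.log_softmax(dim=1)
teacher_out = torch.randn(4, 3).log_softmax(dim=1)      # placeholder soft labels
labels = torch.tensor([0, 2, 1, 0])                     # placeholder hard labels

loss = lamb * criterion_l(student_out, labels) + (1 - lamb) * criterion_t(
    student_out, teacher_out
)
loss.backward()  # in the repo each term is optimized inside train_mini_batch
```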
/train_student.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import torch.optim as optim
5 | from pathlib import Path
6 | from models import Model
7 | from dataloader import load_data, load_out_t
8 | from utils import (
9 | get_logger,
10 | get_evaluator,
11 | set_seed,
12 | get_training_config,
13 | check_writable,
14 | check_readable,
15 | compute_min_cut_loss,
16 | graph_split,
17 | feature_prop,
18 | )
19 | from train_and_eval import distill_run_transductive, distill_run_inductive
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser(description="PyTorch DGL implementation")
24 | parser.add_argument("--device", type=int, default=-1, help="CUDA device, -1 means CPU")
25 | parser.add_argument("--seed", type=int, default=0, help="Random seed")
26 | parser.add_argument(
27 | "--log_level",
28 | type=int,
29 | default=20,
30 | help="Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}",
31 | )
32 | parser.add_argument(
33 | "--console_log",
34 | action="store_true",
35 | help="Set to True to display log info in console",
36 | )
37 | parser.add_argument(
38 | "--output_path", type=str, default="outputs", help="Path to save outputs"
39 | )
40 | parser.add_argument(
41 | "--num_exp", type=int, default=1, help="Repeat how many experiments"
42 | )
43 | parser.add_argument(
44 | "--exp_setting",
45 | type=str,
46 | default="tran",
47 | help="Experiment setting, one of [tran, ind]",
48 | )
49 | parser.add_argument(
50 | "--eval_interval", type=int, default=1, help="Evaluate once per how many epochs"
51 | )
52 | parser.add_argument(
53 | "--save_results",
54 | action="store_true",
55 | help="Set to True to save the loss curves, trained model, and min-cut loss for the transductive setting",
56 | )
57 |
58 | """Dataset"""
59 | parser.add_argument("--dataset", type=str, default="cora", help="Dataset")
60 | parser.add_argument("--data_path", type=str, default="./data", help="Path to data")
61 | parser.add_argument(
62 | "--labelrate_train",
63 | type=int,
64 | default=20,
65 | help="How many labeled data per class as train set",
66 | )
67 | parser.add_argument(
68 | "--labelrate_val",
69 | type=int,
70 | default=30,
71 | help="How many labeled data per class in valid set",
72 | )
73 | parser.add_argument(
74 | "--split_idx",
75 | type=int,
76 | default=0,
77 | help="For Non-Homo datasets only, one of [0,1,2,3,4]",
78 | )
79 |
80 | """Model"""
81 | parser.add_argument(
82 | "--model_config_path",
83 | type=str,
84 | default="./train.conf.yaml",
85 | help="Path to model configeration",
86 | )
87 | parser.add_argument("--teacher", type=str, default="SAGE", help="Teacher model")
88 | parser.add_argument("--student", type=str, default="MLP", help="Student model")
89 | parser.add_argument(
90 | "--num_layers", type=int, default=2, help="Student model number of layers"
91 | )
92 | parser.add_argument(
93 | "--hidden_dim",
94 | type=int,
95 | default=64,
96 | help="Student model hidden layer dimensions",
97 | )
98 | parser.add_argument("--dropout_ratio", type=float, default=0)
99 | parser.add_argument(
100 | "--norm_type", type=str, default="none", help="One of [none, batch, layer]"
101 | )
102 |
103 | """SAGE Specific"""
104 | parser.add_argument("--batch_size", type=int, default=512)
105 | parser.add_argument(
106 | "--fan_out",
107 | type=str,
108 | default="5,5",
109 | help="Number of samples for each layer in SAGE. Length = num_layers",
110 | )
111 | parser.add_argument(
112 | "--num_workers", type=int, default=0, help="Number of workers for sampler"
113 | )
114 |
115 | """Optimization"""
116 | parser.add_argument("--learning_rate", type=float, default=0.01)
117 | parser.add_argument("--weight_decay", type=float, default=0.0005)
118 | parser.add_argument(
119 | "--max_epoch", type=int, default=500, help="Evaluate once per how many epochs"
120 | )
121 | parser.add_argument(
122 | "--patience",
123 | type=int,
124 | default=50,
125 | help="Early stop is the score on validation set does not improve for how many epochs",
126 | )
127 |
128 | """Ablation"""
129 | parser.add_argument(
130 | "--feature_noise",
131 | type=float,
132 | default=0,
133 | help="add white noise to features for analysis, value in [0, 1] for noise level",
134 | )
135 | parser.add_argument(
136 | "--split_rate",
137 | type=float,
138 | default=0.2,
139 | help="Rate for graph split, see comment of graph_split for more details",
140 | )
141 | parser.add_argument(
142 | "--compute_min_cut",
143 | action="store_true",
144 | help="Set to True to compute and store the min-cut loss",
145 | )
146 | parser.add_argument(
147 | "--feature_aug_k",
148 | type=int,
149 | default=0,
150 | help="Augment node futures by aggregating feature_aug_k-hop neighbor features",
151 | )
152 |
153 | """Distiall"""
154 | parser.add_argument(
155 | "--lamb",
156 | type=float,
157 | default=0,
158 | help="Parameter balances loss from hard labels and teacher outputs, take values in [0, 1]",
159 | )
160 | parser.add_argument(
161 | "--out_t_path", type=str, default="outputs", help="Path to load teacher outputs"
162 | )
163 | args = parser.parse_args()
164 |
165 | return args
166 |
167 |
168 | def run(args):
169 | """
170 | Returns:
171 | score_lst: a list of evaluation results on test set.
172 | len(score_lst) = 1 for the transductive setting.
173 | len(score_lst) = 2 for the inductive/production setting.
174 | """
175 |
176 | """ Set seed, device, and logger """
177 | set_seed(args.seed)
178 | if torch.cuda.is_available() and args.device >= 0:
179 | device = torch.device("cuda:" + str(args.device))
180 | else:
181 | device = "cpu"
182 |
183 | if args.feature_noise != 0 and args.seed == 0:
184 | args.output_path = Path.cwd().joinpath(
185 | args.output_path, "noisy_features", f"noise_{args.feature_noise}"
186 | )
187 | # Teacher is assumed to be trained on the same noisy features as well.
188 | args.out_t_path = args.output_path
189 |
190 | if args.feature_aug_k > 0 and args.seed == 0:
191 | args.output_path = Path.cwd().joinpath(
192 | args.output_path, "aug_features", f"aug_hop_{args.feature_aug_k}"
193 | )
194 | # NOTE: Teacher may or may not have augmented features, specify args.out_t_path explicitly.
195 | # args.out_t_path =
196 | args.student = f"GA{args.feature_aug_k}{args.student}"
197 |
198 | if args.exp_setting == "tran":
199 | output_dir = Path.cwd().joinpath(
200 | args.output_path,
201 | "transductive",
202 | args.dataset,
203 | f"{args.teacher}_{args.student}",
204 | f"seed_{args.seed}",
205 | )
206 | out_t_dir = Path.cwd().joinpath(
207 | args.out_t_path,
208 | "transductive",
209 | args.dataset,
210 | args.teacher,
211 | f"seed_{args.seed}",
212 | )
213 | elif args.exp_setting == "ind":
214 | output_dir = Path.cwd().joinpath(
215 | args.output_path,
216 | "inductive",
217 | f"split_rate_{args.split_rate}",
218 | args.dataset,
219 | f"{args.teacher}_{args.student}",
220 | f"seed_{args.seed}",
221 | )
222 | out_t_dir = Path.cwd().joinpath(
223 | args.out_t_path,
224 | "inductive",
225 | f"split_rate_{args.split_rate}",
226 | args.dataset,
227 | args.teacher,
228 | f"seed_{args.seed}",
229 | )
230 | else:
231 | raise ValueError(f"Unknown experiment setting! {args.exp_setting}")
232 | args.output_dir = output_dir
233 |
234 | check_writable(output_dir, overwrite=False)
235 | check_readable(out_t_dir)
236 |
237 | logger = get_logger(output_dir.joinpath("log"), args.console_log, args.log_level)
238 | logger.info(f"output_dir: {output_dir}")
239 | logger.info(f"out_t_dir: {out_t_dir}")
240 |
241 | """ Load data and model config"""
242 | g, labels, idx_train, idx_val, idx_test = load_data(
243 | args.dataset,
244 | args.data_path,
245 | split_idx=args.split_idx,
246 | seed=args.seed,
247 | labelrate_train=args.labelrate_train,
248 | labelrate_val=args.labelrate_val,
249 | )
250 |
251 | logger.info(f"Total {g.number_of_nodes()} nodes.")
252 | logger.info(f"Total {g.number_of_edges()} edges.")
253 |
254 | feats = g.ndata["feat"]
255 | args.feat_dim = g.ndata["feat"].shape[1]
256 | args.label_dim = labels.int().max().item() + 1
257 |
258 | if 0 < args.feature_noise <= 1:
259 | feats = (
260 | 1 - args.feature_noise
261 | ) * feats + args.feature_noise * torch.randn_like(feats)
262 |
263 | """ Model config """
264 | conf = {}
265 | if args.model_config_path is not None:
266 | conf = get_training_config(
267 | args.model_config_path, args.student, args.dataset
268 | ) # Note: student config
269 | conf = dict(args.__dict__, **conf)
270 | conf["device"] = device
271 | logger.info(f"conf: {conf}")
272 |
273 | """ Model init """
274 | model = Model(conf)
275 | optimizer = optim.Adam(
276 | model.parameters(), lr=conf["learning_rate"], weight_decay=conf["weight_decay"]
277 | )
278 | criterion_l = torch.nn.NLLLoss()
279 | criterion_t = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
280 | evaluator = get_evaluator(conf["dataset"])
281 |
282 | """Load teacher model output"""
283 | out_t = load_out_t(out_t_dir)
284 | logger.debug(
285 | f"teacher score on train data: {evaluator(out_t[idx_train], labels[idx_train])}"
286 | )
287 | logger.debug(
288 | f"teacher score on val data: {evaluator(out_t[idx_val], labels[idx_val])}"
289 | )
290 | logger.debug(
291 | f"teacher score on test data: {evaluator(out_t[idx_test], labels[idx_test])}"
292 | )
293 |
294 | """Data split and run"""
295 | loss_and_score = []
296 | if args.exp_setting == "tran":
297 | idx_l = idx_train
298 | idx_t = torch.cat([idx_train, idx_val, idx_test])
299 | distill_indices = (idx_l, idx_t, idx_val, idx_test)
300 |
301 | # propagate node feature
302 | if args.feature_aug_k > 0:
303 | feats = feature_prop(feats, g, args.feature_aug_k)
304 |
305 | out, score_val, score_test = distill_run_transductive(
306 | conf,
307 | model,
308 | feats,
309 | labels,
310 | out_t,
311 | distill_indices,
312 | criterion_l,
313 | criterion_t,
314 | evaluator,
315 | optimizer,
316 | logger,
317 | loss_and_score,
318 | )
319 | score_lst = [score_test]
320 |
321 | elif args.exp_setting == "ind":
322 | # Create inductive split
323 | obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = graph_split(
324 | idx_train, idx_val, idx_test, args.split_rate, args.seed
325 | )
326 | obs_idx_l = obs_idx_train
327 | obs_idx_t = torch.cat([obs_idx_train, obs_idx_val, obs_idx_test])
328 | distill_indices = (
329 | obs_idx_l,
330 | obs_idx_t,
331 | obs_idx_val,
332 | obs_idx_test,
333 | idx_obs,
334 | idx_test_ind,
335 | )
336 |
337 | # propagate node feature. The propagation for the observed graph only happens within the subgraph obs_g
338 | if args.feature_aug_k > 0:
339 | obs_g = g.subgraph(idx_obs)
340 | obs_feats = feature_prop(feats[idx_obs], obs_g, args.feature_aug_k)
341 | feats = feature_prop(feats, g, args.feature_aug_k)
342 | feats[idx_obs] = obs_feats
343 |
344 | out, score_val, score_test_tran, score_test_ind = distill_run_inductive(
345 | conf,
346 | model,
347 | feats,
348 | labels,
349 | out_t,
350 | distill_indices,
351 | criterion_l,
352 | criterion_t,
353 | evaluator,
354 | optimizer,
355 | logger,
356 | loss_and_score,
357 | )
358 | score_lst = [score_test_tran, score_test_ind]
359 |
360 | logger.info(
361 | f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. dropout_ratio: {conf['dropout_ratio']}"
362 | )
363 | logger.info(f"# params {sum(p.numel() for p in model.parameters())}")
364 |
365 | """ Saving student outputs """
366 | out_np = out.detach().cpu().numpy()
367 | np.savez(output_dir.joinpath("out"), out_np)
368 |
369 | """ Saving loss curve and model """
370 | if args.save_results:
371 | # Loss curves
372 | loss_and_score = np.array(loss_and_score)
373 | np.savez(output_dir.joinpath("loss_and_score"), loss_and_score)
374 |
375 | # Model
376 | torch.save(model.state_dict(), output_dir.joinpath("model.pth"))
377 |
378 | """ Saving min-cut loss"""
379 | if args.exp_setting == "tran" and args.compute_min_cut:
380 | min_cut = compute_min_cut_loss(g, out)
381 | with open(output_dir.parent.joinpath("min_cut_loss"), "a+") as f:
382 | f.write(f"{min_cut :.4f}\n")
383 |
384 | return score_lst
385 |
386 |
387 | def repeat_run(args):
388 | scores = []
389 | for seed in range(args.num_exp):
390 | args.seed = seed
391 | scores.append(run(args))
392 | scores_np = np.array(scores)
393 | return scores_np.mean(axis=0), scores_np.std(axis=0)
394 |
395 |
396 | def main():
397 | args = get_args()
398 | if args.num_exp == 1:
399 | score = run(args)
400 | score_str = "".join([f"{s : .4f}\t" for s in score])
401 |
402 | elif args.num_exp > 1:
403 | score_mean, score_std = repeat_run(args)
404 | score_str = "".join(
405 | [f"{s : .4f}\t" for s in score_mean] + [f"{s : .4f}\t" for s in score_std]
406 | )
407 |
408 | with open(args.output_dir.parent.joinpath("exp_results"), "a+") as f:
409 | f.write(f"{score_str}\n")
410 |
411 | # for collecting aggregated results
412 | print(score_str)
413 |
414 |
415 | if __name__ == "__main__":
416 | main()
417 |
--------------------------------------------------------------------------------
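For reference, the `--feature_noise` ablation applied in `run()` above simply blends the node features with Gaussian noise; a standalone sketch with placeholder tensors is shown below.

```python
# Sketch only: the feature-noise blend used when 0 < feature_noise <= 1.
import torch

feature_noise = 0.3                      # illustrative noise level in (0, 1]
feats = torch.randn(5, 8)                # stand-in for g.ndata["feat"]
feats = (1 - feature_noise) * feats + feature_noise * torch.randn_like(feats)
```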
/train_teacher.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import torch.optim as optim
5 | from pathlib import Path
6 | from models import Model
7 | from dataloader import load_data
8 | from utils import (
9 | get_logger,
10 | get_evaluator,
11 | set_seed,
12 | get_training_config,
13 | check_writable,
14 | compute_min_cut_loss,
15 | graph_split,
16 | feature_prop,
17 | )
18 | from train_and_eval import run_transductive, run_inductive
19 |
20 |
21 | def get_args():
22 | parser = argparse.ArgumentParser(description="PyTorch DGL implementation")
23 | parser.add_argument("--device", type=int, default=-1, help="CUDA device, -1 means CPU")
24 | parser.add_argument("--seed", type=int, default=0, help="Random seed")
25 | parser.add_argument(
26 | "--log_level",
27 | type=int,
28 | default=20,
29 | help="Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}",
30 | )
31 | parser.add_argument(
32 | "--console_log",
33 | action="store_true",
34 | help="Set to True to display log info in console",
35 | )
36 | parser.add_argument(
37 | "--output_path", type=str, default="outputs", help="Path to save outputs"
38 | )
39 | parser.add_argument(
40 | "--num_exp", type=int, default=1, help="Repeat how many experiments"
41 | )
42 | parser.add_argument(
43 | "--exp_setting",
44 | type=str,
45 | default="tran",
46 | help="Experiment setting, one of [tran, ind]",
47 | )
48 | parser.add_argument(
49 | "--eval_interval", type=int, default=1, help="Evaluate once per how many epochs"
50 | )
51 | parser.add_argument(
52 | "--save_results",
53 | action="store_true",
54 | help="Set to True to save the loss curves, trained model, and min-cut loss for the transductive setting",
55 | )
56 |
57 | """Dataset"""
58 | parser.add_argument("--dataset", type=str, default="cora", help="Dataset")
59 | parser.add_argument("--data_path", type=str, default="./data", help="Path to data")
60 | parser.add_argument(
61 | "--labelrate_train",
62 | type=int,
63 | default=20,
64 | help="How many labeled data per class as train set",
65 | )
66 | parser.add_argument(
67 | "--labelrate_val",
68 | type=int,
69 | default=30,
70 | help="How many labeled data per class in valid set",
71 | )
72 | parser.add_argument(
73 | "--split_idx",
74 | type=int,
75 | default=0,
76 | help="For Non-Homo datasets only, one of [0,1,2,3,4]",
77 | )
78 |
79 | """Model"""
80 | parser.add_argument(
81 | "--model_config_path",
82 | type=str,
83 | default="./train.conf.yaml",
84 | help="Path to model configeration",
85 | )
86 | parser.add_argument("--teacher", type=str, default="SAGE", help="Teacher model")
87 | parser.add_argument(
88 | "--num_layers", type=int, default=2, help="Model number of layers"
89 | )
90 | parser.add_argument(
91 | "--hidden_dim", type=int, default=128, help="Model hidden layer dimensions"
92 | )
93 | parser.add_argument("--dropout_ratio", type=float, default=0)
94 | parser.add_argument(
95 | "--norm_type", type=str, default="none", help="One of [none, batch, layer]"
96 | )
97 |
98 | """SAGE Specific"""
99 | parser.add_argument("--batch_size", type=int, default=512)
100 | parser.add_argument(
101 | "--fan_out",
102 | type=str,
103 | default="5,5",
104 | help="Number of samples for each layer in SAGE. Length = num_layers",
105 | )
106 | parser.add_argument(
107 | "--num_workers", type=int, default=0, help="Number of workers for sampler"
108 | )
109 |
110 | """Optimization"""
111 | parser.add_argument("--learning_rate", type=float, default=0.01)
112 | parser.add_argument("--weight_decay", type=float, default=0.0005)
113 | parser.add_argument(
114 | "--max_epoch", type=int, default=500, help="Evaluate once per how many epochs"
115 | )
116 | parser.add_argument(
117 | "--patience",
118 | type=int,
119 | default=50,
120 | help="Early stop is the score on validation set does not improve for how many epochs",
121 | )
122 |
123 | """Ablation"""
124 | parser.add_argument(
125 | "--feature_noise",
126 | type=float,
127 | default=0,
128 | help="add white noise to features for analysis, value in [0, 1] for noise level",
129 | )
130 | parser.add_argument(
131 | "--split_rate",
132 | type=float,
133 | default=0.2,
134 | help="Rate for graph split, see comment of graph_split for more details",
135 | )
136 | parser.add_argument(
137 | "--compute_min_cut",
138 | action="store_true",
139 | help="Set to True to compute and store the min-cut loss",
140 | )
141 | parser.add_argument(
142 | "--feature_aug_k",
143 | type=int,
144 | default=0,
145 | help="Augment node futures by aggregating feature_aug_k-hop neighbor features",
146 | )
147 |
148 | args = parser.parse_args()
149 | return args
150 |
151 |
152 | def run(args):
153 | """
154 | Returns:
155 | score_lst: a list of evaluation results on test set.
156 | len(score_lst) = 1 for the transductive setting.
157 | len(score_lst) = 2 for the inductive/production setting.
158 | """
159 |
160 | """ Set seed, device, and logger """
161 | set_seed(args.seed)
162 | if torch.cuda.is_available() and args.device >= 0:
163 | device = torch.device("cuda:" + str(args.device))
164 | else:
165 | device = "cpu"
166 |
167 | if args.feature_noise != 0 and args.seed == 0:
168 | args.output_path = Path.cwd().joinpath(
169 | args.output_path, "noisy_features", f"noise_{args.feature_noise}"
170 | )
171 |
172 | if args.feature_aug_k > 0 and args.seed == 0:
173 | args.output_path = Path.cwd().joinpath(
174 | args.output_path, "aug_features", f"aug_hop_{args.feature_aug_k}"
175 | )
176 | args.teacher = f"GA{args.feature_aug_k}{args.teacher}"
177 |
178 | if args.exp_setting == "tran":
179 | output_dir = Path.cwd().joinpath(
180 | args.output_path,
181 | "transductive",
182 | args.dataset,
183 | args.teacher,
184 | f"seed_{args.seed}",
185 | )
186 | elif args.exp_setting == "ind":
187 | output_dir = Path.cwd().joinpath(
188 | args.output_path,
189 | "inductive",
190 | f"split_rate_{args.split_rate}",
191 | args.dataset,
192 | args.teacher,
193 | f"seed_{args.seed}",
194 | )
195 | else:
196 | raise ValueError(f"Unknown experiment setting! {args.exp_setting}")
197 | args.output_dir = output_dir
198 |
199 | check_writable(output_dir, overwrite=False)
200 | logger = get_logger(output_dir.joinpath("log"), args.console_log, args.log_level)
201 | logger.info(f"output_dir: {output_dir}")
202 |
203 | """ Load data """
204 | g, labels, idx_train, idx_val, idx_test = load_data(
205 | args.dataset,
206 | args.data_path,
207 | split_idx=args.split_idx,
208 | seed=args.seed,
209 | labelrate_train=args.labelrate_train,
210 | labelrate_val=args.labelrate_val,
211 | )
212 | logger.info(f"Total {g.number_of_nodes()} nodes.")
213 | logger.info(f"Total {g.number_of_edges()} edges.")
214 |
215 | feats = g.ndata["feat"]
216 | args.feat_dim = g.ndata["feat"].shape[1]
217 | args.label_dim = labels.int().max().item() + 1
218 |
219 | if 0 < args.feature_noise <= 1:
220 | feats = (
221 | 1 - args.feature_noise
222 | ) * feats + args.feature_noise * torch.randn_like(feats)
223 |
224 | """ Model config """
225 | conf = {}
226 | if args.model_config_path is not None:
227 | conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
228 | conf = dict(args.__dict__, **conf)
229 | conf["device"] = device
230 | logger.info(f"conf: {conf}")
231 |
232 | """ Model init """
233 | model = Model(conf)
234 | optimizer = optim.Adam(
235 | model.parameters(), lr=conf["learning_rate"], weight_decay=conf["weight_decay"]
236 | )
237 | criterion = torch.nn.NLLLoss()
238 | evaluator = get_evaluator(conf["dataset"])
239 |
240 | """ Data split and run """
241 | loss_and_score = []
242 | if args.exp_setting == "tran":
243 | indices = (idx_train, idx_val, idx_test)
244 |
245 | # propagate node feature
246 | if args.feature_aug_k > 0:
247 | feats = feature_prop(feats, g, args.feature_aug_k)
248 |
249 | out, score_val, score_test = run_transductive(
250 | conf,
251 | model,
252 | g,
253 | feats,
254 | labels,
255 | indices,
256 | criterion,
257 | evaluator,
258 | optimizer,
259 | logger,
260 | loss_and_score,
261 | )
262 | score_lst = [score_test]
263 |
264 | elif args.exp_setting == "ind":
265 | indices = graph_split(idx_train, idx_val, idx_test, args.split_rate, args.seed)
266 |
267 | # propagate node feature. The propagation for the observed graph only happens within the subgraph obs_g
268 | if args.feature_aug_k > 0:
269 | idx_obs = indices[3]
270 | obs_g = g.subgraph(idx_obs)
271 | obs_feats = feature_prop(feats[idx_obs], obs_g, args.feature_aug_k)
272 | feats = feature_prop(feats, g, args.feature_aug_k)
273 | feats[idx_obs] = obs_feats
274 |
275 | out, score_val, score_test_tran, score_test_ind = run_inductive(
276 | conf,
277 | model,
278 | g,
279 | feats,
280 | labels,
281 | indices,
282 | criterion,
283 | evaluator,
284 | optimizer,
285 | logger,
286 | loss_and_score,
287 | )
288 | score_lst = [score_test_tran, score_test_ind]
289 |
290 | logger.info(
291 | f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. dropout_ratio: {conf['dropout_ratio']}"
292 | )
293 | logger.info(f"# params {sum(p.numel() for p in model.parameters())}")
294 |
295 | """ Saving teacher outputs """
296 | out_np = out.detach().cpu().numpy()
297 | np.savez(output_dir.joinpath("out"), out_np)
298 |
299 | """ Saving loss curve and model """
300 | if args.save_results:
301 | # Loss curves
302 | loss_and_score = np.array(loss_and_score)
303 | np.savez(output_dir.joinpath("loss_and_score"), loss_and_score)
304 |
305 | # Model
306 | torch.save(model.state_dict(), output_dir.joinpath("model.pth"))
307 |
308 | """ Saving min-cut loss """
309 | if args.exp_setting == "tran" and args.compute_min_cut:
310 | min_cut = compute_min_cut_loss(g, out)
311 | with open(output_dir.parent.joinpath("min_cut_loss"), "a+") as f:
312 | f.write(f"{min_cut :.4f}\n")
313 |
314 | return score_lst
315 |
316 |
317 | def repeat_run(args):
318 | scores = []
319 | for seed in range(args.num_exp):
320 | args.seed = seed
321 | scores.append(run(args))
322 | scores_np = np.array(scores)
323 | return scores_np.mean(axis=0), scores_np.std(axis=0)
324 |
325 |
326 | def main():
327 | args = get_args()
328 | if args.num_exp == 1:
329 | score = run(args)
330 | score_str = "".join([f"{s : .4f}\t" for s in score])
331 |
332 | elif args.num_exp > 1:
333 | score_mean, score_std = repeat_run(args)
334 | score_str = "".join(
335 | [f"{s : .4f}\t" for s in score_mean] + [f"{s : .4f}\t" for s in score_std]
336 | )
337 |
338 | with open(args.output_dir.parent.joinpath("exp_results"), "a+") as f:
339 | f.write(f"{score_str}\n")
340 |
341 | # for collecting aggregated results
342 | print(score_str)
343 |
344 |
345 | if __name__ == "__main__":
346 | main()
347 |
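348 | # ---------------------------------------------------------------------------
349 | # Usage sketch (hedged): the actual CLI flags are defined by get_args() earlier in
350 | # this file; the flag names and values below are assumptions that mirror the
351 | # attribute names used above (args.exp_setting, args.teacher, args.dataset,
352 | # args.num_exp) and the sage_*.sh scripts under experiments/.
353 | #
354 | #   python train_teacher.py --exp_setting tran --teacher SAGE --dataset cora --num_exp 3
355 | #
356 | # The teacher outputs are saved with np.savez(output_dir.joinpath("out")), which
357 | # writes out.npz with the array stored under numpy's default key "arr_0", so they
358 | # can be reloaded later (e.g. for the distillation step) as:
359 | #
360 | #   np.load(output_dir.joinpath("out.npz"))["arr_0"]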
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import logging
4 | import pytz
5 | import random
6 | import os
7 | import yaml
8 | import shutil
9 | from datetime import datetime
10 | from ogb.nodeproppred import Evaluator
11 | from dgl import function as fn
12 |
13 | CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"]
14 | OGB_data = ["ogbn-arxiv", "ogbn-products"]
15 | NonHom_data = ["pokec", "penn94"]
16 | BGNN_data = ["house_class", "vk_class"]
17 |
18 |
19 | def set_seed(seed):
20 | torch.manual_seed(seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.backends.cudnn.benchmark = False
24 | torch.backends.cudnn.deterministic = True
25 | if torch.cuda.is_available():
26 | torch.cuda.manual_seed_all(seed)
27 |
28 |
29 | def get_training_config(config_path, model_name, dataset):
30 | with open(config_path, "r") as conf:
31 | full_config = yaml.load(conf, Loader=yaml.FullLoader)
32 | dataset_specific_config = full_config["global"]
33 | model_specific_config = full_config[dataset][model_name]
34 |
35 | if model_specific_config is not None:
36 | specific_config = dict(dataset_specific_config, **model_specific_config)
37 | else:
38 | specific_config = dataset_specific_config
39 |
40 | specific_config["model_name"] = model_name
41 | return specific_config
42 |
43 |
44 | def check_writable(path, overwrite=True):
45 | if not os.path.exists(path):
46 | os.makedirs(path)
47 | elif overwrite:
48 | shutil.rmtree(path)
49 | os.makedirs(path)
50 | else:
51 | pass
52 |
53 |
54 | def check_readable(path):
55 | if not os.path.exists(path):
56 | raise ValueError(f"No such file or directory! {path}")
57 |
58 |
59 | def timetz(*args):
60 | tz = pytz.timezone("US/Pacific")
61 | return datetime.now(tz).timetuple()
62 |
63 |
64 | def get_logger(filename, console_log=False, log_level=logging.INFO):
65 | tz = pytz.timezone("US/Pacific")
66 | log_time = datetime.now(tz).strftime("%b%d_%H_%M_%S")
67 | logger = logging.getLogger(__name__)
68 | logger.propagate = False # avoid duplicate logging
69 | logger.setLevel(log_level)
70 |
71 | # Clean logger first to avoid duplicated handlers
72 | for hdlr in logger.handlers[:]:
73 | logger.removeHandler(hdlr)
74 |
75 | file_handler = logging.FileHandler(filename)
76 | formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%b%d %H-%M-%S")
77 | formatter.converter = timetz
78 | file_handler.setFormatter(formatter)
79 | logger.addHandler(file_handler)
80 |
81 | if console_log:
82 | console_handler = logging.StreamHandler()
83 | console_handler.setFormatter(formatter)
84 | logger.addHandler(console_handler)
85 | return logger
86 |
87 |
88 | def idx_split(idx, ratio, seed=0):
89 | """
90 |     Randomly split idx into two portions: the first with int(len(idx) * ratio) elements and the second with the rest.
91 | """
92 | set_seed(seed)
93 | n = len(idx)
94 | cut = int(n * ratio)
95 | idx_idx_shuffle = torch.randperm(n)
96 |
97 | idx1_idx, idx2_idx = idx_idx_shuffle[:cut], idx_idx_shuffle[cut:]
98 | idx1, idx2 = idx[idx1_idx], idx[idx2_idx]
99 | # assert((torch.cat([idx1, idx2]).sort()[0] == idx.sort()[0]).all())
100 | return idx1, idx2
101 |
102 |
103 | def graph_split(idx_train, idx_val, idx_test, rate, seed):
104 | """
105 | Args:
106 |         The original setting is transductive: the full graph is observed, and idx_train takes up a small portion of it.
107 |         The graph is split by further dividing idx_test into [idx_test_tran, idx_test_ind].
108 |         rate = |idx_test_ind| / |idx_test| (how much of the test set to hide for the inductive evaluation).
109 |
110 |     Ex. ogbn-products
111 |         as loaded: train : val : test = 8 : 2 : 90, with rate = 0.2
112 |         after split: train : val : test_tran : test_ind = 8 : 2 : 72 : 18
113 |
114 |     Return:
115 |         Indices starting with 'obs_' are node indices within the observed subgraph,
116 |         whereas indices starting with 'idx_' are node indices in the original graph.
117 | """
118 | idx_test_ind, idx_test_tran = idx_split(idx_test, rate, seed)
119 |
120 | idx_obs = torch.cat([idx_train, idx_val, idx_test_tran])
121 | N1, N2 = idx_train.shape[0], idx_val.shape[0]
122 | obs_idx_all = torch.arange(idx_obs.shape[0])
123 | obs_idx_train = obs_idx_all[:N1]
124 | obs_idx_val = obs_idx_all[N1 : N1 + N2]
125 | obs_idx_test = obs_idx_all[N1 + N2 :]
126 |
127 | return obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind
128 |
129 |
130 | def get_evaluator(dataset):
131 | if dataset in CPF_data + NonHom_data + BGNN_data:
132 |
133 | def evaluator(out, labels):
134 | pred = out.argmax(1)
135 | return pred.eq(labels).float().mean().item()
136 |
137 | elif dataset in OGB_data:
138 | ogb_evaluator = Evaluator(dataset)
139 |
140 | def evaluator(out, labels):
141 | pred = out.argmax(1, keepdim=True)
142 | input_dict = {"y_true": labels.unsqueeze(1), "y_pred": pred}
143 | return ogb_evaluator.eval(input_dict)["acc"]
144 |
145 | else:
146 | raise ValueError("Unknown dataset")
147 |
148 | return evaluator
149 |
150 |
159 | def compute_min_cut_loss(g, out):
160 | out = out.to("cpu")
161 |     S = out.exp()  # soft assignment scores; out is expected to hold log-probabilities
162 |     A = g.adj().to_dense()  # dense adjacency matrix
163 |     D = g.in_degrees().float().diag()  # diagonal (in-)degree matrix
164 |     min_cut = (  # normalized min-cut objective: tr(S^T A S) / tr(S^T D S)
165 | torch.matmul(torch.matmul(S.transpose(1, 0), A), S).trace()
166 | / torch.matmul(torch.matmul(S.transpose(1, 0), D), S).trace()
167 | )
168 | return min_cut.item()
169 |
170 |
171 | def feature_prop(feats, g, k):
172 | """
173 |     Augment node features by propagating them within the k-hop neighborhood.
174 | The propagation is done in the SGC fashion, i.e. hop by hop and symmetrically normalized by node degrees.
175 | """
176 | assert feats.shape[0] == g.num_nodes()
177 |
178 | degs = g.in_degrees().float().clamp(min=1)
179 | norm = torch.pow(degs, -0.5).unsqueeze(1)
180 |
181 | # compute (D^-1/2 A D^-1/2)^k X
182 | for _ in range(k):
183 | feats = feats * norm
184 | g.ndata["h"] = feats
185 | g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
186 | feats = g.ndata.pop("h")
187 | feats = feats * norm
188 |
189 | return feats
190 |
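191 |
192 | if __name__ == "__main__":
193 |     # Usage sketch (hedged): exercise graph_split and feature_prop on a toy random
194 |     # graph. The toy graph, feature matrix, and index splits below are illustrative
195 |     # placeholders, not part of the training pipeline.
196 |     import dgl
197 |
198 |     toy_g = dgl.rand_graph(100, 500)  # 100 nodes, 500 random edges
199 |     toy_feats = torch.randn(100, 16)  # 16-dim node features
200 |     idx_train, idx_val, idx_test = (
201 |         torch.arange(0, 8),
202 |         torch.arange(8, 10),
203 |         torch.arange(10, 100),
204 |     )
205 |
206 |     obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = graph_split(
207 |         idx_train, idx_val, idx_test, rate=0.2, seed=0
208 |     )
209 |     # rate=0.2 hides ~20% of idx_test as idx_test_ind; the obs_* tensors index
210 |     # into the observed subgraph induced by the remaining nodes (idx_obs).
211 |     obs_g = toy_g.subgraph(idx_obs)
212 |     obs_feats = feature_prop(toy_feats[idx_obs], obs_g, k=2)  # (D^-1/2 A D^-1/2)^2 X
213 |     print(obs_feats.shape, idx_test_ind.shape)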
--------------------------------------------------------------------------------