├── .gitignore ├── README.md ├── arxiv_dgl ├── README.md ├── criterion.py ├── gat.py ├── load_checkpoint.py ├── models.py ├── scripts │ ├── gat-teachers.sh │ ├── run_all.sh │ └── run_all_kd_and_aux.sh ├── sign.py ├── submit.py ├── submit_sign.py ├── test_timing_gat.py └── test_timing_sign.py ├── arxiv_pyg ├── README.md ├── correlation.py ├── criterion.py ├── gnn.py ├── gnn_kd_and_aux.py ├── logger.py ├── scripts │ ├── run_gcn.sh │ ├── run_kd_and_aux.sh │ └── run_sage.sh ├── submit.py └── test.py ├── awesome-efficient-gnns.md ├── img ├── architectures.png ├── arxiv-mag.png ├── computing-gnns.PNG ├── degree-quant.PNG ├── distillation.PNG ├── dual-nas.PNG ├── graph-saint.PNG ├── molhiv.png ├── ogb-lsc.PNG ├── pinsage.PNG ├── pipeline.png ├── ppi.png ├── sgc.PNG └── techniques.png ├── mag_pyg ├── README.md ├── criterion.py ├── gnn.py ├── gnn_kd_and_aux.py ├── logger.py ├── scripts │ ├── run.sh │ ├── run_kd_and_aux.sh │ └── teacher.sh ├── submit.py └── test.py ├── mol_pyg └── README.md └── ppi_pyg ├── README.md ├── criterion.py ├── gnn.py ├── logger.py ├── scripts ├── baselines.sh ├── run.sh └── train_teacher.sh ├── submit.py └── train_teacher.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | *.ipynb 3 | dataset/ 4 | logs/ 5 | .vscode/ 6 | .DS_Store 7 | 8 | arxiv_dgl/dataset/ 9 | arxiv_dgl/output/ 10 | arxiv_dgl/logits/ 11 | arxiv_dgl/features/ 12 | arxiv_dgl/checkpoints/ 13 | 14 | arxiv_pyg/dataset 15 | arxiv_pyg/logs 16 | 17 | ppi_pyg/data 18 | ppi_pyg/logs 19 | ppi_pyg/checkpoints 20 | 21 | mag_pyg/data 22 | mag_pyg/logs 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | pip-wheel-metadata/ 47 | share/python-wheels/ 48 | 
*.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Distillation for Resource-efficient Graph Neural Networks 2 | 3 | This repository provides resources on Graph Neural Network efficiency and scalability, as well as implementations of knowledge distillation techniques for developing resource-efficient GNNs. 4 | 5 | ![Knowledge distillation pipeline for GNNs](img/pipeline.png) 6 | 7 | Check out the accompanying paper ['On Representation Knowledge Distillation for Graph Neural Networks'](https://arxiv.org/abs/2111.04964), which introduces new GNN distillation techniques using contrastive learning to preserve the global topology of teacher and student embeddings. 8 | 9 | > Chaitanya K. Joshi, Fayao Liu, Xu Xun, Jie Lin, and Chuan Sheng Foo. On Representation Knowledge Distillation for Graph Neural Networks. IEEE Transactions on Neural Networks and Learning Systems (TNNLS), *Special Issue on Deep Neural Networks for Graphs: Theory, Models, Algorithms and Applications*. 
10 | > 11 | > [PDF](https://arxiv.org/pdf/2111.04964.pdf) | [Blog](https://www.chaitjo.com/post/efficient-gnns/) 12 | 13 | ❓New to GNN scalability? See [`awesome-efficient-gnns.md`](awesome-efficient-gnns.md) and the [accompanying blogpost](https://www.chaitjo.com/post/efficient-gnns/) for a curated overview of papers on efficient and scalable Graph Representation Learning for real-world applications. 14 | 15 | ## Distillation Techniques 16 | 17 | ![Representation distillation techniques for GNNs](img/techniques.png) 18 | 19 | We benchmark the following knowledge distillation techniques for GNNs: 20 | - **Local Structure Preserving loss**, [Yang et al., CVPR 2020](https://arxiv.org/abs/2003.10477): preserve pairwise relationships over graph edges, but may not preserve global topology due to latent interactions. 21 | - **Global Structure Preserving loss**, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964): preserve all pairwise global relationships, but are computationally more cumbersome. 22 | - 🌟 **Graph Contrastive Representation Distillation**, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964): contrastive learning among positive/negative pairwise relations across the teacher and student embedding spaces. 23 | - 🔥 Your new GNN distillation technique? 24 | 25 | We also include baselines: **Logit-based KD**, [Hinton et al., 2015](https://arxiv.org/abs/1503.02531); and feature-mimicking baselines from computer vision: **FitNet**, [Romero et al., 2014](https://arxiv.org/abs/1412.6550), **Attention Transfer**, [Zagoruyko and Komodakis, 2016](https://arxiv.org/abs/1612.03928). 26 | 27 | ## Datasets and Architectures 28 | 29 | We conduct benchmarks on large-scale and real-world graph datasets, where the performance gap between expressive+cumbersome teacher and resource-efficient student GNNs is non-negligible: 30 | - **Graph classification** on `MOLHIV` from Open Graph Benchmark/MoleculeNet -- GIN-E/PNA teachers, GCN/GIN students.
31 | - **Node classification** on `ARXIV` and `MAG` from Open Graph Benchmark and Microsoft Academic Graph -- GAT/R-GCN teachers, GCN/GraphSage/SIGN students. 32 | - **3D point cloud segmentation** on `S3DIS` -- not released publicly yet. 33 | - **Node classification** on `PPI` -- provided to reproduce results from [Yang et al.](https://arxiv.org/abs/2003.10477) 34 | 35 | ## Installation and Usage 36 | 37 | Our results are reported with Python 3.7, PyTorch 1.7.1, and CUDA 11.0. 38 | We used the following GPUs: RTX3090 for ARXIV/MAG, V100 for MOLHIV/S3DIS. 39 | 40 | Usage instructions for each dataset are provided within the corresponding directory. 41 | 42 | ```sh 43 | # Create new conda environment 44 | conda create -n ogb python=3.7 45 | conda activate ogb 46 | 47 | # Install PyTorch (Check CUDA version!) 48 | conda install pytorch=1.7.1 cudatoolkit=11.0 -c pytorch 49 | 50 | # Install DGL 51 | conda install -c dglteam dgl-cuda11.0 52 | 53 | # Install PyG 54 | CUDA=cu110 55 | pip3 install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+${CUDA}.html 56 | pip3 install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+${CUDA}.html 57 | pip3 install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.1+${CUDA}.html 58 | pip3 install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.1+${CUDA}.html 59 | pip3 install torch-geometric 60 | 61 | # Install other dependencies 62 | conda install tqdm scikit-learn pandas urllib3 tensorboard matplotlib 63 | pip3 install ipdb nvidia-ml-py3 64 | 65 | # Install OGB 66 | pip3 install -U ogb 67 | ``` 68 | 69 | ## Citation 70 | 71 | ``` 72 | @article{joshi2022representation, 73 | title={On Representation Knowledge Distillation for Graph Neural Networks}, 74 | author={Chaitanya K.
Joshi and Fayao Liu and Xu Xun and Jie Lin and Chuan-Sheng Foo}, 75 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 76 | year={2022} 77 | } 78 | 79 | @misc{joshi2022efficientgnns, 80 | author = {Joshi, Chaitanya K.}, 81 | title = {Recent Advances in Efficient and Scalable Graph Neural Networks}, 82 | year = {2022}, 83 | howpublished = {\url{https://www.chaitjo.com/post/efficient-gnns/}}, 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /arxiv_dgl/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Distillation for GNNs (ARXIV with DGL) 2 | 3 | **Dataset**: ARXIV 4 | 5 | **Library**: DGL 6 | 7 | This directory contains code to benchmark knowledge distillation for GNNs on the ARXIV dataset, developed in the DGL framework. 8 | The main purpose of the DGL codebase is to: 9 | - Train teacher models (3-layer GATs), which ran into out-of-memory (OOM) errors when trained via PyG. Code adapted from: https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb/ogbn-arxiv 10 | - Train graph-agnostic SIGN models, for which the best open-source implementation was in DGL: https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb/sign 11 | 12 | Note that other student GNNs like GCN/GraphSage are trained on ARXIV via the PyG implementation in `arxiv_pyg`. 13 | 14 | ## Directory Structure 15 | 16 | ``` 17 | .
18 | ├── dataset # automatically created by OGB data downloaders 19 | ├── checkpoints # saves GAT teacher checkpoints 20 | ├── features # saves GAT teacher final hidden features 21 | ├── logits # saves GAT teacher predicted logits 22 | ├── output # saves GAT teacher predicted outputs 23 | ├── logs # logging directory for SIGN models 24 | | 25 | ├── scripts # scripts to conduct full experiments and reproduce results 26 | │ ├── gat-teachers.sh # trains GAT teachers for 10 seeds 27 | │ ├── run_all_kd_and_aux.sh # script to benchmark KD techniques 28 | │ └── run_all.sh # script to benchmark KD techniques 29 | | 30 | ├── README.md 31 | | 32 | ├── criterion.py # KD loss functions 33 | ├── gat.py # trains GAT teacher and dumps checkpoints, features, logits, outputs, etc. 34 | ├── load_checkpoint.py # loads and evaluates a GAT checkpoint on the test set 35 | ├── models.py # model definitions for GAT and GCN 36 | ├── sign.py # trains SIGN models with/without KD 37 | ├── submit.py # load and evaluate GAT model predictions on the test set 38 | ├── submit_sign.py # load and evaluate SIGN models on the test set 39 | ├── test_timing_gat.py # this script was used to measure inference time 40 | └── test_timing_sign.py # this script was used to measure inference time 41 | ``` 42 | 43 | ## Example Usage 44 | 45 | For full usage, each file has accompanying flags and documentation. 46 | Also see the `scripts` folder for reproducing results. 
47 | -------------------------------------------------------------------------------- /arxiv_dgl/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch_geometric.utils import softmax 6 | 7 | 8 | def kd_criterion(logits, labels, teacher_logits, alpha=0.9, T=4): 9 | """Logit-based KD, [Hinton et al., 2015](https://arxiv.org/abs/1503.02531) 10 | """ 11 | loss_cls = F.cross_entropy(logits, labels) 12 | 13 | loss_kd = F.kl_div( 14 | F.log_softmax(logits/ T, dim=1), 15 | F.softmax(teacher_logits/ T, dim=1), 16 | log_target=False # NOTE(review): no `reduction` is passed, so PyTorch's default 'mean' applies (averages over every element, i.e. batch size * number of classes); PyTorch documents reduction='batchmean' as matching the KL definition -- switching rescales loss_kd, so re-check alpha/T if changed 17 | ) 18 | 19 | loss = loss_kd* (alpha* T* T) + loss_cls* (1-alpha) 20 | 21 | return loss, loss_cls, loss_kd 22 | 23 | 24 | def fitnet_criterion(logits, labels, feat, teacher_feat, beta=1000): 25 | """FitNet, [Romero et al., 2014](https://arxiv.org/abs/1412.6550) 26 | """ 27 | loss_cls = F.cross_entropy(logits, labels) 28 | 29 | feat = F.normalize(feat, p=2, dim=-1) 30 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 31 | 32 | loss_fitnet = F.mse_loss(feat, teacher_feat) 33 | 34 | loss = loss_cls + beta* loss_fitnet 35 | 36 | return loss, loss_cls, loss_fitnet 37 | 38 | 39 | def at_criterion(logits, labels, feat, teacher_feat, beta=1000): 40 | """Attention Transfer, [Zagoruyko and Komodakis, 2016](https://arxiv.org/abs/1612.03928) 41 | """ 42 | loss_cls = F.cross_entropy(logits, labels) 43 | 44 | feat = feat.pow(2).sum(-1) # collapse the last dim into a per-row squared-L2 "attention" value; these are L2-normalized below before comparison 45 | teacher_feat = teacher_feat.pow(2).sum(-1) 46 | 47 | loss_at = F.mse_loss( 48 | F.normalize(feat, p=2, dim=-1), 49 | F.normalize(teacher_feat, p=2, dim=-1) 50 | ) 51 | 52 | loss = loss_cls + beta* loss_at 53 | 54 | return loss, loss_cls, loss_at 55 | 56 | 57 | def gpw_criterion(logits, labels, feat, teacher_feat, kernel='cosine', beta=1, max_samples=8192): 58 | """Global Structure Preserving loss, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 59 | """ 60 | loss_cls =
F.cross_entropy(logits, labels) 61 | 62 | if max_samples < feat.shape[0]: # randomly subsample rows (without replacement, same indices for student and teacher) so the pairwise matrices are at most max_samples x max_samples 63 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 64 | feat = feat[sampled_inds] 65 | teacher_feat = teacher_feat[sampled_inds] 66 | 67 | pw_sim = None 68 | teacher_pw_sim = None 69 | if kernel == 'cosine': 70 | feat = F.normalize(feat, p=2, dim=-1) 71 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 72 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten() 73 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten() 74 | elif kernel == 'poly': 75 | feat = F.normalize(feat, p=2, dim=-1) 76 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 77 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten()**2 78 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten()**2 79 | elif kernel == 'l2': 80 | pw_sim = (feat.unsqueeze(0) - feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() # for 2-D feat of shape (N, D) this materializes an (N, N, D) difference tensor before reducing -- memory grows with max_samples**2 81 | teacher_pw_sim = (teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 82 | elif kernel == 'rbf': 83 | pw_sim = torch.exp(-0.5* ((feat.unsqueeze(0) - feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 84 | teacher_pw_sim = torch.exp(-0.5* ((teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 85 | else: 86 | raise NotImplementedError 87 | 88 | loss_gpw = F.mse_loss(pw_sim, teacher_pw_sim) 89 | 90 | loss = loss_cls + beta* loss_gpw 91 | 92 | return loss, loss_cls, loss_gpw 93 | 94 | 95 | def nce_criterion(logits, labels, feat, teacher_feat, beta=0.5, nce_T=0.075, max_samples=8192): 96 | """Graph Contrastive Representation Distillation, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 97 | """ 98 | loss_cls = F.cross_entropy(logits, labels) 99 | 100 | if max_samples < feat.shape[0]: 101 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 102 | feat = feat[sampled_inds] 103 | teacher_feat = teacher_feat[sampled_inds] 104 | 105 | feat =
F.normalize(feat, p=2, dim=-1) 106 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 107 | 108 | nce_logits = torch.mm(feat, teacher_feat.transpose(0, 1)) 109 | nce_labels = torch.arange(feat.shape[0]).to(feat.device) # student row i's positive is teacher row i (the diagonal); every other teacher row acts as a negative 110 | 111 | loss_nce = F.cross_entropy(nce_logits/ nce_T, nce_labels) 112 | 113 | loss = loss_cls + beta* loss_nce 114 | 115 | return loss, loss_cls, loss_nce 116 | -------------------------------------------------------------------------------- /arxiv_dgl/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import math 6 | import os 7 | import random 8 | import time 9 | import glob 10 | 11 | import dgl 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from matplotlib import pyplot as plt 17 | from matplotlib.ticker import AutoMinorLocator, MultipleLocator 18 | from ogb.nodeproppred import DglNodePropPredDataset, Evaluator 19 | 20 | from models import GAT 21 | 22 | epsilon = 1 - math.log(2) # offset used by custom_loss_function: log(epsilon + CE) - log(epsilon) 23 | 24 | device = None 25 | 26 | dataset = "ogbn-arxiv" 27 | n_node_feats, n_classes = 0, 0 # placeholders; overwritten by load_data() via `global` 28 | 29 | 30 | def seed(seed=0): 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = False 38 | dgl.random.seed(seed) 39 | 40 | 41 | def load_data(dataset): 42 | global n_node_feats, n_classes 43 | 44 | data = DglNodePropPredDataset(name=dataset) 45 | evaluator = Evaluator(name=dataset) 46 | 47 | splitted_idx = data.get_idx_split() 48 | train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] 49 | graph, labels = data[0] 50 | 51 | n_node_feats = graph.ndata["feat"].shape[1] 52 | n_classes = (labels.max() + 1).item() 53 | 54 | return graph, labels,
train_idx, val_idx, test_idx, evaluator 55 | 56 | 57 | def preprocess(graph): 58 | global n_node_feats 59 | 60 | # make bidirected 61 | feat = graph.ndata["feat"] 62 | graph = dgl.to_bidirected(graph) 63 | graph.ndata["feat"] = feat 64 | 65 | # add self-loop 66 | print(f"Total edges before adding self-loop {graph.number_of_edges()}") 67 | graph = graph.remove_self_loop().add_self_loop() 68 | print(f"Total edges after adding self-loop {graph.number_of_edges()}") 69 | 70 | graph.create_formats_() 71 | 72 | return graph 73 | 74 | 75 | def gen_model(args): 76 | if args.use_labels: 77 | n_node_feats_ = n_node_feats + n_classes 78 | else: 79 | n_node_feats_ = n_node_feats 80 | 81 | model = GAT( 82 | n_node_feats_, 83 | n_classes, 84 | n_hidden=args.n_hidden, 85 | n_layers=args.n_layers, 86 | n_heads=args.n_heads, 87 | activation=F.relu, 88 | dropout=args.dropout, 89 | input_drop=args.input_drop, 90 | attn_drop=args.attn_drop, 91 | edge_drop=args.edge_drop, 92 | use_attn_dst=not args.no_attn_dst, 93 | use_symmetric_norm=args.use_norm, 94 | ) 95 | 96 | return model 97 | 98 | 99 | def custom_loss_function(x, labels): 100 | y = F.cross_entropy(x, labels[:, 0], reduction="none") 101 | y = torch.log(epsilon + y) - math.log(epsilon) 102 | return torch.mean(y) 103 | 104 | 105 | def add_labels(feat, labels, idx): 106 | onehot = torch.zeros([feat.shape[0], n_classes], device=device) 107 | onehot[idx, labels[idx, 0]] = 1 108 | return torch.cat([feat, onehot], dim=-1) 109 | 110 | 111 | @torch.no_grad() 112 | def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator): 113 | model.eval() 114 | 115 | feat = graph.ndata["feat"] 116 | 117 | if args.use_labels: 118 | feat = add_labels(feat, labels, train_idx) 119 | 120 | pred = model(graph, feat) 121 | 122 | if args.n_label_iters > 0: 123 | unlabel_idx = torch.cat([val_idx, test_idx]) 124 | for _ in range(args.n_label_iters): 125 | feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1) 126 | pred = 
model(graph, feat) 127 | 128 | train_loss = custom_loss_function(pred[train_idx], labels[train_idx]) 129 | val_loss = custom_loss_function(pred[val_idx], labels[val_idx]) 130 | test_loss = custom_loss_function(pred[test_idx], labels[test_idx]) 131 | 132 | model_feat = model.feat 133 | 134 | return ( 135 | evaluator(pred[train_idx], labels[train_idx]), 136 | evaluator(pred[val_idx], labels[val_idx]), 137 | evaluator(pred[test_idx], labels[test_idx]), 138 | train_loss, 139 | val_loss, 140 | test_loss, 141 | pred, 142 | model_feat 143 | ) 144 | 145 | 146 | def count_parameters(args): 147 | model = gen_model(args) 148 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 149 | 150 | 151 | def main(): 152 | global device, n_node_feats, n_classes, epsilon 153 | 154 | argparser = argparse.ArgumentParser( 155 | "GAT checkpoint on ogbn-arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter 156 | ) 157 | argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.") 158 | argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") 159 | argparser.add_argument("--seed", type=int, default=0, help="seed") 160 | argparser.add_argument( 161 | "--use-labels", action="store_true", help="Use labels in the training set as input features." 
162 | ) 163 | argparser.add_argument("--n-label-iters", type=int, default=1, help="number of label iterations") 164 | argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate") 165 | argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.") 166 | argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.") 167 | argparser.add_argument("--lr", type=float, default=0.002, help="learning rate") 168 | argparser.add_argument("--n-layers", type=int, default=3, help="number of layers") 169 | argparser.add_argument("--n-heads", type=int, default=3, help="number of heads") 170 | argparser.add_argument("--n-hidden", type=int, default=250, help="number of hidden units") 171 | argparser.add_argument("--dropout", type=float, default=0.75, help="dropout rate") 172 | argparser.add_argument("--input-drop", type=float, default=0.25, help="input drop rate") 173 | argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention drop rate") 174 | argparser.add_argument("--edge-drop", type=float, default=0.3, help="edge drop rate") 175 | argparser.add_argument("--wd", type=float, default=0, help="weight decay") 176 | argparser.add_argument("--checkpoint-files", type=str, default="./checkpoints/*.pt", help="address of checkpoint files") 177 | args = argparser.parse_args() 178 | 179 | if not args.use_labels and args.n_label_iters > 0: 180 | raise ValueError("'--use-labels' must be enabled when n_label_iters > 0") 181 | 182 | if args.cpu: 183 | device = torch.device("cpu") 184 | else: 185 | device = torch.device(f"cuda:{args.gpu}") 186 | 187 | # load data & preprocess 188 | graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset) 189 | graph = preprocess(graph) 190 | 191 | graph, labels, train_idx, val_idx, test_idx = map( 192 | lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx) 193 | ) 194 | 195 | evaluator_wrapper = lambda pred, 
labels: evaluator.eval( 196 | {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} 197 | )["acc"] 198 | 199 | # run 200 | val_accs, test_accs = [], [] 201 | 202 | for file in glob.iglob(args.checkpoint_files): 203 | print("load:", file) 204 | checkpoint = torch.load(file) 205 | 206 | # define model and optimizer 207 | model = gen_model(args).to(device) 208 | model.load_state_dict(checkpoint['model_state_dict']) 209 | # optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd) 210 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 211 | 212 | train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred, model_feat = evaluate( 213 | args, model, graph, labels, train_idx, val_idx, test_idx, evaluator_wrapper 214 | ) 215 | val_accs.append(val_acc) 216 | test_accs.append(test_acc) 217 | 218 | n_running = file.split('/')[-1][:-3] 219 | os.makedirs(f"./logits", exist_ok=True) 220 | torch.save(pred, f"./logits/{n_running}.pt") 221 | os.makedirs(f"./features", exist_ok=True) 222 | torch.save(model_feat, f"./features/{n_running}.pt") 223 | 224 | print(args) 225 | print() 226 | print(f"Run {len(val_accs)} times") 227 | print(f"Average val accuracy: {np.mean(val_accs)*100:.2f} ± {np.std(val_accs)*100:.2f}") 228 | print(f"Average test accuracy: {np.mean(test_accs)*100:.2f} ± {np.std(test_accs)*100:.2f}") 229 | print(f"Number of params: {count_parameters(args)}") 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /arxiv_dgl/models.py: -------------------------------------------------------------------------------- 1 | import dgl.nn.pytorch as dglnn 2 | import torch 3 | import torch.nn as nn 4 | from dgl import function as fn 5 | from dgl._ffi.base import DGLError 6 | from dgl.nn.pytorch.utils import Identity 7 | from dgl.ops import edge_softmax 8 | from dgl.utils import expand_as_pair 9 | 10 | 11 | class 
ElementWiseLinear(nn.Module): 12 | def __init__(self, size, weight=True, bias=True, inplace=False): 13 | super().__init__() 14 | if weight: 15 | self.weight = nn.Parameter(torch.Tensor(size)) 16 | else: 17 | self.weight = None 18 | if bias: 19 | self.bias = nn.Parameter(torch.Tensor(size)) 20 | else: 21 | self.bias = None 22 | self.inplace = inplace 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | if self.weight is not None: 28 | nn.init.ones_(self.weight) 29 | if self.bias is not None: 30 | nn.init.zeros_(self.bias) 31 | 32 | def forward(self, x): 33 | if self.inplace: 34 | if self.weight is not None: 35 | x.mul_(self.weight) 36 | if self.bias is not None: 37 | x.add_(self.bias) 38 | else: 39 | if self.weight is not None: 40 | x = x * self.weight 41 | if self.bias is not None: 42 | x = x + self.bias 43 | return x 44 | 45 | 46 | class GCN(nn.Module): 47 | def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, use_linear): 48 | super().__init__() 49 | self.n_layers = n_layers 50 | self.n_hidden = n_hidden 51 | self.n_classes = n_classes 52 | self.use_linear = use_linear 53 | 54 | self.convs = nn.ModuleList() 55 | if use_linear: 56 | self.linear = nn.ModuleList() 57 | self.norms = nn.ModuleList() 58 | 59 | for i in range(n_layers): 60 | in_hidden = n_hidden if i > 0 else in_feats 61 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 62 | bias = i == n_layers - 1 63 | 64 | self.convs.append(dglnn.GraphConv(in_hidden, out_hidden, "both", bias=bias)) 65 | if use_linear: 66 | self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False)) 67 | if i < n_layers - 1: 68 | self.norms.append(nn.BatchNorm1d(out_hidden)) 69 | 70 | self.input_drop = nn.Dropout(min(0.1, dropout)) 71 | self.dropout = nn.Dropout(dropout) 72 | self.activation = activation 73 | 74 | def forward(self, graph, feat): 75 | h = feat 76 | h = self.input_drop(h) 77 | 78 | for i in range(self.n_layers): 79 | conv = self.convs[i](graph, h) 80 | 81 
| if self.use_linear: 82 | linear = self.linear[i](h) 83 | h = conv + linear 84 | else: 85 | h = conv 86 | 87 | if i < self.n_layers - 1: 88 | h = self.norms[i](h) 89 | h = self.activation(h) 90 | h = self.dropout(h) 91 | 92 | return h 93 | 94 | 95 | class GATConv(nn.Module): 96 | def __init__( 97 | self, 98 | in_feats, 99 | out_feats, 100 | num_heads=1, 101 | feat_drop=0.0, 102 | attn_drop=0.0, 103 | edge_drop=0.0, 104 | negative_slope=0.2, 105 | use_attn_dst=True, 106 | residual=False, 107 | activation=None, 108 | allow_zero_in_degree=False, 109 | use_symmetric_norm=False, 110 | ): 111 | super(GATConv, self).__init__() 112 | self._num_heads = num_heads 113 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 114 | self._out_feats = out_feats 115 | self._allow_zero_in_degree = allow_zero_in_degree 116 | self._use_symmetric_norm = use_symmetric_norm 117 | if isinstance(in_feats, tuple): 118 | self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) 119 | self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False) 120 | else: 121 | self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) 122 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 123 | if use_attn_dst: 124 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 125 | else: 126 | self.register_buffer("attn_r", None) 127 | self.feat_drop = nn.Dropout(feat_drop) 128 | self.attn_drop = nn.Dropout(attn_drop) 129 | self.edge_drop = edge_drop 130 | self.leaky_relu = nn.LeakyReLU(negative_slope) 131 | if residual: 132 | self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False) 133 | else: 134 | self.register_buffer("res_fc", None) 135 | self.reset_parameters() 136 | self._activation = activation 137 | 138 | def reset_parameters(self): 139 | gain = nn.init.calculate_gain("relu") 140 | if hasattr(self, "fc"): 141 | nn.init.xavier_normal_(self.fc.weight, gain=gain) 142 
| else: 143 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 144 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 145 | nn.init.xavier_normal_(self.attn_l, gain=gain) 146 | if isinstance(self.attn_r, nn.Parameter): 147 | nn.init.xavier_normal_(self.attn_r, gain=gain) 148 | if isinstance(self.res_fc, nn.Linear): 149 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 150 | 151 | def set_allow_zero_in_degree(self, set_value): 152 | self._allow_zero_in_degree = set_value 153 | 154 | def forward(self, graph, feat): 155 | with graph.local_scope(): 156 | if not self._allow_zero_in_degree: 157 | if (graph.in_degrees() == 0).any(): 158 | assert False 159 | 160 | if isinstance(feat, tuple): 161 | h_src = self.feat_drop(feat[0]) 162 | h_dst = self.feat_drop(feat[1]) 163 | if not hasattr(self, "fc_src"): 164 | self.fc_src, self.fc_dst = self.fc, self.fc 165 | feat_src, feat_dst = h_src, h_dst 166 | feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) 167 | feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) 168 | else: 169 | h_src = self.feat_drop(feat) 170 | feat_src = h_src 171 | feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats) 172 | if graph.is_block: 173 | h_dst = h_src[: graph.number_of_dst_nodes()] 174 | feat_dst = feat_src[: graph.number_of_dst_nodes()] 175 | else: 176 | h_dst = h_src 177 | feat_dst = feat_src 178 | 179 | if self._use_symmetric_norm: 180 | degs = graph.out_degrees().float().clamp(min=1) 181 | norm = torch.pow(degs, -0.5) 182 | shp = norm.shape + (1,) * (feat_src.dim() - 1) 183 | norm = torch.reshape(norm, shp) 184 | feat_src = feat_src * norm 185 | 186 | # NOTE: GAT paper uses "first concatenation then linear projection" 187 | # to compute attention scores, while ours is "first projection then 188 | # addition", the two approaches are mathematically equivalent: 189 | # We decompose the weight vector a mentioned in the paper into 190 | # [a_l || a_r], then 191 | # a^T [Wh_i 
|| Wh_j] = a_l Wh_i + a_r Wh_j 192 | # Our implementation is much efficient because we do not need to 193 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, 194 | # addition could be optimized with DGL's built-in function u_add_v, 195 | # which further speeds up computation and saves memory footprint. 196 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) 197 | graph.srcdata.update({"ft": feat_src, "el": el}) 198 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 199 | if self.attn_r is not None: 200 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) 201 | graph.dstdata.update({"er": er}) 202 | graph.apply_edges(fn.u_add_v("el", "er", "e")) 203 | else: 204 | graph.apply_edges(fn.copy_u("el", "e")) 205 | e = self.leaky_relu(graph.edata.pop("e")) 206 | 207 | if self.training and self.edge_drop > 0: 208 | perm = torch.randperm(graph.number_of_edges(), device=e.device) 209 | bound = int(graph.number_of_edges() * self.edge_drop) 210 | eids = perm[bound:] 211 | graph.edata["a"] = torch.zeros_like(e) 212 | graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids)) 213 | else: 214 | graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) 215 | 216 | # message passing 217 | graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft")) 218 | rst = graph.dstdata["ft"] 219 | 220 | if self._use_symmetric_norm: 221 | degs = graph.in_degrees().float().clamp(min=1) 222 | norm = torch.pow(degs, 0.5) 223 | shp = norm.shape + (1,) * (feat_dst.dim() - 1) 224 | norm = torch.reshape(norm, shp) 225 | rst = rst * norm 226 | 227 | # residual 228 | if self.res_fc is not None: 229 | resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) 230 | rst = rst + resval 231 | 232 | # activation 233 | if self._activation is not None: 234 | rst = self._activation(rst) 235 | 236 | return rst 237 | 238 | 239 | class GAT(nn.Module): 240 | def __init__( 241 | self, 242 | in_feats, 243 | n_classes, 244 | 
n_hidden, 245 | n_layers, 246 | n_heads, 247 | activation, 248 | dropout=0.0, 249 | input_drop=0.0, 250 | attn_drop=0.0, 251 | edge_drop=0.0, 252 | use_attn_dst=True, 253 | use_symmetric_norm=False, 254 | ): 255 | super().__init__() 256 | self.in_feats = in_feats 257 | self.n_hidden = n_hidden 258 | self.n_classes = n_classes 259 | self.n_layers = n_layers 260 | self.num_heads = n_heads 261 | 262 | self.convs = nn.ModuleList() 263 | self.norms = nn.ModuleList() 264 | 265 | for i in range(n_layers): 266 | in_hidden = n_heads * n_hidden if i > 0 else in_feats 267 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 268 | num_heads = n_heads if i < n_layers - 1 else 1 269 | out_channels = n_heads 270 | 271 | self.convs.append( 272 | GATConv( 273 | in_hidden, 274 | out_hidden, 275 | num_heads=num_heads, 276 | attn_drop=attn_drop, 277 | edge_drop=edge_drop, 278 | use_attn_dst=use_attn_dst, 279 | use_symmetric_norm=use_symmetric_norm, 280 | residual=True, 281 | ) 282 | ) 283 | 284 | if i < n_layers - 1: 285 | self.norms.append(nn.BatchNorm1d(out_channels * out_hidden)) 286 | 287 | self.bias_last = ElementWiseLinear(n_classes, weight=False, bias=True, inplace=True) 288 | 289 | self.input_drop = nn.Dropout(input_drop) 290 | self.dropout = nn.Dropout(dropout) 291 | self.activation = activation 292 | 293 | def forward(self, graph, feat): 294 | h = feat 295 | h = self.input_drop(h) 296 | 297 | for i in range(self.n_layers): 298 | conv = self.convs[i](graph, h) 299 | 300 | h = conv 301 | 302 | if i < self.n_layers - 1: 303 | h = h.flatten(1) 304 | h = self.norms[i](h) 305 | h = self.activation(h, inplace=True) 306 | h = self.dropout(h) 307 | 308 | self.feat = h # tracks final features before prediction 309 | 310 | h = h.mean(1) 311 | h = self.bias_last(h) 312 | 313 | return h -------------------------------------------------------------------------------- /arxiv_dgl/scripts/gat-teachers.sh: -------------------------------------------------------------------------------- 
1 | #!/bin/bash 2 | 3 | tmux new -s expt -d 4 | tmux send-keys "conda activate ogb" C-m 5 | 6 | tmux send-keys " 7 | python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred --gpu 0 --seed 0 --n-runs 5 --expt-name gat-3L250x3h & 8 | python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred --gpu 1 --seed 5 --n-runs 5 --expt-name gat-3L250x3h & 9 | wait" C-m 10 | 11 | tmux send-keys " 12 | python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred --gpu 0 --seed 0 --n-runs 5 --n-heads 4 --expt-name gat-3L250x4h & 13 | python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred --gpu 1 --seed 5 --n-runs 5 --n-heads 4 --expt-name gat-3L250x4h & 14 | wait" C-m 15 | 16 | tmux send-keys "tmux kill-session -t expt" C-m 17 | -------------------------------------------------------------------------------- /arxiv_dgl/scripts/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Experimental setup: 4 | # - We tune the loss balancing weights α, β for all techniques on the validation set when using both logit-based and representation distillation together. 5 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 6 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 7 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head. 8 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude.
9 | 10 | tmux new -s expt -d 11 | tmux send-keys "conda activate ogb" C-m 12 | 13 | 14 | 15 | training=supervised 16 | 17 | tmux send-keys " 18 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training & 19 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training & 20 | wait" C-m 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | training=kd 29 | 30 | tmux send-keys " 31 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training & 32 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training & 33 | wait" C-m 34 | 35 | 36 | 37 | 38 | 39 | 40 | training=fitnet 41 | beta=10000 42 | 43 | tmux send-keys " 44 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta & 45 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta & 46 | wait" C-m 47 | 48 | 49 | 50 | training=at 51 | beta=1000 52 | 53 | tmux send-keys " 54 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta & 55 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta & 56 | wait" C-m 57 | 58 | 59 | 60 | 61 | training=gpw 62 | beta=10000000 63 | max_samples=2048 64 | 
proj_dim=128 65 | 66 | tmux send-keys " 67 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim & 68 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim & 69 | wait" C-m 70 | 71 | 72 | 73 | 74 | training=nce 75 | beta=0.5 76 | nce_T=0.075 77 | max_samples=8196 78 | proj_dim=256 79 | 80 | tmux send-keys " 81 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 82 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 83 | wait" C-m 84 | 85 | 86 | 87 | 88 | 89 | tmux send-keys "tmux kill-session -t expt" C-m 90 | -------------------------------------------------------------------------------- /arxiv_dgl/scripts/run_all_kd_and_aux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Experimental setup: 4 | # - We tune the loss balancing weights α, β for all techniques on the validation set when using both logit-based and representation distillation together. 5 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 6 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 7 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head.
8 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude. 9 | 10 | tmux new -s expt -d 11 | tmux send-keys "conda activate ogb" C-m 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | training=kd 20 | 21 | tmux send-keys " 22 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training & 23 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training & 24 | wait" C-m 25 | 26 | 27 | 28 | 29 | 30 | 31 | training=fitnet 32 | beta=1000 33 | 34 | tmux send-keys " 35 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta & 36 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta & 37 | wait" C-m 38 | 39 | 40 | 41 | training=at 42 | beta=100 43 | 44 | tmux send-keys " 45 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta & 46 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta & 47 | wait" C-m 48 | 49 | 50 | 51 | 52 | training=gpw 53 | beta=1000000 54 | max_samples=2048 55 | proj_dim=128 56 | 57 | tmux send-keys " 58 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim & 59 | python3 
sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $training --beta $beta --max_samples $max_samples --proj_dim $proj_dim & 60 | wait" C-m 61 | 62 | 63 | 64 | training=nce 65 | beta=0.1 66 | nce_T=0.075 67 | max_samples=16384 68 | proj_dim=256 69 | expt_name=nce 70 | 71 | tmux send-keys " 72 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 0 --gpu 0 --training $training --expt_name $expt_name --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 73 | python3 sign.py --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 5 --seed 5 --gpu 1 --training $training --expt_name $expt_name --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 74 | wait" C-m 75 | 76 | 77 | 78 | tmux send-keys "tmux kill-session -t expt" C-m 79 | -------------------------------------------------------------------------------- /arxiv_dgl/submit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | 4 | import numpy as np 5 | import torch 6 | from ogb.nodeproppred import DglNodePropPredDataset, Evaluator 7 | 8 | device = None 9 | 10 | dataset = "ogbn-arxiv" 11 | n_node_feats, n_classes = 0, 0 12 | 13 | 14 | def load_data(dataset): 15 | global n_node_feats, n_classes 16 | 17 | data = DglNodePropPredDataset(name=dataset) 18 | evaluator = Evaluator(name=dataset) 19 | 20 | splitted_idx = data.get_idx_split() 21 | train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] 22 | graph, labels = data[0] 23 | 24 | n_node_feats = graph.ndata["feat"].shape[1] 25 | n_classes = (labels.max() + 1).item() 26 | 27 | return graph, labels, train_idx, val_idx, test_idx, evaluator 28 | 29 | 30 | def evaluate(labels, pred, train_idx, val_idx, test_idx, 
evaluator): 31 | return ( 32 | evaluator(pred[train_idx], labels[train_idx]), 33 | evaluator(pred[val_idx], labels[val_idx]), 34 | evaluator(pred[test_idx], labels[test_idx]), 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | argparser = argparse.ArgumentParser() 40 | argparser.add_argument("--pred-files", type=str, default="./output/*.pt", help="address of prediction files") 41 | args = argparser.parse_args() 42 | 43 | # load data 44 | graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset) 45 | evaluator_wrapper = lambda pred, labels: evaluator.eval( 46 | {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} 47 | )["acc"] 48 | 49 | val_accs, test_accs = [], [] 50 | 51 | for pred_file in glob.iglob(args.pred_files): 52 | print("load:", pred_file) 53 | pred = torch.load(pred_file) 54 | train_acc, val_acc, test_acc = evaluate(labels, pred, train_idx, val_idx, test_idx, evaluator_wrapper) 55 | val_accs.append(val_acc) 56 | test_accs.append(test_acc) 57 | 58 | print(args) 59 | print() 60 | print(f"Run {len(val_accs)} times") 61 | print(f"Average val accuracy: {np.mean(val_accs)*100:.2f} ± {np.std(val_accs)*100:.2f}") 62 | print(f"Average test accuracy: {np.mean(test_accs)*100:.2f} ± {np.std(test_accs)*100:.2f}") 63 | -------------------------------------------------------------------------------- /arxiv_dgl/submit_sign.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | # Experiment settings 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--expt_name', type=str, default="debug", 11 | help='Name of experiment logs/results folder') 12 | args = parser.parse_args() 13 | 14 | expt_args = [] 15 | model = [] 16 | total_param = [] 17 | BestEpoch = [] 18 | Validation = [] 19 | Test = [] 20 | Train = [] 21 | BestTrain = [] 22 | 23 | for root, dirs, files in 
os.walk(f'logs/{args.expt_name}'): 24 | if 'results.pt' in files: 25 | results = torch.load(os.path.join(root, 'results.pt'), map_location=torch.device('cpu')) 26 | expt_args.append(results['args']) 27 | total_param.append(results['total_param']) 28 | BestEpoch.append(results['BestEpoch']) 29 | Train.append(results['Train']) 30 | Validation.append(results['Validation']) 31 | Test.append(results['Test']) 32 | 33 | print(results['args'].seed, results['Test'], results['Validation']) 34 | 35 | print(expt_args[0]) 36 | print() 37 | print(f'Test performance: {np.mean(Test)*100:.2f} +- {np.std(Test)*100:.2f}') 38 | print(f'Validation performance: {np.mean(Validation)*100:.2f} +- {np.std(Validation)*100:.2f}') 39 | print(f'Train performance: {np.mean(Train)*100:.2f} +- {np.std(Train)*100:.2f}') 40 | print(f'Total parameters: {int(np.mean(total_param))}') -------------------------------------------------------------------------------- /arxiv_dgl/test_timing_gat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import math 6 | import os 7 | import random 8 | import time 9 | import glob 10 | 11 | import dgl 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from matplotlib import pyplot as plt 17 | from matplotlib.ticker import AutoMinorLocator, MultipleLocator 18 | from ogb.nodeproppred import DglNodePropPredDataset, Evaluator 19 | 20 | from models import GAT 21 | 22 | import nvidia_smi 23 | nvidia_smi.nvmlInit() 24 | 25 | 26 | epsilon = 1 - math.log(2) 27 | 28 | device = None 29 | 30 | dataset = "ogbn-arxiv" 31 | n_node_feats, n_classes = 0, 0 32 | 33 | 34 | def seed(seed=0): 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | 
torch.backends.cudnn.benchmark = False 42 | dgl.random.seed(seed) 43 | 44 | 45 | def load_data(dataset): 46 | global n_node_feats, n_classes 47 | 48 | data = DglNodePropPredDataset(name=dataset) 49 | evaluator = Evaluator(name=dataset) 50 | 51 | splitted_idx = data.get_idx_split() 52 | train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"] 53 | graph, labels = data[0] 54 | 55 | n_node_feats = graph.ndata["feat"].shape[1] 56 | n_classes = (labels.max() + 1).item() 57 | 58 | return graph, labels, train_idx, val_idx, test_idx, evaluator 59 | 60 | 61 | def preprocess(graph): 62 | global n_node_feats 63 | 64 | # make bidirected 65 | feat = graph.ndata["feat"] 66 | graph = dgl.to_bidirected(graph) 67 | graph.ndata["feat"] = feat 68 | 69 | # add self-loop 70 | print(f"Total edges before adding self-loop {graph.number_of_edges()}") 71 | graph = graph.remove_self_loop().add_self_loop() 72 | print(f"Total edges after adding self-loop {graph.number_of_edges()}") 73 | 74 | graph.create_formats_() 75 | 76 | return graph 77 | 78 | 79 | def gen_model(args): 80 | if args.use_labels: 81 | n_node_feats_ = n_node_feats + n_classes 82 | else: 83 | n_node_feats_ = n_node_feats 84 | 85 | model = GAT( 86 | n_node_feats_, 87 | n_classes, 88 | n_hidden=args.n_hidden, 89 | n_layers=args.n_layers, 90 | n_heads=args.n_heads, 91 | activation=F.relu, 92 | dropout=args.dropout, 93 | input_drop=args.input_drop, 94 | attn_drop=args.attn_drop, 95 | edge_drop=args.edge_drop, 96 | use_attn_dst=not args.no_attn_dst, 97 | use_symmetric_norm=args.use_norm, 98 | ) 99 | 100 | return model 101 | 102 | 103 | def custom_loss_function(x, labels): 104 | y = F.cross_entropy(x, labels[:, 0], reduction="none") 105 | y = torch.log(epsilon + y) - math.log(epsilon) 106 | return torch.mean(y) 107 | 108 | 109 | def add_labels(feat, labels, idx): 110 | onehot = torch.zeros([feat.shape[0], n_classes], device=device) 111 | onehot[idx, labels[idx, 0]] = 1 112 | return 
torch.cat([feat, onehot], dim=-1) 113 | 114 | 115 | @torch.no_grad() 116 | def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator, device): 117 | model.eval() 118 | 119 | feat = graph.ndata["feat"] 120 | 121 | if args.use_labels: 122 | feat = add_labels(feat, labels, train_idx) 123 | 124 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device) 125 | 126 | iter_start_time = time.time() 127 | pred = model(graph, feat) 128 | iter_time = time.time() - iter_start_time 129 | gpu_used = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used 130 | 131 | if args.n_label_iters > 0: 132 | unlabel_idx = torch.cat([val_idx, test_idx]) 133 | for _ in range(args.n_label_iters): 134 | feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1) 135 | pred = model(graph, feat) 136 | 137 | train_loss = custom_loss_function(pred[train_idx], labels[train_idx]) 138 | val_loss = custom_loss_function(pred[val_idx], labels[val_idx]) 139 | test_loss = custom_loss_function(pred[test_idx], labels[test_idx]) 140 | 141 | model_feat = model.feat 142 | 143 | return ( 144 | evaluator(pred[train_idx], labels[train_idx]), 145 | evaluator(pred[val_idx], labels[val_idx]), 146 | evaluator(pred[test_idx], labels[test_idx]), 147 | train_loss, 148 | val_loss, 149 | test_loss, 150 | pred, 151 | model_feat, 152 | iter_time, 153 | gpu_used 154 | ) 155 | 156 | 157 | def count_parameters(args): 158 | model = gen_model(args) 159 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 160 | 161 | 162 | def main(): 163 | seed(42) 164 | 165 | global device, n_node_feats, n_classes, epsilon 166 | 167 | argparser = argparse.ArgumentParser( 168 | "GAT checkpoint on ogbn-arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter 169 | ) 170 | argparser.add_argument("--cpu", action="store_true", help="CPU mode. 
This option overrides --gpu.") 171 | argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") 172 | argparser.add_argument("--seed", type=int, default=0, help="seed") 173 | argparser.add_argument( 174 | "--use-labels", action="store_true", help="Use labels in the training set as input features." 175 | ) 176 | argparser.add_argument("--n-label-iters", type=int, default=1, help="number of label iterations") 177 | argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate") 178 | argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.") 179 | argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.") 180 | argparser.add_argument("--lr", type=float, default=0.002, help="learning rate") 181 | argparser.add_argument("--n-layers", type=int, default=3, help="number of layers") 182 | argparser.add_argument("--n-heads", type=int, default=3, help="number of heads") 183 | argparser.add_argument("--n-hidden", type=int, default=250, help="number of hidden units") 184 | argparser.add_argument("--dropout", type=float, default=0.75, help="dropout rate") 185 | argparser.add_argument("--input-drop", type=float, default=0.25, help="input drop rate") 186 | argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention drop rate") 187 | argparser.add_argument("--edge-drop", type=float, default=0.3, help="edge drop rate") 188 | argparser.add_argument("--wd", type=float, default=0, help="weight decay") 189 | argparser.add_argument("--checkpoint-files", type=str, default="./checkpoints/gat-3L250x3h/*.pt", help="address of checkpoint files") 190 | args = argparser.parse_args() 191 | 192 | if not args.use_labels and args.n_label_iters > 0: 193 | raise ValueError("'--use-labels' must be enabled when n_label_iters > 0") 194 | 195 | if args.cpu: 196 | device = torch.device("cpu") 197 | else: 198 | device = torch.device(f"cuda:{args.gpu}") 199 | 200 | # load 
data & preprocess 201 | graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset) 202 | graph = preprocess(graph) 203 | 204 | graph, labels, train_idx, val_idx, test_idx = map( 205 | lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx) 206 | ) 207 | 208 | evaluator_wrapper = lambda pred, labels: evaluator.eval( 209 | {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels} 210 | )["acc"] 211 | 212 | # run 213 | val_accs, test_accs = [], [] 214 | t_iter_time = [] 215 | t_gpu_used = [] 216 | for file in glob.iglob(args.checkpoint_files): 217 | torch.cuda.empty_cache() 218 | 219 | print("load:", file) 220 | checkpoint = torch.load(file) 221 | 222 | # define model and optimizer 223 | model = gen_model(args).to(device) 224 | model.load_state_dict(checkpoint['model_state_dict']) 225 | # optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd) 226 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 227 | 228 | train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred, model_feat, iter_time, gpu_used = evaluate( 229 | args, model, graph, labels, train_idx, val_idx, test_idx, evaluator_wrapper, args.gpu 230 | ) 231 | val_accs.append(val_acc) 232 | test_accs.append(test_acc) 233 | t_iter_time.append(iter_time) 234 | t_gpu_used.append(gpu_used) 235 | 236 | n_running = file.split('/')[-1][:-3] 237 | # os.makedirs(f"./logits", exist_ok=True) 238 | # torch.save(pred, f"./logits/{n_running}.pt") 239 | # os.makedirs(f"./features", exist_ok=True) 240 | # torch.save(model_feat, f"./features/{n_running}.pt") 241 | 242 | print(args) 243 | print() 244 | print(f"Run {len(val_accs)} times") 245 | print(f"Average val accuracy: {np.mean(val_accs)*100:.2f} ± {np.std(val_accs)*100:.2f}") 246 | print(f"Average test accuracy: {np.mean(test_accs)*100:.2f} ± {np.std(test_accs)*100:.2f}") 247 | print(f"Avg. Iteration Time: {np.mean(t_iter_time[1:]):.3f}s ± {np.std(t_iter_time[1:]):.3f}") 248 | print(f"Avg. 
GPU Used: {np.mean(t_gpu_used) * 1e-9:.1f}GB ± {np.std(t_gpu_used) * 1e-9:.1f}") 249 | print(f"Number of params: {count_parameters(args)}") 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | 255 | # python test_gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --gpu 0 -------------------------------------------------------------------------------- /arxiv_dgl/test_timing_sign.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import time 4 | import os 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | import dgl 13 | import dgl.function as fn 14 | 15 | from ogb.nodeproppred import DglNodePropPredDataset, Evaluator 16 | 17 | import nvidia_smi 18 | nvidia_smi.nvmlInit() 19 | 20 | 21 | 22 | def get_ogb_evaluator(dataset): 23 | """ 24 | Get evaluator from Open Graph Benchmark based on dataset 25 | """ 26 | evaluator = Evaluator(name=dataset) 27 | return lambda preds, labels: evaluator.eval({ 28 | "y_true": labels.view(-1, 1), 29 | "y_pred": preds.view(-1, 1), 30 | })["acc"] 31 | 32 | 33 | def convert_mag_to_homograph(g, device): 34 | """ 35 | Featurize node types that don't have input features (i.e. author, 36 | institution, field_of_study) by averaging their neighbor features. 37 | Then convert the graph to a undirected homogeneous graph. 
38 | """ 39 | src_writes, dst_writes = g.all_edges(etype="writes") 40 | src_topic, dst_topic = g.all_edges(etype="has_topic") 41 | src_aff, dst_aff = g.all_edges(etype="affiliated_with") 42 | new_g = dgl.heterograph({ 43 | ("paper", "written", "author"): (dst_writes, src_writes), 44 | ("paper", "has_topic", "field"): (src_topic, dst_topic), 45 | ("author", "aff", "inst"): (src_aff, dst_aff) 46 | }) 47 | new_g = new_g.to(device) 48 | new_g.nodes["paper"].data["feat"] = g.nodes["paper"].data["feat"] 49 | new_g["written"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat")) 50 | new_g["has_topic"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat")) 51 | new_g["aff"].update_all(fn.copy_u("feat", "m"), fn.mean("m", "feat")) 52 | g.nodes["author"].data["feat"] = new_g.nodes["author"].data["feat"] 53 | g.nodes["institution"].data["feat"] = new_g.nodes["inst"].data["feat"] 54 | g.nodes["field_of_study"].data["feat"] = new_g.nodes["field"].data["feat"] 55 | 56 | # Convert to homogeneous graph 57 | # Get DGL type id for paper type 58 | target_type_id = g.get_ntype_id("paper") 59 | g = dgl.to_homogeneous(g, ndata=["feat"]) 60 | g = dgl.add_reverse_edges(g, copy_ndata=True) 61 | # Mask for paper nodes 62 | g.ndata["target_mask"] = g.ndata[dgl.NTYPE] == target_type_id 63 | return g 64 | 65 | 66 | def load_dataset(name, device): 67 | """ 68 | Load dataset and move graph and features to device 69 | """ 70 | if name not in ["ogbn-products", "ogbn-arxiv", "ogbn-mag"]: 71 | raise RuntimeError("Dataset {} is not supported".format(name)) 72 | dataset = DglNodePropPredDataset(name=name) 73 | splitted_idx = dataset.get_idx_split() 74 | train_nid = splitted_idx["train"] 75 | val_nid = splitted_idx["valid"] 76 | test_nid = splitted_idx["test"] 77 | g, labels = dataset[0] 78 | g = g.to(device) 79 | if name == "ogbn-arxiv": 80 | g = dgl.add_reverse_edges(g, copy_ndata=True) 81 | g = dgl.add_self_loop(g) 82 | g.ndata['feat'] = g.ndata['feat'].float() 83 | elif name == "ogbn-mag": 84 
| # MAG is a heterogeneous graph. The task is to make prediction for 85 | # paper nodes 86 | labels = labels["paper"] 87 | train_nid = train_nid["paper"] 88 | val_nid = val_nid["paper"] 89 | test_nid = test_nid["paper"] 90 | g = convert_mag_to_homograph(g, device) 91 | else: 92 | g.ndata['feat'] = g.ndata['feat'].float() 93 | n_classes = dataset.num_classes 94 | labels = labels.squeeze() 95 | evaluator = get_ogb_evaluator(name) 96 | 97 | print(f"# Nodes: {g.number_of_nodes()}\n" 98 | f"# Edges: {g.number_of_edges()}\n" 99 | f"# Train: {len(train_nid)}\n" 100 | f"# Val: {len(val_nid)}\n" 101 | f"# Test: {len(test_nid)}\n" 102 | f"# Classes: {n_classes}") 103 | 104 | return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator 105 | 106 | 107 | class FeedForwardNet(nn.Module): 108 | def __init__(self, in_feats, hidden, out_feats, n_layers, dropout): 109 | super(FeedForwardNet, self).__init__() 110 | self.layers = nn.ModuleList() 111 | self.n_layers = n_layers 112 | if n_layers == 1: 113 | self.layers.append(nn.Linear(in_feats, out_feats)) 114 | else: 115 | self.layers.append(nn.Linear(in_feats, hidden)) 116 | for i in range(n_layers - 2): 117 | self.layers.append(nn.Linear(hidden, hidden)) 118 | self.layers.append(nn.Linear(hidden, out_feats)) 119 | if self.n_layers > 1: 120 | self.prelu = nn.PReLU() 121 | self.dropout = nn.Dropout(dropout) 122 | self.reset_parameters() 123 | 124 | def reset_parameters(self): 125 | gain = nn.init.calculate_gain("relu") 126 | for layer in self.layers: 127 | nn.init.xavier_uniform_(layer.weight, gain=gain) 128 | nn.init.zeros_(layer.bias) 129 | 130 | def forward(self, x): 131 | for layer_id, layer in enumerate(self.layers): 132 | x = layer(x) 133 | if layer_id < self.n_layers - 1: 134 | x = self.dropout(self.prelu(x)) 135 | return x 136 | 137 | 138 | class SIGN(nn.Module): 139 | def __init__(self, in_feats, hidden, out_feats, num_hops, n_layers, 140 | dropout, input_drop): 141 | super(SIGN, self).__init__() 142 | self.dropout = 
nn.Dropout(dropout) 143 | self.prelu = nn.PReLU() 144 | self.inception_ffs = nn.ModuleList() 145 | self.input_drop = nn.Dropout(input_drop) 146 | for hop in range(num_hops): 147 | self.inception_ffs.append( 148 | FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)) 149 | self.project = FeedForwardNet(num_hops * hidden, hidden, out_feats, 150 | n_layers, dropout) 151 | 152 | def forward(self, feats): 153 | feats = [self.input_drop(feat) for feat in feats] 154 | hidden = [] 155 | for feat, ff in zip(feats, self.inception_ffs): 156 | hidden.append(ff(feat)) 157 | self.out_feat = self.dropout(self.prelu(torch.cat(hidden, dim=-1))) 158 | out = self.project(self.out_feat) 159 | return out 160 | 161 | def reset_parameters(self): 162 | for ff in self.inception_ffs: 163 | ff.reset_parameters() 164 | self.project.reset_parameters() 165 | 166 | 167 | def get_n_params(model): 168 | pp = 0 169 | for p in list(model.parameters()): 170 | nn = 1 171 | for s in list(p.size()): 172 | nn = nn*s 173 | pp += nn 174 | return pp 175 | 176 | 177 | def neighbor_average_features(g, args): 178 | """ 179 | Compute multi-hop neighbor-averaged node features 180 | """ 181 | print("Compute neighbor-averaged feats") 182 | g.ndata["feat_0"] = g.ndata["feat"] 183 | for hop in range(1, args.R + 1): 184 | g.update_all(fn.copy_u(f"feat_{hop-1}", "msg"), 185 | fn.mean("msg", f"feat_{hop}")) 186 | res = [] 187 | for hop in range(args.R + 1): 188 | res.append(g.ndata.pop(f"feat_{hop}")) 189 | 190 | if args.dataset == "ogbn-mag": 191 | # For MAG dataset, only return features for target node types (i.e. 
192 | # paper nodes) 193 | target_mask = g.ndata["target_mask"] 194 | target_ids = g.ndata[dgl.NID][target_mask] 195 | num_target = target_mask.sum().item() 196 | new_res = [] 197 | for x in res: 198 | feat = torch.zeros((num_target,) + x.shape[1:], 199 | dtype=x.dtype, device=x.device) 200 | feat[target_ids] = x[target_mask] 201 | new_res.append(feat) 202 | res = new_res 203 | return res 204 | 205 | 206 | def prepare_data(device, args): 207 | """ 208 | Load dataset and compute neighbor-averaged node features used by SIGN model 209 | """ 210 | data = load_dataset(args.dataset, device) 211 | g, labels, n_classes, train_nid, val_nid, test_nid, evaluator = data 212 | in_feats = g.ndata['feat'].shape[1] 213 | feats = neighbor_average_features(g, args) 214 | labels = labels.to(device) 215 | # move to device 216 | train_nid = train_nid.to(device) 217 | val_nid = val_nid.to(device) 218 | test_nid = test_nid.to(device) 219 | return feats, labels, in_feats, n_classes, \ 220 | train_nid, val_nid, test_nid, evaluator 221 | 222 | 223 | def test(model, feats, labels, test_loader, evaluator, train_nid, val_nid, test_nid, device_id): 224 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id) 225 | model.eval() 226 | device = labels.device 227 | preds = [] 228 | logits = [] 229 | iter_time = 0 230 | max_gpu_used = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used 231 | for batch in test_loader: 232 | iter_start_time = time.time() 233 | batch_feats = [feat[batch].to(device) for feat in feats] 234 | batch_logits = model(batch_feats) 235 | iter_time += (time.time() - iter_start_time) 236 | gpu_used = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used 237 | if gpu_used >= max_gpu_used: 238 | max_gpu_used = gpu_used 239 | 240 | preds.append(torch.argmax(batch_logits, dim=-1)) 241 | logits.append(batch_logits) 242 | 243 | # Concat mini-batch prediction results along node dimension 244 | preds = torch.cat(preds, dim=0) 245 | logits = torch.cat(logits, dim=0) 246 | train_res = 
evaluator(preds[train_nid], labels[train_nid]) 247 | val_res = evaluator(preds[val_nid], labels[val_nid]) 248 | test_res = evaluator(preds[test_nid], labels[test_nid]) 249 | 250 | return train_res, val_res, test_res, iter_time, max_gpu_used 251 | 252 | 253 | def run(run_idx, args, data, device): 254 | feats, labels, in_size, num_classes, \ 255 | train_nid, val_nid, test_nid, evaluator = data 256 | test_loader = torch.utils.data.DataLoader( 257 | torch.arange(labels.shape[0]), batch_size=args.eval_batch_size, 258 | shuffle=False, drop_last=False) 259 | 260 | # Initialize model and optimizer for each run 261 | num_hops = args.R + 1 262 | model = SIGN(in_size, args.num_hidden, num_classes, num_hops, 263 | args.ff_layer, args.dropout, args.input_dropout) 264 | model = model.to(device) 265 | total_param = get_n_params(model) 266 | print("# Params:", total_param) 267 | 268 | with torch.no_grad(): 269 | train_res, val_res, test_res, iter_time, max_gpu_used = test( 270 | model, feats, labels, test_loader, evaluator, train_nid, val_nid, test_nid, args.gpu) 271 | 272 | return iter_time, max_gpu_used 273 | 274 | 275 | def main(args): 276 | seed(42) 277 | 278 | if args.gpu < 0: 279 | device = "cpu" 280 | else: 281 | device = "cuda:{}".format(args.gpu) 282 | 283 | with torch.no_grad(): 284 | data = prepare_data(device, args) 285 | 286 | t_iter_time = [] 287 | t_gpu_used = [] 288 | for i in range(args.num_runs): 289 | iter_time, gpu_used = run(i, args, data, device) 290 | t_iter_time.append(iter_time) 291 | t_gpu_used.append(gpu_used) 292 | 293 | print(f"Avg. Iteration Time: {np.mean(t_iter_time[1:]):.3f}s ± {np.std(t_iter_time[1:]):.3f}") 294 | print(f"Avg. 
GPU Used: {np.mean(t_gpu_used) * 1e-9:.1f}GB ± {np.std(t_gpu_used) * 1e-9:.1f}") 295 | 296 | 297 | def seed(seed=0): 298 | random.seed(seed) 299 | np.random.seed(seed) 300 | torch.manual_seed(seed) 301 | torch.cuda.manual_seed(seed) 302 | torch.cuda.manual_seed_all(seed) 303 | torch.backends.cudnn.deterministic = True 304 | torch.backends.cudnn.benchmark = False 305 | dgl.random.seed(seed) 306 | 307 | 308 | if __name__ == "__main__": 309 | parser = argparse.ArgumentParser(description="SIGN") 310 | 311 | # Experimental settings 312 | parser.add_argument("--dataset", type=str, default="ogbn-arxiv") 313 | parser.add_argument("--seed", type=int, default=0, help="seed") 314 | parser.add_argument("--num-epochs", type=int, default=1000) 315 | parser.add_argument("--gpu", type=int, default=0) 316 | parser.add_argument("--num-runs", type=int, default=10, 317 | help="number of times to repeat the experiment") 318 | 319 | # SIGN settings 320 | parser.add_argument("--num-hidden", type=int, default=512) 321 | parser.add_argument("--ff-layer", type=int, default=2, 322 | help="number of feed-forward layers") 323 | parser.add_argument("--R", type=int, default=5, 324 | help="number of hops") 325 | parser.add_argument("--dropout", type=float, default=0.5, 326 | help="dropout on activation") 327 | parser.add_argument("--input-dropout", type=float, default=0.1, 328 | help="dropout on input features") 329 | parser.add_argument("--weight-decay", type=float, default=0) 330 | parser.add_argument("--lr", type=float, default=0.001) 331 | parser.add_argument("--eval-every", type=int, default=10) 332 | parser.add_argument("--batch-size", type=int, default=50000) 333 | parser.add_argument("--eval-batch-size", type=int, default=100000, 334 | help="evaluation batch size") 335 | 336 | args = parser.parse_args() 337 | 338 | print(args) 339 | main(args) -------------------------------------------------------------------------------- /arxiv_pyg/README.md: 
-------------------------------------------------------------------------------- 1 | # Knowledge Distillation for GNNs (ARXIV with PyG) 2 | 3 | **Dataset**: ARXIV 4 | 5 | **Library**: PyG 6 | 7 | This repository contains code to benchmark knowledge distillation for GNNs on the ARXIV dataset, developed in the PyG framework. 8 | The main purpose of the PyG codebase is to: 9 | - Train student models GCN and GraphSAGE on the ARXIV dataset with and without knowledge distillation. 10 | 11 | Note that the teacher GNNs are trained on ARXIV via the DGL implementation in `arxiv_dgl`. 12 | To train students via PyG, you first need to train and dump teacher models via DGL (the PyG teacher implementation ran out of memory). 13 | 14 | ![ARXIV Results](../img/arxiv-mag.png) 15 | 16 | ## Directory Structure 17 | 18 | ``` 19 | . 20 | ├── dataset # automatically created by OGB data downloaders 21 | ├── logs # logging directory for student models 22 | | 23 | ├── scripts # scripts to conduct full experiments and reproduce results 24 | │ ├── run_gcn.sh # script to benchmark all KD losses for GCN 25 | │ ├── run_kd_and_aux.sh # script to benchmark all KD+Auxiliary losses 26 | │ └── run_sage.sh # script to benchmark all KD losses for GraphSAGE 27 | | 28 | ├── README.md 29 | | 30 | ├── correlation.py # script used to compute structural correlation between teacher and student embedding spaces (CKA and Mantel Tests) 31 | ├── criterion.py # KD loss functions 32 | ├── gnn_kd_and_aux.py # train student GNNs with combined KD + auxiliary losses 33 | ├── gnn.py # train student GNNs via Auxiliary representation distillation loss 34 | ├── logger.py # logging utilities 35 | ├── submit.py # read log directory to aggregate results 36 | └── test.py # test model checkpoint and timing 37 | ``` 38 | 39 | ## Example Usage 40 | 41 | For full usage, see the command-line flags and documentation in each file. 42 | Also see the `scripts` folder for reproducing results.
43 | -------------------------------------------------------------------------------- /arxiv_pyg/correlation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import random 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import GCNConv, SAGEConv 11 | from torch_geometric.utils import subgraph 12 | 13 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 14 | 15 | from scipy.stats import pearsonr, spearmanr 16 | from scipy.spatial.distance import pdist, squareform 17 | from sklearn.metrics import pairwise_distances 18 | 19 | from gnn import * 20 | 21 | 22 | def pairwise_euclidean_dist(A, eps=1e-10): 23 | # Fast pairwise euclidean distance on GPU 24 | # TODO why is this not fast? NOTE(review): not called by this script; it only appears in commented-out code further down. 25 | sqrA = torch.sum(torch.pow(A, 2), 1, keepdim=True).expand(A.shape[0], A.shape[0]) 26 | return torch.sqrt( 27 | sqrA - 2*torch.mm(A, A.t()) + sqrA.t() + eps 28 | ) 29 | 30 | 31 | def centering(K): # Double-centre K: returns H K H with H = I - 1 1^T / n (allocates dense n x n matrices). 32 | n = K.shape[0] 33 | unit = np.ones([n, n]) 34 | I = np.eye(n) 35 | H = I - unit / n 36 | 37 | return np.dot(np.dot(H, K), H) # H K H (double centering). NOTE(review): the single-sided variant below is not equivalent for the HSIC sums in this file (H K H gives tr(K H L H), H K gives tr(K L H)); keep the double-sided form. 38 | # return np.dot(H, K) # KH 39 | 40 | 41 | def rbf(X, sigma=None): # RBF kernel matrix exp(-||x_i - x_j||^2 / (2 sigma^2)); when sigma is None it is set to sqrt(median non-zero squared distance). 42 | GX = np.dot(X, X.T) 43 | KX = np.diag(GX) - GX + (np.diag(GX) - GX).T 44 | if sigma is None: 45 | mdist = np.median(KX[KX != 0]) 46 | sigma = math.sqrt(mdist) # NOTE(review): `math` is not imported in this file; this only works if `from gnn import *` re-exports it (np.sqrt would avoid the dependency) - TODO confirm. 47 | KX *= - 0.5 / (sigma * sigma) 48 | KX = np.exp(KX) 49 | return KX 50 | 51 | 52 | def kernel_HSIC(X, Y, sigma): 53 | return np.sum(centering(rbf(X, sigma)) * centering(rbf(Y, sigma))) 54 | 55 | 56 | def linear_HSIC(X, Y): 57 | L_X = np.dot(X, X.T) 58 | L_Y = np.dot(Y, Y.T) 59 | return np.sum(centering(L_X) * centering(L_Y)) 60 | 61 | 62 | def linear_CKA(X, Y): 63 | hsic = linear_HSIC(X, Y) 64 | var1 = np.sqrt(linear_HSIC(X,
X)) 65 | var2 = np.sqrt(linear_HSIC(Y, Y)) 66 | 67 | return hsic / (var1 * var2) 68 | 69 | 70 | def kernel_CKA(X, Y, sigma=None): 71 | hsic = kernel_HSIC(X, Y, sigma) 72 | var1 = np.sqrt(kernel_HSIC(X, X, sigma)) 73 | var2 = np.sqrt(kernel_HSIC(Y, Y, sigma)) 74 | 75 | return hsic / (var1 * var2) 76 | 77 | 78 | def fast_linear_CKA(X, Y): 79 | L_X = centering(np.dot(X, X.T)) 80 | L_Y = centering(np.dot(Y, Y.T)) 81 | hsic = np.sum(L_X * L_Y) 82 | var1 = np.sqrt(np.sum(L_X * L_X)) 83 | var2 = np.sqrt(np.sum(L_Y * L_Y)) 84 | 85 | return hsic / (var1 * var2) 86 | 87 | 88 | # Load ARXIV dataset 89 | 90 | device = 'cuda:0' # 'cpu' # f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' NOTE(review): hardcoded to cuda:0; args.device defined below is not read in this script. 91 | device = torch.device(device) 92 | 93 | dataset = PygNodePropPredDataset(name='ogbn-arxiv', transform=T.ToSparseTensor()) 94 | 95 | data = dataset[0] 96 | data.adj_t = data.adj_t.to_symmetric() 97 | data = data.to(device) 98 | 99 | split_idx = dataset.get_idx_split() 100 | train_idx = split_idx['train'].to(device) 101 | valid_idx = split_idx['valid'].to(device) 102 | test_idx = split_idx['test'].to(device) 103 | 104 | edge_index = torch.stack(data.adj_t.coo()[:2]) 105 | edge_index = subgraph(valid_idx, edge_index, relabel_nodes=True)[0] # keep edges with both ends in valid_idx; relabel_nodes=True renumbers nodes by their position in valid_idx, matching the row order of f_t / f_s below 106 | src, dst = edge_index 107 | 108 | evaluator = Evaluator(name='ogbn-arxiv') 109 | 110 | ### 111 | 112 | class args: 113 | # Experiment settings 114 | device = -1 # NOTE(review): unused; the device is hardcoded above 115 | seed = 0 116 | 117 | gnn = "sage" #"sage" 118 | 119 | # metric = "cosine-local" 120 | 121 | num_layers = 2 122 | hidden_channels = 256 123 | dropout = 0.5 124 | 125 | checkpoint_dirs = { 126 | # Example 127 | "GCD": "/home/user/molecules/arxiv_pyg/logs_gcd/kd_and_aux/gcd/sage-L2h256-Student-gcd", 128 | "NCE": "/home/user/molecules/arxiv_pyg/logs_kd_and_aux/sage-rerun/nce/sage-L2h256-Student-nce", 129 | "LPW": "/home/user/molecules/arxiv_pyg/logs_kd_and_aux/sage/lpw-cosine/sage-L2h256-Student-lpw", 130 | "GPW":
"/home/user/molecules/arxiv_pyg/logs_kd_and_aux/sage-rerun/gpw-cosine/sage-L2h256-Student-gpw", 131 | "AT": "/home/user/molecules/arxiv_pyg/logs_kd_and_aux/sage-rerun/at/sage-L2h256-Student-at", 132 | "FitNet": "/home/user/molecules/arxiv_pyg/logs_kd_and_aux/sage/fitnet/sage-L2h256-Student-fitnet", 133 | "KD": "/home/user/molecules/arxiv_pyg/logs/sage/kd/sage-L2h256-Student-kd", 134 | "Sup.": "/home/user/molecules/arxiv_pyg/logs/sage/supervised/sage-L2h256-Student-supervised" 135 | } 136 | 137 | ### 138 | 139 | with torch.no_grad(): 140 | 141 | for method, checkpoint_dir in zip(checkpoint_dirs.keys(), checkpoint_dirs.values()): 142 | print(checkpoint_dir) 143 | 144 | student_checkpoints = {} 145 | for root, dirs, files in os.walk(checkpoint_dir): 146 | if 'results.pt' in files: 147 | seed = int(root[-1]) # NOTE(review): parses only the last character of the run directory name (presumably '...-seed<N>'), so seeds >= 10 would be mis-parsed - TODO confirm naming 148 | student_checkpoints[seed] = os.path.join(root, 'results.pt') 149 | 150 | corr_list = [] 151 | corr_list_local = [] 152 | cka_list = [] 153 | for run in range(10): # assumes checkpoints for seeds 0-9 exist; student_checkpoints[run] below raises KeyError otherwise 154 | torch.cuda.empty_cache() 155 | 156 | # Teacher features and pairwise distances 157 | f_t = torch.load(f"../arxiv_dgl/features/gat-3L250x3h/{run}.pt", map_location=torch.device('cpu')).to(device) 158 | f_t = f_t[valid_idx] 159 | 160 | # if args.metric == "cosine": 161 | # f_t = F.normalize(f_t, p=2, dim=-1) 162 | # pw_t = 1 - torch.mm(f_t, f_t.transpose(0, 1)).cpu().numpy() 163 | # np.fill_diagonal(pw_t, 0) 164 | # pw_t = squareform(pw_t) 165 | 166 | # elif args.metric == "cosine-local": 167 | # f_t = F.normalize(f_t, p=2, dim=-1) 168 | # pw_t = 1 - F.cosine_similarity(f_t[src], f_t[dst]).cpu().numpy() 169 | 170 | # elif args.metric == 'cka': 171 | # f_t = F.normalize(f_t, p=2, dim=-1).cpu().numpy() 172 | 173 | # else: 174 | # # pw_t = pairwise_euclidean_dist(f_t) 175 | # # pw_t = pairwise_distances(f_t.cpu().numpy(), metric='euclidean', n_jobs=32) 176 | # pw_t = pdist(f_t.cpu().numpy(), 'euclidean') 177 | 178 | f_t = F.normalize(f_t, p=2, dim=-1) 179 | pw_t = 1 - torch.mm(f_t, f_t.transpose(0,
1)).cpu().numpy() 180 | np.fill_diagonal(pw_t, 0) 181 | pw_t = squareform(pw_t) # condensed vector of all pairwise cosine distances among valid nodes. NOTE(review): squareform's default checks require an exactly symmetric matrix with zero diagonal; float mm output may not be bit-exactly symmetric - confirm or pass checks=False 182 | pw_t_local = 1 - F.cosine_similarity(f_t[src], f_t[dst]).cpu().numpy() 183 | 184 | ### 185 | 186 | # Student features and pairwise distances 187 | if args.gnn == 'sage': 188 | model = SAGE(data.num_features, args.hidden_channels, 189 | dataset.num_classes, args.num_layers, 190 | args.dropout).to(device) 191 | elif args.gnn == 'gcn': 192 | model = GCN(data.num_features, args.hidden_channels, 193 | dataset.num_classes, args.num_layers, 194 | args.dropout).to(device) 195 | else: 196 | raise ValueError('Invalid GNN type') 197 | 198 | checkpoint = torch.load(student_checkpoints[run], map_location=torch.device('cpu')) 199 | model.load_state_dict(checkpoint['model_state_dict']) 200 | model.eval() 201 | 202 | model(data.x, data.adj_t) # output discarded; run for its side effect - model.out_feat (read next) is presumably set inside forward() of the SAGE/GCN imported via `from gnn import *` 203 | f_s = model.out_feat[valid_idx] 204 | 205 | f_s = F.normalize(f_s, p=2, dim=-1) 206 | pw_s = 1 - torch.mm(f_s, f_s.transpose(0, 1)).cpu().numpy() 207 | np.fill_diagonal(pw_s, 0) 208 | pw_s = squareform(pw_s) 209 | pw_s_local = 1 - F.cosine_similarity(f_s[src], f_s[dst]).cpu().numpy() 210 | 211 | # Compute correlation: Pearson r between teacher and student cosine distances over all valid-node pairs (global) and over valid-subgraph edges (local), plus linear CKA of the L2-normalised features 212 | corr_list.append(pearsonr(pw_t, pw_s)[0]) 213 | corr_list_local.append(pearsonr(pw_t_local, pw_s_local)[0]) 214 | cka_list.append(fast_linear_CKA(f_s.cpu().numpy(), f_t.cpu().numpy())) 215 | 216 | # print(corr_list) 217 | print(f"{method}: {np.mean(corr_list):.4f} +- {np.std(corr_list):.4f}, {np.mean(corr_list_local):.4f} +- {np.std(corr_list_local):.4f}, {np.mean(cka_list):.4f} +- {np.std(cka_list):.4f}") 218 | -------------------------------------------------------------------------------- /arxiv_pyg/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch_geometric.utils import softmax, to_dense_adj, subgraph, negative_sampling, add_self_loops 6 | 7 | 8 | def kd_criterion(logits, labels,
teacher_logits, alpha=0.9, T=4): 9 | """Logit-based KD, [Hinton et al., 2015](https://arxiv.org/abs/1503.02531). Returns (loss, loss_cls, loss_kd) with loss = alpha * T^2 * loss_kd + (1 - alpha) * cross-entropy, where loss_kd is KL(teacher || student) on temperature-T softened distributions. 10 | """ 11 | loss_cls = F.cross_entropy(logits, labels) 12 | 13 | loss_kd = F.kl_div( 14 | F.log_softmax(logits/ T, dim=1), 15 | F.softmax(teacher_logits/ T, dim=1), 16 | log_target=False # NOTE(review): no `reduction` is passed, so the default 'mean' averages over every element rather than per sample; PyTorch recommends 'batchmean' for the true KL. Changing it rescales loss_kd (alpha/T may need retuning) 17 | ) 18 | 19 | loss = loss_kd* (alpha* T* T) + loss_cls* (1-alpha) 20 | 21 | return loss, loss_cls, loss_kd 22 | 23 | 24 | def fitnet_criterion(logits, labels, feat, teacher_feat, beta=1000): 25 | """FitNet, [Romero et al., 2014](https://arxiv.org/abs/1412.6550) 26 | """ 27 | loss_cls = F.cross_entropy(logits, labels) 28 | 29 | feat = F.normalize(feat, p=2, dim=-1) 30 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 31 | 32 | loss_fitnet = F.mse_loss(feat, teacher_feat) 33 | 34 | loss = loss_cls + beta* loss_fitnet 35 | 36 | return loss, loss_cls, loss_fitnet 37 | 38 | 39 | def at_criterion(logits, labels, feat, teacher_feat, beta=1000): 40 | """Attention Transfer, [Zagoruyko and Komodakis, 2016](https://arxiv.org/abs/1612.03928) 41 | """ 42 | loss_cls = F.cross_entropy(logits, labels) 43 | 44 | feat = feat.pow(2).sum(-1) # squared activations summed over the last dim; assuming (num_nodes, dim) inputs this is one value per node, so the normalize/MSE below compare these vectors across the nodes in the batch 45 | teacher_feat = teacher_feat.pow(2).sum(-1) 46 | 47 | loss_at = F.mse_loss( 48 | F.normalize(feat, p=2, dim=-1), 49 | F.normalize(teacher_feat, p=2, dim=-1) 50 | ) 51 | 52 | loss = loss_cls + beta* loss_at 53 | 54 | return loss, loss_cls, loss_at 55 | 56 | 57 | def gpw_criterion(logits, labels, feat, teacher_feat, kernel='cosine', beta=1, max_samples=8192): 58 | """Global Structure Preserving loss, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 59 | """ 60 | loss_cls = F.cross_entropy(logits, labels) 61 | 62 | if max_samples < feat.shape[0]: # subsample rows (global numpy RNG) so the m x m similarity matrices below stay bounded by max_samples 63 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 64 | feat = feat[sampled_inds] 65 | teacher_feat = teacher_feat[sampled_inds] 66 | 67 | pw_sim = None 68 | teacher_pw_sim = None 69 | if kernel == 'cosine': 70 | feat = F.normalize(feat, p=2, dim=-1) 71 |
teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 72 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten() 73 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten() 74 | elif kernel == 'poly': 75 | feat = F.normalize(feat, p=2, dim=-1) 76 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 77 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten()**2 78 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten()**2 79 | elif kernel == 'l2': # NOTE(review): the 'l2' and 'rbf' branches broadcast to an (m, m, dim) difference tensor, i.e. O(m^2 * dim) memory for m sampled rows of 2-D feat 80 | pw_sim = (feat.unsqueeze(0) - feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 81 | teacher_pw_sim = (teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 82 | elif kernel == 'rbf': 83 | pw_sim = torch.exp(-0.5* ((feat.unsqueeze(0) - feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 84 | teacher_pw_sim = torch.exp(-0.5* ((teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 85 | else: 86 | raise NotImplementedError 87 | 88 | loss_gpw = F.mse_loss(pw_sim, teacher_pw_sim) 89 | 90 | loss = loss_cls + beta* loss_gpw 91 | 92 | return loss, loss_cls, loss_gpw 93 | 94 | 95 | def lpw_criterion(logits, labels, feat, teacher_feat, edge_index, kernel='cosine', beta=100, criterion='kld'): 96 | """Local Structure Preserving loss, [Yang et al., CVPR 2020](https://arxiv.org/abs/2003.10477) 97 | """ 98 | loss_cls = F.cross_entropy(logits, labels) 99 | 100 | src, dst = edge_index 101 | 102 | if kernel == 'cosine': # each branch scores every edge, then PyG softmax normalises the scores over edges sharing the same dst node 103 | pw_sim = softmax(F.cosine_similarity(feat[src], feat[dst]), dst) 104 | teacher_pw_sim = softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst]), dst) 105 | elif kernel == 'poly': 106 | pw_sim = softmax(F.cosine_similarity(feat[src], feat[dst])**2, dst) 107 | teacher_pw_sim = softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst])**2, dst) 108 | elif kernel == 'l2': 109 | pw_sim = softmax((feat[src] - feat[dst]).norm(p=2, dim=-1), dst) 110 | teacher_pw_sim = softmax((teacher_feat[src] -
teacher_feat[dst]).norm(p=2, dim=-1), dst) 111 | elif kernel == 'rbf': 112 | pw_sim = softmax(torch.exp( -0.5* ((feat[src] - feat[dst])**2).sum(dim=-1) ), dst) 113 | teacher_pw_sim = softmax(torch.exp( -0.5* ((teacher_feat[src] - teacher_feat[dst])**2).sum(dim=-1) ), dst) 114 | else: 115 | raise NotImplementedError 116 | 117 | if criterion == 'mse': 118 | loss_lpw = F.mse_loss(pw_sim, teacher_pw_sim) 119 | elif criterion == 'kld': 120 | loss_lpw = F.kl_div(torch.log(pw_sim), teacher_pw_sim, log_target=False) # NOTE(review): torch.log yields -inf if a softmax weight underflows to 0 - consider clamping pw_sim (TODO confirm whether this occurs) 121 | else: 122 | raise NotImplementedError 123 | 124 | loss = loss_cls + beta* loss_lpw 125 | 126 | return loss, loss_cls, loss_lpw 127 | 128 | 129 | def nce_criterion(logits, labels, feat, teacher_feat, beta=0.5, nce_T=0.075, max_samples=8192): 130 | """Graph Contrastive Representation Distillation, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 131 | """ 132 | loss_cls = F.cross_entropy(logits, labels) 133 | 134 | if max_samples < feat.shape[0]: 135 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 136 | feat = feat[sampled_inds] 137 | teacher_feat = teacher_feat[sampled_inds] 138 | 139 | feat = F.normalize(feat, p=2, dim=-1) 140 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 141 | 142 | nce_logits = torch.mm(feat, teacher_feat.transpose(0, 1)) 143 | nce_labels = torch.arange(feat.shape[0]).to(feat.device) # student row i's positive is teacher row i (the diagonal); the other sampled teacher rows act as negatives 144 | 145 | loss_nce = F.cross_entropy(nce_logits/ nce_T, nce_labels) 146 | 147 | loss = loss_cls + beta* loss_nce 148 | 149 | return loss, loss_cls, loss_nce 150 | -------------------------------------------------------------------------------- /arxiv_pyg/gnn_kd_and_aux.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import time 4 | import os 5 | import random 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import torch_geometric.transforms as T 12 |
from torch_geometric.nn import GCNConv, SAGEConv 13 | from torch_geometric.utils import subgraph 14 | 15 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 16 | 17 | from logger import Logger 18 | 19 | from criterion import * 20 | 21 | 22 | class GCN(torch.nn.Module): 23 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 24 | super(GCN, self).__init__() 25 | 26 | self.convs = torch.nn.ModuleList() 27 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 28 | self.bns = torch.nn.ModuleList() 29 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 30 | for _ in range(num_layers - 2): 31 | self.convs.append( 32 | GCNConv(hidden_channels, hidden_channels, cached=True)) 33 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 34 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 35 | 36 | self.dropout = dropout 37 | 38 | def reset_parameters(self): 39 | for conv in self.convs: 40 | conv.reset_parameters() 41 | for bn in self.bns: 42 | bn.reset_parameters() 43 | 44 | def forward(self, x, adj_t): 45 | for i, conv in enumerate(self.convs[:-1]): 46 | x = conv(x, adj_t) 47 | x = self.bns[i](x) 48 | x = F.relu(x) 49 | x = F.dropout(x, p=self.dropout, training=self.training) 50 | self.out_feat = x 51 | x = self.convs[-1](x, adj_t) 52 | return x 53 | 54 | 55 | class SAGE(torch.nn.Module): 56 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers,dropout): 57 | super(SAGE, self).__init__() 58 | 59 | self.convs = torch.nn.ModuleList() 60 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 61 | self.bns = torch.nn.ModuleList() 62 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 63 | for _ in range(num_layers - 2): 64 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 65 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 66 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 67 | 68 | self.dropout = dropout 69 | 70 | def 
reset_parameters(self): 71 | for conv in self.convs: 72 | conv.reset_parameters() 73 | for bn in self.bns: 74 | bn.reset_parameters() 75 | 76 | def forward(self, x, adj_t): 77 | for i, conv in enumerate(self.convs[:-1]): 78 | x = conv(x, adj_t) 79 | x = self.bns[i](x) 80 | x = F.relu(x) 81 | x = F.dropout(x, p=self.dropout, training=self.training) 82 | self.out_feat = x 83 | x = self.convs[-1](x, adj_t) 84 | return x 85 | 86 | 87 | class ProjectionGCD(torch.nn.Module): 88 | def __init__(self, hidden_channels, proj_dim): 89 | super(ProjectionGCD, self).__init__() 90 | self.conv = GCNConv(hidden_channels, proj_dim) 91 | self.bn = torch.nn.BatchNorm1d(proj_dim) 92 | 93 | def forward(self, x, adj_t): 94 | x = self.conv(x, adj_t) 95 | x = self.bn(x) 96 | x = F.relu(x) 97 | return x 98 | 99 | 100 | def train(model, data, train_idx, optimizer, args, teacher_out_feat, teacher_logits, student_proj=None, teacher_proj=None, edge_index=None): 101 | model.train() 102 | if student_proj: 103 | student_proj.train() 104 | if teacher_proj: 105 | teacher_proj.train() 106 | 107 | out = model(data.x, data.adj_t)[train_idx] 108 | labels = data.y.squeeze(1)[train_idx] 109 | 110 | if args.training == 'supervised': 111 | loss = F.cross_entropy(out, labels) 112 | loss_cls = loss 113 | loss_aux = loss*0 114 | elif args.training == 'kd': 115 | loss, loss_cls, loss_aux = kd_criterion( 116 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 117 | ) 118 | elif args.training == 'fitnet': 119 | out_feat = student_proj(model.out_feat[train_idx]) 120 | teacher_out_feat = teacher_proj(teacher_out_feat[train_idx]) 121 | _, _, loss_aux = fitnet_criterion( 122 | out, labels, out_feat, teacher_out_feat, args.beta 123 | ) 124 | loss, loss_cls, _ = kd_criterion( 125 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 126 | ) 127 | loss = loss + args.beta* loss_aux 128 | elif args.training == 'at': 129 | out_feat = model.out_feat[train_idx] 130 | teacher_out_feat = 
teacher_out_feat[train_idx] 131 | _, _, loss_aux = at_criterion( 132 | out, labels, out_feat, teacher_out_feat, args.beta 133 | ) 134 | loss, loss_cls, _ = kd_criterion( 135 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 136 | ) 137 | loss = loss + args.beta* loss_aux 138 | elif args.training == 'gpw': 139 | out_feat = student_proj(model.out_feat[train_idx]) 140 | teacher_out_feat = teacher_proj(teacher_out_feat[train_idx]) 141 | _, _, loss_aux = gpw_criterion( 142 | out, labels, out_feat, teacher_out_feat, 143 | args.kernel, args.beta, args.max_samples 144 | ) 145 | loss, loss_cls, _ = kd_criterion( 146 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 147 | ) 148 | loss = loss + args.beta* loss_aux 149 | elif args.training == 'lpw': 150 | out_feat = model.out_feat[train_idx] 151 | teacher_out_feat = teacher_out_feat[train_idx] 152 | _, _, loss_aux = lpw_criterion( 153 | out, labels, out_feat, teacher_out_feat, 154 | edge_index, args.kernel, args.beta 155 | ) 156 | loss, loss_cls, _ = kd_criterion( 157 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 158 | ) 159 | loss = loss + args.beta* loss_aux 160 | elif args.training == 'nce': 161 | out_feat = student_proj(model.out_feat[train_idx]) 162 | teacher_out_feat = teacher_proj(teacher_out_feat[train_idx]) 163 | _, _, loss_aux = nce_criterion( 164 | out, labels, out_feat, teacher_out_feat, 165 | args.beta, args.nce_T, args.max_samples 166 | ) 167 | loss, loss_cls, _ = kd_criterion( 168 | out, labels, teacher_logits[train_idx], args.alpha, args.kd_T 169 | ) 170 | loss = loss + args.beta* loss_aux 171 | elif args.training == 'gcd': 172 | out_feat = student_proj(model.out_feat, data.adj_t)[train_idx] 173 | teacher_out_feat = teacher_proj(teacher_out_feat, data.adj_t)[train_idx] 174 | _, _, loss_aux = nce_criterion( 175 | out, labels, out_feat, teacher_out_feat, 176 | args.beta, args.nce_T, args.max_samples 177 | ) 178 | loss, loss_cls, _ = kd_criterion( 179 | out, labels, 
teacher_logits[train_idx], args.alpha, args.kd_T 180 | ) 181 | loss = loss + args.beta* loss_aux 182 | else: 183 | raise NotImplementedError 184 | 185 | optimizer.zero_grad() 186 | loss.backward() 187 | optimizer.step() 188 | 189 | return loss.item(), loss_cls.item(), loss_aux.item() 190 | 191 | 192 | @torch.no_grad() 193 | def test(model, data, split_idx, evaluator): 194 | model.eval() 195 | 196 | out = model(data.x, data.adj_t) 197 | y_pred = out.argmax(dim=-1, keepdim=True) 198 | 199 | train_acc = evaluator.eval({ 200 | 'y_true': data.y[split_idx['train']], 201 | 'y_pred': y_pred[split_idx['train']], 202 | })['acc'] 203 | valid_acc = evaluator.eval({ 204 | 'y_true': data.y[split_idx['valid']], 205 | 'y_pred': y_pred[split_idx['valid']], 206 | })['acc'] 207 | test_acc = evaluator.eval({ 208 | 'y_true': data.y[split_idx['test']], 209 | 'y_pred': y_pred[split_idx['test']], 210 | })['acc'] 211 | 212 | return out, (train_acc, valid_acc, test_acc) 213 | 214 | 215 | def seed(seed=0): 216 | random.seed(seed) 217 | np.random.seed(seed) 218 | torch.manual_seed(seed) 219 | torch.cuda.manual_seed(seed) 220 | torch.cuda.manual_seed_all(seed) 221 | torch.backends.cudnn.deterministic = True 222 | torch.backends.cudnn.benchmark = False 223 | # dgl.random.seed(seed) 224 | 225 | 226 | def main(): 227 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 228 | device = torch.device(device) 229 | 230 | dataset = PygNodePropPredDataset(name='ogbn-arxiv', 231 | transform=T.ToSparseTensor()) 232 | 233 | data = dataset[0] 234 | data.adj_t = data.adj_t.to_symmetric() 235 | data = data.to(device) 236 | 237 | split_idx = dataset.get_idx_split() 238 | train_idx = split_idx['train'].to(device) 239 | 240 | edge_index = None 241 | if args.training in ['lpw', 'nce-edges', 'nce-labels-edges']: 242 | edge_index = torch.stack(data.adj_t.coo()[:2]) 243 | edge_index = subgraph(train_idx, edge_index, relabel_nodes=True)[0] 244 | 245 | if args.gnn == 'sage': 246 | model = 
SAGE(data.num_features, args.hidden_channels, 247 | dataset.num_classes, args.num_layers, 248 | args.dropout).to(device) 249 | elif args.gnn == 'gcn': 250 | model = GCN(data.num_features, args.hidden_channels, 251 | dataset.num_classes, args.num_layers, 252 | args.dropout).to(device) 253 | else: 254 | raise ValueError('Invalid GNN type') 255 | 256 | print(model) 257 | total_param = 0 258 | for param in model.parameters(): 259 | total_param += np.prod(list(param.data.size())) 260 | print(f'Total parameters: {total_param}') 261 | 262 | evaluator = Evaluator(name='ogbn-arxiv') 263 | logger = Logger(args.runs, args) 264 | 265 | for run in range(args.runs): 266 | seed(args.seed + run) 267 | 268 | model.reset_parameters() 269 | 270 | teacher_out_feat, teacher_logits = None, None 271 | if args.training != 'supervised': 272 | teacher_out_feat = torch.load(f"../arxiv_dgl/features/gat-3L250x3h/{args.seed + run}.pt").to(device) 273 | teacher_logits = torch.load(f"../arxiv_dgl/logits/gat-3L250x3h/{args.seed + run}.pt").to(device) 274 | 275 | if args.training in ["nce", "fitnet", "gpw", "gcd"]: 276 | if args.training == 'gcd': 277 | student_proj = ProjectionGCD(args.hidden_channels, args.proj_dim).to(device) 278 | teacher_proj = ProjectionGCD(750, args.proj_dim).to(device) 279 | 280 | else: 281 | student_proj = torch.nn.Sequential( 282 | torch.nn.Linear(args.hidden_channels, args.proj_dim), 283 | torch.nn.BatchNorm1d(args.proj_dim), 284 | torch.nn.ReLU() 285 | ).to(device) 286 | 287 | teacher_proj = torch.nn.Sequential( 288 | torch.nn.Linear(750, args.proj_dim), 289 | torch.nn.BatchNorm1d(args.proj_dim), 290 | torch.nn.ReLU() 291 | ).to(device) 292 | 293 | optimizer = torch.optim.Adam([ 294 | {'params': model.parameters(), 'lr': args.lr}, 295 | {'params': student_proj.parameters(), 'lr': args.lr}, 296 | {'params': teacher_proj.parameters(), 'lr': args.lr} 297 | ]) 298 | else: 299 | student_proj, teacher_proj = None, None 300 | optimizer = torch.optim.Adam(model.parameters(), 
lr=args.lr) 301 | 302 | # Create Tensorboard logger 303 | start_time_str = time.strftime("%Y%m%dT%H%M%S") 304 | log_dir = os.path.join( 305 | "logs/kd_and_aux", 306 | args.expt_name, 307 | f"{args.gnn}-L{args.num_layers}h{args.hidden_channels}-Student-{args.training}", 308 | f"{start_time_str}-GPU{args.device}-seed{args.seed+run}" 309 | ) 310 | tb_logger = SummaryWriter(log_dir) 311 | 312 | # Start training 313 | best_epoch = 0 314 | best_train = 0 315 | best_val = 0 316 | best_test = 0 317 | best_logits = None 318 | for epoch in range(1, 1 + args.epochs): 319 | train_loss, train_loss_cls, train_loss_aux = train( 320 | model, data, train_idx, optimizer, 321 | args, teacher_out_feat, teacher_logits, 322 | student_proj, teacher_proj, edge_index 323 | ) 324 | 325 | logits, result = test(model, data, split_idx, evaluator) 326 | 327 | logger.add_result(run, result) 328 | 329 | if epoch % args.log_steps == 0: 330 | train_acc, valid_acc, test_acc = result 331 | print(f'Run: {run + 1:02d}, ' 332 | f'Epoch: {epoch:02d}, ' 333 | f'Loss_total: {train_loss:.4f}, ' 334 | f'Loss_cls: {train_loss_cls:.4f}, ' 335 | f'Loss_aux: {train_loss_aux:.8f}, ' 336 | f'Train: {100 * train_acc:.2f}%, ' 337 | f'Valid: {100 * valid_acc:.2f}% ' 338 | f'Test: {100 * test_acc:.2f}%') 339 | 340 | # Log statistics to Tensorboard, etc. 
341 | tb_logger.add_scalar('loss/train', train_loss, epoch) 342 | tb_logger.add_scalar('loss/cls', train_loss_cls, epoch) 343 | tb_logger.add_scalar('loss/aux', train_loss_aux, epoch) 344 | tb_logger.add_scalar('acc/train', train_acc, epoch) 345 | tb_logger.add_scalar('acc/valid', valid_acc, epoch) 346 | tb_logger.add_scalar('acc/test', test_acc, epoch) 347 | 348 | if valid_acc > best_val: 349 | best_epoch = epoch 350 | best_train = train_acc 351 | best_val = valid_acc 352 | best_test = test_acc 353 | best_logits = logits 354 | 355 | logger.print_statistics(run) 356 | torch.save({ 357 | 'args': args, 358 | 'total_param': total_param, 359 | 'BestEpoch': best_epoch, 360 | 'Train': best_train, 361 | 'Validation': best_val, 362 | 'Test': best_test, 363 | 'logits': best_logits, 364 | 'model_state_dict': model.state_dict(), 365 | 'optimizer_state_dict': optimizer.state_dict(), 366 | }, os.path.join(log_dir, "results.pt")) 367 | 368 | logger.print_statistics() 369 | 370 | 371 | if __name__ == "__main__": 372 | parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)') 373 | 374 | # Experiment settings 375 | parser.add_argument('--device', type=int, default=0) 376 | parser.add_argument("--seed", type=int, default=0, help="seed") 377 | parser.add_argument('--log_steps', type=int, default=1) 378 | parser.add_argument('--training', type=str, default="supervised") 379 | parser.add_argument('--expt_name', type=str, default="debug") 380 | 381 | # GNN settings 382 | parser.add_argument('--gnn', type=str, default='gcn') 383 | parser.add_argument('--num_layers', type=int, default=3) 384 | parser.add_argument('--hidden_channels', type=int, default=256) 385 | parser.add_argument('--dropout', type=float, default=0.5) 386 | parser.add_argument('--lr', type=float, default=0.01) 387 | parser.add_argument('--epochs', type=int, default=500) 388 | parser.add_argument('--runs', type=int, default=10) 389 | 390 | # KD settings 391 | parser.add_argument('--alpha', type=float, default=0.9, 
392 | help='alpha parameter for KD (default: 0.5)') 393 | parser.add_argument('--kd_T', type=float, default=4.0, 394 | help='temperature parameter for KD (default: 1.0)') 395 | parser.add_argument('--beta', type=float, default=0.5, 396 | help='beta parameter for auxiliary distillation losses (default: 0.5)') 397 | 398 | # NCE/auxiliary distillation settings 399 | parser.add_argument('--nce_T', type=float, default=0.075, 400 | help='temperature parameter for NCE (default: 0.075)') 401 | parser.add_argument('--max_samples', type=int, default=8192, 402 | help='maximum samples for NCE/GPW (default: 8192)') 403 | parser.add_argument('--proj_dim', type=int, default=256, 404 | help='common projection dimensionality for NCE/FitNet (default: 150)') 405 | parser.add_argument('--kernel', type=str, default='rbf', 406 | help='kernel for LPW: cosine, polynomial, l2, rbf (default: rbf)') 407 | 408 | args = parser.parse_args() 409 | 410 | print(args) 411 | main() -------------------------------------------------------------------------------- /arxiv_pyg/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 
= r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') -------------------------------------------------------------------------------- /arxiv_pyg/scripts/run_gcn.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | # Experimental setup: 5 | # - We tune the loss balancing weights α, β in for all techniques on the validation set when using both logit-based and representation distillation together. 6 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 7 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 8 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head. 9 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude. 
10 | # Everything below runs inside the detached tmux session "expt"; each command is sent to it with tmux send-keys. 11 | tmux new -s expt -d 12 | tmux send-keys "conda activate ogb" C-m 13 | 14 | gnn=gcn 15 | 16 | runs=5 17 | 18 | num_layers=2 19 | 20 | ############################ 21 | # Each experiment starts two jobs in parallel (GPU 0 with --seed 0, GPU 1 with --seed 5) and waits for both to finish. 22 | training=supervised 23 | expt_name=supervised 24 | tmux send-keys " 25 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 26 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 27 | wait" C-m 28 | 29 | training=kd 30 | expt_name=kd 31 | tmux send-keys " 32 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 33 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 34 | wait" C-m 35 | 36 | training=fitnet 37 | expt_name=fitnet 38 | beta=1000 39 | tmux send-keys " 40 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 41 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 42 | wait" C-m 43 | 44 | training=at 45 | expt_name=at 46 | beta=100000 47 | tmux send-keys " 48 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 49 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 50 | wait" C-m 51 | 52 | training=gpw 53 | expt_name=gpw-rbf 54 | beta=100000 55 | max_samples=2048 56 | proj_dim=128 57 | kernel=rbf 58 | tmux send-keys " 59 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 60 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name
$expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 61 | wait" C-m 62 | 63 | training=gpw 64 | expt_name=gpw-l2 65 | beta=1 66 | max_samples=2048 67 | proj_dim=128 68 | kernel=l2 69 | tmux send-keys " 70 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 71 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 72 | wait" C-m 73 | 74 | training=gpw 75 | expt_name=gpw-poly 76 | beta=100 77 | max_samples=4096 78 | proj_dim=128 79 | kernel=poly 80 | tmux send-keys " 81 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 82 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 83 | wait" C-m 84 | 85 | training=gpw 86 | expt_name=gpw-cosine 87 | beta=100 88 | max_samples=4096 89 | proj_dim=128 90 | kernel=cosine 91 | tmux send-keys " 92 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 93 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 94 | wait" C-m 95 | 96 | training=lpw 97 | expt_name=lpw-rbf 98 | beta=100 99 | max_samples=2048 100 | proj_dim=128 101 | kernel=rbf 102 | tmux send-keys " 103 | python gnn.py --device 0 --seed 0 --runs $runs
--training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 104 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 105 | wait" C-m 106 | 107 | training=lpw 108 | expt_name=lpw-l2 109 | beta=100 110 | max_samples=2048 111 | proj_dim=128 112 | kernel=l2 113 | tmux send-keys " 114 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 115 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 116 | wait" C-m 117 | 118 | training=lpw 119 | expt_name=lpw-poly 120 | beta=100 121 | max_samples=4096 122 | proj_dim=128 123 | kernel=poly 124 | tmux send-keys " 125 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 126 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 127 | wait" C-m 128 | 129 | training=lpw 130 | expt_name=lpw-cosine 131 | beta=100 132 | max_samples=4096 133 | proj_dim=128 134 | kernel=cosine 135 | tmux send-keys " 136 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 137 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta
--max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 138 | wait" C-m 139 | 140 | training=nce 141 | expt_name=nce 142 | beta=0.1 143 | nce_T=0.075 144 | max_samples=16384 145 | proj_dim=256 146 | tmux send-keys " 147 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 148 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 149 | wait" C-m 150 | 151 | ############################ 152 | 153 | 154 | 155 | tmux send-keys "tmux kill-session -t expt" C-m 156 | -------------------------------------------------------------------------------- /arxiv_pyg/scripts/run_kd_and_aux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | tmux new -s expt -d 5 | tmux send-keys "conda activate ogb" C-m 6 | 7 | runs=5 8 | 9 | num_layers=2 10 | 11 | 12 | # Runs fitnet, at, gpw, lpw and nce with gnn=sage first, then the same settings again with gnn=gcn. 13 | gnn=sage 14 | 15 | 16 | ############################ 17 | 18 | training=fitnet 19 | expt_name=sage-rerun/fitnet 20 | beta=100 21 | tmux send-keys " 22 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 23 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 24 | wait" C-m 25 | 26 | training=at 27 | expt_name=sage-rerun/at 28 | beta=10000 29 | tmux send-keys " 30 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 31 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 32 |
wait" C-m 33 | 34 | training=gpw 35 | expt_name=sage-rerun/gpw-cosine 36 | beta=10 37 | max_samples=4096 38 | proj_dim=128 39 | kernel=cosine 40 | tmux send-keys " 41 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 42 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 43 | wait" C-m 44 | 45 | training=lpw 46 | expt_name=sage-rerun/lpw-cosine 47 | beta=100 48 | max_samples=4096 49 | proj_dim=128 50 | kernel=cosine 51 | tmux send-keys " 52 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 53 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 54 | wait" C-m 55 | 56 | training=nce 57 | expt_name=sage-rerun/nce 58 | beta=0.01 59 | nce_T=0.05 60 | max_samples=16384 61 | proj_dim=256 62 | tmux send-keys " 63 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 64 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 65 | wait" C-m 66 | 67 | ############################ 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | gnn=gcn 80 | 81 | ############################ 82 | 83
| training=fitnet 84 | expt_name=gcn-rerun/fitnet 85 | beta=100 86 | tmux send-keys " 87 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 88 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 89 | wait" C-m 90 | 91 | training=at 92 | expt_name=gcn-rerun/at 93 | beta=10000 94 | tmux send-keys " 95 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 96 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 97 | wait" C-m 98 | 99 | training=gpw 100 | expt_name=gcn-rerun/gpw-cosine 101 | beta=10 102 | max_samples=4096 103 | proj_dim=128 104 | kernel=cosine 105 | tmux send-keys " 106 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 107 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 108 | wait" C-m 109 | 110 | training=lpw 111 | expt_name=gcn-rerun/lpw-cosine 112 | beta=100 113 | max_samples=4096 114 | proj_dim=128 115 | kernel=cosine 116 | tmux send-keys " 117 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 118 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta
$beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 119 | wait" C-m 120 | 121 | training=nce 122 | expt_name=gcn-rerun/nce 123 | beta=0.01 124 | nce_T=0.05 125 | max_samples=16384 126 | proj_dim=256 127 | tmux send-keys " 128 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 129 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 130 | wait" C-m 131 | 132 | ############################ 133 | 134 | 135 | 136 | 137 | 138 | tmux send-keys "tmux kill-session -t expt" C-m 139 | -------------------------------------------------------------------------------- /arxiv_pyg/scripts/run_sage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Experimental setup: 5 | # - We tune the loss-balancing weights α and β for all techniques on the validation set when using both logit-based and representation distillation together. 6 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 7 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 8 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head. 9 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude.
10 | # Everything below runs inside the detached tmux session "expt"; each command is sent to it with tmux send-keys. 11 | tmux new -s expt -d 12 | tmux send-keys "conda activate ogb" C-m 13 | 14 | gnn=sage 15 | 16 | runs=5 17 | 18 | num_layers=2 19 | 20 | ############################ 21 | # Each experiment starts two jobs in parallel (GPU 0 with --seed 0, GPU 1 with --seed 5) and waits for both to finish. 22 | training=supervised 23 | expt_name=sage/supervised 24 | tmux send-keys " 25 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 26 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 27 | wait" C-m 28 | 29 | training=kd 30 | expt_name=sage/kd 31 | tmux send-keys " 32 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 33 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn & 34 | wait" C-m 35 | 36 | training=fitnet 37 | expt_name=sage/fitnet 38 | beta=1000 39 | tmux send-keys " 40 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 41 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 42 | wait" C-m 43 | 44 | training=at 45 | expt_name=sage/at 46 | beta=100000 47 | tmux send-keys " 48 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 49 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta & 50 | wait" C-m 51 | 52 | training=gpw 53 | expt_name=sage/gpw-rbf 54 | beta=100000 55 | max_samples=2048 56 | proj_dim=128 57 | kernel=rbf 58 | tmux send-keys " 59 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples
--proj_dim $proj_dim --kernel $kernel & 60 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 61 | wait" C-m 62 | 63 | training=gpw 64 | expt_name=sage/gpw-l2 65 | beta=1 66 | max_samples=2048 67 | proj_dim=128 68 | kernel=l2 69 | tmux send-keys " 70 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 71 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 72 | wait" C-m 73 | 74 | training=gpw 75 | expt_name=sage/gpw-poly 76 | beta=100 77 | max_samples=4096 78 | proj_dim=128 79 | kernel=poly 80 | tmux send-keys " 81 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 82 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 83 | wait" C-m 84 | 85 | training=gpw 86 | expt_name=sage/gpw-cosine 87 | beta=100 88 | max_samples=4096 89 | proj_dim=128 90 | kernel=cosine 91 | tmux send-keys " 92 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 93 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel
& 94 | wait" C-m 95 | 96 | training=lpw 97 | expt_name=sage/lpw-rbf 98 | beta=100 99 | max_samples=2048 100 | proj_dim=128 101 | kernel=rbf 102 | tmux send-keys " 103 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 104 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 105 | wait" C-m 106 | 107 | training=lpw 108 | expt_name=sage/lpw-l2 109 | beta=100 110 | max_samples=2048 111 | proj_dim=128 112 | kernel=l2 113 | tmux send-keys " 114 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 115 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 116 | wait" C-m 117 | 118 | training=lpw 119 | expt_name=sage/lpw-poly 120 | beta=100 121 | max_samples=4096 122 | proj_dim=128 123 | kernel=poly 124 | tmux send-keys " 125 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 126 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 127 | wait" C-m 128 | 129 | training=lpw 130 | expt_name=sage/lpw-cosine 131 | beta=100 132 | max_samples=4096 133 | proj_dim=128 134 | kernel=cosine 135 | tmux send-keys " 136 | python gnn.py --device 0 --seed 0 --runs
$runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 137 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --kernel $kernel & 138 | wait" C-m 139 | 140 | training=nce 141 | expt_name=sage/nce 142 | beta=0.1 143 | nce_T=0.075 144 | max_samples=16384 145 | proj_dim=256 146 | tmux send-keys " 147 | python gnn.py --device 0 --seed 0 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 148 | python gnn.py --device 1 --seed 5 --runs $runs --training $training --expt_name $expt_name --num_layers $num_layers --gnn $gnn --beta $beta --max_samples $max_samples --proj_dim $proj_dim --nce_T $nce_T & 149 | wait" C-m 150 | 151 | ############################ 152 | 153 | 154 | 155 | tmux send-keys "tmux kill-session -t expt" C-m 156 | -------------------------------------------------------------------------------- /arxiv_pyg/submit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | # Experiment settings 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--expt_name', type=str, default="debug", 11 | help='Name of experiment logs/results folder') 12 | parser.add_argument('--run', type=str, default="debug", 13 | help='Name of specific run') 14 | args = parser.parse_args() 15 | 16 | expt_args = [] 17 | model = [] 18 | total_param = [] 19 | BestEpoch = [] 20 | Validation = [] 21 | Test = [] 22 | Train = [] 23 | BestTrain = [] 24 | 25 | # for root, dirs, files in os.walk(f'logs_kd_and_aux/{args.expt_name}/{args.run}'): 26 | for root, dirs, files in 
os.walk(f'logs/{args.expt_name}/{args.run}'): 27 | if 'results.pt' in files: 28 | results = torch.load(os.path.join(root, 'results.pt'), map_location=torch.device('cpu')) 29 | expt_args.append(results['args']) 30 | total_param.append(results['total_param']) 31 | BestEpoch.append(results['BestEpoch']) 32 | Validation.append(results['Validation']) 33 | Test.append(results['Test']) 34 | Train.append(results['Train']) 35 | 36 | print(expt_args[0]) 37 | print() 38 | print(f'Test performance: {np.mean(Test)*100:.2f} +- {np.std(Test)*100:.2f}') 39 | print(f'Validation performance: {np.mean(Validation)*100:.2f} +- {np.std(Validation)*100:.2f}') 40 | print(f'Train performance: {np.mean(Train)*100:.2f} +- {np.std(Train)*100:.2f}') 41 | print(f'Total parameters: {int(np.mean(total_param))}') -------------------------------------------------------------------------------- /arxiv_pyg/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import time 4 | import os 5 | import random 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import torch_geometric.transforms as T 12 | from torch_geometric.nn import GCNConv, SAGEConv 13 | from torch_geometric.utils import subgraph 14 | 15 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 16 | 17 | from logger import Logger 18 | 19 | import nvidia_smi 20 | nvidia_smi.nvmlInit() 21 | 22 | 23 | class GCN(torch.nn.Module): 24 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 25 | super(GCN, self).__init__() 26 | 27 | self.convs = torch.nn.ModuleList() 28 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 29 | self.bns = torch.nn.ModuleList() 30 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 31 | for _ in range(num_layers - 2): 32 | self.convs.append( 33 | GCNConv(hidden_channels, hidden_channels, cached=True)) 34 
| self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 35 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 36 | 37 | self.dropout = dropout 38 | 39 | def reset_parameters(self): 40 | for conv in self.convs: 41 | conv.reset_parameters() 42 | for bn in self.bns: 43 | bn.reset_parameters() 44 | 45 | def forward(self, x, adj_t): 46 | for i, conv in enumerate(self.convs[:-1]): 47 | x = conv(x, adj_t) 48 | x = self.bns[i](x) 49 | x = F.relu(x) 50 | x = F.dropout(x, p=self.dropout, training=self.training) 51 | self.out_feat = x 52 | x = self.convs[-1](x, adj_t) 53 | return x 54 | 55 | 56 | class SAGE(torch.nn.Module): 57 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers,dropout): 58 | super(SAGE, self).__init__() 59 | 60 | self.convs = torch.nn.ModuleList() 61 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 62 | self.bns = torch.nn.ModuleList() 63 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 64 | for _ in range(num_layers - 2): 65 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 66 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 67 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 68 | 69 | self.dropout = dropout 70 | 71 | def reset_parameters(self): 72 | for conv in self.convs: 73 | conv.reset_parameters() 74 | for bn in self.bns: 75 | bn.reset_parameters() 76 | 77 | def forward(self, x, adj_t): 78 | for i, conv in enumerate(self.convs[:-1]): 79 | x = conv(x, adj_t) 80 | x = self.bns[i](x) 81 | x = F.relu(x) 82 | x = F.dropout(x, p=self.dropout, training=self.training) 83 | self.out_feat = x 84 | x = self.convs[-1](x, adj_t) 85 | return x 86 | 87 | 88 | @torch.no_grad() 89 | def test(model, data, split_idx, evaluator, device): 90 | model.eval() 91 | 92 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device) 93 | 94 | iter_start_time = time.time() 95 | out = model(data.x, data.adj_t) 96 | iter_time = time.time() - iter_start_time 97 | gpu_used = 
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used 98 | 99 | y_pred = out.argmax(dim=-1, keepdim=True) 100 | 101 | train_acc = evaluator.eval({ 102 | 'y_true': data.y[split_idx['train']], 103 | 'y_pred': y_pred[split_idx['train']], 104 | })['acc'] 105 | valid_acc = evaluator.eval({ 106 | 'y_true': data.y[split_idx['valid']], 107 | 'y_pred': y_pred[split_idx['valid']], 108 | })['acc'] 109 | test_acc = evaluator.eval({ 110 | 'y_true': data.y[split_idx['test']], 111 | 'y_pred': y_pred[split_idx['test']], 112 | })['acc'] 113 | 114 | return train_acc, valid_acc, test_acc, iter_time, gpu_used 115 | 116 | 117 | def seed(seed=0): 118 | random.seed(seed) 119 | np.random.seed(seed) 120 | torch.manual_seed(seed) 121 | torch.cuda.manual_seed(seed) 122 | torch.cuda.manual_seed_all(seed) 123 | torch.backends.cudnn.deterministic = True 124 | torch.backends.cudnn.benchmark = False 125 | # dgl.random.seed(seed) 126 | 127 | 128 | def main(): 129 | seed(42) 130 | 131 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 132 | device = torch.device(device) 133 | 134 | dataset = PygNodePropPredDataset(name='ogbn-arxiv', 135 | transform=T.ToSparseTensor()) 136 | 137 | data = dataset[0] 138 | data.adj_t = data.adj_t.to_symmetric() 139 | data = data.to(device) 140 | 141 | split_idx = dataset.get_idx_split() 142 | 143 | if args.gnn == 'sage': 144 | model = SAGE(data.num_features, args.hidden_channels, 145 | dataset.num_classes, args.num_layers, 146 | args.dropout).to(device) 147 | elif args.gnn == 'gcn': 148 | model = GCN(data.num_features, args.hidden_channels, 149 | dataset.num_classes, args.num_layers, 150 | args.dropout).to(device) 151 | else: 152 | raise ValueError('Invalid GNN type') 153 | 154 | print(model) 155 | total_param = 0 156 | for param in model.parameters(): 157 | total_param += np.prod(list(param.data.size())) 158 | print(f'Total parameters: {total_param}') 159 | 160 | if args.checkpoint != "": 161 | checkpoint = torch.load(f"{args.checkpoint}/results.pt", 
map_location=torch.device('cpu')) 162 | model.load_state_dict(checkpoint['model_state_dict']) 163 | model.eval() 164 | 165 | evaluator = Evaluator(name='ogbn-arxiv') 166 | logger = Logger(args.runs, args) 167 | 168 | t_iter_time = [] 169 | t_gpu_used = [] 170 | for run in range(args.runs): 171 | torch.cuda.empty_cache() 172 | model.reset_parameters() 173 | 174 | result = test(model, data, split_idx, evaluator, args.device) 175 | 176 | logger.add_result(run, result[:3]) 177 | 178 | train_acc, valid_acc, test_acc, iter_time, gpu_used = result 179 | t_iter_time.append(iter_time) 180 | t_gpu_used.append(gpu_used) 181 | print(f'Run: {run + 1:02d}, ' 182 | f'Train: {100 * train_acc:.2f}%, ' 183 | f'Valid: {100 * valid_acc:.2f}% ' 184 | f'Test: {100 * test_acc:.2f}% ' 185 | f'Time: {iter_time:.3f}s, ' 186 | f'GPU: {gpu_used * 1e-9:.1f}GB' ) 187 | 188 | logger.print_statistics() 189 | print(f"Avg. Iteration Time: {np.mean(t_iter_time[1:]):.3f}s ± {np.std(t_iter_time[1:]):.3f}") 190 | print(f"Avg. 
GPU Used: {np.mean(t_gpu_used) * 1e-9:.1f}GB ± {np.std(t_gpu_used) * 1e-9:.1f}") 191 | 192 | 193 | if __name__ == "__main__": 194 | parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)') 195 | 196 | # Experiment settings 197 | parser.add_argument('--device', type=int, default=0) 198 | parser.add_argument('--checkpoint', type=str, default="") 199 | 200 | # GNN settings 201 | parser.add_argument('--gnn', type=str, default='gcn') 202 | parser.add_argument('--num_layers', type=int, default=2) 203 | parser.add_argument('--hidden_channels', type=int, default=256) 204 | parser.add_argument('--dropout', type=float, default=0.5) 205 | parser.add_argument('--runs', type=int, default=10) 206 | 207 | args = parser.parse_args() 208 | 209 | print(args) 210 | main() -------------------------------------------------------------------------------- /awesome-efficient-gnns.md: -------------------------------------------------------------------------------- 1 | # 🚀 Awesome Efficient Graph Neural Networks 2 | 3 | This is a curated list of must-read papers on efficient **Graph Neural Networks** and scalable **Graph Representation Learning** for real-world applications. 4 | Contributions for new papers and topics are welcome! 5 | 6 | **Accompanying Blogpost**: [chaitjo.com/post/efficient-gnns](https://www.chaitjo.com/post/efficient-gnns/) 7 | 8 | ## Efficient and Scalable GNN Architectures 9 | - [ICML 2019] [**Simplifying Graph Convolutional Networks**](https://arxiv.org/abs/1902.07153). Felix Wu, Tianyi Zhang, Amauri Holanda de Souza Jr., Christopher Fifty, Tao Yu, Kilian Q. Weinberger. 10 | - [ICML 2020 Workshop] [**SIGN: Scalable Inception Graph Neural Networks**](https://arxiv.org/abs/2004.11198). Fabrizio Frasca, Emanuele Rossi, Davide Eynard, Ben Chamberlain, Michael Bronstein, Federico Monti. 11 | - [ICLR 2021 Workshop] [**Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions**](https://arxiv.org/abs/2104.01481). Shyam A. Tailor, Felix L. 
Opolka, Pietro Liò, Nicholas D. Lane. 12 | - [ICLR 2021] [**On Graph Neural Networks versus Graph-Augmented MLPs**](https://arxiv.org/abs/2010.15116). Lei Chen, Zhengdao Chen, Joan Bruna. 13 | - [ICML 2021] [**Training Graph Neural Networks with 1000 Layers**](https://arxiv.org/abs/2106.07476). Guohao Li, Matthias Müller, Bernard Ghanem, Vladlen Koltun. 14 | 15 |

16 | 17 |
18 | Source: Simplifying Graph Convolutional Networks 19 |

20 | 21 | ## Neural Architecture Search for GNNs 22 | - [IJCAI 2020] [**GraphNAS: Graph Neural Architecture Search with Reinforcement Learning**](https://arxiv.org/abs/1904.09981). Yang Gao, Hong Yang, Peng Zhang, Chuan Zhou, Yue Hu. 23 | - [AAAI 2021 Workshop] [**Probabilistic Dual Network Architecture Search on Graphs**](https://arxiv.org/abs/2003.09676). Yiren Zhao, Duo Wang, Xitong Gao, Robert Mullins, Pietro Liò, Mateja Jamnik. 24 | - [IJCAI 2021] [**Automated Machine Learning on Graphs: A Survey**](https://arxiv.org/abs/2103.00742). 25 | Ziwei Zhang, Xin Wang, Wenwu Zhu. 26 | 27 |

28 | 29 |
30 | Source: Probabilistic Dual Network Architecture Search on Graphs 31 |

32 | 33 | ## Large-scale Graphs and Sampling Techniques 34 | - [KDD 2019] [**Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks**](https://arxiv.org/abs/1905.07953). Wei-Lin Chiang, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, Cho-Jui Hsieh. 35 | - [ICLR 2020] [**GraphSAINT: Graph Sampling Based Inductive Learning Method**](https://arxiv.org/abs/1907.04931). Hanqing Zeng, Hongkuan Zhou, Ajitesh Srivastava, Rajgopal Kannan, Viktor Prasanna. 36 | - [CVPR 2020] [**L2-GCN: Layer-Wise and Learned Efficient Training of Graph Convolutional Networks**](https://openaccess.thecvf.com/content_CVPR_2020/html/You_L2-GCN_Layer-Wise_and_Learned_Efficient_Training_of_Graph_Convolutional_Networks_CVPR_2020_paper.html). Yuning You, Tianlong Chen, Zhangyang Wang, Yang Shen. 37 | - [KDD 2020] [**Scaling Graph Neural Networks with Approximate PageRank**](https://arxiv.org/abs/2007.01570). Aleksandar Bojchevski, Johannes Klicpera, Bryan Perozzi, Amol Kapoor, Martin Blais, Benedek Rózemberczki, Michal Lukasik, Stephan Günnemann. 38 | - [ICML 2021] [**GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings**](https://arxiv.org/abs/2106.05609). Matthias Fey, Jan E. Lenssen, Frank Weichert, Jure Leskovec. 39 | - [ICLR 2021] [**Graph Traversal with Tensor Functionals: A Meta-Algorithm for Scalable Learning**](https://arxiv.org/abs/2102.04350). Elan Markowitz, Keshav Balasubramanian, Mehrnoosh Mirtaheri, Sami Abu-El-Haija, Bryan Perozzi, Greg Ver Steeg, Aram Galstyan. 40 | 41 |

42 | 43 |
44 | Source: GraphSAINT: Graph Sampling Based Inductive Learning Method 45 |

46 | 47 | ## Low Precision and Quantized GNNs 48 | - [EuroMLSys 2021] [**Learned Low Precision Graph Neural Networks**](https://arxiv.org/abs/2009.09232). Yiren Zhao, Duo Wang, Daniel Bates, Robert Mullins, Mateja Jamnik, Pietro Liò. 49 | - [ICLR 2021] [**Degree-Quant: Quantization-Aware Training for Graph Neural Networks**](https://arxiv.org/abs/2008.05000). Shyam A. Tailor, Javier Fernandez-Marques, Nicholas D. Lane. 50 | - [CVPR 2021] [**Binary Graph Neural Networks**](https://arxiv.org/abs/2012.15823). Mehdi Bahri, Gaétan Bahl, Stefanos Zafeiriou. 51 | 52 |

53 | 54 |
55 | Source: Degree-Quant: Quantization-Aware Training for Graph Neural Networks 56 |

57 | 58 | ## Knowledge Distillation for GNNs 59 | - [CVPR 2020] [**Distilling Knowledge from Graph Convolutional Networks**](https://arxiv.org/abs/2003.10477). Yiding Yang, Jiayan Qiu, Mingli Song, Dacheng Tao, Xinchao Wang. 60 | - [WWW 2021] [**Extract the Knowledge of Graph Neural Networks and Go Beyond it: An Effective Knowledge Distillation Framework**](https://arxiv.org/abs/2103.02885). Cheng Yang, Jiawei Liu, Chuan Shi. 61 | - [IJCAI 2021] [**On Self-Distilling Graph Neural Network**](https://www.ijcai.org/proceedings/2021/314). Yuzhao Chen, Yatao Bian, Xi Xiao, Yu Rong, Tingyang Xu, Junzhou Huang. 62 | - [IJCAI 2021] [**Graph-Free Knowledge Distillation for Graph Neural Networks**](https://arxiv.org/abs/2105.07519). Xiang Deng, Zhongfei Zhang. 63 | - [IEEE TNNLS 2022] [**On Representation Knowledge Distillation for Graph Neural Networks**](https://arxiv.org/abs/2111.04964). Chaitanya K. Joshi, Fayao Liu, Xu Xun, Jie Lin, Chuan-Sheng Foo. 64 | 65 |

66 | 67 |
68 | Source: On Representation Knowledge Distillation for Graph Neural Networks 69 |

70 | 71 | ## Hardware Acceleration of GNNs 72 | - [IPDPS 2019] [**Accurate, Efficient and Scalable Graph Embedding**](https://arxiv.org/abs/1810.11899). Hanqing Zeng, Hongkuan Zhou, Ajitesh Srivastava, Rajgopal Kannan, Viktor Prasanna. 73 | - [IEEE TC 2020] [**EnGN: A High-Throughput and Energy-Efficient Accelerator for Large Graph Neural Networks**](https://arxiv.org/abs/1909.00155). Shengwen Liang, Ying Wang, Cheng Liu, Lei He, Huawei Li, Xiaowei Li. 74 | - [FPGA 2020] [**GraphACT: Accelerating GCN Training on CPU-FPGA Heterogeneous Platforms**](https://arxiv.org/abs/2001.02498). Hanqing Zeng, Viktor Prasanna. 75 | - [IEEE CAD 2021] [**Rubik: A Hierarchical Architecture for Efficient Graph Learning**](https://arxiv.org/abs/2009.12495). Xiaobing Chen, Yuke Wang, Xinfeng Xie, Xing Hu, Abanti Basak, Ling Liang, Mingyu Yan, Lei Deng, Yufei Ding, Zidong Du, Yunji Chen, Yuan Xie. 76 | - [ACM Computing Surveys 2021] [**Computing Graph Neural Networks: A Survey from Algorithms to Accelerators**](https://arxiv.org/abs/2010.00130). Sergi Abadal, Akshay Jain, Robert Guirado, Jorge López-Alonso, Eduard Alarcón. 77 | 78 |

79 | 80 |
81 | Source: Computing Graph Neural Networks: A Survey from Algorithms to Accelerators 82 |

83 | 84 | ## Code Frameworks, Libraries, and Datasets 85 | 86 | - [PyG] [**PyTorch Geometric**](https://www.pyg.org/). 87 | - [DGL] [**Deep Graph Library**](https://www.dgl.ai/). 88 | - [NeurIPS 2020] [**Open Graph Benchmark: Datasets for Machine Learning on Graphs**](https://arxiv.org/abs/2005.00687). Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, Jure Leskovec. 89 | - [KDD Cup 2021] [**OGB-LSC: A Large-Scale Challenge for Machine Learning on Graphs**](https://ogb.stanford.edu/kddcup2021/). Weihua Hu, Matthias Fey, Hongyu Ren, Maho Nakata, Yuxiao Dong, Jure Leskovec. 90 | - [CIKM 2021] [**PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models**](https://arxiv.org/abs/2104.07788). Benedek Rozemberczki, Paul Scherer, Yixuan He, George Panagopoulos, Alexander Riedel, Maria Astefanoaei, Oliver Kiss, Ferenc Beres, Guzmán López, Nicolas Collignon, Rik Sarkar. 91 | 92 |

93 | 94 |
95 | Source: OGB-LSC: A Large-Scale Challenge for Machine Learning on Graphs 96 |

97 | 98 | ## Industrial Applications and Systems 99 | - [KDD 2018] [**Graph Convolutional Neural Networks for Web-Scale Recommender Systems**](https://arxiv.org/abs/1806.01973). Rex Ying, Ruining He, Kaifeng Chen, Pong Eksombatchai, William L. Hamilton, Jure Leskovec. 100 | - [VLDB 2019] [**AliGraph: A Comprehensive Graph Neural Network Platform**](https://arxiv.org/abs/1902.08730). Rong Zhu, Kun Zhao, Hongxia Yang, Wei Lin, Chang Zhou, Baole Ai, Yong Li, Jingren Zhou. 101 | - [KDD 2020] [**PinnerSage: Multi-Modal User Embedding Framework for Recommendations at Pinterest**](https://arxiv.org/abs/2007.03634). Aditya Pal, Chantat Eksombatchai, Yitong Zhou, Bo Zhao, Charles Rosenberg, Jure Leskovec. 102 | - [CIKM 2020] [**P-Companion: A Principled Framework for Diversified Complementary Product Recommendation**](https://dl.acm.org/doi/10.1145/3340531.3412732). Junheng Hao, Tong Zhao, Jin Li, Xin Luna Dong, Christos Faloutsos, Yizhou Sun, Wei Wang. 103 | - [CIKM 2021] [**ETA Prediction with Graph Neural Networks in Google Maps**](https://arxiv.org/abs/2108.11482). Austin Derrow-Pinion, Jennifer She, David Wong, Oliver Lange, Todd Hester, Luis Perez, Marc Nunkesser, Seongjae Lee, Xueying Guo, Brett Wiltshire, Peter W. Battaglia, Vishal Gupta, Ang Li, Zhongwen Xu, Alvaro Sanchez-Gonzalez, Yujia Li, Petar Veličković. 104 | 105 |

106 | 107 |
108 | Source: Graph Convolutional Neural Networks for Web-Scale Recommender Systems 109 |

110 | -------------------------------------------------------------------------------- /img/architectures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/architectures.png -------------------------------------------------------------------------------- /img/arxiv-mag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/arxiv-mag.png -------------------------------------------------------------------------------- /img/computing-gnns.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/computing-gnns.PNG -------------------------------------------------------------------------------- /img/degree-quant.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/degree-quant.PNG -------------------------------------------------------------------------------- /img/distillation.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/distillation.PNG -------------------------------------------------------------------------------- /img/dual-nas.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/dual-nas.PNG -------------------------------------------------------------------------------- /img/graph-saint.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/graph-saint.PNG -------------------------------------------------------------------------------- /img/molhiv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/molhiv.png -------------------------------------------------------------------------------- /img/ogb-lsc.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/ogb-lsc.PNG -------------------------------------------------------------------------------- /img/pinsage.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/pinsage.PNG -------------------------------------------------------------------------------- /img/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/pipeline.png -------------------------------------------------------------------------------- /img/ppi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/ppi.png -------------------------------------------------------------------------------- /img/sgc.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/sgc.PNG 
-------------------------------------------------------------------------------- /img/techniques.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/efficient-gnns/f0b079d19e1a2587c085cba2bfca59d3d2a08298/img/techniques.png -------------------------------------------------------------------------------- /mag_pyg/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Distillation for GNNs (MAG with PyG) 2 | 3 | **Dataset**: MAG 4 | 5 | **Library**: PyG 6 | 7 | This repository contains code to benchmark knowledge distillation for GNNs on the MAG dataset, developed in the PyG framework. 8 | The main purpose of the codebase is to: 9 | - Train teacher R-GCN models on MAG dataset via supervised learning and export the checkpoints 10 | - Train student R-GCN models with/without knowledge distillation. 11 | 12 | ![MAG Results](../img/arxiv-mag.png) 13 | 14 | ## Directory Structure 15 | 16 | ``` 17 | . 18 | ├── dataset # automatically created by OGB data downloaders 19 | | 20 | ├── scripts # scripts to conduct full experiments and reproduce results 21 | │ ├── run_kd_and_aux.sh # script to benchmark all KD+Auxiliary losses 22 | │ ├── run.sh # script to benchmark all KD losses 23 | │ └── teacher.sh # script to train and save teacher checkpoints 24 | | 25 | ├── README.md 26 | | 27 | ├── criterion.py # KD loss functions 28 | ├── gnn_kd_and_aux.py # train student GNNs via KD+Auxiliary loss training 29 | ├── gnn.py # train student GNNs via Auxiliary representation distillation loss 30 | ├── logger.py # logging utilities 31 | ├── submit.py # read log directory to aggregate results 32 | └── test.py # test model checkpoint and timing 33 | ``` 34 | 35 | ## Example Usage 36 | 37 | For full usage, each file has accompanying flags and documentation. 38 | Also see the `scripts` folder for reproducing results. 
39 | -------------------------------------------------------------------------------- /mag_pyg/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch_geometric.utils import softmax, to_dense_adj, subgraph, negative_sampling, add_self_loops 6 | 7 | 8 | def kd_criterion(logits, labels, teacher_logits, alpha=0.9, T=4):  # returns (total loss, cross-entropy term, KD term); alpha weights the soft-target term, T is the softmax temperature 9 |     """Logit-based KD, [Hinton et al., 2015](https://arxiv.org/abs/1503.02531) 10 |     """ 11 |     loss_cls = F.cross_entropy(logits, labels) 12 | 13 |     loss_kd = F.kl_div(  # NOTE(review): no `reduction` is given, so PyTorch's default 'mean' averages over every element rather than per sample; 'batchmean' matches the KL definition but would rescale this term (and the tuned alpha/T) - confirm before changing 14 |         F.log_softmax(logits/ T, dim=1), 15 |         F.softmax(teacher_logits/ T, dim=1), 16 |         log_target=False 17 |     ) 18 | 19 |     loss = loss_kd* (alpha* T* T) + loss_cls* (1-alpha)  # the T*T factor keeps the soft-target term's scale comparable across temperatures (Hinton et al.) 20 | 21 |     return loss, loss_cls, loss_kd 22 | 23 | 24 | def fitnet_criterion(logits, labels, feat, teacher_feat, beta=1000):  # cross-entropy + beta * MSE between L2-normalised student and teacher features 25 |     """FitNet, [Romero et al., 2014](https://arxiv.org/abs/1412.6550) 26 |     """ 27 |     loss_cls = F.cross_entropy(logits, labels) 28 | 29 |     feat = F.normalize(feat, p=2, dim=-1) 30 |     teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 31 | 32 |     loss_fitnet = F.mse_loss(feat, teacher_feat) 33 | 34 |     loss = loss_cls + beta* loss_fitnet 35 | 36 |     return loss, loss_cls, loss_fitnet 37 | 38 | 39 | def at_criterion(logits, labels, feat, teacher_feat, beta=1000):  # cross-entropy + beta * MSE between normalised squared-norm vectors of student and teacher features 40 |     """Attention Transfer, [Zagoruyko and Komodakis, 2016](https://arxiv.org/abs/1612.03928) 41 |     """ 42 |     loss_cls = F.cross_entropy(logits, labels) 43 | 44 |     feat = feat.pow(2).sum(-1)  # squared L2 norm over the last dim (one value per node if feat is (N, d) - assumed) 45 |     teacher_feat = teacher_feat.pow(2).sum(-1) 46 | 47 |     loss_at = F.mse_loss( 48 |         F.normalize(feat, p=2, dim=-1), 49 |         F.normalize(teacher_feat, p=2, dim=-1) 50 |     ) 51 | 52 |     loss = loss_cls + beta* loss_at 53 | 54 |     return loss, loss_cls, loss_at 55 | 56 | 57 | def gpw_criterion(logits, labels, feat, teacher_feat, kernel='cosine', beta=1, max_samples=8192):  # cross-entropy + beta * MSE between student and teacher pairwise-similarity matrices 58 |     """Global Structure Preserving loss, [Joshi et al., TNNLS
2022](https://arxiv.org/abs/2111.04964) 59 |     """ 60 |     loss_cls = F.cross_entropy(logits, labels) 61 | 62 |     if max_samples < feat.shape[0]:  # subsample rows: every kernel below builds an (n x n) pairwise matrix, so cost grows quadratically with n 63 |         sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 64 |         feat = feat[sampled_inds] 65 |         teacher_feat = teacher_feat[sampled_inds] 66 | 67 |     pw_sim = None 68 |     teacher_pw_sim = None 69 |     if kernel == 'cosine': 70 |         feat = F.normalize(feat, p=2, dim=-1) 71 |         teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 72 |         pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten() 73 |         teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten() 74 |     elif kernel == 'poly': 75 |         feat = F.normalize(feat, p=2, dim=-1) 76 |         teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 77 |         pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten()**2 78 |         teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten()**2 79 |     elif kernel == 'l2': 80 |         pw_sim = (feat.unsqueeze(0) - feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 81 |         teacher_pw_sim = (teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 82 |     elif kernel == 'rbf': 83 |         pw_sim = torch.exp(-0.5* ((feat.unsqueeze(0) - feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 84 |         teacher_pw_sim = torch.exp(-0.5* ((teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 85 |     else: 86 |         raise NotImplementedError 87 | 88 |     loss_gpw = F.mse_loss(pw_sim, teacher_pw_sim) 89 | 90 |     loss = loss_cls + beta* loss_gpw 91 | 92 |     return loss, loss_cls, loss_gpw 93 | 94 | 95 | def lpw_criterion(logits, labels, feat, teacher_feat, edge_index, kernel='cosine', beta=100, criterion='kld'):  # cross-entropy + beta * discrepancy between per-edge similarity distributions of student and teacher 96 |     """Local Structure Preserving loss, [Yang et al., CVPR 2020](https://arxiv.org/abs/2003.10477) 97 |     """ 98 |     loss_cls = F.cross_entropy(logits, labels) 99 | 100 |     src, dst = edge_index  # assumes edge_index is a (2, E) tensor - TODO confirm; PyG softmax(x, dst) below normalises scores over edges that share a dst node 101 | 102 |     if kernel == 'cosine': 103 |         pw_sim = softmax(F.cosine_similarity(feat[src], feat[dst]), dst) 104 |         teacher_pw_sim =
softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst]), dst) 105 |     elif kernel == 'poly': 106 |         pw_sim = softmax(F.cosine_similarity(feat[src], feat[dst])**2, dst) 107 |         teacher_pw_sim = softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst])**2, dst) 108 |     elif kernel == 'l2': 109 |         pw_sim = softmax((feat[src] - feat[dst]).norm(p=2, dim=-1), dst) 110 |         teacher_pw_sim = softmax((teacher_feat[src] - teacher_feat[dst]).norm(p=2, dim=-1), dst) 111 |     elif kernel == 'rbf': 112 |         pw_sim = softmax(torch.exp( -0.5* ((feat[src] - feat[dst])**2).sum(dim=-1) ), dst) 113 |         teacher_pw_sim = softmax(torch.exp( -0.5* ((teacher_feat[src] - teacher_feat[dst])**2).sum(dim=-1) ), dst) 114 |     else: 115 |         raise NotImplementedError 116 | 117 |     if criterion == 'mse': 118 |         loss_lpw = F.mse_loss(pw_sim, teacher_pw_sim) 119 |     elif criterion == 'kld': 120 |         loss_lpw = F.kl_div(torch.log(pw_sim), teacher_pw_sim, log_target=False)  # NOTE(review): default 'mean' reduction here too; torch.log would give -inf if a softmax weight underflowed to 0 121 |     else: 122 |         raise NotImplementedError 123 | 124 |     loss = loss_cls + beta* loss_lpw 125 | 126 |     return loss, loss_cls, loss_lpw 127 | 128 | 129 | def nce_criterion(logits, labels, feat, teacher_feat, beta=0.5, nce_T=0.075, max_samples=8192):  # cross-entropy + beta * contrastive loss matching each student row to its own teacher row 130 |     """Graph Contrastive Representation Distillation, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 131 |     """ 132 |     loss_cls = F.cross_entropy(logits, labels) 133 | 134 |     if max_samples < feat.shape[0]: 135 |         sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 136 |         feat = feat[sampled_inds] 137 |         teacher_feat = teacher_feat[sampled_inds] 138 | 139 |     feat = F.normalize(feat, p=2, dim=-1) 140 |     teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 141 | 142 |     nce_logits = torch.mm(feat, teacher_feat.transpose(0, 1))  # (n x n) cosine similarities: student row i vs. teacher row j 143 |     nce_labels = torch.arange(feat.shape[0]).to(feat.device)  # the positive for student row i is teacher row i (the diagonal) 144 | 145 |     loss_nce = F.cross_entropy(nce_logits/ nce_T, nce_labels) 146 | 147 |     loss = loss_cls + beta* loss_nce 148 | 149 |     return loss, loss_cls, loss_nce 150 |
-------------------------------------------------------------------------------- /mag_pyg/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 = r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') -------------------------------------------------------------------------------- /mag_pyg/scripts/run.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | # Experimental setup: 5 | # - We tune the loss balancing weights α, β in for all techniques on the validation set when 
using both logit-based and representation distillation together. 6 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 7 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 8 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head. 9 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude. 10 | 11 | tmux new -s mag -d  # detached session named 'mag'; each send-keys below types a command batch into it 12 | tmux send-keys "conda activate ogb" C-m 13 | 14 | runs=5 15 | 16 | num_layers=2 17 | hidden_channels=32 18 | 19 | ############################ 20 | 21 | training=supervised 22 | 23 | tmux send-keys " 24 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training & 25 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training & 26 | wait" C-m 27 | 28 | ############################ 29 | 30 | training=kd 31 | 32 | tmux send-keys " 33 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training & 34 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training & 35 | wait" C-m 36 | 37 | ############################ 38 | 39 | training=fitnet 40 | beta=100 41 | 42 | tmux send-keys " 43 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 44 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 45 | wait" C-m 46 | 47 | ############################ 48 | 49 | training=at 50 | beta=10000 51 | 52 | tmux send-keys " 53 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training
--beta $beta & 54 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 55 | wait" C-m 56 | 57 | # ############################ 58 | 59 | expt_name=graph_saint/lpw-rbf  # NOTE(review): expt_name is assigned throughout this script but never passed to gnn.py (run_kd_and_aux.sh passes --expt_name) - confirm intended 60 | training=lpw 61 | kernel=rbf 62 | beta=100 63 | 64 | tmux send-keys " 65 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 66 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 67 | wait" C-m 68 | 69 | ############################ 70 | 71 | expt_name=graph_saint/lpw-cosine 72 | training=lpw 73 | kernel=cosine 74 | beta=100 75 | 76 | tmux send-keys " 77 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 78 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 79 | wait" C-m 80 | 81 | ############################ 82 | 83 | expt_name=graph_saint/lpw-poly 84 | training=lpw 85 | kernel=poly 86 | beta=100 87 | 88 | tmux send-keys " 89 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 90 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel & 91 | wait" C-m 92 | 93 | ############################ 94 | 95 | expt_name=graph_saint/gpw-rbf 96 | training=gpw 97 | kernel=rbf 98 | beta=100 99 | max_samples=8192 100 | 101 | tmux send-keys " 102 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training
--beta $beta --kernel $kernel --max_samples $max_samples & 103 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --max_samples $max_samples & 104 | wait" C-m 105 | 106 | ############################ 107 | 108 | expt_name=graph_saint/gpw-cosine 109 | training=gpw 110 | kernel=cosine 111 | beta=100 112 | max_samples=24576 113 | 114 | tmux send-keys " 115 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --max_samples $max_samples & 116 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --max_samples $max_samples & 117 | wait" C-m 118 | 119 | ############################ 120 | 121 | expt_name=graph_saint/gpw-poly 122 | training=gpw 123 | kernel=poly 124 | beta=100 125 | max_samples=24576 126 | 127 | tmux send-keys " 128 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --max_samples $max_samples & 129 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --max_samples $max_samples & 130 | wait" C-m 131 | 132 | ############################ 133 | 134 | expt_name=graph_saint/nce 135 | training=nce 136 | beta=0.1 137 | nce_T=0.075 138 | max_samples=24576 139 | 140 | tmux send-keys " 141 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --nce_T $nce_T --max_samples $max_samples & 142 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --nce_T $nce_T
--max_samples $max_samples & 143 | wait" C-m 144 | 145 | ########################### 146 | 147 | tmux send-keys "tmux kill-session -t mag" C-m  # typed into the session shell, so presumably executed only after the last batch's `wait` returns 148 | -------------------------------------------------------------------------------- /mag_pyg/scripts/run_kd_and_aux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Experimental setup: 5 | # - We tune the loss balancing weights α, β for all techniques on the validation set when using both logit-based and representation distillation together. 6 | # - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}. 7 | # - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP. 8 | # - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head. 9 | # When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude. 10 | 11 | tmux new -s mag -d  # detached session named 'mag'; each send-keys below types a command batch into it 12 | tmux send-keys "conda activate ogb" C-m 13 | 14 | runs=5 15 | 16 | num_layers=2 17 | hidden_channels=32 18 | 19 | ############################ 20 | 21 | training=fitnet 22 | beta=1 23 | 24 | tmux send-keys " 25 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 26 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 27 | wait" C-m 28 | 29 | ############################ 30 | 31 | training=at 32 | beta=100 33 | 34 | tmux send-keys " 35 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 36 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta & 37 | wait" C-m 38 | 39 | # ############################
40 | 41 | expt_name=graph_saint/lpw-rbf 42 | training=lpw 43 | kernel=rbf 44 | beta=1 45 | 46 | tmux send-keys " 47 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 48 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 49 | wait" C-m 50 | 51 | ############################ 52 | 53 | expt_name=graph_saint/lpw-cosine 54 | training=lpw 55 | kernel=cosine 56 | beta=1 57 | 58 | tmux send-keys " 59 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 60 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 61 | wait" C-m 62 | 63 | ############################ 64 | 65 | expt_name=graph_saint/lpw-poly 66 | training=lpw 67 | kernel=poly 68 | beta=1 69 | 70 | tmux send-keys " 71 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 72 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name & 73 | wait" C-m 74 | 75 | ############################ 76 | 77 | expt_name=graph_saint/gpw-cosine 78 | training=gpw 79 | kernel=cosine 80 | beta=1 81 | max_samples=24576 82 | 83 | tmux send-keys " 84 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta
$beta --kernel $kernel --expt_name $expt_name --max_samples $max_samples & 85 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name --max_samples $max_samples & 86 | wait" C-m 87 | 88 | ############################ 89 | 90 | expt_name=graph_saint/gpw-poly 91 | training=gpw 92 | kernel=poly 93 | beta=1 94 | max_samples=24576 95 | 96 | tmux send-keys " 97 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name --max_samples $max_samples & 98 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --kernel $kernel --expt_name $expt_name --max_samples $max_samples & 99 | wait" C-m 100 | 101 | ############################ 102 | 103 | expt_name=graph_saint/nce 104 | training=nce 105 | beta=0.1 106 | nce_T=0.075 107 | max_samples=24576 108 | 109 | tmux send-keys " 110 | python gnn_kd_and_aux.py --device 0 --seed 0 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --nce_T $nce_T --max_samples $max_samples --expt_name $expt_name & 111 | python gnn_kd_and_aux.py --device 1 --seed 5 --runs $runs --num_layers $num_layers --hidden_channels $hidden_channels --training $training --beta $beta --nce_T $nce_T --max_samples $max_samples --expt_name $expt_name & 112 | wait" C-m 113 | 114 | ########################### 115 | 116 | tmux send-keys "tmux kill-session -t mag" C-m  # typed into the session shell, so presumably executed only after the last batch's `wait` returns 117 | -------------------------------------------------------------------------------- /mag_pyg/scripts/teacher.sh: -------------------------------------------------------------------------------- 1
| 2 | #!/bin/bash 3 | 4 | tmux new -s expt -d 5 | tmux send-keys "conda activate ogb" C-m 6 | 7 | runs=5 8 | 9 | 10 | ############################ 11 | 12 | tmux send-keys " 13 | python gnn.py --device 0 --seed 0 --runs $runs --num_layers 3 --hidden_channels 512 & 14 | python gnn.py --device 1 --seed 5 --runs $runs --num_layers 3 --hidden_channels 512 & 15 | wait" C-m 16 | 17 | tmux send-keys "tmux kill-session -t expt" C-m 18 | -------------------------------------------------------------------------------- /mag_pyg/submit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | # Experiment settings 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--expt_name', type=str, default="debug", 11 | help='Name of experiment logs/results folder') 12 | parser.add_argument('--run', type=str, default="debug", 13 | help='Name of specific run') 14 | args = parser.parse_args() 15 | 16 | expt_args = [] 17 | model = [] 18 | total_param = [] 19 | BestEpoch = [] 20 | Validation = [] 21 | Test = [] 22 | Train = [] 23 | BestTrain = [] 24 | 25 | for root, dirs, files in os.walk(f'/data/user/mag_pyg_logs/logs_kd_and_aux/{args.expt_name}/{args.run}'): 26 | if 'results.pt' in files: 27 | results = torch.load(os.path.join(root, 'results.pt'), map_location=torch.device('cpu')) 28 | expt_args.append(results['args']) 29 | total_param.append(results['total_param']) 30 | BestEpoch.append(results['BestEpoch']) 31 | Validation.append(results['Validation']) 32 | Test.append(results['Test']) 33 | Train.append(results['Train']) 34 | 35 | print(results['args'].seed, results['Test'], results['Validation'], results['Train']) 36 | 37 | print(expt_args[0]) 38 | print() 39 | print(f'Test performance: {np.mean(Test)*100:.2f} +- {np.std(Test)*100:.2f}') 40 | print(f'Validation performance: {np.mean(Validation)*100:.2f} +- 
{np.std(Validation)*100:.2f}') 41 | print(f'Train performance: {np.mean(Train)*100:.2f} +- {np.std(Train)*100:.2f}') 42 | print(f'Total parameters: {int(np.mean(total_param))}') -------------------------------------------------------------------------------- /mag_pyg/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import time 4 | import os 5 | import random 6 | from copy import copy 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn import ModuleList, Linear, ParameterDict, Parameter 12 | 13 | from torch_sparse import SparseTensor 14 | from torch_geometric.utils import to_undirected, subgraph 15 | from torch_geometric.data import Data, GraphSAINTRandomWalkSampler 16 | from torch_geometric.utils.hetero import group_hetero_graph 17 | from torch_geometric.nn import MessagePassing 18 | 19 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 20 | 21 | from logger import Logger 22 | 23 | import nvidia_smi 24 | nvidia_smi.nvmlInit() 25 | 26 | 27 | class RGCNConv(MessagePassing): 28 | def __init__(self, in_channels, out_channels, num_node_types, 29 | num_edge_types): 30 | super(RGCNConv, self).__init__(aggr='mean') 31 | 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.num_node_types = num_node_types 35 | self.num_edge_types = num_edge_types 36 | 37 | self.rel_lins = ModuleList([ 38 | Linear(in_channels, out_channels, bias=False) 39 | for _ in range(num_edge_types) 40 | ]) 41 | 42 | self.root_lins = ModuleList([ 43 | Linear(in_channels, out_channels, bias=True) 44 | for _ in range(num_node_types) 45 | ]) 46 | 47 | self.reset_parameters() 48 | 49 | def reset_parameters(self): 50 | for lin in self.rel_lins: 51 | lin.reset_parameters() 52 | for lin in self.root_lins: 53 | lin.reset_parameters() 54 | 55 | def forward(self, x, edge_index, edge_type, node_type): 56 | out = x.new_zeros(x.size(0), 
self.out_channels) 57 | 58 | for i in range(self.num_edge_types): 59 | mask = edge_type == i 60 | out.add_(self.propagate(edge_index[:, mask], x=x, edge_type=i)) 61 | 62 | for i in range(self.num_node_types): 63 | mask = node_type == i 64 | out[mask] += self.root_lins[i](x[mask]) 65 | 66 | return out 67 | 68 | def message(self, x_j, edge_type: int): 69 | return self.rel_lins[edge_type](x_j) 70 | 71 | 72 | class RGCN(torch.nn.Module): 73 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 74 | dropout, num_nodes_dict, x_types, num_edge_types): 75 | super(RGCN, self).__init__() 76 | 77 | self.in_channels = in_channels 78 | self.hidden_channels = hidden_channels 79 | self.out_channels = out_channels 80 | self.num_layers = num_layers 81 | self.dropout = dropout 82 | 83 | node_types = list(num_nodes_dict.keys()) 84 | num_node_types = len(node_types) 85 | 86 | self.num_node_types = num_node_types 87 | self.num_edge_types = num_edge_types 88 | 89 | # Create embeddings for all node types that do not come with features. 90 | self.emb_dict = ParameterDict({ 91 | f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels)) 92 | for key in set(node_types).difference(set(x_types)) 93 | }) 94 | 95 | I, H, O = in_channels, hidden_channels, out_channels # noqa 96 | 97 | # Create `num_layers` many message passing layers. 98 | self.convs = ModuleList() 99 | self.convs.append(RGCNConv(I, H, num_node_types, num_edge_types)) 100 | for _ in range(num_layers - 2): 101 | self.convs.append(RGCNConv(H, H, num_node_types, num_edge_types)) 102 | self.convs.append(RGCNConv(H, O, self.num_node_types, num_edge_types)) 103 | 104 | self.reset_parameters() 105 | 106 | def reset_parameters(self): 107 | for emb in self.emb_dict.values(): 108 | torch.nn.init.xavier_uniform_(emb) 109 | for conv in self.convs: 110 | conv.reset_parameters() 111 | 112 | def group_input(self, x_dict, node_type, local_node_idx): 113 | # Create global node feature matrix. 
114 | h = torch.zeros((node_type.size(0), self.in_channels), 115 | device=node_type.device) 116 | 117 | for key, x in x_dict.items(): 118 | mask = node_type == key 119 | h[mask] = x[local_node_idx[mask]] 120 | 121 | for key, emb in self.emb_dict.items(): 122 | mask = node_type == int(key) 123 | h[mask] = emb[local_node_idx[mask]] 124 | 125 | return h 126 | 127 | def forward(self, x_dict, edge_index, edge_type, node_type, 128 | local_node_idx): 129 | 130 | x = self.group_input(x_dict, node_type, local_node_idx) 131 | 132 | for i, conv in enumerate(self.convs): 133 | x = conv(x, edge_index, edge_type, node_type) 134 | if i != self.num_layers - 1: 135 | x = F.relu(x) 136 | x = F.dropout(x, p=0.5, training=self.training) 137 | self.out_feat = x 138 | 139 | return x 140 | 141 | def inference(self, x_dict, edge_index_dict, key2int): 142 | # We can perform full-batch inference on GPU. 143 | 144 | device = list(x_dict.values())[0].device 145 | 146 | x_dict = copy(x_dict) 147 | for key, emb in self.emb_dict.items(): 148 | x_dict[int(key)] = emb 149 | 150 | adj_t_dict = {} 151 | for key, (row, col) in edge_index_dict.items(): 152 | adj_t_dict[key] = SparseTensor(row=col, col=row).to(device) 153 | 154 | iter_start_time = time.time() 155 | for i, conv in enumerate(self.convs): 156 | out_dict = {} 157 | 158 | for j, x in x_dict.items(): 159 | out_dict[j] = conv.root_lins[j](x) 160 | 161 | for keys, adj_t in adj_t_dict.items(): 162 | src_key, target_key = keys[0], keys[-1] 163 | out = out_dict[key2int[target_key]] 164 | tmp = adj_t.matmul(x_dict[key2int[src_key]], reduce='mean') 165 | out.add_(conv.rel_lins[key2int[keys]](tmp)) 166 | 167 | if i != self.num_layers - 1: 168 | for j in range(self.num_node_types): 169 | F.relu_(out_dict[j]) 170 | 171 | x_dict = out_dict 172 | 173 | iter_time = time.time() - iter_start_time 174 | 175 | return x_dict, iter_time 176 | 177 | 178 | @torch.no_grad() 179 | def test(model, data, x_dict, edge_index_dict, key2int, split_idx, evaluator, 
device): 180 | model.eval() 181 | 182 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device) 183 | 184 | # iter_start_time = time.time() 185 | out, iter_time = model.inference(x_dict, edge_index_dict, key2int) 186 | # iter_time = time.time() - iter_start_time 187 | gpu_used = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used 188 | 189 | out = out[key2int['paper']] 190 | 191 | y_pred = out.argmax(dim=-1, keepdim=True).cpu() 192 | y_true = data.y_dict['paper'] 193 | 194 | train_acc = evaluator.eval({ 195 | 'y_true': y_true[split_idx['train']['paper']], 196 | 'y_pred': y_pred[split_idx['train']['paper']], 197 | })['acc'] 198 | valid_acc = evaluator.eval({ 199 | 'y_true': y_true[split_idx['valid']['paper']], 200 | 'y_pred': y_pred[split_idx['valid']['paper']], 201 | })['acc'] 202 | test_acc = evaluator.eval({ 203 | 'y_true': y_true[split_idx['test']['paper']], 204 | 'y_pred': y_pred[split_idx['test']['paper']], 205 | })['acc'] 206 | 207 | return train_acc, valid_acc, test_acc, iter_time, gpu_used 208 | 209 | 210 | def seed(seed=0): 211 | random.seed(seed) 212 | np.random.seed(seed) 213 | torch.manual_seed(seed) 214 | torch.cuda.manual_seed(seed) 215 | torch.cuda.manual_seed_all(seed) 216 | torch.backends.cudnn.deterministic = True 217 | torch.backends.cudnn.benchmark = False 218 | 219 | 220 | def main(): 221 | seed(42) 222 | 223 | dataset = PygNodePropPredDataset(name='ogbn-mag') 224 | data = dataset[0] 225 | split_idx = dataset.get_idx_split() 226 | evaluator = Evaluator(name='ogbn-mag') 227 | logger = Logger(args.runs, args) 228 | 229 | # We do not consider those attributes for now. 230 | data.node_year_dict = None 231 | data.edge_reltype_dict = None 232 | 233 | print(data) 234 | 235 | edge_index_dict = data.edge_index_dict 236 | 237 | # We need to add reverse edges to the heterogeneous graph. 
238 | r, c = edge_index_dict[('author', 'affiliated_with', 'institution')] 239 | edge_index_dict[('institution', 'to', 'author')] = torch.stack([c, r]) 240 | 241 | r, c = edge_index_dict[('author', 'writes', 'paper')] 242 | edge_index_dict[('paper', 'to', 'author')] = torch.stack([c, r]) 243 | 244 | r, c = edge_index_dict[('paper', 'has_topic', 'field_of_study')] 245 | edge_index_dict[('field_of_study', 'to', 'paper')] = torch.stack([c, r]) 246 | 247 | # Convert to undirected paper <-> paper relation. 248 | edge_index = to_undirected(edge_index_dict[('paper', 'cites', 'paper')]) 249 | edge_index_dict[('paper', 'cites', 'paper')] = edge_index 250 | 251 | # We convert the individual graphs into a single big one, so that sampling 252 | # neighbors does not need to care about different edge types. 253 | # This will return the following: 254 | # * `edge_index`: The new global edge connectivity. 255 | # * `edge_type`: The edge type for each edge. 256 | # * `node_type`: The node type for each node. 257 | # * `local_node_idx`: The original index for each node. 258 | # * `local2global`: A dictionary mapping original (local) node indices of 259 | # type `key` to global ones. 260 | # `key2int`: A dictionary that maps original keys to their new canonical type. 261 | out = group_hetero_graph(data.edge_index_dict, data.num_nodes_dict) 262 | edge_index, edge_type, node_type, local_node_idx, local2global, key2int = out 263 | 264 | homo_data = Data(edge_index=edge_index, edge_attr=edge_type, 265 | node_type=node_type, local_node_idx=local_node_idx, 266 | num_nodes=node_type.size(0)) 267 | 268 | homo_data.y = node_type.new_full((node_type.size(0), 1), -1) 269 | homo_data.y[local2global['paper']] = data.y_dict['paper'] 270 | 271 | homo_data.train_mask = torch.zeros((node_type.size(0)), dtype=torch.bool) 272 | homo_data.train_mask[local2global['paper'][split_idx['train']['paper']]] = True 273 | 274 | print(homo_data) 275 | 276 | # Map informations to their canonical type. 
277 | x_dict = {} 278 | for key, x in data.x_dict.items(): 279 | x_dict[key2int[key]] = x 280 | 281 | num_nodes_dict = {} 282 | for key, N in data.num_nodes_dict.items(): 283 | num_nodes_dict[key2int[key]] = N 284 | 285 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 286 | 287 | model = RGCN(128, args.hidden_channels, dataset.num_classes, args.num_layers, 288 | args.dropout, num_nodes_dict, list(x_dict.keys()), 289 | len(edge_index_dict.keys())).to(device) 290 | print(model) 291 | total_param = 0 292 | for param in model.parameters(): 293 | total_param += np.prod(list(param.data.size())) 294 | print(f'Total parameters: {total_param}') 295 | total_param = total_param - 1134649*128 - 59965*128 - 8740*128 296 | print(f'GNN parameters: {total_param}') 297 | print(f'Embedding parameters: {1134649*128 - 59965*128 - 8740*128}') 298 | 299 | if args.checkpoint != "": 300 | checkpoint = torch.load(f"{args.checkpoint}/results.pt", map_location=torch.device('cpu')) 301 | model.load_state_dict(checkpoint['model_state_dict']) 302 | model.eval() 303 | 304 | x_dict = {k: v.to(device) for k, v in x_dict.items()} 305 | 306 | t_iter_time = [] 307 | t_gpu_used = [] 308 | for run in range(args.runs): 309 | torch.cuda.empty_cache() 310 | model.reset_parameters() 311 | result = test(model, data, x_dict, edge_index_dict, key2int, split_idx, evaluator, args.device) 312 | 313 | logger.add_result(run, result[:3]) 314 | train_acc, valid_acc, test_acc, iter_time, gpu_used = result 315 | t_iter_time.append(iter_time) 316 | t_gpu_used.append(gpu_used) 317 | print(f'Run: {run + 1:02d}, ' 318 | f'Train: {100 * train_acc:.2f}%, ' 319 | f'Valid: {100 * valid_acc:.2f}%, ' 320 | f'Test: {100 * test_acc:.2f}%, ' 321 | f'Time: {iter_time:.3f}s, ' 322 | f'GPU: {gpu_used * 1e-9:.1f}GB' ) 323 | 324 | logger.print_statistics() 325 | print(f"Avg. Iteration Time: {np.mean(t_iter_time[1:]):.3f}s ± {np.std(t_iter_time[1:]):.3f}") 326 | print(f"Avg. 
GPU Used: {np.mean(t_gpu_used) * 1e-9:.1f}GB ± {np.std(t_gpu_used)}") 327 | 328 | if __name__ == "__main__": 329 | parser = argparse.ArgumentParser(description='OGBN-MAG (GraphSAINT)') 330 | 331 | # Experiment settings 332 | parser.add_argument('--device', type=int, default=0) 333 | parser.add_argument('--checkpoint', type=str, default="") 334 | 335 | # GNN settings 336 | parser.add_argument('--gnn', type=str, default='rgcn') 337 | parser.add_argument('--num_layers', type=int, default=2) 338 | parser.add_argument('--hidden_channels', type=int, default=32) 339 | parser.add_argument('--dropout', type=float, default=0.5) 340 | parser.add_argument('--runs', type=int, default=10) 341 | parser.add_argument('--batch_size', type=int, default=20000) 342 | parser.add_argument('--walk_length', type=int, default=2) 343 | parser.add_argument('--num_steps', type=int, default=30) 344 | 345 | args = parser.parse_args() 346 | 347 | print(args) 348 | main() 349 | -------------------------------------------------------------------------------- /mol_pyg/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Distillation for GNNs (Graph classification on OGBG-MOL* datasets) 2 | 3 | **Dataset**: MOLHIV (and other OGBG-MOL* datasets) 4 | 5 | **Library**: PyG 6 | 7 | This folder contains code to benchmark knowledge distillation for GNNs on various molecular graph classification datasets from OGBG-MOL*, experiments were predominantly done on MOLHIV. 8 | 9 | Under preparation. 
10 | 11 | ![MOLHIV Results](../img/molhiv.png) 12 | -------------------------------------------------------------------------------- /ppi_pyg/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Distillation for GNNs (PPI with PyG) 2 | 3 | **Dataset**: PPI 4 | 5 | **Library**: PyG 6 | 7 | This folder contains code to benchmark knowledge distillation for GNNs on the PPI dataset, developed in the PyG framework. 8 | The main purpose of the codebase is to: 9 | - Train teacher GAT models on the PPI dataset via supervised learning and export the checkpoints. 10 | - Train student GNN models (e.g. GAT, GCN, GraphSAGE) with/without knowledge distillation. 11 | 12 | ![PPI Results](../img/ppi.png) 13 | 14 | ## Directory Structure 15 | 16 | ``` 17 | . 18 | ├── checkpoints 19 | ├── logs 20 | ├── data # automatically created by OGB data downloaders 21 | | 22 | ├── scripts # scripts to conduct full experiments and reproduce results 23 | │ ├── baselines.sh # script to train baseline students (supervised and logit-based KD) 24 | │ ├── run.sh # script to benchmark all KD losses 25 | │ └── train_teacher.sh # script to train and save teacher checkpoints 26 | | 27 | ├── README.md 28 | | 29 | ├── criterion.py # KD loss functions 30 | ├── gnn.py # train student GNNs (supervised, logit-based KD, or representation distillation losses) 31 | ├── train_teacher.py # train teacher GNNs and export checkpoints 32 | ├── logger.py # logging utilities 33 | └── submit.py # read log directory to aggregate results 34 | ``` 35 | 36 | ## Example Usage 37 | 38 | For full usage, each file has accompanying flags and documentation. 39 | Also see the `scripts` folder for reproducing results.
40 | -------------------------------------------------------------------------------- /ppi_pyg/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch_geometric.utils import softmax, to_dense_adj, subgraph, negative_sampling, add_self_loops 6 | 7 | 8 | def kd_criterion(logits, labels, teacher_logits, alpha=0.5, T=1): 9 | """Logit-based KD, [Hinton et al., 2015](https://arxiv.org/abs/1503.02531) 10 | """ 11 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 12 | 13 | loss_kd = F.binary_cross_entropy_with_logits(logits, torch.sigmoid(teacher_logits)) 14 | 15 | loss = loss_kd* (alpha* T* T) + loss_cls* (1-alpha) 16 | # TODO use temperature? 17 | 18 | return loss, loss_cls, loss_kd 19 | 20 | 21 | def fitnet_criterion(logits, labels, feat, teacher_feat, beta=1000): 22 | """FitNet, [Romero et al., 2014](https://arxiv.org/abs/1412.6550) 23 | """ 24 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 25 | 26 | feat = F.normalize(feat, p=2, dim=-1) 27 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 28 | 29 | loss_fitnet = F.mse_loss(feat, teacher_feat) 30 | 31 | loss = loss_cls + beta* loss_fitnet 32 | 33 | return loss, loss_cls, loss_fitnet 34 | 35 | 36 | def at_criterion(logits, labels, feat, teacher_feat, beta=1000): 37 | """Attention Transfer, [Zagoruyko and Komodakis, 2016](https://arxiv.org/abs/1612.03928) 38 | """ 39 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 40 | 41 | feat = feat.pow(2).sum(-1) 42 | teacher_feat = teacher_feat.pow(2).sum(-1) 43 | 44 | loss_at = F.mse_loss( 45 | F.normalize(feat, p=2, dim=-1), 46 | F.normalize(teacher_feat, p=2, dim=-1) 47 | ) 48 | 49 | loss = loss_cls + beta* loss_at 50 | 51 | return loss, loss_cls, loss_at 52 | 53 | 54 | def gpw_criterion(logits, labels, feat, teacher_feat, kernel='cosine', beta=1, max_samples=8192): 55 | """Global Structure 
Preserving loss, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 56 | """ 57 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 58 | 59 | if max_samples < feat.shape[0]: 60 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 61 | feat = feat[sampled_inds] 62 | teacher_feat = teacher_feat[sampled_inds] 63 | 64 | pw_sim = None 65 | teacher_pw_sim = None 66 | if kernel == 'cosine': 67 | feat = F.normalize(feat, p=2, dim=-1) 68 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 69 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten() 70 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten() 71 | elif kernel == 'poly': 72 | feat = F.normalize(feat, p=2, dim=-1) 73 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 74 | pw_sim = torch.mm(feat, feat.transpose(0, 1)).flatten()**2 75 | teacher_pw_sim = torch.mm(teacher_feat, teacher_feat.transpose(0, 1)).flatten()**2 76 | elif kernel == 'l2': 77 | pw_sim = (feat.unsqueeze(0) - feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 78 | teacher_pw_sim = (teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1)).norm(p=2, dim=-1).flatten() 79 | elif kernel == 'rbf': 80 | pw_sim = torch.exp(-0.5* ((feat.unsqueeze(0) - feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 81 | teacher_pw_sim = torch.exp(-0.5* ((teacher_feat.unsqueeze(0) - teacher_feat.unsqueeze(1))**2).sum(dim=-1).flatten()) 82 | else: 83 | raise NotImplementedError 84 | 85 | loss_gpw = F.mse_loss(pw_sim, teacher_pw_sim) 86 | 87 | loss = loss_cls + beta* loss_gpw 88 | 89 | return loss, loss_cls, loss_gpw 90 | 91 | 92 | def lpw_criterion(logits, labels, feat, teacher_feat, edge_index, kernel='cosine', beta=100, criterion='kld'): 93 | """Local Structure Preserving loss, [Yang et al., CVPR 2020](https://arxiv.org/abs/2003.10477) 94 | """ 95 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 96 | 97 | src, dst = edge_index 98 | 99 | if kernel == 'cosine': 100 | pw_sim = 
softmax(F.cosine_similarity(feat[src], feat[dst]), dst) 101 | teacher_pw_sim = softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst]), dst) 102 | elif kernel == 'poly': 103 | pw_sim = softmax(F.cosine_similarity(feat[src], feat[dst])**2, dst) 104 | teacher_pw_sim = softmax(F.cosine_similarity(teacher_feat[src], teacher_feat[dst])**2, dst) 105 | elif kernel == 'l2': 106 | pw_sim = softmax((feat[src] - feat[dst]).norm(p=2, dim=-1), dst) 107 | teacher_pw_sim = softmax((teacher_feat[src] - teacher_feat[dst]).norm(p=2, dim=-1), dst) 108 | elif kernel == 'rbf': 109 | pw_sim = softmax(torch.exp( -0.5* ((feat[src] - feat[dst])**2).sum(dim=-1) ), dst) 110 | teacher_pw_sim = softmax(torch.exp( -0.5* ((teacher_feat[src] - teacher_feat[dst])**2).sum(dim=-1) ), dst) 111 | else: 112 | raise NotImplementedError 113 | 114 | if criterion == 'mse': 115 | loss_lpw = F.mse_loss(pw_sim, teacher_pw_sim) 116 | elif criterion == 'kld': 117 | loss_lpw = F.kl_div(torch.log(pw_sim), teacher_pw_sim, log_target=False) 118 | else: 119 | raise NotImplementedError 120 | 121 | loss = loss_cls + beta* loss_lpw 122 | 123 | return loss, loss_cls, loss_lpw 124 | 125 | 126 | def nce_criterion(logits, labels, feat, teacher_feat, beta=0.5, nce_T=0.075, max_samples=8192): 127 | """Graph Contrastive Representation Distillation, [Joshi et al., TNNLS 2022](https://arxiv.org/abs/2111.04964) 128 | """ 129 | loss_cls = F.binary_cross_entropy_with_logits(logits, labels) 130 | 131 | if max_samples < feat.shape[0]: 132 | sampled_inds = np.random.choice(feat.shape[0], max_samples, replace=False) 133 | feat = feat[sampled_inds] 134 | teacher_feat = teacher_feat[sampled_inds] 135 | 136 | feat = F.normalize(feat, p=2, dim=-1) 137 | teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) 138 | 139 | nce_logits = torch.mm(feat, teacher_feat.transpose(0, 1)) 140 | nce_labels = torch.arange(feat.shape[0]).to(feat.device) 141 | 142 | loss_nce = F.cross_entropy(nce_logits/ nce_T, nce_labels) 143 | 144 | loss = 
loss_cls + beta* loss_nce 145 | 146 | return loss, loss_cls, loss_nce 147 | -------------------------------------------------------------------------------- /ppi_pyg/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 = r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') -------------------------------------------------------------------------------- /ppi_pyg/scripts/baselines.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | tmux new -s baselines -d 5 | tmux 
send-keys "conda activate ogb" C-m

runs=5

# Queue one experiment in the tmux session. The same configuration runs on
# GPUs 0 and 1 with seeds 0 and 5, and `wait` holds the session until both
# finish. Reads: runs, training, expt_name, gnn, num_layers, hidden_channels.
launch() {
    local flags="--runs $runs --training $training --expt_name $expt_name --gnn $gnn --num_layers $num_layers --hidden_channels $hidden_channels"
    tmux send-keys "
python gnn.py --device 0 --seed 0 $flags &
python gnn.py --device 1 --seed 5 $flags &
wait" C-m
}

############################

# Each baseline student, given as "gnn num_layers hidden_channels", is
# trained with plain supervision and then with logit-based KD. expt_name
# matches the training mode. For gat, hidden_channels=64 means 4 heads.
for config in "student 5 68" "gcn 2 256" "sage 2 256" "gat 2 64"; do
    read -r gnn num_layers hidden_channels <<< "$config"
    for training in supervised kd; do
        expt_name=$training
        launch
    done
done

tmux send-keys "tmux kill-session -t baselines" C-m
# --------------------------------------------------------------------------------
# /ppi_pyg/scripts/run.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
# Queue every representation-distillation experiment on PPI for each student
# architecture inside a detached tmux session, then close the session.
#
# Experimental setup:
# - We tune the loss balancing weights α, β for all techniques on the validation set when using both logit-based and representation distillation together.
# - For KD, we tune α ∈ {0.8, 0.9}, τ1 ∈ {4, 5}.
# - For FitNet, AT, LSP, and GSP, we tune β ∈ {100, 1000, 10000} and the kernel for LSP, GSP.
# - For G-CRD, we tune β ∈ {0.01, 0.05}, τ2 ∈ {0.05, 0.075, 0.1} and the projection head.
# When comparing representation distillation methods, we set α to 0 in order to ablate performance and reduce β by one order of magnitude.

session=all
runs=5

tmux new -s "$session" -d
# Always target the session explicitly; without -t, tmux picks a default
# session, which need not be the one created above.
tmux send-keys -t "$session" "conda activate ogb" C-m

# Queue one experiment: seed 0 on GPU 0 and seed 5 on GPU 1 run in parallel,
# and the trailing `wait` makes the tmux shell finish both before starting the
# next queued command, so experiments run one after another.
# Usage: run_pair TRAINING EXPT_NAME [EXTRA_FLAGS...]
# Reads the globals session, runs, gnn, num_layers and hidden_channels.
# Variables are expanded here (double quotes) when the keys are sent, so the
# tmux shell receives literal values.
run_pair() {
    local training=$1
    local expt_name=$2
    shift 2
    local flags="--runs $runs --training $training --expt_name $expt_name --gnn $gnn --num_layers $num_layers --hidden_channels $hidden_channels $*"
    tmux send-keys -t "$session" "
python gnn.py --device 0 --seed 0 $flags &
python gnn.py --device 1 --seed 5 $flags &
wait" C-m
}

# Queue all distillation methods for the architecture described by the
# current values of gnn, num_layers and hidden_channels.
run_all_methods() {
    run_pair fitnet fitnet --beta 1000
    run_pair at at --beta 1000
    local kernel
    for kernel in rbf poly cosine l2; do
        run_pair lpw "lpw-$kernel" --beta 100 --kernel "$kernel"
    done
    run_pair nce nce --beta 0.1 --nce_T 0.075 --max_samples 16384 --proj_dim 256
}

gnn=student
num_layers=5
hidden_channels=68
run_all_methods

gnn=gcn
num_layers=2
hidden_channels=256
run_all_methods

gnn=sage
num_layers=2
hidden_channels=256
run_all_methods

gnn=gat
num_layers=2
hidden_channels=64
run_all_methods

# Queued last, so it only runs after every experiment above has finished.
tmux send-keys -t "$session" "tmux kill-session -t $session" C-m

# --------------------------------------------------------------------------------
# /ppi_pyg/scripts/train_teacher.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
# Train the PPI teacher (train_teacher.py) with seed 0 on GPU 0 and seed 5 on
# GPU 1 in parallel inside a detached tmux session, then close the session.

session=expt
runs=5

tmux new -s "$session" -d
# Target the session explicitly so keys never land in another tmux session.
tmux send-keys -t "$session" "conda activate ogb" C-m

tmux send-keys -t "$session" "
python train_teacher.py --device 0 --seed 0 --runs $runs &
python train_teacher.py --device 1 --seed 5 --runs $runs &
wait" C-m

tmux send-keys -t "$session" "tmux kill-session -t $session" C-m
# --------------------------------------------------------------------------------
# /ppi_pyg/submit.py:
# --------------------------------------------------------------------------------
"""Aggregate the per-seed ``results.pt`` files of one PPI run and print mean +- std."""
import os
import torch
import numpy as np
import argparse


if __name__ == "__main__":
    # Experiment settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--expt_name', type=str, default="debug",
                        help='Name of experiment logs/results folder')
    parser.add_argument('--run', type=str, default="debug",
                        help='Name of specific run')
    args = parser.parse_args()

    # One entry per results.pt found anywhere below logs/<expt_name>/<run>/.
    expt_args = []
    total_param = []
    BestEpoch = []
    Validation = []
    Test = []
    Train = []

    results_root = f'logs/{args.expt_name}/{args.run}'
    for root, dirs, files in os.walk(results_root):
        # os.walk order is filesystem-dependent; sorting dirs in place makes
        # the per-seed printout (and which args are echoed below) deterministic.
        dirs.sort()
        if 'results.pt' in files:
            results = torch.load(os.path.join(root, 'results.pt'), map_location=torch.device('cpu'))
            expt_args.append(results['args'])
            total_param.append(results['total_param'])
            BestEpoch.append(results['BestEpoch'])
            Validation.append(results['Validation'])
            Test.append(results['Test'])
            Train.append(results['Train'])

            print(results['args'].seed, results['Test'], results['Validation'], results['Train'])

    # Without this guard, expt_args[0] raises an opaque IndexError (and the
    # np.mean calls warn about empty input) when the path holds no results.
    if not expt_args:
        raise SystemExit(f'No results.pt found under {results_root}')

    print(expt_args[0])
    print()
    print(f'Test performance: {np.mean(Test)*100:.2f} +- {np.std(Test)*100:.2f}')
    print(f'Validation performance: {np.mean(Validation)*100:.2f} +- {np.std(Validation)*100:.2f}')
    print(f'Train performance: {np.mean(Train)*100:.2f} +- {np.std(Train)*100:.2f}')
    print(f'Total parameters: {int(np.mean(total_param))}')
# --------------------------------------------------------------------------------
# /ppi_pyg/train_teacher.py:
# --------------------------------------------------------------------------------
"""Train the GAT teacher network on PPI and checkpoint its best-validation epoch."""
import argparse
import numpy as np
import time
import os
import random

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.utils import subgraph
from torch_geometric.datasets import PPI
from torch_geometric.data import DataLoader

from sklearn.metrics import f1_score

from logger import Logger

from criterion import *


class TeacherNet(torch.nn.Module):
    """Three-layer GAT with a linear skip connection around every GAT layer.

    The two hidden GAT layers use 4 heads of 256 channels, concatenated to
    ``4 * 256`` features (matching the skip ``Linear`` layers). The output GAT
    layer averages 6 heads (``concat=False``) into ``out_channels`` logits.
    The activations fed to the output layer are cached in ``self.out_feat``
    on every forward pass.
    """

    # NOTE: attribute names define the state_dict keys of saved checkpoints;
    # keep them unchanged so existing checkpoints still load.
    def __init__(self, in_channels, out_channels):
        super(TeacherNet, self).__init__()
        self.conv1 = GATConv(in_channels, 256, heads=4)
        self.lin1 = torch.nn.Linear(in_channels, 4 * 256)
        self.conv2 = GATConv(4 * 256, 256, heads=4)
        self.lin2 = torch.nn.Linear(4 * 256, 4 * 256)
        self.conv3 = GATConv(4 * 256, out_channels, heads=6, concat=False)
        self.lin3 = torch.nn.Linear(4 * 256, out_channels)

    def reset_parameters(self):
        """Re-initialise every GAT and skip layer (called before each run)."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.conv3.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.lin3.reset_parameters()

    def forward(self, x, edge_index):
        """Return per-node logits; also stores the penultimate features in ``self.out_feat``."""
        x = F.elu(self.conv1(x, edge_index) + self.lin1(x))
        x = F.elu(self.conv2(x, edge_index) + self.lin2(x))
        self.out_feat = x
        x = self.conv3(x, edge_index) + self.lin3(x)
        return x


def train(model, train_dataset, train_loader, optimizer, args, device):
    """Run one training epoch with per-class binary cross-entropy on the logits.

    Args:
        model: network mapping ``(x, edge_index)`` to per-node logits.
        train_dataset: unused; kept for call-site compatibility.
        train_loader: iterable of graph batches exposing ``x``, ``edge_index``
            and ``y`` and supporting ``.to(device)``.
        optimizer: optimizer stepped once per batch.
        args: unused; kept for call-site compatibility.
        device: device each batch is moved to.

    Returns:
        Mean training loss over the epoch's batches, or 0.0 if the loader
        yielded no batches.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        loss = F.binary_cross_entropy_with_logits(out, batch.y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard the division: an empty loader used to raise NameError on `step`.
    return total_loss / num_batches if num_batches else 0.0

@torch.no_grad()
def test(model, dataset, loader, device):
    """Return the micro-averaged F1 score of ``model`` over ``loader``.

    Logits are thresholded at 0. If the model predicts no positive label at
    all, 0 is returned without calling ``f1_score`` (presumably to sidestep
    sklearn's ill-defined-precision case). ``dataset`` is unused and kept for
    call-site compatibility.
    """
    model.eval()

    ys, preds = [], []
    for data in loader:
        ys.append(data.y)
        out = model(data.x.to(device), data.edge_index.to(device))
        preds.append((out > 0).float().cpu())

    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0


def seed(seed=0):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    """Train ``args.runs`` TeacherNet models on PPI and checkpoint each run's best epoch.

    Reads the module-level ``args`` parsed in the ``__main__`` block. Run
    ``r`` is seeded with ``args.seed + r`` and writes its TensorBoard events
    and ``checkpoint.pt`` (saved whenever validation F1 improves) to
    ``checkpoints/seed{args.seed + r}``.
    """
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    train_dataset = PPI('data/PPI/', split='train')
    val_dataset = PPI('data/PPI/', split='val')
    test_dataset = PPI('data/PPI/', split='test')
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

    model = TeacherNet(train_dataset.num_features, train_dataset.num_classes).to(device)
    print(model)
    total_param = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_param}')

    logger = Logger(args.runs, args)

    for run in range(args.runs):
        seed(args.seed + run)

        model.reset_parameters()

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        # Create Tensorboard logger (SummaryWriter also creates log_dir, which
        # the checkpoint below is saved into).
        log_dir = os.path.join(
            "checkpoints",
            f"seed{args.seed+run}"
        )
        tb_logger = SummaryWriter(log_dir)

        try:
            # Metrics of the epoch with the best validation F1 so far.
            best_epoch = 0
            best_train = 0
            best_val = 0
            best_test = 0
            for epoch in range(1, 1 + args.epochs):
                train_loss = train(
                    model, train_dataset, train_loader,
                    optimizer, args, device
                )

                train_acc = test(model, train_dataset, train_loader, device)
                valid_acc = test(model, val_dataset, val_loader, device)
                test_acc = test(model, test_dataset, test_loader, device)
                logger.add_result(run, (train_acc, valid_acc, test_acc))

                if epoch % args.log_steps == 0:
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {train_loss:.4f}, '
                          f'Train: {100 * train_acc:.2f}%, '
                          f'Valid: {100 * valid_acc:.2f}% '
                          f'Test: {100 * test_acc:.2f}%')

                # Log statistics to Tensorboard, etc.
                tb_logger.add_scalar('loss/train', train_loss, epoch)
                tb_logger.add_scalar('acc/train', train_acc, epoch)
                tb_logger.add_scalar('acc/valid', valid_acc, epoch)
                tb_logger.add_scalar('acc/test', test_acc, epoch)

                if valid_acc > best_val:
                    best_epoch = epoch
                    best_train = train_acc
                    best_val = valid_acc
                    best_test = test_acc

                    torch.save({
                        'args': args,
                        'total_param': total_param,
                        'BestEpoch': best_epoch,
                        'Train': best_train,
                        'Validation': best_val,
                        'Test': best_test,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, os.path.join(log_dir, "checkpoint.pt"))
        finally:
            # Flush pending events and release the event file; previously one
            # writer per run was left open for the rest of the process.
            tb_logger.close()

        logger.print_statistics(run)

    logger.print_statistics()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPI')

    # Experiment settings
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument("--seed", type=int, default=0, help="seed")
    parser.add_argument('--log_steps', type=int, default=1)

    # GNN settings
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--runs', type=int, default=10)

    args = parser.parse_args()

    print(args)
    main()
# --------------------------------------------------------------------------------