├── .gitignore ├── LICENSE ├── README.md ├── deepgcn_env_install.sh ├── eff_gcn_modules └── rev │ ├── __init__.py │ ├── gcn_revop.py │ ├── memgcn.py │ └── rev_layer.py ├── examples ├── modelnet_cls │ ├── README.md │ ├── __init__.py │ ├── architecture.py │ ├── config.py │ ├── data.py │ └── main.py ├── ogb │ ├── README.md │ ├── ogbg_mol │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py │ ├── ogbg_ppa │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py │ ├── ogbl_collab │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py │ ├── ogbn_arxiv │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py │ ├── ogbn_products │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py │ └── ogbn_proteins │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── dataset.py │ │ ├── main.py │ │ ├── model.py │ │ └── test.py ├── ogb_eff │ ├── ogbn_arxiv_dgl │ │ ├── README.md │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── main.py │ │ └── model_rev.py │ └── ogbn_proteins │ │ ├── README.md │ │ ├── __init__.py │ │ ├── args.py │ │ ├── dataset.py │ │ ├── main.py │ │ ├── model_rev.py │ │ └── test.py ├── part_sem_seg │ ├── README.md │ ├── __init__.py │ ├── architecture.py │ ├── config.py │ ├── data.py │ ├── eval.py │ ├── main.py │ └── visualize.py ├── ppi │ ├── README.md │ ├── architecture.py │ ├── main.py │ └── opt.py ├── sem_seg_dense │ ├── README.md │ ├── __init__.py │ ├── architecture.py │ ├── config.py │ ├── test.py │ └── train.py └── sem_seg_sparse │ ├── README.md │ ├── __init__.py │ ├── architecture.py │ ├── config.py │ ├── script │ ├── test.sh │ └── train.sh │ ├── test.py │ └── train.py ├── gcn_lib ├── __init__.py ├── dense │ ├── __init__.py │ ├── torch_edge.py │ ├── torch_nn.py │ └── torch_vertex.py └── sparse │ ├── __init__.py │ ├── 
torch_edge.py │ ├── torch_message.py │ ├── torch_nn.py │ └── torch_vertex.py ├── misc ├── deeper_gcn_intro.png ├── deeper_power_mean.png ├── deeper_softmax.png ├── intro.png ├── modelnet_cls.png ├── part_sem_seg.png ├── pipeline.png ├── ppi.png └── sem_seg_s3dis.png └── utils ├── __init__.py ├── ckpt_util.py ├── data_util.py ├── logger.py ├── loss.py ├── metrics.py ├── optim.py ├── pc_viz.py ├── pyg_util.py └── tf_logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | *ipynb 2 | *.pyc 3 | .git/ 4 | .idea/ 5 | *checkpoints* 6 | *logs* 7 | *ibex 8 | 9 | examples/sem_seg_dense/script/ 10 | examples/sem_seg_dense/checkpoints/ 11 | examples/sem_seg_dense/logs/ 12 | 13 | examples/sem_seg_sparse/checkpoints/ 14 | examples/sem_seg_sparse/logs/ 15 | examples/sem_seg_sparse/script/ 16 | 17 | examples/ppi/checkpoints/ 18 | examples/ppi/logs/ 19 | examples/ppi/script/ 20 | 21 | examples/part_sem_seg/logs/ 22 | examples/part_sem_seg/checkpoints/ 23 | examples/part_sem_seg/result/ 24 | examples/part_sem_seg/script/ 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 DeepGCNs.org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGCNs: Can GCNs Go as Deep as CNNs? 2 | In this work, we present new ways to successfully train very deep GCNs. We borrow concepts from CNNs, mainly residual/dense connections and dilated convolutions, and adapt them to GCN architectures. Through extensive experiments, we show the positive effect of these deep GCN frameworks. 3 | 4 | [[Project]](https://www.deepgcns.org/) [[Paper]](https://arxiv.org/abs/1904.03751) [[Slides]](https://docs.google.com/presentation/d/1L82wWymMnHyYJk3xUKvteEWD5fX0jVRbCbI65Cxxku0/edit?usp=sharing) [[Tensorflow Code]](https://github.com/lightaime/deep_gcns) [[Pytorch Code]](https://github.com/lightaime/deep_gcns_torch) 5 | 6 |

7 | 8 |

9 | 10 | ## Overview 11 | We run extensive experiments to show how different components (#Layers, #Filters, #Nearest Neighbors, Dilation, etc.) affect `DeepGCNs`. We also provide ablation studies on different types of Deep GCNs (MRGCN, EdgeConv, GraphSage and GIN). 12 | 13 |

14 | 15 |
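One of the components studied above is dilation in the k-nearest-neighbor graph. The sketch below is a simplified, pure-Python illustration of the idea and not the repository's `DilatedKnnGraph`: for each point it takes the `k * dilation` nearest points and keeps every `dilation`-th one. The function name `dilated_knn` and the toy 1-D points are assumptions made for this example, and each point is counted as its own nearest neighbor.

```python
def dilated_knn(points, k, dilation=1):
    """Hypothetical helper: for each point, pick k neighbor indices from its
    k * dilation nearest points by keeping every `dilation`-th candidate."""
    neighbors = []
    for p in points:
        # Squared Euclidean distance from p to every point (including itself).
        dists = [(sum((a - b) ** 2 for a, b in zip(p, q)), j)
                 for j, q in enumerate(points)]
        dists.sort()  # nearest first; ties are broken by index
        nearest = [j for _, j in dists[:k * dilation]]
        neighbors.append(nearest[::dilation])
    return neighbors


# Toy 1-D point cloud, used only for illustration.
points = [(0.0,), (1.0,), (2.0,), (3.0,), (4.0,), (5.0,)]
plain = dilated_knn(points, k=2, dilation=1)    # ordinary k-NN
dilated = dilated_knn(points, k=2, dilation=2)  # same k, wider receptive field
print(plain[0], dilated[0])
```

With the same `k`, a larger `dilation` reaches points farther away without increasing the number of neighbors that are aggregated.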

16 | 17 | 18 | ## How to train, test and evaluate our models 19 | Please see the `README.md` of each task inside the `examples` folder for details. 20 | All information about code, data, and pretrained models can be found there. 21 | * DeepGCNs ([ICCV'2019](https://arxiv.org/abs/1904.03751), [TPAMI'2021](https://arxiv.org/abs/1910.06849)): [S3DIS](examples/sem_seg_dense), [PartNet](examples/part_sem_seg), [ModelNet40](examples/modelnet_cls), [PPI](/examples/ppi) 22 | * DeeperGCN ([Arxiv'2020](https://arxiv.org/abs/2006.07739)): [OGB](examples/ogb) 23 | * GNN'1000 ([ICML'2021](https://arxiv.org/abs/2106.07476)): [OGB](examples/ogb_eff) 24 | 25 | ## Recommended Requirements 26 | * [Python>=3.7](https://www.python.org/) 27 | * [Pytorch>=1.9.0](https://pytorch.org) 28 | * [pytorch_geometric>=1.6.0](https://pytorch-geometric.readthedocs.io/en/latest/) 29 | * [ogb>=1.3.1](https://github.com/snap-stanford/ogb) only used for experiments on OGB datasets 30 | * [dgl>=0.5.3](https://github.com/dmlc/dgl) only used for the experiment `examples/ogb_eff/ogbn_arxiv_dgl` 31 | 32 | Install the environment by running: 33 | ``` 34 | source deepgcn_env_install.sh 35 | ``` 36 | 37 | ## Code Architecture 38 | . 
39 | ├── misc # Misc images 40 | ├── utils # Common useful modules 41 | ├── gcn_lib # gcn library 42 | │ ├── dense # gcn library for dense data (B x C x N x 1) 43 | │ └── sparse # gcn library for sparse data (N x C) 44 | ├── eff_gcn_modules # modules for mem efficient gnns 45 | ├── examples 46 | │ ├── modelnet_cls # code for point clouds classification on ModelNet40 47 | │ ├── sem_seg_dense # code for point clouds semantic segmentation on S3DIS (data type: dense) 48 | │ ├── sem_seg_sparse # code for point clouds semantic segmentation on S3DIS (data type: sparse) 49 | │ ├── part_sem_seg # code for part segmentation on PartNet 50 | │ ├── ppi # code for node classification on PPI dataset 51 | │ └── ogb # code for node/graph property prediction on OGB datasets 52 | │ └── ogb_eff # code for node/graph property prediction on OGB datasets with memory efficient GNNs 53 | └── ... 54 | 55 | ## Citation 56 | Please cite our paper if you find anything helpful, 57 | 58 | ``` 59 | @InProceedings{li2019deepgcns, 60 | title={DeepGCNs: Can GCNs Go as Deep as CNNs?}, 61 | author={Guohao Li and Matthias Müller and Ali Thabet and Bernard Ghanem}, 62 | booktitle={The IEEE International Conference on Computer Vision (ICCV)}, 63 | year={2019} 64 | } 65 | ``` 66 | 67 | ``` 68 | @article{li2021deepgcns_pami, 69 | title={Deepgcns: Making gcns go as deep as cnns}, 70 | author={Li, Guohao and M{\"u}ller, Matthias and Qian, Guocheng and Perez, Itzel Carolina Delgadillo and Abualshour, Abdulellah and Thabet, Ali Kassem and Ghanem, Bernard}, 71 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 72 | year={2021}, 73 | publisher={IEEE} 74 | } 75 | ``` 76 | 77 | ``` 78 | @misc{li2020deepergcn, 79 | title={DeeperGCN: All You Need to Train Deeper GCNs}, 80 | author={Guohao Li and Chenxin Xiong and Ali Thabet and Bernard Ghanem}, 81 | year={2020}, 82 | eprint={2006.07739}, 83 | archivePrefix={arXiv}, 84 | primaryClass={cs.LG} 85 | } 86 | ``` 87 | 88 | ``` 89 | 
@InProceedings{li2021gnn1000, 90 | title={Training Graph Neural Networks with 1000 layers}, 91 | author={Guohao Li and Matthias Müller and Bernard Ghanem and Vladlen Koltun}, 92 | booktitle={International Conference on Machine Learning (ICML)}, 93 | year={2021} 94 | } 95 | ``` 96 | 97 | ## License 98 | MIT License 99 | 100 | ## Contact 101 | For more information please contact [Guohao Li](https://ghli.org), [Matthias Muller](https://matthias.pw/), [Guocheng Qian](https://www.gcqian.com/). 102 | -------------------------------------------------------------------------------- /deepgcn_env_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # make sure command is : source deepgcn_env_install.sh 3 | 4 | # install anaconda3. 5 | # cd ~/ 6 | # wget https://repo.anaconda.com/archive/Anaconda3-2019.07-Linux-x86_64.sh 7 | # bash Anaconda3-2019.07-Linux-x86_64.sh 8 | 9 | 10 | source ~/.bashrc 11 | export TORCH_CUDA_ARCH_LIST="7.0;7.5" # v100: 7.0; 2080ti: 7.5; titan xp: 6.1 12 | 13 | # make sure system cuda version is the same with pytorch cuda 14 | # follow the instruction of PyTorch Geometric: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html 15 | export PATH=/usr/local/cuda-10.2/bin:$PATH 16 | export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH 17 | 18 | conda create -n deepgcn 19 | conda activate deepgcn 20 | # make sure pytorch version >=1.4.0 21 | conda install -y pytorch=1.9.0 torchvision cudatoolkit=10.2 python=3.7 -c pytorch 22 | pip install tensorboard 23 | 24 | # command to install pytorch geometric, please refer to the official website for latest installation. 
25 | # https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html 26 | CUDA=cu102 27 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html 28 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html 29 | pip install torch-geometric 30 | 31 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html 32 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html 33 | 34 | pip install requests 35 | 36 | # install useful modules 37 | pip install tqdm 38 | 39 | # additional package required for ogb experiments 40 | pip install ogb 41 | ### check the version of ogb installed, if it is not the latest 42 | python -c "import ogb; print(ogb.__version__)" 43 | # please update the version by running 44 | pip install -U ogb 45 | 46 | # additional package required for dgl implementation 47 | pip install dgl-cu102 48 | -------------------------------------------------------------------------------- /eff_gcn_modules/rev/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/eff_gcn_modules/rev/__init__.py -------------------------------------------------------------------------------- /eff_gcn_modules/rev/memgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | try: 5 | from .gcn_revop import InvertibleModuleWrapper 6 | except: 7 | from gcn_revop import InvertibleModuleWrapper 8 | 9 | class GroupAdditiveCoupling(torch.nn.Module): 10 | def __init__(self, Fms, split_dim=-1, group=2): 11 | super(GroupAdditiveCoupling, self).__init__() 12 | 13 | self.Fms = Fms 14 | self.split_dim = split_dim 15 | self.group = group 16 | 17 | def forward(self, x, edge_index, *args): 18 | xs = torch.chunk(x, self.group, 
dim=self.split_dim) 19 | chunked_args = list(map(lambda arg: torch.chunk(arg, self.group, dim=self.split_dim), args)) 20 | args_chunks = list(zip(*chunked_args)) 21 | y_in = sum(xs[1:]) 22 | 23 | ys = [] 24 | for i in range(self.group): 25 | Fmd = self.Fms[i].forward(y_in, edge_index, *args_chunks[i]) 26 | y = xs[i] + Fmd 27 | y_in = y 28 | ys.append(y) 29 | 30 | out = torch.cat(ys, dim=self.split_dim) 31 | 32 | return out 33 | 34 | def inverse(self, y, edge_index, *args): 35 | ys = torch.chunk(y, self.group, dim=self.split_dim) 36 | chunked_args = list(map(lambda arg: torch.chunk(arg, self.group, dim=self.split_dim), args)) 37 | args_chunks = list(zip(*chunked_args)) 38 | 39 | xs = [] 40 | for i in range(self.group-1, -1, -1): 41 | if i != 0: 42 | y_in = ys[i-1] 43 | else: 44 | y_in = sum(xs) 45 | 46 | Fmd = self.Fms[i].forward(y_in, edge_index, *args_chunks[i]) 47 | x = ys[i] - Fmd 48 | xs.append(x) 49 | 50 | x = torch.cat(xs[::-1], dim=self.split_dim) 51 | 52 | return x 53 | -------------------------------------------------------------------------------- /eff_gcn_modules/rev/rev_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | try: 5 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv 6 | from gcn_lib.sparse.torch_vertex import GENConv 7 | from gcn_lib.sparse.torch_nn import norm_layer 8 | except: 9 | print("An import exception occurred") 10 | 11 | 12 | class SharedDropout(nn.Module): 13 | def __init__(self): 14 | super(SharedDropout, self).__init__() 15 | self.mask = None 16 | 17 | def set_mask(self, mask): 18 | self.mask = mask 19 | 20 | def forward(self, x): 21 | if self.training: 22 | assert self.mask is not None 23 | out = x * self.mask 24 | return out 25 | else: 26 | return x 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | def __init__(self, norm, in_channels): 31 | super(BasicBlock, self).__init__() 32 | self.norm = 
norm_layer(norm, in_channels) 33 | self.dropout = SharedDropout() 34 | 35 | def forward(self, x, edge_index, dropout_mask=None, edge_emb=None): 36 | # dropout_mask = kwargs.get('dropout_mask', None) 37 | # edge_emb = kwargs.get('edge_emb', None) 38 | out = self.norm(x) 39 | out = F.relu(out) 40 | 41 | if isinstance(self.dropout, SharedDropout): 42 | if dropout_mask is not None: 43 | self.dropout.set_mask(dropout_mask) 44 | out = self.dropout(out) 45 | 46 | if edge_emb is not None: 47 | out = self.gcn(out, edge_index, edge_emb) 48 | else: 49 | out = self.gcn(out, edge_index) 50 | 51 | return out 52 | 53 | 54 | class GENBlock(BasicBlock): 55 | def __init__(self, in_channels, out_channels, 56 | aggr='max', 57 | t=1.0, learn_t=False, 58 | p=1.0, learn_p=False, 59 | y=0.0, learn_y=False, 60 | msg_norm=False, 61 | learn_msg_scale=False, 62 | encode_edge=False, 63 | edge_feat_dim=0, 64 | norm='layer', mlp_layers=1): 65 | super(GENBlock, self).__init__(norm, in_channels) 66 | 67 | self.gcn = GENConv(in_channels, out_channels, 68 | aggr=aggr, 69 | t=t, learn_t=learn_t, 70 | p=p, learn_p=learn_p, 71 | y=y, learn_y=learn_y, 72 | msg_norm=msg_norm, 73 | learn_msg_scale=learn_msg_scale, 74 | encode_edge=encode_edge, 75 | edge_feat_dim=edge_feat_dim, 76 | norm=norm, 77 | mlp_layers=mlp_layers) 78 | 79 | 80 | class GCNBlock(BasicBlock): 81 | def __init__(self, in_channels, out_channels, 82 | norm='layer'): 83 | super(GCNBlock, self).__init__(norm, in_channels) 84 | 85 | self.gcn = GCNConv(in_channels, out_channels) 86 | 87 | 88 | class SAGEBlock(BasicBlock): 89 | def __init__(self, in_channels, out_channels, 90 | norm='layer', 91 | dropout=0.0): 92 | super(SAGEBlock, self).__init__(norm, in_channels) 93 | 94 | self.gcn = SAGEConv(in_channels, out_channels) 95 | 96 | 97 | class GATBlock(BasicBlock): 98 | def __init__(self, in_channels, out_channels, 99 | heads=1, 100 | norm='layer', 101 | att_dropout=0.0, 102 | dropout=0.0): 103 | super(GATBlock, self).__init__(norm, 
in_channels) 104 | 105 | self.gcn = GATConv(in_channels, out_channels, 106 | heads=heads, 107 | concat=False, 108 | dropout=att_dropout, 109 | add_self_loops=False) 110 | -------------------------------------------------------------------------------- /examples/modelnet_cls/README.md: -------------------------------------------------------------------------------- 1 | ## [Point cloud classification on ModelNet40](https://arxiv.org/pdf/1910.06849.pdf) 2 | 3 |

4 | 5 |

6 | 7 | ### Train 8 | We train PlainGCN-28 and ResGCN-28 models on one Tesla V100. 9 | For DenseGCN, we use 4 Tesla V100s. 10 | 11 | For training ResGCN-28, run: 12 | ``` 13 | python main.py --phase train --n_blocks 28 --block res --data_dir /path/to/modelnet40 14 | ``` 15 | Just need to set `--data_dir` into your data folder, dataset will be downloaded automatically. 16 | 17 | ### Test 18 | Models can be tested on one 1080Ti. 19 | Our pretrained models are available [Google Drive](https://drive.google.com/drive/folders/1LUWH0V3ZoHNQBylj0u0_36Mx0-UrDh1v?usp=sharing). 20 | 21 | Use the parameter `--pretrained_model` to set a specific pretrained model to load. For example, 22 | 23 | ``` 24 | python main.py --phase test --n_blocks 28 --block res --pretrained_model /path/to/pretrained_model --data_dir /path/to/modelnet40 25 | ``` 26 | 27 | -------------------------------------------------------------------------------- /examples/modelnet_cls/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) 3 | 4 | -------------------------------------------------------------------------------- /examples/modelnet_cls/architecture.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import __init__ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import Sequential as Seq 8 | from gcn_lib.dense import BasicConv, GraphConv2d, ResDynBlock2d, DenseDynBlock2d, DilatedKnnGraph, PlainDynBlock2d 9 | 10 | 11 | class DeepGCN(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(DeepGCN, self).__init__() 14 | channels = opt.n_filters 15 | k = opt.k 16 | act = opt.act 17 | norm = opt.norm 18 | bias = opt.bias 19 | knn = 'matrix' # implement knn using matrix multiplication 20 | epsilon = opt.epsilon 21 | 
stochastic = opt.use_stochastic 22 | conv = opt.conv 23 | c_growth = channels 24 | emb_dims = opt.emb_dims 25 | self.n_blocks = opt.n_blocks 26 | 27 | self.knn = DilatedKnnGraph(k, 1, stochastic, epsilon) 28 | self.head = GraphConv2d(opt.in_channels, channels, conv, act, norm, bias=False) 29 | 30 | if opt.block.lower() == 'dense': 31 | self.backbone = Seq(*[DenseDynBlock2d(channels+c_growth*i, c_growth, k, 1+i, conv, act, 32 | norm, bias, stochastic, epsilon, knn) 33 | for i in range(self.n_blocks-1)]) 34 | fusion_dims = int( 35 | (channels + channels + c_growth * (self.n_blocks-1)) * self.n_blocks // 2) 36 | 37 | elif opt.block.lower() == 'res': 38 | if opt.use_dilation: 39 | self.backbone = Seq(*[ResDynBlock2d(channels, k, i + 1, conv, act, norm, 40 | bias, stochastic, epsilon, knn) 41 | for i in range(self.n_blocks - 1)]) 42 | else: 43 | self.backbone = Seq(*[ResDynBlock2d(channels, k, 1, conv, act, norm, 44 | bias, stochastic, epsilon, knn) 45 | for _ in range(self.n_blocks - 1)]) 46 | fusion_dims = int(channels + c_growth * (self.n_blocks - 1)) 47 | else: 48 | # Plain GCN. 
No dilation, no stochastic, no residual connections 49 | stochastic = False 50 | 51 | self.backbone = Seq(*[PlainDynBlock2d(channels, k, 1, conv, act, norm, 52 | bias, stochastic, epsilon, knn) 53 | for i in range(self.n_blocks - 1)]) 54 | 55 | fusion_dims = int(channels+c_growth*(self.n_blocks-1)) 56 | 57 | self.fusion_block = BasicConv([fusion_dims, emb_dims], 'leakyrelu', norm, bias=False) 58 | self.prediction = Seq(*[BasicConv([emb_dims * 2, 512], 'leakyrelu', norm, drop=opt.dropout), 59 | BasicConv([512, 256], 'leakyrelu', norm, drop=opt.dropout), 60 | BasicConv([256, opt.n_classes], None, None)]) 61 | self.model_init() 62 | 63 | def model_init(self): 64 | for m in self.modules(): 65 | if isinstance(m, torch.nn.Conv2d): 66 | torch.nn.init.kaiming_normal_(m.weight) 67 | m.weight.requires_grad = True 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | m.bias.requires_grad = True 71 | 72 | def forward(self, inputs): 73 | feats = [self.head(inputs, self.knn(inputs[:, 0:3]))] 74 | for i in range(self.n_blocks-1): 75 | feats.append(self.backbone[i](feats[-1])) 76 | 77 | feats = torch.cat(feats, dim=1) 78 | fusion = self.fusion_block(feats) 79 | x1 = F.adaptive_max_pool2d(fusion, 1) 80 | x2 = F.adaptive_avg_pool2d(fusion, 1) 81 | return self.prediction(torch.cat((x1, x2), dim=1)).squeeze(-1).squeeze(-1) 82 | 83 | 84 | if __name__ == '__main__': 85 | import argparse 86 | parser = argparse.ArgumentParser(description='Point Cloud Segmentation') 87 | # ----------------- Model related 88 | parser.add_argument('--k', default=9, type=int, help='neighbor num (default:9)') 89 | parser.add_argument('--block', default='res', type=str, help='graph backbone block type {res, plain, dense}') 90 | parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}') 91 | parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}') 92 | parser.add_argument('--norm', default='batch', type=str, 93 | help='batch or 
instance normalization {batch, instance}') 94 | parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False') 95 | parser.add_argument('--n_blocks', type=int, default=14, help='number of basic blocks in the backbone') 96 | parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features') 97 | parser.add_argument('--in_channels', type=int, default=3, help='Dimension of input ') 98 | parser.add_argument('--n_classes', type=int, default=40, help='Dimension of out_channels ') 99 | parser.add_argument('--emb_dims', type=int, default=1024, metavar='N', help='Dimension of embeddings') 100 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate') 101 | # dilated knn 102 | parser.add_argument('--use_dilation', default=True, type=bool, help='use dilated knn or not') 103 | parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn') 104 | parser.add_argument('--use_stochastic', default=True, type=bool, help='stochastic for gcn, True or False') 105 | 106 | args = parser.parse_args() 107 | args.device = torch.device('cuda') 108 | 109 | feats = torch.rand((2, 3, 1024, 1), dtype=torch.float).to(args.device) 110 | num_neighbors = 20 111 | 112 | print('Input size {}'.format(feats.size())) 113 | net = DeepGCN(args).to(args.device) 114 | out = net(feats) 115 | print(net) 116 | print('Output size {}'.format(out.size())) 117 | -------------------------------------------------------------------------------- /examples/modelnet_cls/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import glob 4 | import h5py 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def download(data_dir): 10 | if not os.path.exists(data_dir): 11 | os.makedirs(data_dir) 12 | if not os.path.exists(os.path.join(data_dir, 'modelnet40_ply_hdf5_2048')): 13 | www = 
'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 14 | zipfile = os.path.basename(www) 15 | os.system('wget %s; unzip %s' % (www, zipfile)) 16 | os.system('mv %s %s' % (zipfile[:-4], data_dir)) 17 | os.system('rm %s' % (zipfile)) 18 | 19 | 20 | def load_data(data_dir, partition): 21 | download(data_dir) 22 | all_data = [] 23 | all_label = [] 24 | for h5_name in glob.glob(os.path.join(data_dir, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)): 25 | with h5py.File(h5_name, 'r') as f: 26 | data = f['data'][:].astype('float32') 27 | label = f['label'][:].astype('int64') 28 | all_data.append(data) 29 | all_label.append(label) 30 | all_data = np.concatenate(all_data, axis=0) 31 | all_label = np.concatenate(all_label, axis=0) 32 | return all_data, all_label 33 | 34 | 35 | def translate_pointcloud(pointcloud): 36 | """ 37 | for scaling and shifting the point cloud 38 | :param pointcloud: 39 | :return: 40 | """ 41 | scale = np.random.uniform(low=2. / 3., high=3. / 2., size=[3]) 42 | shift = np.random.uniform(low=-0.2, high=0.2, size=[3]) 43 | translated_pointcloud = np.add(np.multiply(pointcloud, scale), shift).astype('float32') 44 | return translated_pointcloud 45 | 46 | 47 | class ModelNet40(Dataset): 48 | """ 49 | This is the data loader for ModelNet 40 50 | ModelNet40 contains 12,311 meshed CAD models from 40 categories. 
51 | 52 | num_points: 1024 by default 53 | data_dir 54 | partition: train or test 55 | """ 56 | def __init__(self, num_points=1024, data_dir="/data/deepgcn/modelnet40", partition='train'): 57 | self.data, self.label = load_data(data_dir, partition) 58 | self.num_points = num_points 59 | self.partition = partition 60 | 61 | def __getitem__(self, item): 62 | pointcloud = self.data[item][:self.num_points] 63 | label = self.label[item] 64 | if self.partition == 'train': 65 | pointcloud = translate_pointcloud(pointcloud) 66 | np.random.shuffle(pointcloud) 67 | return pointcloud, label 68 | 69 | def __len__(self): 70 | return self.data.shape[0] 71 | 72 | def num_classes(self): 73 | return np.max(self.label) + 1 74 | 75 | 76 | if __name__ == '__main__': 77 | train = ModelNet40(1024) 78 | test = ModelNet40(1024, partition='test') 79 | for data, label in train: 80 | print(data.shape) 81 | print(label.shape) 82 | -------------------------------------------------------------------------------- /examples/ogb/README.md: -------------------------------------------------------------------------------- 1 | # [DeeperGCN: All You Need to Train Deeper GCNs](https://arxiv.org/abs/2006.07739) 2 | In this work, we propose a novel Generalized Aggregation Function suited for graph convolutions. We show how our function covers all commonly used aggregations. Our generalized aggregation function is fully differentiable and can also be learned in an end-to-end fashion. We also show how by modifying current GCN skip connections and introducing a novel message normalization layer, we can enhance the performance in several benchmarks. Through combining our generalized aggregations, modified skip connections, and message normalization, we achieve state-of-the-art (SOTA) performance on four [Open Graph Benchmark](https://ogb.stanford.edu/) (OGB) datasets. 
3 | [[paper](https://arxiv.org/pdf/2006.07739.pdf)] 4 | 5 | ## Overview 6 | The definition of generalized message aggregation functions helps us find a family of differentiable, permutation-invariant aggregators. To include the *Mean* and *Max* aggregations in this function space, we propose two variants of generalized mean-max aggregation functions, ***SoftMax_Aggβ(.)*** 7 | and ***PowerMean_Aggp(.)***. They can also be instantiated as a *Min* aggregator as β or p goes to −∞. 8 | 9 |

10 | 11 |
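The limiting behavior described above can be checked numerically. The following is a minimal pure-Python sketch on scalar messages, assuming the usual forms SoftMax_Agg (a softmax-weighted sum with inverse temperature β) and PowerMean_Agg (the p-th root of the mean of p-th powers, for positive messages). The names `softmax_agg` and `powermean_agg` are illustrative only; this is not the repository's `GENConv` code.

```python
import math


def softmax_agg(messages, beta):
    """Softmax-weighted sum of scalar messages (illustrative sketch)."""
    scaled = [beta * m for m in messages]
    top = max(scaled)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    total = sum(weights)
    return sum(w / total * m for w, m in zip(weights, messages))


def powermean_agg(messages, p):
    """Power mean of positive scalar messages, p != 0 (illustrative sketch)."""
    return (sum(m ** p for m in messages) / len(messages)) ** (1.0 / p)


msgs = [1.0, 2.0, 4.0]
mean_like = softmax_agg(msgs, beta=0.0)   # uniform weights: the mean
max_like = softmax_agg(msgs, beta=50.0)   # large beta: approaches the max
min_like = softmax_agg(msgs, beta=-50.0)  # very negative beta: approaches the min
pm_mean = powermean_agg(msgs, p=1.0)      # p = 1: the arithmetic mean
pm_max = powermean_agg(msgs, p=200.0)     # large p: approaches the max
print(mean_like, max_like, min_like, pm_mean, pm_max)
```

Because both functions are smooth in β and p, these parameters can in principle be learned by gradient descent, which is what the `--learn_t` and `--learn_p` options in the example scripts suggest.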

12 | 13 | ## DyResGEN 14 | 15 | Learning curves of 7-layer DyResGEN with ***SoftMax_Aggβ(.)*** and MsgNorm. 16 | 17 |

18 | 19 |

20 | 21 | Learning curves of 7-layer DyResGEN with ***PowerMean_Aggp(.)*** and MsgNorm. 22 | 23 |

24 | 25 |

26 | 27 | ## Results on OGB Datasets 28 | 29 | 30 | |Dataset | Test | 31 | |-------------|---------------| 32 | |[ogbn-products](ogbn_products)|0.8098 ± 0.0020| 33 | |[ogbn-proteins](ogbn_proteins)|0.8580 ± 0.0017| 34 | |[ogbn-arxiv](ogbn_arxiv) |0.7192 ± 0.0016| 35 | |[ogbg-molhiv](ogbg_mol) |0.7858 ± 0.0117| 36 | |[ogbg-molpcba](ogbg_mol) |0.2745 ± 0.0025| 37 | |[ogbg-ppa](ogbg_ppa) |0.7712 ± 0.0071| 38 | 39 | ## Requirements 40 | 41 | - [PyTorch 1.5.0](https://pytorch.org/get-started/locally/) 42 | - [torch-geometric 1.6.0](https://pytorch-geometric.readthedocs.io/en/latest/index.html) 43 | - [ogb >= 1.1.1](https://ogb.stanford.edu/docs/home/) 44 | 45 | Install the environment by running: 46 | 47 | source deepgcn_env_install.sh 48 | 49 | Please cite our paper if you find anything helpful, 50 | 51 | ``` 52 | @misc{li2020deepergcn, 53 | title={DeeperGCN: All You Need to Train Deeper GCNs}, 54 | author={Guohao Li and Chenxin Xiong and Ali Thabet and Bernard Ghanem}, 55 | year={2020}, 56 | eprint={2006.07739}, 57 | archivePrefix={arXiv}, 58 | primaryClass={cs.LG} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_mol/README.md: -------------------------------------------------------------------------------- 1 | # ogbg_mol 2 | The code is shared by two molecular datasets: ogbg_molhiv and ogbg_molpcba. 
3 | ## Default 4 | --use_gpu False 5 | --dataset ogbg-molhiv 6 | --batch_size 32 7 | --block res+ #options: [plain, res, res+] 8 | --conv gen 9 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 10 | --num_layers 3 11 | --conv_encode_edge False 12 | --add_virtual_node False 13 | --mlp_layers 1 14 | --norm batch 15 | --hidden_channels 256 16 | --epochs 300 17 | --lr 0.01 18 | --dropout 0.5 19 | --graph_pooling mean #options: [mean, max, sum] 20 | ## ogbg_molhiv: DyResGEN 21 | ### Train 22 | python main.py --use_gpu --conv_encode_edge --num_layers 7 --dataset ogbg-molhiv --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.2 --lr 0.0001 23 | ### Test (use pre-trained model, [download](https://drive.google.com/file/d/1ja1xc2a4U4ps8AtZm5xo2CmffWA-C5Yl/view?usp=sharing) from Google Drive) 24 | python test.py --use_gpu --conv_encode_edge --num_layers 7 --dataset ogbg-molhiv --block res+ --gcn_aggr softmax --t 1.0 --learn_t 25 | 26 | ## ogbg_molpcba: ResGEN + virtual nodes 27 | ### Train 28 | python main.py --use_gpu --conv_encode_edge --add_virtual_node --mlp_layers 2 --num_layers 14 --dataset ogbg-molpcba --block res+ --gcn_aggr softmax_sg --t 0.1 29 | 30 | ### Test (use pre-trained model, [download](https://drive.google.com/file/d/1OYds41b7NNKGYBt52bro8lbxSCXALalx/view?usp=sharing) from Google Drive) 31 | 32 | python test.py --use_gpu --conv_encode_edge --add_virtual_node --mlp_layers 2 --num_layers 14 --dataset ogbg-molpcba --block res+ --gcn_aggr softmax_sg --t 0.1 --model_load_path ogbg_molpcba_pretrained_model.pth 33 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_mol/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | 
-------------------------------------------------------------------------------- /examples/ogb/ogbg_mol/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import uuid 3 | import logging 4 | import time 5 | import os 6 | import sys 7 | from utils.logger import create_exp_dir 8 | import glob 9 | 10 | 11 | class ArgsInit(object): 12 | def __init__(self): 13 | parser = argparse.ArgumentParser(description='DeeperGCN') 14 | # dataset 15 | parser.add_argument('--dataset', type=str, default="ogbg-molhiv", 16 | help='dataset name (default: ogbg-molhiv)') 17 | parser.add_argument('--num_workers', type=int, default=0, 18 | help='number of workers (default: 0)') 19 | parser.add_argument('--batch_size', type=int, default=32, 20 | help='input batch size for training (default: 32)') 21 | parser.add_argument('--feature', type=str, default='full', 22 | help='two options: full or simple') 23 | parser.add_argument('--add_virtual_node', action='store_true') 24 | # training & eval settings 25 | parser.add_argument('--use_gpu', action='store_true') 26 | parser.add_argument('--device', type=int, default=0, 27 | help='which gpu to use if any (default: 0)') 28 | parser.add_argument('--epochs', type=int, default=300, 29 | help='number of epochs to train (default: 300)') 30 | parser.add_argument('--lr', type=float, default=0.01, 31 | help='learning rate set for optimizer.') 32 | parser.add_argument('--dropout', type=float, default=0.5) 33 | parser.add_argument('--grad_clip', type=float, default=0., 34 | help='gradient clipping') 35 | 36 | # model 37 | parser.add_argument('--num_layers', type=int, default=3, 38 | help='the number of layers of the networks') 39 | parser.add_argument('--mlp_layers', type=int, default=1, 40 | help='the number of layers of mlp in conv') 41 | parser.add_argument('--hidden_channels', type=int, default=256, 42 | help='the dimension of embeddings of nodes and edges') 43 | parser.add_argument('--block', 
default='res+', type=str, 44 | help='graph backbone block type {res+, res, dense, plain}') 45 | parser.add_argument('--conv', type=str, default='gen', 46 | help='the type of GCNs') 47 | parser.add_argument('--gcn_aggr', type=str, default='max', 48 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, power]') 49 | parser.add_argument('--norm', type=str, default='batch', 50 | help='the type of normalization layer') 51 | parser.add_argument('--num_tasks', type=int, default=1, 52 | help='the number of prediction tasks') 53 | # learnable parameters 54 | parser.add_argument('--t', type=float, default=1.0, 55 | help='the temperature of SoftMax') 56 | parser.add_argument('--p', type=float, default=1.0, 57 | help='the power of PowerMean') 58 | parser.add_argument('--learn_t', action='store_true') 59 | parser.add_argument('--learn_p', action='store_true') 60 | parser.add_argument('--y', type=float, default=0.0, 61 | help='the power of softmax_sum and powermean_sum') 62 | parser.add_argument('--learn_y', action='store_true') 63 | 64 | # message norm 65 | parser.add_argument('--msg_norm', action='store_true') 66 | parser.add_argument('--learn_msg_scale', action='store_true') 67 | # encode edge in conv 68 | parser.add_argument('--conv_encode_edge', action='store_true') 69 | # graph pooling type 70 | parser.add_argument('--graph_pooling', type=str, default='mean', 71 | help='graph pooling method') 72 | # save model 73 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 74 | help='the directory used to save models') 75 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 76 | # load pre-trained model 77 | parser.add_argument('--model_load_path', type=str, default='ogbg_molhiv_pretrained_model.pth', 78 | help='the path of pre-trained model') 79 | 80 | self.args = parser.parse_args() 81 | 82 | def save_exp(self): 83 | self.args.save = '{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 84 | '-GA_{}-T_{}-LT_{}-P_{}-LP_{}-Y_{}-LY_{}' 
\ 85 | '-MN_{}-LS_{}'.format(self.args.save, self.args.block, self.args.conv, 86 | self.args.num_layers, self.args.hidden_channels, 87 | self.args.dropout, self.args.gcn_aggr, 88 | self.args.t, self.args.learn_t, self.args.p, self.args.learn_p, 89 | self.args.y, self.args.learn_y, 90 | self.args.msg_norm, self.args.learn_msg_scale) 91 | 92 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 93 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 94 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 95 | log_format = '%(asctime)s %(message)s' 96 | logging.basicConfig(stream=sys.stdout, 97 | level=logging.INFO, 98 | format=log_format, 99 | datefmt='%m/%d %I:%M:%S %p') 100 | fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 101 | fh.setFormatter(logging.Formatter(log_format)) 102 | logging.getLogger().addHandler(fh) 103 | 104 | return self.args 105 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_mol/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | from model import DeeperGCN 5 | from tqdm import tqdm 6 | from args import ArgsInit 7 | from utils.ckpt_util import save_ckpt 8 | import logging 9 | import time 10 | import statistics 11 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 12 | 13 | 14 | def train(model, device, loader, optimizer, task_type, grad_clip=0.): 15 | loss_list = [] 16 | model.train() 17 | 18 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 19 | batch = batch.to(device) 20 | 21 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 22 | pass 23 | else: 24 | optimizer.zero_grad() 25 | pred = model(batch) 26 | is_labeled = batch.y == batch.y 27 | if "classification" in task_type: 28 | loss = 
cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 29 | else: 30 | loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 31 | 32 | loss.backward() 33 | 34 | if grad_clip > 0: 35 | torch.nn.utils.clip_grad_value_( 36 | model.parameters(), 37 | grad_clip) 38 | 39 | optimizer.step() 40 | 41 | loss_list.append(loss.item()) 42 | return statistics.mean(loss_list) 43 | 44 | 45 | @torch.no_grad() 46 | def eval(model, device, loader, evaluator): 47 | model.eval() 48 | y_true = [] 49 | y_pred = [] 50 | 51 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 52 | batch = batch.to(device) 53 | 54 | if batch.x.shape[0] == 1: 55 | pass 56 | else: 57 | pred = model(batch) 58 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 59 | y_pred.append(pred.detach().cpu()) 60 | 61 | y_true = torch.cat(y_true, dim=0).numpy() 62 | y_pred = torch.cat(y_pred, dim=0).numpy() 63 | 64 | input_dict = {"y_true": y_true, 65 | "y_pred": y_pred} 66 | 67 | return evaluator.eval(input_dict) 68 | 69 | 70 | def main(): 71 | 72 | args = ArgsInit().save_exp() 73 | 74 | if args.use_gpu: 75 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 76 | else: 77 | device = torch.device('cpu') 78 | 79 | sub_dir = 'BS_{}-NF_{}'.format(args.batch_size, 80 | args.feature) 81 | 82 | dataset = PygGraphPropPredDataset(name=args.dataset) 83 | args.num_tasks = dataset.num_tasks 84 | logging.info('%s' % args) 85 | 86 | if args.feature == 'full': 87 | pass 88 | elif args.feature == 'simple': 89 | print('using simple feature') 90 | # only retain the top two node/edge features 91 | dataset.data.x = dataset.data.x[:, :2] 92 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 93 | 94 | evaluator = Evaluator(args.dataset) 95 | split_idx = dataset.get_idx_split() 96 | 97 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, 98 | 
num_workers=args.num_workers) 99 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 100 | num_workers=args.num_workers) 101 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, 102 | num_workers=args.num_workers) 103 | 104 | model = DeeperGCN(args).to(device) 105 | 106 | logging.info(model) 107 | 108 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 109 | 110 | results = {'highest_valid': 0, 111 | 'final_train': 0, 112 | 'final_test': 0, 113 | 'highest_train': 0} 114 | 115 | start_time = time.time() 116 | 117 | for epoch in range(1, args.epochs + 1): 118 | logging.info("=====Epoch {}".format(epoch)) 119 | logging.info('Training...') 120 | 121 | epoch_loss = train(model, device, train_loader, optimizer, dataset.task_type, grad_clip=args.grad_clip) 122 | 123 | logging.info('Evaluating...') 124 | train_result = eval(model, device, train_loader, evaluator)[dataset.eval_metric] 125 | valid_result = eval(model, device, valid_loader, evaluator)[dataset.eval_metric] 126 | test_result = eval(model, device, test_loader, evaluator)[dataset.eval_metric] 127 | 128 | logging.info({'Train': train_result, 129 | 'Validation': valid_result, 130 | 'Test': test_result}) 131 | 132 | model.print_params(epoch=epoch) 133 | 134 | if train_result > results['highest_train']: 135 | 136 | results['highest_train'] = train_result 137 | 138 | if valid_result > results['highest_valid']: 139 | results['highest_valid'] = valid_result 140 | results['final_train'] = train_result 141 | results['final_test'] = test_result 142 | 143 | save_ckpt(model, optimizer, 144 | round(epoch_loss, 4), epoch, 145 | args.model_save_path, 146 | sub_dir, name_post='valid_best') 147 | 148 | logging.info("%s" % results) 149 | 150 | end_time = time.time() 151 | total_time = end_time - start_time 152 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 153 | 154 | 155 | if __name__ == 
"__main__": 156 | cls_criterion = torch.nn.BCEWithLogitsLoss() 157 | reg_criterion = torch.nn.MSELoss() 158 | main() 159 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_mol/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | from model import DeeperGCN 4 | from tqdm import tqdm 5 | from args import ArgsInit 6 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 7 | 8 | 9 | @torch.no_grad() 10 | def eval(model, device, loader, evaluator): 11 | model.eval() 12 | y_true = [] 13 | y_pred = [] 14 | 15 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 16 | batch = batch.to(device) 17 | 18 | if batch.x.shape[0] == 1: 19 | pass 20 | else: 21 | pred = model(batch) 22 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 23 | y_pred.append(pred.detach().cpu()) 24 | 25 | y_true = torch.cat(y_true, dim=0).numpy() 26 | y_pred = torch.cat(y_pred, dim=0).numpy() 27 | 28 | input_dict = {"y_true": y_true, 29 | "y_pred": y_pred} 30 | 31 | return evaluator.eval(input_dict) 32 | 33 | 34 | def main(): 35 | 36 | args = ArgsInit().args 37 | 38 | if args.use_gpu: 39 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 40 | else: 41 | device = torch.device('cpu') 42 | 43 | dataset = PygGraphPropPredDataset(name=args.dataset) 44 | args.num_tasks = dataset.num_tasks 45 | print(args) 46 | 47 | if args.feature == 'full': 48 | pass 49 | elif args.feature == 'simple': 50 | print('using simple feature') 51 | # only retain the top two node/edge features 52 | dataset.data.x = dataset.data.x[:, :2] 53 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 54 | 55 | 56 | split_idx = dataset.get_idx_split() 57 | 58 | evaluator = Evaluator(args.dataset) 59 | 60 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=False, 
61 | num_workers=args.num_workers) 62 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 63 | num_workers=args.num_workers) 64 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, 65 | num_workers=args.num_workers) 66 | 67 | model = DeeperGCN(args) 68 | 69 | model.load_state_dict(torch.load(args.model_load_path, map_location=device)['model_state_dict']) 70 | model.to(device) 71 | 72 | train_result = eval(model, device, train_loader, evaluator)[dataset.eval_metric] 73 | valid_result = eval(model, device, valid_loader, evaluator)[dataset.eval_metric] 74 | test_result = eval(model, device, test_loader, evaluator)[dataset.eval_metric] 75 | 76 | print({'Train': train_result, 77 | 'Validation': valid_result, 78 | 'Test': test_result}) 79 | 80 | model.print_params(final=True) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_ppa/README.md: -------------------------------------------------------------------------------- 1 | # ogbg-ppa 2 | We initialize the node features of the ogbg_ppa dataset by aggregating the features of each node's connected edges with a Sum (Add) aggregation, as we do for ogbn_proteins. 
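The edge-to-node initialization described above can be illustrated with a small, dependency-free sketch. It is not the repository's `extract_node_feature` (which `main.py` imports from `utils.data_util` and calls with `reduce=args.aggr`). The helper name `aggregate_edge_features`, the plain-list data layout, and the choice to aggregate over each edge's source node are assumptions made for illustration only.

```python
def aggregate_edge_features(edge_index, edge_attr, num_nodes, reduce="add"):
    """Hypothetical helper: build node features from incident-edge features.

    edge_index: [sources, targets], two equal-length lists of node ids.
    edge_attr:  one feature list per edge, all of the same length.
    """
    dim = len(edge_attr[0])
    # Collect the features of every edge under its source node (an assumption).
    buckets = [[] for _ in range(num_nodes)]
    for src, feat in zip(edge_index[0], edge_attr):
        buckets[src].append(feat)

    out = []
    for feats in buckets:
        if not feats:
            # Isolated node: fall back to a zero vector.
            out.append([0.0] * dim)
            continue
        cols = list(zip(*feats))  # transpose to per-dimension columns
        if reduce == "add":
            out.append([float(sum(c)) for c in cols])
        elif reduce == "mean":
            out.append([sum(c) / len(c) for c in cols])
        elif reduce == "max":
            out.append([float(max(c)) for c in cols])
        else:
            raise ValueError("unknown reduce: {}".format(reduce))
    return out


# Toy graph: edges 0-1 and 1-2, each stored in both directions.
edge_index = [[0, 1, 1, 2],
              [1, 0, 2, 1]]
edge_attr = [[1.0, 0.0],
             [1.0, 0.0],
             [0.0, 2.0],
             [0.0, 2.0]]
node_x = aggregate_edge_features(edge_index, edge_attr, num_nodes=3)
print(node_x)
```

Here node 1 touches both edges, so with `reduce="add"` its feature is the sum of the two edge feature vectors. Passing `reduce="mean"` or `reduce="max"` mirrors the other options listed for `--aggr`.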
3 | 4 | ## Default 5 | --use_gpu False 6 | --batch_size 32 7 | --aggr add #options: [mean, max, add] 8 | --block res+ #options: [plain, res, res+] 9 | --conv gen 10 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 11 | --num_layers 3 12 | --conv_encode_edge False 13 | --mlp_layers 2 14 | --norm layer 15 | --hidden_channels 128 16 | --epochs 200 17 | --lr 0.01 18 | --dropout 0.5 19 | --graph_pooling mean #options: [mean, max, add] 20 | ## ResGEN 21 | ### Train 22 | python main.py --use_gpu --conv_encode_edge --num_layers 28 --gcn_aggr softmax_sg --t 0.01 23 | 24 | 25 | ### Test (use pre-trained model, [download](https://drive.google.com/file/d/1vlmNPUgDes8QJ0SQoo-K5L_yFVeV1lkH/view?usp=sharing) from Google Drive) 26 | python test.py --use_gpu --conv_encode_edge --num_layers 28 --gcn_aggr softmax_sg --t 0.01 27 | 28 | 29 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_ppa/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_ppa/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import uuid 3 | import logging 4 | import time 5 | import os 6 | import sys 7 | from utils.logger import create_exp_dir 8 | import glob 9 | 10 | 11 | class ArgsInit(object): 12 | def __init__(self): 13 | parser = argparse.ArgumentParser(description='DeeperGCN') 14 | # dataset 15 | parser.add_argument('--dataset', type=str, default="ogbg-ppa", 16 | help='dataset name (default: ogbg-ppa)') 17 | parser.add_argument('--num_workers', type=int, default=0, 18 | help='number of workers (default: 0)') 19 | 
parser.add_argument('--batch_size', type=int, default=32, 20 | help='input batch size for training (default: 32)') 21 | # extract node features 22 | parser.add_argument('--aggr', type=str, default='add', 23 | help='the aggregation operator to obtain nodes\' initial features [mean, max, add]') 24 | parser.add_argument('--not_extract_node_feature', action='store_true') 25 | # training & eval settings 26 | parser.add_argument('--use_gpu', action='store_true') 27 | parser.add_argument('--device', type=int, default=0, 28 | help='which gpu to use if any (default: 0)') 29 | parser.add_argument('--epochs', type=int, default=200, 30 | help='number of epochs to train (default: 200)') 31 | parser.add_argument('--lr', type=float, default=0.01, 32 | help='learning rate set for optimizer.') 33 | parser.add_argument('--dropout', type=float, default=0.5) 34 | # model 35 | parser.add_argument('--num_layers', type=int, default=3, 36 | help='the number of layers of the networks') 37 | parser.add_argument('--mlp_layers', type=int, default=2, 38 | help='the number of layers of mlp in conv') 39 | parser.add_argument('--hidden_channels', type=int, default=128, 40 | help='the dimension of embeddings of nodes and edges') 41 | parser.add_argument('--block', default='res+', type=str, 42 | help='graph backbone block type {res+, res, dense, plain}') 43 | parser.add_argument('--conv', type=str, default='gen', 44 | help='the type of GCNs') 45 | parser.add_argument('--gcn_aggr', type=str, default='max', 46 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, power]') 47 | parser.add_argument('--norm', type=str, default='layer', 48 | help='the type of normalization layer') 49 | parser.add_argument('--num_tasks', type=int, default=1, 50 | help='the number of prediction tasks') 51 | # learnable parameters 52 | parser.add_argument('--t', type=float, default=1.0, 53 | help='the temperature of SoftMax') 54 | parser.add_argument('--p', type=float, default=1.0, 55 | help='the power of 
PowerMean') 56 | parser.add_argument('--learn_t', action='store_true') 57 | parser.add_argument('--learn_p', action='store_true') 58 | # message norm 59 | parser.add_argument('--msg_norm', action='store_true') 60 | parser.add_argument('--learn_msg_scale', action='store_true') 61 | # encode edge in conv 62 | parser.add_argument('--conv_encode_edge', action='store_true') 63 | # graph pooling type 64 | parser.add_argument('--graph_pooling', type=str, default='mean', 65 | help='graph pooling method') 66 | # save model 67 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 68 | help='the directory used to save models') 69 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 70 | # load pre-trained model 71 | parser.add_argument('--model_load_path', type=str, default='ogbg_ppa_pretrained_model.pth', 72 | help='the path of pre-trained model') 73 | # others, eval steps 74 | parser.add_argument('--eval_steps', type=int, default=5) 75 | parser.add_argument('--num_layers_threshold', type=int, default=14) 76 | 77 | self.args = parser.parse_args() 78 | 79 | def save_exp(self): 80 | self.args.save = '{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 81 | '-A_{}-GA_{}-T_{}-LT_{}-P_{}-LP_{}' \ 82 | '-MN_{}-LS_{}'.format(self.args.save, self.args.block, self.args.conv, 83 | self.args.num_layers, self.args.hidden_channels, self.args.dropout, 84 | self.args.aggr, self.args.gcn_aggr, 85 | self.args.t, self.args.learn_t, self.args.p, self.args.learn_p, 86 | self.args.msg_norm, self.args.learn_msg_scale) 87 | 88 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 89 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 90 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 91 | log_format = '%(asctime)s %(message)s' 92 | logging.basicConfig(stream=sys.stdout, 93 | level=logging.INFO, 94 | format=log_format, 95 | datefmt='%m/%d %I:%M:%S %p') 96 | fh = 
logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 97 | fh.setFormatter(logging.Formatter(log_format)) 98 | logging.getLogger().addHandler(fh) 99 | 100 | return self.args 101 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_ppa/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | from model import DeeperGCN 5 | from tqdm import tqdm 6 | from args import ArgsInit 7 | from utils.ckpt_util import save_ckpt 8 | from utils.data_util import add_zeros, extract_node_feature 9 | import logging 10 | from functools import partial 11 | import time 12 | import statistics 13 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 14 | 15 | 16 | def train(model, device, loader, optimizer, criterion): 17 | loss_list = [] 18 | model.train() 19 | 20 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 21 | batch = batch.to(device) 22 | 23 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 24 | pass 25 | else: 26 | pred = model(batch) 27 | optimizer.zero_grad() 28 | 29 | loss = criterion(pred.to(torch.float32), batch.y.view(-1, )) 30 | 31 | loss.backward() 32 | optimizer.step() 33 | loss_list.append(loss.item()) 34 | return statistics.mean(loss_list) 35 | 36 | 37 | @torch.no_grad() 38 | def eval(model, device, loader, evaluator): 39 | model.eval() 40 | y_true = [] 41 | y_pred = [] 42 | 43 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 44 | batch = batch.to(device) 45 | 46 | if batch.x.shape[0] == 1: 47 | pass 48 | else: 49 | pred = model(batch) 50 | y_true.append(batch.y.view(-1, 1).detach().cpu()) 51 | y_pred.append(torch.argmax(pred.detach(), dim=1).view(-1, 1).cpu()) 52 | 53 | y_true = torch.cat(y_true, dim=0).numpy() 54 | y_pred = torch.cat(y_pred, dim=0).numpy() 55 | 56 | input_dict = {"y_true": y_true, "y_pred": y_pred} 57 | 58 | return 
evaluator.eval(input_dict)['acc'] 59 | 60 | 61 | def main(): 62 | args = ArgsInit().save_exp() 63 | 64 | if args.use_gpu: 65 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 66 | else: 67 | device = torch.device('cpu') 68 | 69 | sub_dir = 'BS_{}'.format(args.batch_size) 70 | 71 | if args.not_extract_node_feature: 72 | dataset = PygGraphPropPredDataset(name=args.dataset, 73 | transform=add_zeros) 74 | else: 75 | extract_node_feature_func = partial(extract_node_feature, reduce=args.aggr) 76 | dataset = PygGraphPropPredDataset(name=args.dataset, 77 | transform=extract_node_feature_func) 78 | 79 | sub_dir = sub_dir + '-NF_{}'.format(args.aggr) 80 | 81 | args.num_tasks = dataset.num_classes 82 | evaluator = Evaluator(args.dataset) 83 | 84 | logging.info('%s' % args) 85 | 86 | split_idx = dataset.get_idx_split() 87 | 88 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, 89 | num_workers=args.num_workers) 90 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 91 | num_workers=args.num_workers) 92 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, 93 | num_workers=args.num_workers) 94 | 95 | model = DeeperGCN(args).to(device) 96 | 97 | logging.info(model) 98 | 99 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | results = {'highest_valid': 0, 103 | 'final_train': 0, 104 | 'final_test': 0, 105 | 'highest_train': 0} 106 | 107 | start_time = time.time() 108 | 109 | evaluate = True 110 | 111 | for epoch in range(1, args.epochs + 1): 112 | logging.info("=====Epoch {}".format(epoch)) 113 | logging.info('Training...') 114 | 115 | epoch_loss = train(model, device, train_loader, optimizer, criterion) 116 | 117 | if args.num_layers > args.num_layers_threshold: 118 | if epoch % args.eval_steps != 0: 119 | evaluate = False 
120 | else: 121 | evaluate = True 122 | 123 | model.print_params(epoch=epoch) 124 | 125 | if evaluate: 126 | 127 | logging.info('Evaluating...') 128 | 129 | train_accuracy = eval(model, device, train_loader, evaluator) 130 | valid_accuracy = eval(model, device, valid_loader, evaluator) 131 | test_accuracy = eval(model, device, test_loader, evaluator) 132 | 133 | logging.info({'Train': train_accuracy, 134 | 'Validation': valid_accuracy, 135 | 'Test': test_accuracy}) 136 | 137 | if train_accuracy > results['highest_train']: 138 | 139 | results['highest_train'] = train_accuracy 140 | 141 | if valid_accuracy > results['highest_valid']: 142 | results['highest_valid'] = valid_accuracy 143 | results['final_train'] = train_accuracy 144 | results['final_test'] = test_accuracy 145 | 146 | save_ckpt(model, optimizer, 147 | round(epoch_loss, 4), epoch, 148 | args.model_save_path, 149 | sub_dir, name_post='valid_best') 150 | 151 | logging.info("%s" % results) 152 | 153 | end_time = time.time() 154 | total_time = end_time - start_time 155 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /examples/ogb/ogbg_ppa/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | from model import DeeperGCN 4 | from tqdm import tqdm 5 | from args import ArgsInit 6 | from utils.data_util import add_zeros, extract_node_feature 7 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 8 | from functools import partial 9 | 10 | 11 | @torch.no_grad() 12 | def eval(model, device, loader, evaluator): 13 | model.eval() 14 | y_true = [] 15 | y_pred = [] 16 | 17 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 18 | batch = batch.to(device) 19 | 20 | if batch.x.shape[0] == 1: 21 | pass 22 | 
else: 23 | pred = model(batch) 24 | y_true.append(batch.y.view(-1, 1).detach().cpu()) 25 | y_pred.append(torch.argmax(pred.detach(), dim=1).view(-1, 1).cpu()) 26 | 27 | y_true = torch.cat(y_true, dim=0).numpy() 28 | y_pred = torch.cat(y_pred, dim=0).numpy() 29 | 30 | input_dict = {"y_true": y_true, "y_pred": y_pred} 31 | 32 | return evaluator.eval(input_dict)['acc'] 33 | 34 | 35 | def main(): 36 | 37 | args = ArgsInit().args 38 | 39 | if args.use_gpu: 40 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 41 | else: 42 | device = torch.device('cpu') 43 | 44 | if args.not_extract_node_feature: 45 | dataset = PygGraphPropPredDataset(name=args.dataset, 46 | transform=add_zeros) 47 | else: 48 | extract_node_feature_func = partial(extract_node_feature, reduce=args.aggr) 49 | dataset = PygGraphPropPredDataset(name=args.dataset, 50 | transform=extract_node_feature_func) 51 | 52 | args.num_tasks = dataset.num_classes 53 | evaluator = Evaluator(args.dataset) 54 | 55 | split_idx = dataset.get_idx_split() 56 | 57 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=False, 58 | num_workers=args.num_workers) 59 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 60 | num_workers=args.num_workers) 61 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, 62 | num_workers=args.num_workers) 63 | 64 | print(args) 65 | 66 | model = DeeperGCN(args) 67 | model.load_state_dict(torch.load(args.model_load_path, map_location=device)['model_state_dict']) 68 | model.to(device) 69 | 70 | train_accuracy = eval(model, device, train_loader, evaluator) 71 | valid_accuracy = eval(model, device, valid_loader, evaluator) 72 | test_accuracy = eval(model, device, test_loader, evaluator) 73 | 74 | print({'Train': train_accuracy, 75 | 'Validation': valid_accuracy, 76 | 'Test': test_accuracy}) 77 | model.print_params(final=True) 78 | 79 | 80 
| if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /examples/ogb/ogbl_collab/README.md: -------------------------------------------------------------------------------- 1 | # ogbl-collab 2 | ## Default 3 | --use_gpu False 4 | --block res+ #options: [plain, res, res+] 5 | --conv gen 6 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 7 | --num_layers 3 #the number of layers of DeeperGCN model 8 | --lp_num_layers 3 #the number of layers of the link predictor model 9 | --mlp_layers 1 10 | --norm batch 11 | --lp_norm none #the type of normalization layer for link predictor 12 | --hidden_channels 128 13 | --epochs 400 14 | --lr 0.001 15 | --dropout 0.0 16 | ## DyResGEN 17 | ### Train 18 | SoftMax aggregator with learnable t (initialized as 1.0) 19 | 20 | python main.py --use_gpu --num_layers 7 --block res+ --gcn_aggr softmax --learn_t --t 1.0 21 | 22 | ### Test (use pre-trained model, [DyResGEN](https://drive.google.com/file/d/1aPzYzXiKBN7vnSVHFfO010zJwTYWazgM/view?usp=sharing) and [Link Predictor](https://drive.google.com/file/d/1Y-UZjIxXA6swX8qGLs041Cg_qSh7SFgx/view?usp=sharing) from Google Drive) 23 | python test.py --use_gpu --num_layers 7 --block res+ --gcn_aggr softmax --learn_t --t 1.0 24 | -------------------------------------------------------------------------------- /examples/ogb/ogbl_collab/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb/ogbl_collab/args.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import argparse 3 | import uuid 4 | import logging 5 | import time 6 | import 
os 7 | import sys 8 | from utils.logger import create_exp_dir 9 | import glob 10 | 11 | 12 | class ArgsInit(object): 13 | def __init__(self): 14 | parser = argparse.ArgumentParser(description='DeeperGCN') 15 | # dataset 16 | parser.add_argument('--dataset', type=str, default='ogbl-collab', 17 | help='dataset name (default: ogbl-collab)') 18 | parser.add_argument('--self_loop', action='store_true') 19 | # training & eval settings 20 | parser.add_argument('--use_gpu', action='store_true') 21 | parser.add_argument('--device', type=int, default=0, 22 | help='which gpu to use if any (default: 0)') 23 | parser.add_argument('--epochs', type=int, default=400, 24 | help='number of epochs to train (default: 400)') 25 | parser.add_argument('--lr', type=float, default=0.001, 26 | help='learning rate set for optimizer.') 27 | parser.add_argument('--dropout', type=float, default=0.0) 28 | parser.add_argument('--batch_size', type=int, default=64 * 1024, 29 | help='the number of edges per batch') 30 | # model 31 | parser.add_argument('--num_layers', type=int, default=3, 32 | help='the number of layers of the networks') 33 | parser.add_argument('--lp_num_layers', type=int, default=3, 34 | help='the number of layers of the link predictor model') 35 | parser.add_argument('--mlp_layers', type=int, default=1, 36 | help='the number of layers of mlp in conv') 37 | parser.add_argument('--in_channels', type=int, default=128, 38 | help='the dimension of initial embeddings of nodes') 39 | parser.add_argument('--hidden_channels', type=int, default=128, 40 | help='the dimension of embeddings of nodes') 41 | parser.add_argument('--block', default='res+', type=str, 42 | help='graph backbone block type {res+, res, dense, plain}') 43 | parser.add_argument('--conv', type=str, default='gen', 44 | help='the type of GCNs') 45 | parser.add_argument('--gcn_aggr', type=str, default='max', 46 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, power]') 47 | 
parser.add_argument('--norm', type=str, default='batch', 48 | help='the type of normalization layer') 49 | parser.add_argument('--lp_norm', type=str, default='none', 50 | help='the type of normalization layer for link predictor') 51 | parser.add_argument('--num_tasks', type=int, default=1, 52 | help='the number of prediction tasks') 53 | # learnable parameters 54 | parser.add_argument('--t', type=float, default=1.0, 55 | help='the temperature of SoftMax') 56 | parser.add_argument('--p', type=float, default=1.0, 57 | help='the power of PowerMean') 58 | parser.add_argument('--learn_t', action='store_true') 59 | parser.add_argument('--learn_p', action='store_true') 60 | parser.add_argument('--y', type=float, default=0.0, 61 | help='the power of softmax_sum and powermean_sum') 62 | parser.add_argument('--learn_y', action='store_true') 63 | 64 | # message norm 65 | parser.add_argument('--msg_norm', action='store_true') 66 | parser.add_argument('--scale_msg', action='store_true') 67 | parser.add_argument('--learn_msg_scale', action='store_true') 68 | 69 | # save model 70 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 71 | help='the directory used to save models') 72 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 73 | parser.add_argument('--model_load_path', type=str, default='ogbl_collab_pretrained_model_Hits@50.pth', 74 | help='the path of pre-trained deeperGCN model') 75 | parser.add_argument('--predictor_load_path', type=str, default='ogbl_collab_pretrained_link_predictor_Hits@50.pth', 76 | help='the path of pre-trained predictor model') 77 | parser.add_argument('--use_tensor_board', action='store_true') 78 | self.args = parser.parse_args() 79 | 80 | def save_exp(self): 81 | self.args.save = 'LR_{}-{}-B_{}-C_{}-L_{}-LPL_{}-F_{}' \ 82 | '-NORM_{}-DP_{}-GA_{}-T_{}-LT_{}-' \ 83 | 'P_{}-LP_{}-MN_{}-LS_{}'.format(self.args.lr, 84 | self.args.save, self.args.block, self.args.conv, 85 | self.args.num_layers, 
self.args.lp_num_layers, 86 | self.args.hidden_channels, self.args.norm, 87 | self.args.dropout, self.args.gcn_aggr, 88 | self.args.t, self.args.learn_t, self.args.p, self.args.learn_p, 89 | self.args.msg_norm, self.args.learn_msg_scale) 90 | 91 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 92 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 93 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 94 | log_format = '%(asctime)s %(message)s' 95 | logging.basicConfig(stream=sys.stdout, 96 | level=logging.INFO, 97 | format=log_format, 98 | datefmt='%m/%d %I:%M:%S %p') 99 | fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 100 | fh.setFormatter(logging.Formatter(log_format)) 101 | logging.getLogger().addHandler(fh) 102 | 103 | return self.args 104 | -------------------------------------------------------------------------------- /examples/ogb/ogbl_collab/test.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from ogb.nodeproppred import Evaluator 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from args import ArgsInit 6 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 7 | from model import DeeperGCN, LinkPredictor 8 | 9 | 10 | @torch.no_grad() 11 | def test(model, predictor, x, edge_index, split_edge, evaluator, batch_size): 12 | model.eval() 13 | predictor.eval() 14 | 15 | h = model(x, edge_index) 16 | 17 | pos_train_edge = split_edge['train']['edge'].to(h.device) 18 | pos_valid_edge = split_edge['valid']['edge'].to(h.device) 19 | neg_valid_edge = split_edge['valid']['edge_neg'].to(h.device) 20 | pos_test_edge = split_edge['test']['edge'].to(h.device) 21 | neg_test_edge = split_edge['test']['edge_neg'].to(h.device) 22 | 23 | pos_train_preds = [] 24 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size): 25 | edge = 
pos_train_edge[perm].t() 26 | pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 27 | pos_train_pred = torch.cat(pos_train_preds, dim=0) 28 | 29 | pos_valid_preds = [] 30 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size): 31 | edge = pos_valid_edge[perm].t() 32 | pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 33 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0) 34 | 35 | neg_valid_preds = [] 36 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size): 37 | edge = neg_valid_edge[perm].t() 38 | neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 39 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0) 40 | 41 | pos_test_preds = [] 42 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): 43 | edge = pos_test_edge[perm].t() 44 | pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 45 | pos_test_pred = torch.cat(pos_test_preds, dim=0) 46 | 47 | neg_test_preds = [] 48 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): 49 | edge = neg_test_edge[perm].t() 50 | neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 51 | neg_test_pred = torch.cat(neg_test_preds, dim=0) 52 | 53 | results = {} 54 | for K in [10, 50, 100]: 55 | evaluator.K = K 56 | train_hits = evaluator.eval({ 57 | 'y_pred_pos': pos_train_pred, 58 | 'y_pred_neg': neg_valid_pred, 59 | })[f'hits@{K}'] 60 | valid_hits = evaluator.eval({ 61 | 'y_pred_pos': pos_valid_pred, 62 | 'y_pred_neg': neg_valid_pred, 63 | })[f'hits@{K}'] 64 | test_hits = evaluator.eval({ 65 | 'y_pred_pos': pos_test_pred, 66 | 'y_pred_neg': neg_test_pred, 67 | })[f'hits@{K}'] 68 | 69 | results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) 70 | 71 | return results 72 | 73 | 74 | def main(): 75 | 76 | args = ArgsInit().args 77 | 78 | if args.use_gpu: 79 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 80 | else: 81 | device = 
torch.device('cpu') 82 | 83 | dataset = PygLinkPropPredDataset(name=args.dataset) 84 | data = dataset[0] 85 | # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128]) 86 | split_edge = dataset.get_edge_split() 87 | evaluator = Evaluator(args.dataset) 88 | 89 | x = data.x.to(device) 90 | 91 | edge_index = data.edge_index.to(device) 92 | 93 | args.in_channels = data.x.size(-1) 94 | args.num_tasks = 1 95 | 96 | print(args) 97 | 98 | model = DeeperGCN(args).to(device) 99 | predictor = LinkPredictor(args).to(device) 100 | 101 | model.load_state_dict(torch.load(args.model_load_path)['model_state_dict']) 102 | model.to(device) 103 | 104 | predictor.load_state_dict(torch.load(args.predictor_load_path)['model_state_dict']) 105 | predictor.to(device) 106 | 107 | hits = ['Hits@10', 'Hits@50', 'Hits@100'] 108 | 109 | result = test(model, predictor, x, edge_index, split_edge, evaluator, args.batch_size) 110 | 111 | for k in hits: 112 | train_result, valid_result, test_result = result[k] 113 | print('{}--Train: {}, Validation: {}, Test: {}'.format(k, 114 | train_result, 115 | valid_result, 116 | test_result)) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_arxiv/README.md: -------------------------------------------------------------------------------- 1 | # ogbn-arxiv 2 | ## Default 3 | --use_gpu False 4 | --self_loop False 5 | --block res+ #options: [plain, res, res+] 6 | --conv gen 7 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 8 | --num_layers 3 9 | --mlp_layers 1 10 | --norm batch 11 | --hidden_channels 128 12 | --epochs 500 13 | --lr 0.001 14 | --dropout 0.5 15 | ## ResGEN 16 | ### Train 17 | python main.py --use_gpu --self_loop --num_layers 28 --block res+ --gcn_aggr softmax_sg --t 0.1 18 | 19 | ### Test (use pre-trained model, 
[download](https://drive.google.com/file/d/19DA0SzfInkb3Q2cdeazejJ_mYMAvRZyb/view?usp=sharing) from Google Drive) 20 | python test.py --use_gpu --self_loop --num_layers 28 --block res+ --gcn_aggr softmax_sg --t 0.1 21 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_arxiv/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_arxiv/args.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import argparse 3 | import uuid 4 | import logging 5 | import time 6 | import os 7 | import sys 8 | from utils.logger import create_exp_dir 9 | import glob 10 | 11 | 12 | class ArgsInit(object): 13 | def __init__(self): 14 | parser = argparse.ArgumentParser(description='DeeperGCN') 15 | # dataset 16 | parser.add_argument('--dataset', type=str, default='ogbn-arxiv', 17 | help='dataset name (default: ogbn-arxiv)') 18 | parser.add_argument('--self_loop', action='store_true') 19 | # training & eval settings 20 | parser.add_argument('--use_gpu', action='store_true') 21 | parser.add_argument('--device', type=int, default=0, 22 | help='which gpu to use if any (default: 0)') 23 | parser.add_argument('--epochs', type=int, default=500, 24 | help='number of epochs to train (default: 500)') 25 | parser.add_argument('--lr', type=float, default=0.001, 26 | help='learning rate set for optimizer.') 27 | parser.add_argument('--dropout', type=float, default=0.5) 28 | # model 29 | parser.add_argument('--num_layers', type=int, default=3, 30 | help='the number of layers of the networks') 31 | parser.add_argument('--mlp_layers', type=int, default=1, 32 | help='the number of layers of mlp in 
conv') 33 | parser.add_argument('--in_channels', type=int, default=128, 34 | help='the dimension of initial embeddings of nodes') 35 | parser.add_argument('--hidden_channels', type=int, default=128, 36 | help='the dimension of embeddings of nodes') 37 | parser.add_argument('--block', default='res+', type=str, 38 | help='graph backbone block type {res+, res, dense, plain}') 39 | parser.add_argument('--conv', type=str, default='gen', 40 | help='the type of GCNs') 41 | parser.add_argument('--gcn_aggr', type=str, default='max', 42 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, softmax_sum, power, power_sum]') 43 | parser.add_argument('--norm', type=str, default='batch', 44 | help='the type of normalization layer') 45 | parser.add_argument('--num_tasks', type=int, default=1, 46 | help='the number of prediction tasks') 47 | # learnable parameters 48 | parser.add_argument('--t', type=float, default=1.0, 49 | help='the temperature of SoftMax') 50 | parser.add_argument('--p', type=float, default=1.0, 51 | help='the power of PowerMean') 52 | parser.add_argument('--y', type=float, default=0.0, 53 | help='the power of degrees') 54 | parser.add_argument('--learn_t', action='store_true') 55 | parser.add_argument('--learn_p', action='store_true') 56 | parser.add_argument('--learn_y', action='store_true') 57 | # message norm 58 | parser.add_argument('--msg_norm', action='store_true') 59 | parser.add_argument('--learn_msg_scale', action='store_true') 60 | # save model 61 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 62 | help='the directory used to save models') 63 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 64 | # load pre-trained model 65 | parser.add_argument('--model_load_path', type=str, default='ogbn_arxiv_pretrained_model.pth', 66 | help='the path of pre-trained model') 67 | 68 | self.args = parser.parse_args() 69 | 70 | def save_exp(self): 71 | self.args.save = 
'{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 72 | '-GA_{}-T_{}-LT_{}-P_{}-LP_{}-Y_{}-LY_{}' \ 73 | '-MN_{}-LS_{}'.format(self.args.save, self.args.block, self.args.conv, 74 | self.args.num_layers, self.args.hidden_channels, 75 | self.args.dropout, self.args.gcn_aggr, 76 | self.args.t, self.args.learn_t, 77 | self.args.p, self.args.learn_p, 78 | self.args.y, self.args.learn_y, 79 | self.args.msg_norm, self.args.learn_msg_scale) 80 | 81 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 82 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 83 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 84 | log_format = '%(asctime)s %(message)s' 85 | logging.basicConfig(stream=sys.stdout, 86 | level=logging.INFO, 87 | format=log_format, 88 | datefmt='%m/%d %I:%M:%S %p') 89 | fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 90 | fh.setFormatter(logging.Formatter(log_format)) 91 | logging.getLogger().addHandler(fh) 92 | 93 | return self.args 94 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_arxiv/main.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from ogb.nodeproppred import Evaluator 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.utils import to_undirected, add_self_loops 6 | from args import ArgsInit 7 | from ogb.nodeproppred import PygNodePropPredDataset 8 | from model import DeeperGCN 9 | from utils.ckpt_util import save_ckpt 10 | import logging 11 | import time 12 | 13 | 14 | @torch.no_grad() 15 | def test(model, x, edge_index, y_true, split_idx, evaluator): 16 | model.eval() 17 | out = model(x, edge_index) 18 | 19 | y_pred = out.argmax(dim=-1, keepdim=True) 20 | 21 | train_acc = evaluator.eval({ 22 | 'y_true': y_true[split_idx['train']], 23 | 'y_pred': y_pred[split_idx['train']], 24 | })['acc'] 25 | valid_acc = 
evaluator.eval({ 26 | 'y_true': y_true[split_idx['valid']], 27 | 'y_pred': y_pred[split_idx['valid']], 28 | })['acc'] 29 | test_acc = evaluator.eval({ 30 | 'y_true': y_true[split_idx['test']], 31 | 'y_pred': y_pred[split_idx['test']], 32 | })['acc'] 33 | 34 | return train_acc, valid_acc, test_acc 35 | 36 | 37 | def train(model, x, edge_index, y_true, train_idx, optimizer): 38 | model.train() 39 | 40 | optimizer.zero_grad() 41 | 42 | pred = model(x, edge_index)[train_idx] 43 | 44 | loss = F.nll_loss(pred, y_true.squeeze(1)[train_idx]) 45 | loss.backward() 46 | optimizer.step() 47 | 48 | return loss.item() 49 | 50 | 51 | def main(): 52 | 53 | args = ArgsInit().save_exp() 54 | logging.getLogger().setLevel(logging.INFO) 55 | 56 | if args.use_gpu: 57 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 58 | else: 59 | device = torch.device('cpu') 60 | 61 | dataset = PygNodePropPredDataset(name=args.dataset) 62 | data = dataset[0] 63 | split_idx = dataset.get_idx_split() 64 | 65 | evaluator = Evaluator(args.dataset) 66 | 67 | x = data.x.to(device) 68 | y_true = data.y.to(device) 69 | train_idx = split_idx['train'].to(device) 70 | 71 | edge_index = data.edge_index.to(device) 72 | edge_index = to_undirected(edge_index, data.num_nodes) 73 | 74 | if args.self_loop: 75 | edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0] 76 | 77 | sub_dir = 'SL_{}'.format(args.self_loop) 78 | 79 | args.in_channels = data.x.size(-1) 80 | args.num_tasks = dataset.num_classes 81 | 82 | logging.info('%s' % args) 83 | 84 | model = DeeperGCN(args).to(device) 85 | 86 | logging.info(model) 87 | 88 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 89 | 90 | results = {'highest_valid': 0, 91 | 'final_train': 0, 92 | 'final_test': 0, 93 | 'highest_train': 0} 94 | 95 | start_time = time.time() 96 | 97 | for epoch in range(1, args.epochs + 1): 98 | 99 | epoch_loss = train(model, x, edge_index, y_true, train_idx, optimizer) 
100 | logging.info('Epoch {}, training loss {:.4f}'.format(epoch, epoch_loss)) 101 | model.print_params(epoch=epoch) 102 | 103 | result = test(model, x, edge_index, y_true, split_idx, evaluator) 104 | logging.info(result) 105 | train_accuracy, valid_accuracy, test_accuracy = result 106 | 107 | if train_accuracy > results['highest_train']: 108 | results['highest_train'] = train_accuracy 109 | 110 | if valid_accuracy > results['highest_valid']: 111 | results['highest_valid'] = valid_accuracy 112 | results['final_train'] = train_accuracy 113 | results['final_test'] = test_accuracy 114 | 115 | save_ckpt(model, optimizer, 116 | round(epoch_loss, 4), epoch, 117 | args.model_save_path, 118 | sub_dir, name_post='valid_best') 119 | 120 | logging.info("%s" % results) 121 | 122 | end_time = time.time() 123 | total_time = end_time - start_time 124 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_arxiv/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_undirected, add_self_loops 3 | from ogb.nodeproppred import PygNodePropPredDataset 4 | from ogb.nodeproppred import Evaluator 5 | from args import ArgsInit 6 | from model import DeeperGCN 7 | 8 | 9 | @torch.no_grad() 10 | def test(model, x, edge_index, y_true, split_idx, evaluator): 11 | model.eval() 12 | out = model(x, edge_index) 13 | 14 | y_pred = out.argmax(dim=-1, keepdim=True) 15 | 16 | train_acc = evaluator.eval({ 17 | 'y_true': y_true[split_idx['train']], 18 | 'y_pred': y_pred[split_idx['train']], 19 | })['acc'] 20 | valid_acc = evaluator.eval({ 21 | 'y_true': y_true[split_idx['valid']], 22 | 'y_pred': y_pred[split_idx['valid']], 23 | })['acc'] 24 | test_acc = evaluator.eval({ 25 | 'y_true': 
y_true[split_idx['test']], 26 | 'y_pred': y_pred[split_idx['test']], 27 | })['acc'] 28 | 29 | return train_acc, valid_acc, test_acc 30 | 31 | 32 | def main(): 33 | 34 | args = ArgsInit().args 35 | 36 | if args.use_gpu: 37 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 38 | else: 39 | device = torch.device('cpu') 40 | 41 | dataset = PygNodePropPredDataset(name=args.dataset) 42 | data = dataset[0] 43 | split_idx = dataset.get_idx_split() 44 | 45 | evaluator = Evaluator(args.dataset) 46 | 47 | x = data.x.to(device) 48 | y_true = data.y.to(device) 49 | 50 | edge_index = data.edge_index.to(device) 51 | edge_index = to_undirected(edge_index, data.num_nodes) 52 | 53 | if args.self_loop: 54 | edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0] 55 | 56 | args.in_channels = data.x.size(-1) 57 | args.num_tasks = dataset.num_classes 58 | 59 | print(args) 60 | 61 | model = DeeperGCN(args) 62 | 63 | model.load_state_dict(torch.load(args.model_load_path)['model_state_dict']) 64 | model.to(device) 65 | 66 | result = test(model, x, edge_index, y_true, split_idx, evaluator) 67 | train_accuracy, valid_accuracy, test_accuracy = result 68 | 69 | print({'Train': train_accuracy, 70 | 'Validation': valid_accuracy, 71 | 'Test': test_accuracy}) 72 | 73 | model.print_params(final=True) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_products/README.md: -------------------------------------------------------------------------------- 1 | # ogbn-products 2 | We simply apply a random partition to generate batches for mini-batch training on GPU and full-batch test on CPU. We set the number of partitions to be 10 for training and the batch size is 1 subgraph. 
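
The batching scheme described above can be sketched roughly as follows: every node is assigned to one of `cluster_number` parts uniformly at random, and each part keeps only the edges whose two endpoints fall inside it. This is a minimal, standard-library-only illustration written for this README. The names `random_partition` and `induced_subgraphs` are hypothetical and are not the helpers that `main.py` imports from `utils.data_util`.

```python
import random


def random_partition(num_nodes, cluster_number=10, seed=None):
    """Assign each node a part id in [0, cluster_number) uniformly at random."""
    rng = random.Random(seed)
    return [rng.randrange(cluster_number) for _ in range(num_nodes)]


def induced_subgraphs(edges, parts, cluster_number=10):
    """Build one induced subgraph per part.

    edges: list of (src, dst) pairs with global node ids.
    Returns a list of (nodes, local_edges); local_edges index into nodes.
    Edges whose endpoints land in different parts are dropped.
    """
    nodes = [[] for _ in range(cluster_number)]
    local_id = {}
    for node, part in enumerate(parts):
        local_id[node] = len(nodes[part])
        nodes[part].append(node)

    local_edges = [[] for _ in range(cluster_number)]
    for src, dst in edges:
        if parts[src] == parts[dst]:
            local_edges[parts[src]].append((local_id[src], local_id[dst]))
    return list(zip(nodes, local_edges))


# Toy graph: a ring of 8 nodes split into 2 parts.
edges = [(i, (i + 1) % 8) for i in range(8)]
parts = random_partition(num_nodes=8, cluster_number=2, seed=0)
subgraphs = induced_subgraphs(edges, parts, cluster_number=2)
for nodes, sub_edges in subgraphs:
    print(nodes, sub_edges)
```

In a scheme like this, edges that cross parts are lost within a single partition. `main.py` draws a new partition at the start of every epoch, and the final test runs on the full graph.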
3 | ## Default 4 | --use_gpu False 5 | --self_loop False 6 | --cluster_number 10 7 | --block res+ #options: [plain, res, res+] 8 | --conv gen 9 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 10 | --num_layers 3 11 | --mlp_layers 1 12 | --norm batch 13 | --hidden_channels 128 14 | --epochs 500 15 | --lr 0.001 16 | --dropout 0.5 17 | ## ResGEN 18 | ### Train 19 | python main.py --use_gpu --self_loop --num_layers 14 --gcn_aggr softmax_sg --t 0.1 20 | 21 | ### Test (use pre-trained model, [download](https://drive.google.com/file/d/1OxyA2IZN-4BCfkWzUG8QBS-khxhHHnZB/view?usp=sharing) from Google Drive) 22 | python test.py --self_loop --num_layers 14 --gcn_aggr softmax_sg --t 0.1 23 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_products/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_products/args.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import argparse 3 | import uuid 4 | import logging 5 | import time 6 | import os 7 | import sys 8 | from utils.logger import create_exp_dir 9 | import glob 10 | 11 | 12 | class ArgsInit(object): 13 | def __init__(self): 14 | parser = argparse.ArgumentParser(description='DeeperGCN') 15 | # dataset 16 | parser.add_argument('--dataset', type=str, default='ogbn-products', 17 | help='dataset name (default: ogbn-products)') 18 | parser.add_argument('--cluster_number', type=int, default=10, 19 | help='the number of sub-graphs for training') 20 | parser.add_argument('--self_loop', action='store_true') 21 | # training & eval settings 22 | 
parser.add_argument('--use_gpu', action='store_true') 23 | parser.add_argument('--device', type=int, default=0, 24 | help='which gpu to use if any (default: 0)') 25 | parser.add_argument('--epochs', type=int, default=500, 26 | help='number of epochs to train (default: 500)') 27 | parser.add_argument('--lr', type=float, default=0.001, 28 | help='learning rate set for optimizer.') 29 | parser.add_argument('--dropout', type=float, default=0.5) 30 | # model 31 | parser.add_argument('--num_layers', type=int, default=3, 32 | help='the number of layers of the networks') 33 | parser.add_argument('--mlp_layers', type=int, default=1, 34 | help='the number of layers of mlp in conv') 35 | parser.add_argument('--in_channels', type=int, default=128, 36 | help='the dimension of initial embeddings of nodes') 37 | parser.add_argument('--hidden_channels', type=int, default=128, 38 | help='the dimension of embeddings of nodes') 39 | parser.add_argument('--block', default='res+', type=str, 40 | help='graph backbone block type {res+, res, dense, plain}') 41 | parser.add_argument('--conv', type=str, default='gen', 42 | help='the type of GCNs') 43 | parser.add_argument('--gcn_aggr', type=str, default='max', 44 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, power]') 45 | parser.add_argument('--norm', type=str, default='batch', 46 | help='the type of normalization layer') 47 | parser.add_argument('--num_tasks', type=int, default=1, 48 | help='the number of prediction tasks') 49 | # learnable parameters 50 | parser.add_argument('--t', type=float, default=1.0, 51 | help='the temperature of SoftMax') 52 | parser.add_argument('--p', type=float, default=1.0, 53 | help='the power of PowerMean') 54 | parser.add_argument('--learn_t', action='store_true') 55 | parser.add_argument('--learn_p', action='store_true') 56 | # message norm 57 | parser.add_argument('--msg_norm', action='store_true') 58 | parser.add_argument('--learn_msg_scale', action='store_true') 59 | # save 
model 60 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 61 | help='the directory used to save models') 62 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 63 | # load pre-trained model 64 | parser.add_argument('--model_load_path', type=str, default='ogbn_products_pretrained_model.pth', 65 | help='the path of pre-trained model') 66 | 67 | self.args = parser.parse_args() 68 | 69 | def save_exp(self): 70 | self.args.save = '{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 71 | '-GA_{}-T_{}-LT_{}-P_{}-LP_{}' \ 72 | '-MN_{}-LS_{}'.format(self.args.save, self.args.block, self.args.conv, 73 | self.args.num_layers, self.args.hidden_channels, 74 | self.args.dropout, self.args.gcn_aggr, 75 | self.args.t, self.args.learn_t, self.args.p, self.args.learn_p, 76 | self.args.msg_norm, self.args.learn_msg_scale) 77 | 78 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 79 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 80 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 81 | log_format = '%(asctime)s %(message)s' 82 | logging.basicConfig(stream=sys.stdout, 83 | level=logging.INFO, 84 | format=log_format, 85 | datefmt='%m/%d %I:%M:%S %p') 86 | fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 87 | fh.setFormatter(logging.Formatter(log_format)) 88 | logging.getLogger().addHandler(fh) 89 | 90 | return self.args 91 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_products/main.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from ogb.nodeproppred import Evaluator 3 | import torch 4 | from torch_sparse import SparseTensor 5 | import torch.nn.functional as F 6 | from torch_geometric.utils import add_self_loops 7 | from utils.data_util import intersection, random_partition_graph, generate_sub_graphs 8 | 
from args import ArgsInit 9 | from ogb.nodeproppred import PygNodePropPredDataset 10 | from model import DeeperGCN 11 | import numpy as np 12 | from utils.ckpt_util import save_ckpt 13 | import logging 14 | import statistics 15 | import time 16 | 17 | 18 | @torch.no_grad() 19 | def test(model, x, edge_index, y_true, split_idx, evaluator): 20 | # test on CPU 21 | model.eval() 22 | model.to('cpu') 23 | out = model(x, edge_index) 24 | 25 | y_pred = out.argmax(dim=-1, keepdim=True) 26 | 27 | train_acc = evaluator.eval({ 28 | 'y_true': y_true[split_idx['train']], 29 | 'y_pred': y_pred[split_idx['train']], 30 | })['acc'] 31 | valid_acc = evaluator.eval({ 32 | 'y_true': y_true[split_idx['valid']], 33 | 'y_pred': y_pred[split_idx['valid']], 34 | })['acc'] 35 | test_acc = evaluator.eval({ 36 | 'y_true': y_true[split_idx['test']], 37 | 'y_pred': y_pred[split_idx['test']], 38 | })['acc'] 39 | 40 | return train_acc, valid_acc, test_acc 41 | 42 | 43 | def train(data, model, x, y_true, train_idx, optimizer, device): 44 | loss_list = [] 45 | model.train() 46 | 47 | sg_nodes, sg_edges = data 48 | train_y = y_true[train_idx].squeeze(1) 49 | 50 | idx_clusters = np.arange(len(sg_nodes)) 51 | np.random.shuffle(idx_clusters) 52 | 53 | for idx in idx_clusters: 54 | 55 | x_ = x[sg_nodes[idx]].to(device) 56 | sg_edges_ = sg_edges[idx].to(device) 57 | mapper = {node: idx for idx, node in enumerate(sg_nodes[idx])} 58 | 59 | inter_idx = intersection(sg_nodes[idx], train_idx) 60 | training_idx = [mapper[t_idx] for t_idx in inter_idx] 61 | 62 | optimizer.zero_grad() 63 | 64 | pred = model(x_, sg_edges_) 65 | target = train_y[inter_idx].to(device) 66 | 67 | loss = F.nll_loss(pred[training_idx], target) 68 | loss.backward() 69 | optimizer.step() 70 | loss_list.append(loss.item()) 71 | 72 | return statistics.mean(loss_list) 73 | 74 | 75 | def main(): 76 | 77 | args = ArgsInit().save_exp() 78 | 79 | if args.use_gpu: 80 | device = torch.device("cuda:" + str(args.device)) if 
torch.cuda.is_available() else torch.device("cpu") 81 | else: 82 | device = torch.device('cpu') 83 | 84 | dataset = PygNodePropPredDataset(name=args.dataset) 85 | graph = dataset[0] 86 | 87 | adj = SparseTensor(row=graph.edge_index[0], 88 | col=graph.edge_index[1]) 89 | 90 | if args.self_loop: 91 | adj = adj.set_diag() 92 | graph.edge_index = add_self_loops(edge_index=graph.edge_index, 93 | num_nodes=graph.num_nodes)[0] 94 | split_idx = dataset.get_idx_split() 95 | train_idx = split_idx["train"].tolist() 96 | 97 | evaluator = Evaluator(args.dataset) 98 | 99 | sub_dir = 'random-train_{}-full_batch_test'.format(args.cluster_number) 100 | logging.info(sub_dir) 101 | 102 | args.in_channels = graph.x.size(-1) 103 | args.num_tasks = dataset.num_classes 104 | 105 | logging.info('%s' % args) 106 | 107 | model = DeeperGCN(args).to(device) 108 | 109 | logging.info(model) 110 | 111 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 112 | 113 | results = {'highest_valid': 0, 114 | 'final_train': 0, 115 | 'final_test': 0, 116 | 'highest_train': 0} 117 | 118 | start_time = time.time() 119 | 120 | for epoch in range(1, args.epochs + 1): 121 | # generate batches 122 | parts = random_partition_graph(graph.num_nodes, 123 | cluster_number=args.cluster_number) 124 | data = generate_sub_graphs(adj, parts, cluster_number=args.cluster_number) 125 | 126 | epoch_loss = train(data, model, graph.x, graph.y, train_idx, optimizer, device) 127 | logging.info('Epoch {}, training loss {:.4f}'.format(epoch, epoch_loss)) 128 | model.print_params(epoch=epoch) 129 | 130 | if epoch == args.epochs: 131 | 132 | result = test(model, graph.x, graph.edge_index, graph.y, split_idx, evaluator) 133 | logging.info(result) 134 | 135 | train_accuracy, valid_accuracy, test_accuracy = result 136 | 137 | if train_accuracy > results['highest_train']: 138 | results['highest_train'] = train_accuracy 139 | 140 | if valid_accuracy > results['highest_valid']: 141 | results['highest_valid'] = valid_accuracy 
142 | results['final_train'] = train_accuracy 143 | results['final_test'] = test_accuracy 144 | 145 | save_ckpt(model, optimizer, 146 | round(epoch_loss, 4), epoch, 147 | args.model_save_path, 148 | sub_dir, name_post='valid_best') 149 | 150 | logging.info("%s" % results) 151 | 152 | end_time = time.time() 153 | total_time = end_time - start_time 154 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_products/test.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from ogb.nodeproppred import Evaluator 3 | import torch 4 | from torch_geometric.utils import add_self_loops 5 | from args import ArgsInit 6 | from ogb.nodeproppred import PygNodePropPredDataset 7 | from model import DeeperGCN 8 | 9 | 10 | @torch.no_grad() 11 | def test(model, x, edge_index, y_true, split_idx, evaluator): 12 | # test on CPU 13 | model.eval() 14 | out = model(x, edge_index) 15 | 16 | y_pred = out.argmax(dim=-1, keepdim=True) 17 | 18 | train_acc = evaluator.eval({ 19 | 'y_true': y_true[split_idx['train']], 20 | 'y_pred': y_pred[split_idx['train']], 21 | })['acc'] 22 | valid_acc = evaluator.eval({ 23 | 'y_true': y_true[split_idx['valid']], 24 | 'y_pred': y_pred[split_idx['valid']], 25 | })['acc'] 26 | test_acc = evaluator.eval({ 27 | 'y_true': y_true[split_idx['test']], 28 | 'y_pred': y_pred[split_idx['test']], 29 | })['acc'] 30 | 31 | return train_acc, valid_acc, test_acc 32 | 33 | 34 | def main(): 35 | 36 | args = ArgsInit().args 37 | 38 | dataset = PygNodePropPredDataset(name=args.dataset) 39 | graph = dataset[0] 40 | 41 | if args.self_loop: 42 | graph.edge_index = add_self_loops(edge_index=graph.edge_index, 43 | num_nodes=graph.num_nodes)[0] 44 | split_idx = dataset.get_idx_split() 45 | 46 | evaluator = Evaluator(args.dataset) 47 
| 48 | args.in_channels = graph.x.size(-1) 49 | args.num_tasks = dataset.num_classes 50 | 51 | print(args) 52 | 53 | model = DeeperGCN(args) 54 | 55 | print(model) 56 | 57 | model.load_state_dict(torch.load(args.model_load_path)['model_state_dict']) 58 | result = test(model, graph.x, graph.edge_index, graph.y, split_idx, evaluator) 59 | print(result) 60 | model.print_params(final=True) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_proteins/README.md: -------------------------------------------------------------------------------- 1 | # ogbn-proteins 2 | 3 | We simply apply a random partition to generate batches for both mini-batch training and test. We set the number of partitions to be 10 for training and 5 for test, and we set the batch size to 1 subgraph. We initialize the features of nodes through aggregating the features of their connected edges by a Sum (Add) aggregation. 
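
The node-feature initialization mentioned above can be pictured as follows: each node's feature vector is the element-wise aggregate of the feature vectors of the edges attached to it. The sketch below uses only the standard library and is written for illustration. `init_node_features` is a hypothetical name, and scattering onto the source endpoint of each edge is an assumption, not necessarily what `dataset.py` does.

```python
def init_node_features(edge_index, edge_attr, num_nodes, aggr="add"):
    """Aggregate edge features onto the source node of each edge.

    edge_index: list of (src, dst) pairs.
    edge_attr:  list of feature vectors, one per edge (equal lengths).
    aggr:       'add', 'mean', or 'max', mirroring the --aggr options.
    Nodes with no edges get an all-zero vector.
    """
    dim = len(edge_attr[0])
    buckets = [[] for _ in range(num_nodes)]
    for (src, _dst), feat in zip(edge_index, edge_attr):
        buckets[src].append(feat)

    features = []
    for feats in buckets:
        if not feats:
            features.append([0.0] * dim)
        elif aggr == "add":
            features.append([sum(col) for col in zip(*feats)])
        elif aggr == "mean":
            features.append([sum(col) / len(feats) for col in zip(*feats)])
        elif aggr == "max":
            features.append([max(col) for col in zip(*feats)])
        else:
            raise ValueError("unknown aggr: {}".format(aggr))
    return features


# Undirected toy graph stored with both directions of each edge.
edge_index = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_attr = [[1.0, 0.5], [1.0, 0.5], [2.0, 0.25], [2.0, 0.25]]
x = init_node_features(edge_index, edge_attr, num_nodes=4)
print(x)
```

When both directions of every undirected edge are stored, as in the toy graph, scattering onto one endpoint is enough for each node to receive the features of all its incident edges.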
4 | ## Default 5 | --use_gpu False 6 | --cluster_number 10 7 | --valid_cluster_number 5 8 | --aggr add #options: [mean, max, add] 9 | --block plain #options: [plain, res, res+] 10 | --conv gen 11 | --gcn_aggr max #options: [max, mean, add, softmax, softmax_sg, softmax_sum, power, power_sum] 12 | --num_layers 3 13 | --conv_encode_edge False 14 | --use_one_hot_encoding False 15 | --mlp_layers 2 16 | --norm layer 17 | --hidden_channels 64 18 | --epochs 1000 19 | --lr 0.001 20 | --dropout 0.0 21 | --num_evals 1 22 | 23 | ## DyResGEN-112 24 | 25 | ### Train the model that performs best 26 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 112 --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.1 27 | ### Test (use pre-trained model, [download](https://drive.google.com/file/d/1LjsgXZo02WgzpIJe-SQHrbrwEuQl8VQk/view?usp=sharing) from Google Drive) 28 | python test.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 112 --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.1 29 | ### Test by multiple evaluations (e.g. 5 times) 30 | 31 | python test.py --use_gpu --num_evals 5 --conv_encode_edge --use_one_hot_encoding --num_layers 112 --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.1 32 | 33 | ## Train ResGCN-112 34 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 112 --block res --gcn_aggr max 35 | 36 | #### Train with different GCN models with 28 layers on GPU 37 | 38 | SoftMax aggregator with learnable t (initialized as 1.0) 39 | 40 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr softmax --t 1.0 --learn_t 41 | 42 | PowerMean aggregator with learnable p (initialized as 1.0) 43 | 44 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr power --p 1.0 --learn_p 45 | 46 | Apply MsgNorm (message normalization) layer (e.g. 
SoftMax aggregator with fixed t (e.g. 0.1)) 47 | 48 | **Do not learn parameter s (message scale)** 49 | 50 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr softmax_sg --t 0.1 --msg_norm 51 | **Learn parameter s (message scale)** 52 | 53 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr softmax_sg --t 0.1 --msg_norm --learn_msg_scale 54 | 55 | ## ResGEN 56 | SoftMax aggregator with fixed t (e.g. 0.001) 57 | 58 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr softmax_sg --t 0.001 59 | 60 | PowerMean aggregator with fixed p (e.g. 5.0) 61 | 62 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr power --p 5.0 63 | ## ResGCN+ 64 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res+ --gcn_aggr mean 65 | ## ResGCN 66 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --block res --gcn_aggr mean 67 | ## PlainGCN 68 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --num_layers 28 --gcn_aggr mean 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_proteins/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_proteins/args.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import argparse 3 | import uuid 4 | import logging 5 | import time 6 | import os 7 | import sys 8 | from utils.logger import create_exp_dir 9 | 
import glob 10 | 11 | 12 | class ArgsInit(object): 13 | def __init__(self): 14 | parser = argparse.ArgumentParser(description='DeeperGCN') 15 | # dataset 16 | parser.add_argument('--dataset', type=str, default='ogbn-proteins', 17 | help='dataset name (default: ogbn-proteins)') 18 | parser.add_argument('--cluster_number', type=int, default=10, 19 | help='the number of sub-graphs for training') 20 | parser.add_argument('--valid_cluster_number', type=int, default=5, 21 | help='the number of sub-graphs for evaluation') 22 | parser.add_argument('--aggr', type=str, default='add', 23 | help='the aggregation operator to obtain nodes\' initial features [mean, max, add]') 24 | parser.add_argument('--nf_path', type=str, default='init_node_features_add.pt', 25 | help='the file path where the extracted node features are saved') 26 | # training & eval settings 27 | parser.add_argument('--use_gpu', action='store_true') 28 | parser.add_argument('--device', type=int, default=0, 29 | help='which gpu to use if any (default: 0)') 30 | parser.add_argument('--epochs', type=int, default=1000, 31 | help='number of epochs to train (default: 1000)') 32 | parser.add_argument('--num_evals', type=int, default=1, 33 | help='the number of evaluation runs') 34 | parser.add_argument('--lr', type=float, default=0.001, 35 | help='learning rate set for optimizer.') 36 | parser.add_argument('--dropout', type=float, default=0.0) 37 | # model 38 | parser.add_argument('--num_layers', type=int, default=3, 39 | help='the number of layers of the networks') 40 | parser.add_argument('--mlp_layers', type=int, default=2, 41 | help='the number of layers of mlp in conv') 42 | parser.add_argument('--hidden_channels', type=int, default=64, 43 | help='the dimension of embeddings of nodes and edges') 44 | parser.add_argument('--block', default='plain', type=str, 45 | help='graph backbone block type {res+, res, dense, plain}') 46 | parser.add_argument('--conv', type=str, default='gen', 47 | help='the type of GCNs') 48 |
parser.add_argument('--gcn_aggr', type=str, default='max', 49 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, softmax_sum, power, power_sum]') 50 | parser.add_argument('--norm', type=str, default='layer', 51 | help='the type of normalization layer') 52 | parser.add_argument('--num_tasks', type=int, default=1, 53 | help='the number of prediction tasks') 54 | # learnable parameters 55 | parser.add_argument('--t', type=float, default=1.0, 56 | help='the temperature of SoftMax') 57 | parser.add_argument('--p', type=float, default=1.0, 58 | help='the power of PowerMean') 59 | parser.add_argument('--y', type=float, default=0.0, 60 | help='the power of degrees') 61 | parser.add_argument('--learn_t', action='store_true') 62 | parser.add_argument('--learn_p', action='store_true') 63 | parser.add_argument('--learn_y', action='store_true') 64 | # message norm 65 | parser.add_argument('--msg_norm', action='store_true') 66 | parser.add_argument('--learn_msg_scale', action='store_true') 67 | # encode edge in conv 68 | parser.add_argument('--conv_encode_edge', action='store_true') 69 | # if use one-hot-encoding node feature 70 | parser.add_argument('--use_one_hot_encoding', action='store_true') 71 | # save model 72 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 73 | help='the directory used to save models') 74 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 75 | # load pre-trained model 76 | parser.add_argument('--model_load_path', type=str, default='ogbn_proteins_pretrained_model.pth', 77 | help='the path of pre-trained model') 78 | 79 | self.args = parser.parse_args() 80 | 81 | def save_exp(self): 82 | self.args.save = '{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 83 | '-A_{}-GA_{}-T_{}-LT_{}-P_{}-LP_{}-Y_{}-LY_{}' \ 84 | '-MN_{}-LS_{}'.format(self.args.save, self.args.block, self.args.conv, 85 | self.args.num_layers, self.args.hidden_channels, self.args.dropout, 86 | self.args.aggr, self.args.gcn_aggr, 87 
| self.args.t, self.args.learn_t, 88 | self.args.p, self.args.learn_p, 89 | self.args.y, self.args.learn_y, 90 | self.args.msg_norm, self.args.learn_msg_scale) 91 | 92 | self.args.save = 'log/{}-{}-{}'.format(self.args.save, time.strftime("%Y%m%d-%H%M%S"), str(uuid.uuid4())) 93 | self.args.model_save_path = os.path.join(self.args.save, self.args.model_save_path) 94 | create_exp_dir(self.args.save, scripts_to_save=glob.glob('*.py')) 95 | log_format = '%(asctime)s %(message)s' 96 | logging.basicConfig(stream=sys.stdout, 97 | level=logging.INFO, 98 | format=log_format, 99 | datefmt='%m/%d %I:%M:%S %p') 100 | fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt')) 101 | fh.setFormatter(logging.Formatter(log_format)) 102 | logging.getLogger().addHandler(fh) 103 | 104 | return self.args 105 | -------------------------------------------------------------------------------- /examples/ogb/ogbn_proteins/test.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import torch 3 | from dataset import OGBNDataset 4 | from utils.data_util import intersection, process_indexes 5 | import numpy as np 6 | from ogb.nodeproppred import Evaluator 7 | from model import DeeperGCN 8 | from args import ArgsInit 9 | 10 | 11 | @torch.no_grad() 12 | def multi_evaluate(valid_data_list, dataset, model, evaluator, device): 13 | model.eval() 14 | target = dataset.y.detach().numpy() 15 | 16 | train_pre_ordered_list = [] 17 | valid_pre_ordered_list = [] 18 | test_pre_ordered_list = [] 19 | 20 | test_idx = dataset.test_idx.tolist() 21 | train_idx = dataset.train_idx.tolist() 22 | valid_idx = dataset.valid_idx.tolist() 23 | 24 | for valid_data_item in valid_data_list: 25 | sg_nodes, sg_edges, sg_edges_index, _ = valid_data_item 26 | idx_clusters = np.arange(len(sg_nodes)) 27 | 28 | test_predict = [] 29 | test_target_idx = [] 30 | 31 | train_predict = [] 32 | valid_predict = [] 33 | 34 | train_target_idx = [] 35 | valid_target_idx = [] 36 
| 37 | for idx in idx_clusters: 38 | x = dataset.x[sg_nodes[idx]].float().to(device) 39 | sg_nodes_idx = torch.LongTensor(sg_nodes[idx]).to(device) 40 | 41 | mapper = {node: idx for idx, node in enumerate(sg_nodes[idx])} 42 | sg_edges_attr = dataset.edge_attr[sg_edges_index[idx]].to(device) 43 | 44 | inter_tr_idx = intersection(sg_nodes[idx], train_idx) 45 | inter_v_idx = intersection(sg_nodes[idx], valid_idx) 46 | 47 | train_target_idx += inter_tr_idx 48 | valid_target_idx += inter_v_idx 49 | 50 | tr_idx = [mapper[tr_idx] for tr_idx in inter_tr_idx] 51 | v_idx = [mapper[v_idx] for v_idx in inter_v_idx] 52 | 53 | pred = model(x, sg_nodes_idx, sg_edges[idx].to(device), sg_edges_attr).cpu().detach() 54 | 55 | train_predict.append(pred[tr_idx]) 56 | valid_predict.append(pred[v_idx]) 57 | 58 | inter_te_idx = intersection(sg_nodes[idx], test_idx) 59 | test_target_idx += inter_te_idx 60 | 61 | te_idx = [mapper[te_idx] for te_idx in inter_te_idx] 62 | test_predict.append(pred[te_idx]) 63 | 64 | train_pre = torch.cat(train_predict, 0).numpy() 65 | valid_pre = torch.cat(valid_predict, 0).numpy() 66 | test_pre = torch.cat(test_predict, 0).numpy() 67 | 68 | train_pre_ordered = train_pre[process_indexes(train_target_idx)] 69 | valid_pre_ordered = valid_pre[process_indexes(valid_target_idx)] 70 | test_pre_ordered = test_pre[process_indexes(test_target_idx)] 71 | 72 | train_pre_ordered_list.append(train_pre_ordered) 73 | valid_pre_ordered_list.append(valid_pre_ordered) 74 | test_pre_ordered_list.append(test_pre_ordered) 75 | 76 | train_pre_final = torch.mean(torch.Tensor(train_pre_ordered_list), dim=0) 77 | valid_pre_final = torch.mean(torch.Tensor(valid_pre_ordered_list), dim=0) 78 | test_pre_final = torch.mean(torch.Tensor(test_pre_ordered_list), dim=0) 79 | 80 | eval_result = {} 81 | 82 | input_dict = {"y_true": target[train_idx], "y_pred": train_pre_final} 83 | eval_result["train"] = evaluator.eval(input_dict) 84 | 85 | input_dict = {"y_true": target[valid_idx], "y_pred": 
valid_pre_final} 86 | eval_result["valid"] = evaluator.eval(input_dict) 87 | 88 | input_dict = {"y_true": target[test_idx], "y_pred": test_pre_final} 89 | eval_result["test"] = evaluator.eval(input_dict) 90 | 91 | return eval_result 92 | 93 | 94 | def main(): 95 | args = ArgsInit().args 96 | 97 | if args.use_gpu: 98 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 99 | else: 100 | device = torch.device("cpu") 101 | 102 | dataset = OGBNDataset(dataset_name=args.dataset) 103 | args.num_tasks = dataset.num_tasks 104 | args.nf_path = dataset.extract_node_features(args.aggr) 105 | 106 | evaluator = Evaluator(args.dataset) 107 | 108 | valid_data_list = [] 109 | 110 | for i in range(args.num_evals): 111 | parts = dataset.random_partition_graph(dataset.total_no_of_nodes, 112 | cluster_number=args.valid_cluster_number) 113 | valid_data = dataset.generate_sub_graphs(parts, 114 | cluster_number=args.valid_cluster_number) 115 | valid_data_list.append(valid_data) 116 | 117 | model = DeeperGCN(args) 118 | 119 | model.load_state_dict(torch.load(args.model_load_path)['model_state_dict']) 120 | model.to(device) 121 | result = multi_evaluate(valid_data_list, dataset, model, evaluator, device) 122 | print(result) 123 | model.print_params(final=True) 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_arxiv_dgl/README.md: -------------------------------------------------------------------------------- 1 | # [Training Graph Neural Networks with 1000 Layers (ICML'2021)](https://arxiv.org/abs/2106.07476) 2 | 3 | # ogbn-arxiv dgl implementation 4 | 5 | ## All models are trained with one NVIDIA Tesla V100 (32GB GPU) 6 | 7 | ### Train the RevGAT teacher models (RevGAT+NormAdj+LabelReuse) 8 | Expected results: Average test accuracy: 74.02 ± 0.18 9 | ``` 10 | python3 main.py --use-norm --use-labels --n-label-iters=1 
--no-attn-dst --edge-drop=0.3 --input-drop=0.25 --n-layers 5 --dropout 0.75 --n-hidden 256 --save kd --backbone rev --group 2 --mode teacher 11 | ``` 12 | ### Train the RevGAT student models after training the teacher models (RevGAT+N.Adj+LabelReuse+SelfKD) 13 | Expected results: Average test accuracy: 74.26 ± 0.17 14 | ``` 15 | python3 main.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --n-layers 5 --dropout 0.75 --n-hidden 256 --save kd --backbone rev --group 2 --alpha 0.95 --temp 0.7 --mode student 16 | ``` 17 | 18 | ### Acknowledgements 19 | 20 | Our implementation is based on two previous submissions on OGB: [GAT+norm. adj.+label reuse](https://github.com/Espylapiza/dgl/tree/master/examples/pytorch/ogb/ogbn-arxiv) 21 | and [GAT+label reuse+self KD](https://github.com/ShunliRen/dgl/tree/master/examples/pytorch/ogb/ogbn-arxiv) 22 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_arxiv_dgl/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_arxiv_dgl/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def loss_kd(all_out, teacher_all_out, outputs, labels, teacher_outputs, 6 | alpha, temperature): 7 | """ 8 | loss function for Knowledge Distillation (KD) 9 | """ 10 | 11 | T = temperature 12 | 13 | loss_CE = F.cross_entropy(outputs, labels) 14 | D_KL = nn.KLDivLoss()(F.log_softmax(all_out / T, dim=1), 15 | F.softmax(teacher_all_out / T, dim=1)) * (T * T) 16 | KD_loss = (1. 
- alpha) * loss_CE + alpha * D_KL 17 | 18 | return KD_loss 19 | 20 | def loss_kd_only(all_out, teacher_all_out, temperature): 21 | T = temperature 22 | 23 | D_KL = nn.KLDivLoss()(F.log_softmax(all_out / T, dim=1), 24 | F.softmax(teacher_all_out / T, dim=1)) * (T * T) 25 | 26 | return D_KL 27 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_proteins/README.md: -------------------------------------------------------------------------------- 1 | # [Training Graph Neural Networks with 1000 Layers (ICML'2021)](https://arxiv.org/abs/2106.07476) 2 | 3 | # ogbn-proteins 4 | 5 | Our models RevGNN-Deep (1001 layers with 80 channels each) and RevGNN-Wide (448 layers with 224 channels each) were both trained on a single commodity GPU and achieve an ROC-AUC of 87.74 ± 0.13 and 88.24 ± 0.15 on the ogbn-proteins dataset. To the best of our knowledge, RevGNN-Deep is the deepest GNN in the literature by one order of magnitude. 6 | 7 | ## Default 8 | ``` 9 | --use_gpu False 10 | --cluster_number 10 11 | --valid_cluster_number 5 12 | --aggr add #options: [mean, max, add] 13 | --block plain #options: [plain, res, res+] 14 | --conv gen 15 | --gcn_aggr max #options: [max, mean, add, softmax_sg, softmax, power] 16 | --num_layers 3 17 | --conv_encode_edge False 18 | --mlp_layers 2 19 | --norm layer 20 | --hidden_channels 64 21 | --epochs 1000 22 | --lr 0.01 23 | --dropout 0.0 24 | --num_evals 1 25 | --backbone rev 26 | --group 2 27 | ``` 28 | 29 | ## All models are trained with one NVIDIA Tesla V100 (32GB GPU) 30 | 31 | ## RevGNN-Wide (448 layers, 224 channels) 32 | 33 | ### Train the RevGNN-Wide (448 layers, 224 channels) model on one GPU 34 | ``` 35 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 448 --hidden_channels 224 --lr 0.001 --backbone rev --dropout 0.2 --group 2 36 | ``` 37 | 38 | ### Test the RevGNN-Wide model by multiple view inference (e.g. 
10 times with 3 parts) 39 | Pre-trained model: [download](https://drive.google.com/drive/folders/1Bw6S0OUy8qDIZIfwQOD5I5VBjPdmN9yB?usp=sharing) from Google Drive. 40 | 41 | Expected test ROC-AUC: 88.24 ± 0.15. Requires 48 GB of GPU memory; an NVIDIA RTX 6000 (48 GB) is recommended. 42 | ``` 43 | python test.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 448 --hidden_channels 224 --lr 0.001 --backbone rev --dropout 0.2 --group 2 --model_load_path revgnn_wide.pth --valid_cluster_number 3 --num_evals 10 44 | ``` 45 | ### Test the RevGNN-Wide model by single inference (e.g. 1 time with 5 parts) 46 | Pre-trained model: [download](https://drive.google.com/drive/folders/1Bw6S0OUy8qDIZIfwQOD5I5VBjPdmN9yB?usp=sharing) from Google Drive. 47 | 48 | Expected test ROC-AUC: 87.62 ± 0.18. A 32 GB GPU is sufficient; an NVIDIA Tesla V100 (32 GB) is recommended. 49 | ``` 50 | python test.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 448 --hidden_channels 224 --lr 0.001 --backbone rev --dropout 0.2 --group 2 --model_load_path revgnn_wide.pth --valid_cluster_number 5 --num_evals 1 51 | ``` 52 | 53 | ## RevGNN-Deep (1001 layers, 80 channels) 54 | 55 | ### Train the RevGNN-Deep (1001 layers, 80 channels) model on one GPU 56 | ``` 57 | python main.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 1001 --hidden_channels 80 --lr 0.001 --backbone rev --dropout 0.1 --group 2 58 | ``` 59 | 60 | ### Test the RevGNN-Deep model by multiple view inference (e.g. 10 times with 3 parts) 61 | Pre-trained model: [download](https://drive.google.com/drive/folders/1Bw6S0OUy8qDIZIfwQOD5I5VBjPdmN9yB?usp=sharing) from Google Drive. 62 | 63 | Expected test ROC-AUC: 87.74 ± 0.13. A 32 GB GPU is sufficient; an NVIDIA Tesla V100 (32 GB) is recommended.
64 | ``` 65 | python test.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 1001 --hidden_channels 80 --lr 0.001 --backbone rev --dropout 0.1 --group 2 --model_load_path revgnn_deep.pth --valid_cluster_number 3 --num_evals 10 66 | ``` 67 | 68 | ### Test the RevGNN-Deep model by single inference (e.g. 1 time with 5 parts) 69 | Pre-trained model: [download](https://drive.google.com/drive/folders/1Bw6S0OUy8qDIZIfwQOD5I5VBjPdmN9yB?usp=sharing) from Google Drive. 70 | 71 | Expected test ROC-AUC: 87.06 ± 0.20. A 32 GB GPU is sufficient; an NVIDIA Tesla V100 (32 GB) is recommended. 72 | ``` 73 | python test.py --use_gpu --conv_encode_edge --use_one_hot_encoding --block res+ --gcn_aggr max --num_layers 1001 --hidden_channels 80 --lr 0.001 --backbone rev --dropout 0.1 --group 2 --model_load_path revgnn_deep.pth --valid_cluster_number 5 --num_evals 1 74 | ``` 75 | 76 | ### Acknowledgements 77 | The [reversible module](../../../eff_gcn_modules/rev/gcn_revop.py) is based on [MemCNN](https://github.com/silvandeleemput/memcnn/blob/master/LICENSE.txt), which is released under the MIT license.
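For readers unfamiliar with reversible modules, the idea behind `--backbone rev --group 2` can be sketched as a two-group additive coupling. The snippet below is a minimal pure-Python illustration under stated assumptions, not the code in `eff_gcn_modules/rev`: `couple_forward`, `couple_inverse`, `f`, and `g` are hypothetical names, and plain lists of floats stand in for the two halves of a feature tensor.

```python
import math


def couple_forward(x1, x2, f, g):
    """Map the input halves (x1, x2) to the output halves (y1, y2)."""
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2


def couple_inverse(y1, y2, f, g):
    """Recover (x1, x2) from (y1, y2) by undoing the additions in reverse order."""
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2


# Hypothetical stand-ins for the per-group sub-networks.
# They do not need to be invertible themselves.
def f(v):
    return [math.tanh(2.0 * a) for a in v]


def g(v):
    return [a * a for a in v]


x1, x2 = [0.5, -1.0], [2.0, 0.25]
y1, y2 = couple_forward(x1, x2, f, g)
r1, r2 = couple_inverse(y1, y2, f, g)

# The inputs are recovered up to floating-point rounding.
max_err = max(abs(a - b) for a, b in zip(x1 + x2, r1 + r2))
print(max_err < 1e-9)
```

Because the inputs can be recomputed from the outputs, a reversible layer does not have to keep its input activations for the backward pass, which is the property that makes very deep models fit in a fixed memory budget.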
78 | 79 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_proteins/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | sys.path.append(ROOT_DIR) 5 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_proteins/model_rev.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import torch 3 | import torch.nn as nn 4 | from gcn_lib.sparse.torch_nn import norm_layer 5 | import torch.nn.functional as F 6 | import logging 7 | import eff_gcn_modules.rev.memgcn as memgcn 8 | from eff_gcn_modules.rev.rev_layer import GENBlock 9 | import copy 10 | 11 | 12 | class RevGCN(torch.nn.Module): 13 | def __init__(self, args): 14 | super(RevGCN, self).__init__() 15 | 16 | self.num_layers = args.num_layers 17 | self.dropout = args.dropout 18 | self.group = args.group 19 | 20 | hidden_channels = args.hidden_channels 21 | num_tasks = args.num_tasks 22 | aggr = args.gcn_aggr 23 | 24 | t = args.t 25 | self.learn_t = args.learn_t 26 | p = args.p 27 | self.learn_p = args.learn_p 28 | y = args.y 29 | self.learn_y = args.learn_y 30 | 31 | self.msg_norm = args.msg_norm 32 | learn_msg_scale = args.learn_msg_scale 33 | 34 | conv_encode_edge = args.conv_encode_edge 35 | norm = args.norm 36 | mlp_layers = args.mlp_layers 37 | node_features_file_path = args.nf_path 38 | 39 | self.use_one_hot_encoding = args.use_one_hot_encoding 40 | 41 | self.gcns = torch.nn.ModuleList() 42 | self.last_norm = norm_layer(norm, hidden_channels) 43 | 44 | for layer in range(self.num_layers): 45 | Fms = nn.ModuleList() 46 | fm = GENBlock(hidden_channels//self.group, hidden_channels//self.group, 47 | aggr=aggr, 48 | t=t, learn_t=self.learn_t, 49 | p=p, learn_p=self.learn_p, 50 | 
y=y, learn_y=self.learn_y, 51 | msg_norm=self.msg_norm, 52 | learn_msg_scale=learn_msg_scale, 53 | encode_edge=conv_encode_edge, 54 | edge_feat_dim=hidden_channels, 55 | norm=norm, mlp_layers=mlp_layers) 56 | 57 | for i in range(self.group): 58 | if i == 0: 59 | Fms.append(fm) 60 | else: 61 | Fms.append(copy.deepcopy(fm)) 62 | 63 | 64 | invertible_module = memgcn.GroupAdditiveCoupling(Fms, 65 | group=self.group) 66 | 67 | 68 | gcn = memgcn.InvertibleModuleWrapper(fn=invertible_module, 69 | keep_input=False) 70 | 71 | self.gcns.append(gcn) 72 | 73 | self.node_features = torch.load(node_features_file_path).to(args.device) 74 | 75 | if self.use_one_hot_encoding: 76 | self.node_one_hot_encoder = torch.nn.Linear(8, 8) 77 | self.node_features_encoder = torch.nn.Linear(8 * 2, hidden_channels) 78 | else: 79 | self.node_features_encoder = torch.nn.Linear(8, hidden_channels) 80 | 81 | self.edge_encoder = torch.nn.Linear(8, hidden_channels) 82 | 83 | self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks) 84 | 85 | def forward(self, x, node_index, edge_index, edge_attr, epoch=-1): 86 | 87 | node_features_1st = self.node_features[node_index] 88 | 89 | if self.use_one_hot_encoding: 90 | node_features_2nd = self.node_one_hot_encoder(x) 91 | # concatenate 92 | node_features = torch.cat((node_features_1st, node_features_2nd), dim=1) 93 | else: 94 | node_features = node_features_1st 95 | 96 | h = self.node_features_encoder(node_features) 97 | 98 | edge_emb = self.edge_encoder(edge_attr) 99 | edge_emb = torch.cat([edge_emb]*self.group, dim=-1) 100 | 101 | m = torch.zeros_like(h).bernoulli_(1 - self.dropout) 102 | mask = m.requires_grad_(False) / (1 - self.dropout) 103 | 104 | h = self.gcns[0](h, edge_index, mask, edge_emb) 105 | 106 | for layer in range(1, self.num_layers): 107 | h = self.gcns[layer](h, edge_index, mask, edge_emb) 108 | 109 | h = F.relu(self.last_norm(h)) 110 | h = F.dropout(h, p=self.dropout, training=self.training) 111 | 112 | return 
self.node_pred_linear(h) 113 | 114 | 115 | def print_params(self, epoch=None, final=False): 116 | 117 | if self.learn_t: 118 | ts = [] 119 | for gcn in self.gcns: 120 | ts.append(gcn.t.item()) 121 | if final: 122 | print('Final t {}'.format(ts)) 123 | else: 124 | logging.info('Epoch {}, t {}'.format(epoch, ts)) 125 | 126 | if self.learn_p: 127 | ps = [] 128 | for gcn in self.gcns: 129 | ps.append(gcn.p.item()) 130 | if final: 131 | print('Final p {}'.format(ps)) 132 | else: 133 | logging.info('Epoch {}, p {}'.format(epoch, ps)) 134 | 135 | if self.learn_y: 136 | ys = [] 137 | for gcn in self.gcns: 138 | ys.append(gcn.sigmoid_y.item()) 139 | if final: 140 | print('Final sigmoid(y) {}'.format(ys)) 141 | else: 142 | logging.info('Epoch {}, sigmoid(y) {}'.format(epoch, ys)) 143 | 144 | if self.msg_norm: 145 | ss = [] 146 | for gcn in self.gcns: 147 | ss.append(gcn.msg_norm.msg_scale.item()) 148 | if final: 149 | print('Final s {}'.format(ss)) 150 | else: 151 | logging.info('Epoch {}, s {}'.format(epoch, ss)) 152 | -------------------------------------------------------------------------------- /examples/ogb_eff/ogbn_proteins/test.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import torch 3 | from dataset import OGBNDataset 4 | from utils.data_util import intersection, process_indexes 5 | import numpy as np 6 | from ogb.nodeproppred import Evaluator 7 | # from model import DeeperGCN 8 | # from model_geq import DEQGCN 9 | from model_rev import RevGCN 10 | # from model_revwt import WTRevGCN 11 | from args import ArgsInit 12 | 13 | 14 | @torch.no_grad() 15 | def multi_evaluate(valid_data_list, dataset, model, evaluator, device): 16 | model.eval() 17 | target = dataset.y.detach().numpy() 18 | 19 | train_pre_ordered_list = [] 20 | valid_pre_ordered_list = [] 21 | test_pre_ordered_list = [] 22 | 23 | test_idx = dataset.test_idx.tolist() 24 | train_idx = dataset.train_idx.tolist() 25 | valid_idx = 
dataset.valid_idx.tolist() 26 | 27 | for valid_data_item in valid_data_list: 28 | sg_nodes, sg_edges, sg_edges_index, _ = valid_data_item 29 | idx_clusters = np.arange(len(sg_nodes)) 30 | 31 | test_predict = [] 32 | test_target_idx = [] 33 | 34 | train_predict = [] 35 | valid_predict = [] 36 | 37 | train_target_idx = [] 38 | valid_target_idx = [] 39 | 40 | for idx in idx_clusters: 41 | x = dataset.x[sg_nodes[idx]].float().to(device) 42 | sg_nodes_idx = torch.LongTensor(sg_nodes[idx]).to(device) 43 | 44 | mapper = {node: idx for idx, node in enumerate(sg_nodes[idx])} 45 | sg_edges_attr = dataset.edge_attr[sg_edges_index[idx]].to(device) 46 | 47 | inter_tr_idx = intersection(sg_nodes[idx], train_idx) 48 | inter_v_idx = intersection(sg_nodes[idx], valid_idx) 49 | 50 | train_target_idx += inter_tr_idx 51 | valid_target_idx += inter_v_idx 52 | 53 | tr_idx = [mapper[tr_idx] for tr_idx in inter_tr_idx] 54 | v_idx = [mapper[v_idx] for v_idx in inter_v_idx] 55 | 56 | pred = model(x, sg_nodes_idx, sg_edges[idx].to(device), sg_edges_attr).cpu().detach() 57 | 58 | train_predict.append(pred[tr_idx]) 59 | valid_predict.append(pred[v_idx]) 60 | 61 | inter_te_idx = intersection(sg_nodes[idx], test_idx) 62 | test_target_idx += inter_te_idx 63 | 64 | te_idx = [mapper[te_idx] for te_idx in inter_te_idx] 65 | test_predict.append(pred[te_idx]) 66 | 67 | train_pre = torch.cat(train_predict, 0).numpy() 68 | valid_pre = torch.cat(valid_predict, 0).numpy() 69 | test_pre = torch.cat(test_predict, 0).numpy() 70 | 71 | train_pre_ordered = train_pre[process_indexes(train_target_idx)] 72 | valid_pre_ordered = valid_pre[process_indexes(valid_target_idx)] 73 | test_pre_ordered = test_pre[process_indexes(test_target_idx)] 74 | 75 | train_pre_ordered_list.append(train_pre_ordered) 76 | valid_pre_ordered_list.append(valid_pre_ordered) 77 | test_pre_ordered_list.append(test_pre_ordered) 78 | 79 | torch.cuda.empty_cache() 80 | 81 | train_pre_final = torch.mean(torch.Tensor(train_pre_ordered_list), 
dim=0) 82 | valid_pre_final = torch.mean(torch.Tensor(valid_pre_ordered_list), dim=0) 83 | test_pre_final = torch.mean(torch.Tensor(test_pre_ordered_list), dim=0) 84 | 85 | eval_result = {} 86 | 87 | input_dict = {"y_true": target[train_idx], "y_pred": train_pre_final} 88 | eval_result["train"] = evaluator.eval(input_dict) 89 | 90 | input_dict = {"y_true": target[valid_idx], "y_pred": valid_pre_final} 91 | eval_result["valid"] = evaluator.eval(input_dict) 92 | 93 | input_dict = {"y_true": target[test_idx], "y_pred": test_pre_final} 94 | eval_result["test"] = evaluator.eval(input_dict) 95 | 96 | return eval_result 97 | 98 | 99 | def main(): 100 | args = ArgsInit().args 101 | 102 | if args.use_gpu: 103 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 104 | else: 105 | device = torch.device("cpu") 106 | 107 | args.device = device 108 | 109 | dataset = OGBNDataset(dataset_name=args.dataset) 110 | args.num_tasks = dataset.num_tasks 111 | args.nf_path = dataset.extract_node_features(args.aggr) 112 | 113 | evaluator = Evaluator(args.dataset) 114 | 115 | valid_data_list = [] 116 | 117 | for i in range(args.num_evals): 118 | parts = dataset.random_partition_graph(dataset.total_no_of_nodes, 119 | cluster_number=args.valid_cluster_number) 120 | valid_data = dataset.generate_sub_graphs(parts, 121 | cluster_number=args.valid_cluster_number) 122 | valid_data_list.append(valid_data) 123 | 124 | if args.backbone == 'deepergcn': 125 | # model = DeeperGCN(args).to(device) 126 | pass 127 | # elif args.backbone == 'deq': 128 | # model = DEQGCN(args).to(device) 129 | # elif args.backbone == 'revwt': 130 | # model = WTRevGCN(args).to(device) 131 | elif args.backbone == 'rev': 132 | model = RevGCN(args).to(device) 133 | pass 134 | else: 135 | raise Exception("unknown backbone") 136 | 137 | if torch.cuda.is_available(): 138 | model.load_state_dict(torch.load(args.model_load_path)['model_state_dict']) 139 | else: 140 |
model.load_state_dict(torch.load(args.model_load_path, 141 | map_location=torch.device('cpu'))['model_state_dict']) 142 | 143 | model.to(device) 144 | with torch.cuda.amp.autocast(): 145 | result = multi_evaluate(valid_data_list, dataset, model, evaluator, device) 146 | print(result) 147 | model.print_params(final=True) 148 | peak_memuse = torch.cuda.max_memory_allocated(device) / float(1024 ** 3) 149 | print('Peak memuse {:.2f} G'.format(peak_memuse)) 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /examples/part_sem_seg/README.md: -------------------------------------------------------------------------------- 1 | ## [Part Segmentation on PartNet](https://arxiv.org/pdf/1910.06849.pdf) 2 | 3 |

4 | 5 |

6 | 7 | ### Preparing the Dataset 8 | Request access to download the PartNet v0 dataset [here](https://cs.stanford.edu/~kaichun/partnet/), the official PartNet website. 9 | Once the data is downloaded, extract the `sem_seg_h5` data and put it inside a new folder called `raw`. 10 | For example, our data folder structure looks like this: `/data/deepgcn/partnet/raw/sem_seg_h5/category-level`. `category` is the name of a category, e.g. Bed. `level` is 1, 2, or 3. When we train and test, we set `--data_dir /data/deepgcn/partnet`. 11 | 12 | ### Train 13 | We train each model on one NVIDIA Tesla V100. 14 | 15 | To train the default ResEdgeConv-28 with 64 filters on the Bed category, run: 16 | ``` 17 | python main.py --phase train --category 1 --data_dir /data/deepgcn/partnet 18 | ``` 19 | Note that we focus only on the fine-grained level of part segmentation in the experiments. 20 | For all categories, we use the same default training parameters (see `config.py` for details). 21 | 22 | To train a model with other GCN layers (for example MRGCN), run: 23 | ``` 24 | python main.py --phase train --category 1 --conv mr --data_dir /data/deepgcn/partnet 25 | ``` 26 | Other important parameters are: 27 | ``` 28 | --block graph backbone block type {res, plain, dense} 29 | --conv graph conv layer {edge, mr, sage, gin, gcn, gat} 30 | --n_filters number of channels of deep features, default is 64 31 | --n_blocks number of basic blocks, default is 28 32 | --category NO. of category.
default is 1 (Bed) 33 | ``` 34 | The category list is: 35 | ``` 36 | category_names = ['Bag', 'Bed', 'Bottle', 'Bowl', 'Chair', 'Clock', 'Dishwasher', 'Display', 'Door', 'Earphone', # 0-9 37 | 'Faucet', 'Hat', 'Keyboard', 'Knife', 'Lamp', 'Laptop', 'Microwave', 'Mug', 'Refrigerator', 'Scissors', # 10-19 38 | 'StorageFurniture', 'Table', 'TrashCan', 'Vase'] 39 | ``` 40 | ### Test 41 | We report results on the test set using the checkpoints that perform best on the validation set. 42 | Our pretrained models are available on [Google Drive](https://drive.google.com/drive/folders/15Y7Ao4VBysHBHxyQwYvb2SU1iFi9ZZRK?usp=sharing). 43 | 44 | The naming format of our pretrained models is: `task-category-segmentationLevel-conv-n_blocks-n_filters-otherParameters-val_best_model_best.pth`, e.g. `PartnetSemanticSeg-Bed-L3-res-edge-n28-C64-k9-drop0.5-lr0.005_B6-val_best_model.pth`. 45 | `val_best` means the checkpoint is the best one on the validation set. 46 | 47 | Use the parameter `--pretrained_model` to choose which pretrained model to load. For example: 48 | ``` 49 | python -u main.py --phase test --category 1 --pretrained_model checkpoints/PartnetSemanticSeg-Bed-L3-res-edge-n28-C64-k9-drop0.5-lr0.005_B6-val_best_model.pth --data_dir /data/deepgcn/partnet --test_batch_size 8 50 | ``` 51 | Please also specify the number of blocks and filters. 52 | Note: 53 | - The path given to `--pretrained_model` is relative to `main.py`, so do not prefix it with `examples/part_sem_seg`. Alternatively, pass an absolute path. 54 | - If you do not have a V100, you can set `--test_batch_size` to 1. This does not affect the test accuracy. 55 | 56 | #### Visualization 57 | 1.
step1 58 | Use the script `eval.py` to generate the `.obj` files to be visualized: 59 | ``` 60 | python -u eval.py --phase test --category 1 --pretrained_model checkpoints/PartnetSemanticSeg-Bed-L3-res-edge-n28-C64-k9-drop0.5-lr0.005_B6-val_best_model.pth --data_dir /data/deepgcn/partnet 61 | ``` 62 | 2. step2 63 | To visualize the output of a trained model, use `visualize.py`. 64 | Set the path to the result folder (`--dir_path`), the category number (`--category`), the index of the model instance (`--obj_no`), and the folders to visualize (`--folders`), then run: 65 | ``` 66 | python -u visualize.py --dir_path /change/the/path/to/your/result/ --category 1 --obj_no 0 --folders res 67 | ``` 68 | `dir_path` is the path to your result folder, which has the following structure: 69 | 70 | dir_path 71 | ├── res # result folder for ResGCN 72 | ├── Bed # result of Bed class 73 | ├── Bottle # result of Bottle class 74 | ... # result of other classes 75 | 76 | ├── plain # result folder for PlainGCN 77 | ├── Bed # result of Bed class 78 | ├── Bottle # result of Bottle class 79 | ...
# result of other classes 80 | -------------------------------------------------------------------------------- /examples/part_sem_seg/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) 3 | 4 | -------------------------------------------------------------------------------- /examples/part_sem_seg/architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gcn_lib.dense import BasicConv, GraphConv2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph, PlainDynBlock2d 3 | from torch.nn import Sequential as Seq 4 | import torch.nn.functional as F 5 | 6 | 7 | class DeepGCN(torch.nn.Module): 8 | def __init__(self, opt): 9 | super(DeepGCN, self).__init__() 10 | channels = opt.n_filters 11 | k = opt.k 12 | act = opt.act 13 | norm = opt.norm 14 | bias = opt.bias 15 | knn = 'matrix' # implement knn using matrix multiplication 16 | epsilon = opt.epsilon 17 | stochastic = opt.stochastic 18 | conv = opt.conv 19 | c_growth = channels 20 | emb_dims = 1024 21 | 22 | self.n_blocks = opt.n_blocks 23 | 24 | self.knn = DenseDilatedKnnGraph(k, 1, stochastic, epsilon) 25 | self.head = GraphConv2d(opt.in_channels, channels, conv, act, norm, bias=False) 26 | 27 | if opt.block.lower() == 'res': 28 | if opt.use_dilation: 29 | self.backbone = Seq(*[ResDynBlock2d(channels, k, i + 1, conv, act, norm, 30 | bias, stochastic, epsilon, knn) 31 | for i in range(self.n_blocks - 1)]) 32 | else: 33 | self.backbone = Seq(*[ResDynBlock2d(channels, k, 1, conv, act, norm, 34 | bias, stochastic, epsilon, knn) 35 | for _ in range(self.n_blocks - 1)]) 36 | fusion_dims = int(channels + c_growth * (self.n_blocks - 1)) 37 | elif opt.block.lower() == 'plain': 38 | # Plain GCN. 
No dilation, no stochastic 39 | stochastic = False 40 | self.backbone = Seq(*[PlainDynBlock2d(channels, k, 1, conv, act, norm, 41 | bias, stochastic, epsilon, knn) 42 | for i in range(self.n_blocks - 1)]) 43 | 44 | fusion_dims = int(channels+c_growth*(self.n_blocks-1)) 45 | else: 46 | raise NotImplementedError('{} is not supported in this experiment'.format(opt.block)) 47 | 48 | self.fusion_block = BasicConv([fusion_dims, emb_dims], 'leakyrelu', norm, bias=False) 49 | self.prediction = Seq(*[BasicConv([emb_dims * 3, 512], 'leakyrelu', norm, drop=opt.dropout), 50 | BasicConv([512, 256], 'leakyrelu', norm, drop=opt.dropout), 51 | BasicConv([256, opt.n_classes], None, None)]) 52 | 53 | self.model_init() 54 | 55 | def model_init(self): 56 | for m in self.modules(): 57 | if isinstance(m, torch.nn.Conv2d): 58 | torch.nn.init.kaiming_normal_(m.weight) 59 | m.weight.requires_grad = True 60 | if m.bias is not None: 61 | m.bias.data.zero_() 62 | m.bias.requires_grad = True 63 | 64 | def forward(self, inputs): 65 | feats = [self.head(inputs, self.knn(inputs[:, 0:3]))] 66 | for i in range(self.n_blocks-1): 67 | feats.append(self.backbone[i](feats[-1])) 68 | feats = torch.cat(feats, 1) 69 | fusion = self.fusion_block(feats) 70 | 71 | x1 = F.adaptive_max_pool2d(fusion, 1) 72 | x2 = F.adaptive_avg_pool2d(fusion, 1) 73 | feat_global_pool = torch.cat((x1, x2), dim=1) 74 | feat_global_pool = torch.repeat_interleave(feat_global_pool, repeats=fusion.shape[2], dim=2) 75 | cat_pooled = torch.cat((feat_global_pool, fusion), dim=1) 76 | out = self.prediction(cat_pooled).squeeze(-1) 77 | return F.log_softmax(out, dim=1) 78 | 79 | -------------------------------------------------------------------------------- /examples/part_sem_seg/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(ROOT_DIR) 6 | import 
os.path as osp 7 | from utils.pc_viz import visualize_part_seg 8 | import argparse 9 | 10 | 11 | category_names = ['Bag', 'Bed', 'Bottle', 'Bowl', 'Chair', 'Clock', 'Dishwasher', 'Display', 'Door', 'Earphone', # 0-9 12 | 'Faucet', 'Hat', 'Keyboard', 'Knife', 'Lamp', 'Laptop', 'Microwave', 'Mug', 'Refrigerator', 'Scissors', # 10-19 13 | 'StorageFurniture', 'Table', 'TrashCan', 'Vase'] # 20-23 14 | 15 | parser = argparse.ArgumentParser(description='Qualitative comparision of ResGCN ' 16 | 'against PlainGCN on PartNet segmentation') 17 | 18 | # dir_path set to the location of the result folder. 19 | # result folder should have such structure: 20 | # result 21 | # ├── plain # result folder for PlainGCN 22 | # ├── Bed # the obj director of category Bed 23 | # ├── Bed_0_pred.obj 24 | # ├── res # result folder for ResGCN 25 | # ├── Bed # the obj director of category Bed 26 | # ├── Bed_0_pred.obj 27 | 28 | parser.add_argument('--category', type=int, default=4) 29 | parser.add_argument('--obj_no', default=0, type=int, help='NO. of which obj in a given category to visualize') 30 | parser.add_argument('--dir_path', default='../result', type=str, help='path to the result') 31 | parser.add_argument('--folders', default='plain,res', type=str, 32 | help='use "," to separate different folders, eg. 
"res,plain"') 33 | args = parser.parse_args() 34 | 35 | category = category_names[args.category] 36 | obj_no = args.obj_no 37 | folders = list(map(lambda x: x.strip(), args.folders.split(','))) 38 | 39 | folder_paths = list(map(lambda x: osp.join(args.dir_path, x, category), folders)) 40 | 41 | file_name_pred = '_'.join([category, str(obj_no), 'pred.obj']) 42 | file_name_gt = '_'.join([category, str(obj_no), 'gt.obj']) 43 | 44 | texts = folders.copy() 45 | texts.insert(0, 'Ground Truth') 46 | # show Ground Truth, PlainGCN, ResGCN 47 | visualize_part_seg(file_name_pred, 48 | file_name_gt, 49 | folder_paths, 50 | limit=-1, 51 | text=texts, 52 | interactive=True, 53 | orientation='horizontal') 54 | -------------------------------------------------------------------------------- /examples/ppi/README.md: -------------------------------------------------------------------------------- 1 | ## [Graph Learning on Biological Networks](https://arxiv.org/pdf/1910.06849.pdf) 2 | 3 |

4 | 5 |

### Train
We train each model on one Tesla V100.

To train the default ResMRConv-14 with 64 filters, run
```
python -u examples/ppi/main.py --phase train --data_dir /data/deepgcn/ppi
```
To train a model with other GCN layers (for example EdgeConv, 28 layers, 256 channels in the first layer, with dense connections), run
```
python -u examples/ppi/main.py --phase train --conv edge --data_dir /data/deepgcn/ppi --block dense --n_filters 256 --n_blocks 28
```

Set `--data_dir` to your data folder; the dataset will be downloaded automatically.
Other parameters for changing the architecture are:
```
    --block         graph backbone block type {res, plain, dense}
    --conv          graph conv layer {edge, mr, sage, gin, gcn, gat}
    --n_filters     number of channels of deep features, default is 64
    --n_blocks      number of basic blocks, default is 28
```
### Test
#### Pretrained Models
Our pretrained models are available on [Google Drive](https://drive.google.com/drive/folders/1LoT1B9FDgylUffHY8K43FFfred-luZaz?usp=sharing).

Pretrained models are named `task-connection-conv_type-n_blocks-n_filters_phase_best.pth`. For example, `ppi-res-mr-28-256_val_best.pth` is the model for the PPI node classification task with residual connections, MRGCN convolutions, 28 layers and 256 channels that scored best on the validation set.

Use the `--pretrained_model` parameter to select the pretrained model you want.
```
python -u examples/ppi/main.py --phase test --pretrained_model checkpoints/ppi-res-mr-28-256_val_best.pth --data_dir /data/deepgcn/ppi --n_filters 256 --n_blocks 28 --conv mr --block res
```

```
python -u examples/ppi/main.py --phase test --pretrained_model checkpoints/ppi-dense-mr-14-256_val_best.pth --data_dir /data/deepgcn/ppi --n_filters 256 --n_blocks 14 --conv mr --block dense
```
Make sure the number of blocks and filters you pass matches the name of the pretrained model.

--------------------------------------------------------------------------------
/examples/ppi/architecture.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import Linear as Lin, Sequential as Seq
from gcn_lib.sparse import MultiSeq, MLP, GraphConv, ResGraphBlock, DenseGraphBlock


class DeepGCN(torch.nn.Module):
    """
    static graph

    """
    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        channels = opt.n_filters
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        conv = opt.conv
        heads = opt.n_heads
        c_growth = 0
        self.n_blocks = opt.n_blocks
        self.head = GraphConv(opt.in_channels, channels, conv, act, norm, bias, heads)

        # res_scale = 1 keeps the residual connection; 0 turns the block into a plain one.
        res_scale = 1 if opt.block.lower() == 'res' else 0
        if opt.block.lower() == 'dense':
            c_growth = channels
            self.backbone = MultiSeq(*[DenseGraphBlock(channels+i*c_growth, c_growth, conv, act, norm, bias, heads)
                                       for i in range(self.n_blocks-1)])
        else:
            self.backbone = MultiSeq(*[ResGraphBlock(channels, conv, act, norm, bias, heads, res_scale)
                                       for _ in range(self.n_blocks-1)])
        # Total channels after concatenating the outputs of the head and every block.
        fusion_dims = int(channels * self.n_blocks + c_growth * ((1 + self.n_blocks - 1) * (self.n_blocks - 1) / 2))
        self.fusion_block = MLP([fusion_dims, 1024], act, None, bias)
        self.prediction = Seq(*[MLP([1+fusion_dims, 512], act, norm, bias), torch.nn.Dropout(p=opt.dropout),
                                MLP([512, 256], act, norm, bias), torch.nn.Dropout(p=opt.dropout),
                                MLP([256, opt.n_classes], None, None, bias)])
        self.model_init()

    def model_init(self):
        for m in self.modules():
            if isinstance(m, Lin):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        feats = [self.head(x, edge_index)]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1], edge_index)[0])
        feats = torch.cat(feats, 1)
        # Per-node max over the 1024 fused channels gives one extra feature (hence 1 + fusion_dims).
        fusion, _ = torch.max(self.fusion_block(feats), 1, keepdim=True)
        out = self.prediction(torch.cat((feats, fusion), 1))
        return out


--------------------------------------------------------------------------------
/examples/sem_seg_dense/README.md:
--------------------------------------------------------------------------------
## [Semantic segmentation of indoor scenes](https://arxiv.org/pdf/1904.03751.pdf)

Sem_seg_dense and sem_seg_sparse both address the semantic segmentation task.
They differ in the shape of the graph data.
For sem_seg_sparse, the data shape is N x C (N is the total number of nodes, C is the feature size), and a `batch` variable indicates which graph each node belongs to.
For sem_seg_dense, the data shape is B x C x N x 1 (B is the batch size; N here is the number of nodes per graph).
gcn_lib has two folders, dense and sparse, one for each of these data shapes.


### Train
We use 2 Tesla V100 GPUs for distributed training.
```
cd examples/sem_seg_dense
```

Train ResGCN-28 (DeepGCN with 28 residually connected EdgeConv layers, dilated graph convolutions and batch normalization):
```
CUDA_VISIBLE_DEVICES=0,1 python train.py  --multi_gpus --phase train --data_dir /data/deepgcn/S3DIS
```
Set `--data_dir $yourpath/to/data`; the dataset will be downloaded automatically.

To train a model with other GCN layers (for example, mrgcn), run
```
python train.py --conv mr --multi_gpus --phase train --data_dir /data/deepgcn/S3DIS
```
Other parameters for changing the architecture are:
```
    --block         graph backbone block type {res, plain, dense}
    --conv          graph conv layer {edge, mr}
    --n_filters     number of channels of deep features, default is 64
    --n_blocks      number of basic blocks, default is 28
```

A shallow version of DeepGCN (ResGCN-7) can be trained with the command below:
```
CUDA_VISIBLE_DEVICES=0 python train.py  --multi_gpus --phase train --data_dir /data/deepgcn/S3DIS --n_blocks 7
```


### Evaluation
For a quick test on Area 5, run:

```
python test.py --pretrained_model checkpoints/sem_seg_dense-res-edge-28-64-ckpt_best_model.pth --batch_size 32 --data_dir /data/deepgcn/S3DIS
```

#### Pretrained Models
Our pretrained model is available on [Google Drive](https://drive.google.com/open?id=1iAJbHqiNwc4nJlP67sp1xLkl5EtC4PU_).

Note: Please use our TensorFlow code if you want to reproduce the exact results in the paper.
The PyTorch code performs slightly worse than the TensorFlow code: mIoU on Area 5 is 52.11%, compared to 52.49% for the TensorFlow version.
```
python test.py --pretrained_model checkpoints/sem_seg_dense-res-edge-28-64-ckpt_best_model.pth --batch_size 32 --data_dir /data/deepgcn/S3DIS
```
Lower the batch size if you run out of memory. The batch size does not influence the test results.
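
As far as the metric is concerned, the batch size does not matter because `test.py` sums per-class intersections and unions over all batches before dividing. The sketch below illustrates this with NumPy on random toy labels. The helper `accumulate_iou` and the toy sizes are assumptions made for illustration and are not part of this repository.

```python
import numpy as np


def accumulate_iou(preds, targets, n_classes, batch_size):
    """Hypothetical helper: per-class IoU from intersections/unions summed over batches."""
    n_batches = int(np.ceil(len(preds) / batch_size))
    inter = np.zeros((n_batches, n_classes))
    union = np.zeros((n_batches, n_classes))
    for b in range(n_batches):
        p = preds[b * batch_size:(b + 1) * batch_size]
        t = targets[b * batch_size:(b + 1) * batch_size]
        for cl in range(n_classes):
            inter[b, cl] = np.sum(np.logical_and(p == cl, t == cl))
            union[b, cl] = np.sum(np.logical_or(p == cl, t == cl))
    total_union = union.sum(0)
    # A class absent from both predictions and labels counts as IoU 1, mirroring test.py.
    return np.where(total_union > 0, inter.sum(0) / np.maximum(total_union, 1), 1.0)


rng = np.random.default_rng(0)
targets = rng.integers(0, 13, size=(8, 4096))  # toy data: 8 scenes, 4096 points, 13 classes
preds = rng.integers(0, 13, size=(8, 4096))

miou_bs1 = accumulate_iou(preds, targets, 13, batch_size=1).mean()
miou_bs4 = accumulate_iou(preds, targets, 13, batch_size=4).mean()
print(miou_bs1, miou_bs4)  # the two values agree
```

Averaging per-batch IoU values instead would make the result depend on how scenes are grouped, which is why the sums are kept separate until the end.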

--------------------------------------------------------------------------------
/examples/sem_seg_dense/__init__.py:
--------------------------------------------------------------------------------
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))

--------------------------------------------------------------------------------
/examples/sem_seg_dense/architecture.py:
--------------------------------------------------------------------------------
import __init__
import torch
from gcn_lib.dense import BasicConv, GraphConv2d, PlainDynBlock2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph
from torch.nn import Sequential as Seq


class DenseDeepGCN(torch.nn.Module):
    def __init__(self, opt):
        super(DenseDeepGCN, self).__init__()
        channels = opt.n_filters
        k = opt.k
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.stochastic
        conv = opt.conv
        c_growth = channels
        self.n_blocks = opt.n_blocks

        self.knn = DenseDilatedKnnGraph(k, 1, stochastic, epsilon)
        self.head = GraphConv2d(opt.in_channels, channels, conv, act, norm, bias)

        if opt.block.lower() == 'res':
            # Dilation rate grows with depth: block i uses dilation 1 + i.
            self.backbone = Seq(*[ResDynBlock2d(channels, k, 1+i, conv, act, norm, bias, stochastic, epsilon)
                                  for i in range(self.n_blocks-1)])
            fusion_dims = int(channels + c_growth * (self.n_blocks - 1))
        elif opt.block.lower() == 'dense':
            self.backbone = Seq(*[DenseDynBlock2d(channels+c_growth*i, c_growth, k, 1+i, conv, act,
                                                  norm, bias, stochastic, epsilon)
                                  for i in range(self.n_blocks-1)])
            # Arithmetic series: block outputs have channels, channels + c_growth, ... channels.
            fusion_dims = int(
                (channels + channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2)
        else:
            # Plain GCN: no skip connections, no dilation, no stochastic sampling.
            stochastic = False

            self.backbone = Seq(*[PlainDynBlock2d(channels, k, 1, conv, act, norm,
                                                  bias, stochastic, epsilon)
                                  for _ in range(self.n_blocks - 1)])
            fusion_dims = int(channels + c_growth * (self.n_blocks - 1))

        self.fusion_block = BasicConv([fusion_dims, 1024], act, norm, bias)
        self.prediction = Seq(*[BasicConv([fusion_dims+1024, 512], act, norm, bias),
                                BasicConv([512, 256], act, norm, bias),
                                torch.nn.Dropout(p=opt.dropout),
                                BasicConv([256, opt.n_classes], None, None, bias)])

    def forward(self, inputs):
        # The kNN graph is built from the first three channels (xyz coordinates).
        feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1]))
        feats = torch.cat(feats, dim=1)

        # Global max pooling over all points, then broadcast back to every point.
        fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]])
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
        return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1)


if __name__ == "__main__":
    import random, numpy as np, argparse
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    batch_size = 2
    N = 1024
    device = 'cuda'

    # NOTE: argparse's type=bool converts any non-empty string (including 'False') to True.
    parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation')
    parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)')
    parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)')
    parser.add_argument('--k', default=20, type=int, help='neighbor num (default:20)')
    parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}')
    parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
    parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
    parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization')
    parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
    parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
    parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks')
    parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout')
    parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
    parser.add_argument('--stochastic', default=False, type=bool, help='stochastic for gcn, True or False')
    args = parser.parse_args()

    pos = torch.rand((batch_size, N, 3), dtype=torch.float).to(device)
    x = torch.rand((batch_size, N, 6), dtype=torch.float).to(device)

    # (B, N, 9) -> (B, 9, N, 1), the dense layout expected by the model.
    inputs = torch.cat((pos, x), 2).transpose(1, 2).unsqueeze(-1)

    net = DenseDeepGCN(args).to(device)
    print(net)

    out = net(inputs)

    print(inputs.shape, out.shape)
    import time
    st = time.time()
    runs = 1000

    with torch.no_grad():
        for i in range(runs):

            out = net(inputs)
            torch.cuda.synchronize()

    print(time.time() - st)

--------------------------------------------------------------------------------
/examples/sem_seg_dense/test.py:
--------------------------------------------------------------------------------
import __init__
from tqdm import tqdm
import numpy as np
import torch
import torch_geometric.datasets as GeoData
from torch_geometric.data import DenseDataLoader
import torch_geometric.transforms as T
from config import OptInit
from architecture import DenseDeepGCN
from utils.ckpt_util import load_pretrained_models
import logging


def main():
    opt = OptInit().get_args()

    logging.info('===> Creating dataloader...')
    test_dataset = GeoData.S3DIS(opt.data_dir, opt.area, train=False, pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0)
    opt.n_classes = test_loader.dataset.num_classes
    if opt.no_clutter:
        opt.n_classes -= 1

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    logging.info('===> Start Evaluation ...')
    test(model, test_loader, opt)


def test(model, loader, opt):
    # Per-batch, per-class intersections and unions; summed over batches before dividing.
    Is = np.empty((len(loader), opt.n_classes))
    Us = np.empty((len(loader), opt.n_classes))

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(loader)):
            data = data.to(opt.device)
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
            gt = data.y

            out = model(inputs)
            pred = out.max(dim=1)[1]

            pred_np = pred.cpu().numpy()
            target_np = gt.cpu().numpy()

            for cl in range(opt.n_classes):
                cur_gt_mask = (target_np == cl)
                cur_pred_mask = (pred_np == cl)
                I = np.sum(np.logical_and(cur_pred_mask, cur_gt_mask), dtype=np.float32)
                U = np.sum(np.logical_or(cur_pred_mask, cur_gt_mask), dtype=np.float32)
                Is[i, cl] = I
                Us[i, cl] = U

    ious = np.divide(np.sum(Is, 0), np.sum(Us, 0))
    ious[np.isnan(ious)] = 1
    for cl in range(opt.n_classes):
        logging.info("===> IoU for class {}: {}".format(cl, ious[cl]))
    logging.info("===> mIOU is {}".format(np.mean(ious)))


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/examples/sem_seg_dense/train.py:
--------------------------------------------------------------------------------
import __init__
import numpy as np
import torch
import torch_geometric.datasets as GeoData
from torch_geometric.data import DenseDataLoader
import torch_geometric.transforms as T
from torch.nn import DataParallel
from config import OptInit
from architecture import DenseDeepGCN
from utils.ckpt_util import load_pretrained_models, load_pretrained_optimizer, save_checkpoint
from utils.metrics import AverageMeter
import logging
from tqdm import tqdm


def main():
    opt = OptInit().get_args()
    logging.info('===> Creating dataloader ...')
    train_dataset = GeoData.S3DIS(opt.data_dir, opt.area, True, pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    test_dataset = GeoData.S3DIS(opt.data_dir, opt.area, train=False, pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0)
    opt.n_classes = train_loader.dataset.num_classes

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(DenseDeepGCN(opt)).to(opt.device)

    logging.info('===> loading pre-trained ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)
    logging.info(model)

    logging.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq, opt.lr_decay_rate)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(opt.pretrained_model, optimizer, scheduler, opt.lr)

    logging.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    opt.test_value = 0.

    logging.info('===> start training ...')
    for _ in range(opt.epoch, opt.total_epochs):
        opt.epoch += 1
        logging.info('Epoch:{}'.format(opt.epoch))
        train(model, train_loader, optimizer, criterion, opt)
        if opt.epoch % opt.eval_freq == 0 and opt.eval_freq != -1:
            test(model, test_loader, opt)
        scheduler.step()

        # ------------------ save checkpoints
        # mIoU is maximized, so a checkpoint is the best one when it beats the previous best value
        is_best = (opt.test_value > opt.best_value)
        opt.best_value = max(opt.test_value, opt.best_value)
        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        save_checkpoint({
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_value': opt.best_value,
        }, is_best, opt.ckpt_dir, opt.exp_name)

        # ------------------ tensorboard log
        info = {
            'loss': opt.losses.avg,
            'test_value': opt.test_value,
            'lr': scheduler.get_lr()[0]
        }
        opt.writer.add_scalars('epoch', info, opt.iter)

    logging.info('Saving the final model. Finish!')


def train(model, train_loader, optimizer, criterion, opt):
    opt.losses.reset()
    model.train()
    with tqdm(train_loader) as tqdm_loader:
        for i, data in enumerate(tqdm_loader):
            opt.iter += 1
            desc = 'Epoch:{}  Iter:{}  [{}/{}]  Loss:{Losses.avg: .4f}'\
                .format(opt.epoch, opt.iter, i + 1, len(train_loader), Losses=opt.losses)
            tqdm_loader.set_description(desc)

            if not opt.multi_gpus:
                data = data.to(opt.device)
            # (B, N, C) -> (B, C, N, 1) for positions and features, concatenated along channels.
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
            gt = data.y.to(opt.device)
            # ------------------ zero, output, loss
            optimizer.zero_grad()
            out = model(inputs)
            loss = criterion(out, gt)

            # ------------------ optimization
            loss.backward()
            optimizer.step()

            opt.losses.update(loss.item())


def test(model, loader, opt):
    # Per-batch, per-class intersections and unions; summed over batches before dividing.
    Is = np.empty((len(loader), opt.n_classes))
    Us = np.empty((len(loader), opt.n_classes))

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(loader)):
            if not opt.multi_gpus:
                data = data.to(opt.device)
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
            gt = data.y

            out = model(inputs)
            pred = out.max(dim=1)[1]

            pred_np = pred.cpu().numpy()
            target_np = gt.cpu().numpy()

            for cl in range(opt.n_classes):
                cur_gt_mask = (target_np == cl)
                cur_pred_mask = (pred_np == cl)
                I = np.sum(np.logical_and(cur_pred_mask, cur_gt_mask), dtype=np.float32)
                U = np.sum(np.logical_or(cur_pred_mask, cur_gt_mask), dtype=np.float32)
                Is[i, cl] = I
                Us[i, cl] = U

    ious = np.divide(np.sum(Is, 0), np.sum(Us, 0))
    ious[np.isnan(ious)] = 1
    iou = np.mean(ious)
    if opt.phase == 'test':
        for cl in range(opt.n_classes):
            logging.info("===> IoU for class {}: {}".format(cl, ious[cl]))

    opt.test_value = iou
    logging.info('TEST Epoch: [{}]\t mIoU: {:.4f}\t'.format(opt.epoch, opt.test_value))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/README.md:
--------------------------------------------------------------------------------
## [Semantic segmentation of indoor scenes](https://arxiv.org/pdf/1904.03751.pdf)

Sem_seg_dense and sem_seg_sparse both address the semantic segmentation task. They differ in the shape of the data.
For sem_seg_sparse, the data shape is N x feature_size, and a `batch` variable indicates which graph each node belongs to.
For sem_seg_dense, the data shape is Batch_size x feature_size x nodes_num x 1.

gcn_lib has two folders, dense and sparse, one for each of these data shapes.


### Note
We suggest using sem_seg_dense. It is more efficient.

### Train
We use 2 Tesla V100 GPUs for distributed training. Run:
```
CUDA_VISIBLE_DEVICES=0,1 python examples/sem_seg_sparse/train.py  --multi_gpus --phase train --train_path /data/deepgcn/S3DIS
```
Note on `--train_path`: make sure the folder exists. Set `--train_path path/to/data`; the dataset will be downloaded automatically.

To train a model with other GCN layers (for example mrgcn), run
```
python train.py --conv mr --multi_gpus --phase train --train_path /data/deepgcn/S3DIS
```
Other parameters for changing the architecture are:
```
    --block         graph backbone block type {res, dense}
    --conv          graph conv layer {edge, mr, sage, gin, gcn, gat}
    --n_filters     number of channels of deep features, default is 64
    --n_blocks      number of basic blocks, default is 28
```

### Evaluation
For a quick test on Area 5, run:

```
python test.py --pretrained_model checkpoints/densedeepgcn-res-edge-ckpt_50.pth  --batch_size 1  --test_path /data/deepgcn/S3DIS
```

#### Visualization
Coming soon!
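
The two data layouts described at the top of this README can be related with a small sketch. It uses NumPy arrays as stand-ins for tensors, and it assumes that every graph has the same number of nodes and that nodes of one graph are stored contiguously; the sizes `B`, `N`, `C` are toy values chosen for illustration.

```python
import numpy as np

B, N, C = 2, 5, 9  # toy sizes: 2 graphs, 5 nodes each, 9 features per node

# Sparse layout: all nodes stacked into (B*N, C), plus a batch vector of graph ids.
x_sparse = np.arange(B * N * C, dtype=np.float32).reshape(B * N, C)
batch = np.repeat(np.arange(B), N)  # graph id of every node

# Sparse -> dense: (B*N, C) -> (B, N, C) -> (B, C, N) -> (B, C, N, 1).
# This plain reshape only works under the equal-size assumption stated above.
x_dense = x_sparse.reshape(B, N, C).transpose(0, 2, 1)[..., None]

# Dense -> sparse: drop the trailing axis, move features last, flatten the batch.
x_back = x_dense[..., 0].transpose(0, 2, 1).reshape(B * N, C)

print(x_sparse.shape, batch.shape, x_dense.shape)
```

When graphs differ in size, the dense layout would need padding or resampling, which is one reason the sparse layout carries an explicit `batch` vector.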

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/__init__.py:
--------------------------------------------------------------------------------
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/architecture.py:
--------------------------------------------------------------------------------
import __init__
import torch
from torch.nn import Linear as Lin
from gcn_lib.sparse import MultiSeq, MLP, GraphConv, PlainDynBlock, ResDynBlock, DenseDynBlock, DilatedKnnGraph
from utils.pyg_util import scatter_
from torch_geometric.data import Data


class SparseDeepGCN(torch.nn.Module):
    def __init__(self, opt):
        super(SparseDeepGCN, self).__init__()
        channels = opt.n_filters
        k = opt.k
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.stochastic
        conv = opt.conv
        c_growth = channels

        self.n_blocks = opt.n_blocks

        self.knn = DilatedKnnGraph(k, 1, stochastic, epsilon)
        self.head = GraphConv(opt.in_channels, channels, conv, act, norm, bias)

        if opt.block.lower() == 'res':
            self.backbone = MultiSeq(*[ResDynBlock(channels, k, 1+i, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
                                       for i in range(self.n_blocks-1)])
            fusion_dims = int(channels + c_growth * (self.n_blocks - 1))
        elif opt.block.lower() == 'dense':
            self.backbone = MultiSeq(*[DenseDynBlock(channels+c_growth*i, c_growth, k, 1+i,
                                                     conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
                                       for i in range(self.n_blocks-1)])
            fusion_dims = int(
                (channels + channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2)
        else:
            # Use PlainGCN without skip connection and dilated convolution.
            stochastic = False
            self.backbone = MultiSeq(
                *[PlainDynBlock(channels, k, 1, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
                  for _ in range(self.n_blocks - 1)])
            fusion_dims = int(channels + c_growth * (self.n_blocks - 1))

        self.fusion_block = MLP([fusion_dims, 1024], act, norm, bias)
        self.prediction = MultiSeq(*[MLP([fusion_dims+1024, 512], act, norm, bias),
                                     MLP([512, 256], act, norm, bias, drop=opt.dropout),
                                     MLP([256, opt.n_classes], None, None, bias)])
        self.model_init()

    def model_init(self):
        for m in self.modules():
            if isinstance(m, Lin):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, data):
        corr, color, batch = data.pos, data.x, data.batch
        x = torch.cat((corr, color), dim=1)
        # The kNN graph is built from the first three channels (xyz coordinates).
        feats = [self.head(x, self.knn(x[:, 0:3], batch))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1], batch)[0])
        feats = torch.cat(feats, dim=1)

        # Per-graph max pooling, then broadcast back to the nodes.
        # The uniform repeat count assumes every graph in the batch has the same number of nodes.
        fusion = scatter_('max', self.fusion_block(feats), batch)
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
        return self.prediction(torch.cat((fusion, feats), dim=1))


if __name__ == "__main__":
    import random, numpy as np, argparse
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    batch_size = 2
    N = 1024
    device = 'cuda'

    # NOTE: argparse's type=bool converts any non-empty string (including 'False') to True.
    parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation')
    parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)')
    parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)')
    parser.add_argument('--k', default=20, type=int, help='neighbor num (default:20)')
    parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}')
    parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
    parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
    parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization')
    parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
    parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
    parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks')
    parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout')
    parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
    parser.add_argument('--stochastic', default=False, type=bool, help='stochastic for gcn, True or False')
    args = parser.parse_args()

    pos = torch.rand((batch_size*N, 3), dtype=torch.float).to(device)
    x = torch.rand((batch_size*N, 6), dtype=torch.float).to(device)

    data = Data()
    data.pos = pos
    data.x = x
    # Graph id of every node: [0, ..., 0, 1, ..., 1] with N entries per graph.
    data.batch = torch.arange(batch_size).unsqueeze(-1).expand(-1, N).contiguous().view(-1).contiguous()
    data = data.to(device)

    net = SparseDeepGCN(args).to(device)
    print(net)

    out = net(data)

    print('out logits shape', out.shape)
    import time
    st = time.time()
    runs = 1000

    with torch.no_grad():
        for i in range(runs):
            # print(i)
            out = net(data)
            torch.cuda.synchronize()
    print(time.time() - st)

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/script/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
conda activate deepgcn
python -u test.py --pretrained_model sem_seg/checkpoints/deepgcn-res-edge-190822_ckpt_50.pth  --batch_size 1  --test_path /data/deepgcn/S3DIS

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/script/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash

conda activate deepgcn
python -u train.py --multi_gpus --phase train --train_path /data/deepgcn/S3DIS --batch_size 16
echo "===> training loaded Done..."

--------------------------------------------------------------------------------
/examples/sem_seg_sparse/test.py:
--------------------------------------------------------------------------------
import __init__
from tqdm import tqdm
import numpy as np
import torch
import torch_geometric.datasets as GeoData
from torch_geometric.data import DataLoader
import torch_geometric.transforms as T
from config import OptInit
from architecture import SparseDeepGCN
from utils.ckpt_util import load_pretrained_models
import logging


def main():
    opt = OptInit().get_args()

    logging.info('===> Creating dataloader...')
    # Area 5 is used as the test split.
    test_dataset = GeoData.S3DIS(opt.data_dir, 5, train=False, pre_transform=T.NormalizeScale())
    test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0)
    opt.n_classes = test_loader.dataset.num_classes
    if opt.no_clutter:
        opt.n_classes -= 1

    logging.info('===> Loading the network ...')
    model = SparseDeepGCN(opt).to(opt.device)
    model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    logging.info('===> Start Evaluation ...')
    test(model, test_loader, opt)


32 | def test(model, loader, opt): 33 | Is = np.empty((len(loader), opt.n_classes)) 34 | Us = np.empty((len(loader), opt.n_classes)) 35 | 36 | model.eval() 37 | with torch.no_grad(): 38 | for i, data in enumerate(tqdm(loader)): 39 | data = data.to(opt.device) 40 | out = model(data) 41 | pred = out.max(dim=1)[1] 42 | 43 | pred_np = pred.cpu().numpy() 44 | target_np = data.y.cpu().numpy() 45 | 46 | for cl in range(opt.n_classes): 47 | I = np.sum(np.logical_and(pred_np == cl, target_np == cl)) 48 | U = np.sum(np.logical_or(pred_np == cl, target_np == cl)) 49 | Is[i, cl] = I 50 | Us[i, cl] = U 51 | 52 | ious = np.divide(np.sum(Is, 0), np.sum(Us, 0)) 53 | ious[np.isnan(ious)] = 1 54 | for cl in range(opt.n_classes): 55 | logging.info("===> mIOU for class {}: {}".format(cl, ious[cl])) 56 | logging.info("===> mIOU is {}".format(np.mean(ious))) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | 62 | 63 | -------------------------------------------------------------------------------- /examples/sem_seg_sparse/train.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import torch 3 | import torch_geometric.datasets as GeoData 4 | from torch_geometric.data import DataLoader, DataListLoader 5 | import torch_geometric.transforms as T 6 | from torch_geometric.nn.data_parallel import DataParallel 7 | from config import OptInit 8 | from architecture import SparseDeepGCN 9 | from utils.ckpt_util import load_pretrained_models, load_pretrained_optimizer, save_checkpoint 10 | from utils.metrics import AverageMeter 11 | from utils import optim 12 | import logging 13 | 14 | 15 | def main(): 16 | opt = OptInit().get_args() 17 | logging.info('===> Creating dataloader ...') 18 | train_dataset = GeoData.S3DIS(opt.data_dir, test_area=5, train=True, pre_transform=T.NormalizeScale()) 19 | if opt.multi_gpus: 20 | train_loader = DataListLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4) 21 | else: 22 | 
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4) 23 | opt.n_classes = train_loader.dataset.num_classes 24 | 25 | logging.info('===> Loading the network ...') 26 | model = SparseDeepGCN(opt).to(opt.device) 27 | if opt.multi_gpus: 28 | model = DataParallel(SparseDeepGCN(opt)).to(opt.device) 29 | logging.info('===> loading pre-trained ...') 30 | model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase) 31 | logging.info(model) 32 | 33 | logging.info('===> Init the optimizer ...') 34 | criterion = torch.nn.CrossEntropyLoss().to(opt.device) 35 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) 36 | 37 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq, opt.lr_decay_rate) 38 | optimizer, scheduler, opt.lr = load_pretrained_optimizer(opt.pretrained_model, optimizer, scheduler, opt.lr) 39 | 40 | logging.info('===> Init Metric ...') 41 | opt.losses = AverageMeter() 42 | # opt.test_metric = miou 43 | # opt.test_values = AverageMeter() 44 | opt.test_value = 0. 
45 | 46 | logging.info('===> start training ...') 47 | for _ in range(opt.total_epochs): 48 | opt.epoch += 1 49 | train(model, train_loader, optimizer, scheduler, criterion, opt) 50 | # test_value = test(model, test_loader, test_metric, opt) 51 | scheduler.step() 52 | logging.info('Saving the final model.Finish!') 53 | 54 | 55 | def train(model, train_loader, optimizer, scheduler, criterion, opt): 56 | model.train() 57 | for i, data in enumerate(train_loader): 58 | opt.iter += 1 59 | if not opt.multi_gpus: 60 | data = data.to(opt.device) 61 | gt = data.y 62 | else: 63 | gt = torch.cat([data_batch.y for data_batch in data], 0).to(opt.device) 64 | 65 | # ------------------ zero, output, loss 66 | optimizer.zero_grad() 67 | out = model(data) 68 | loss = criterion(out, gt) 69 | 70 | # ------------------ optimization 71 | loss.backward() 72 | optimizer.step() 73 | 74 | opt.losses.update(loss.item()) 75 | # ------------------ show information 76 | if opt.iter % opt.print_freq == 0: 77 | logging.info('Epoch:{}\t Iter: {}\t [{}/{}]\t Loss: {Losses.avg: .4f}'.format( 78 | opt.epoch, opt.iter, i + 1, len(train_loader), Losses=opt.losses)) 79 | opt.losses.reset() 80 | 81 | # ------------------ tensor board log 82 | info = { 83 | 'loss': loss, 84 | 'test_value': opt.test_value, 85 | 'lr': scheduler.get_lr()[0] 86 | } 87 | for tag, value in info.items(): 88 | opt.writer.scalar_summary(tag, value, opt.iter) 89 | 90 | # ------------------ save checkpoints 91 | # min or max. 
based on the metrics 92 | is_best = (opt.test_value < opt.best_value) 93 | opt.best_value = min(opt.test_value, opt.best_value) 94 | 95 | model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} 96 | # optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()} 97 | save_checkpoint({ 98 | 'epoch': opt.epoch, 99 | 'state_dict': model_cpu, 100 | 'optimizer_state_dict': optimizer.state_dict(), 101 | 'scheduler_state_dict': scheduler.state_dict(), 102 | 'best_value': opt.best_value, 103 | }, is_best, opt.save_path, opt.post) 104 | 105 | 106 | def test(model, test_loader, test_metric, opt): 107 | opt.test_values.reset() 108 | model.eval() 109 | with torch.no_grad(): 110 | for i, data in enumerate(test_loader): 111 | if not opt.multi_gpus: 112 | data = data.to(opt.device) 113 | gt = data.y 114 | else: 115 | gt = torch.cat([data_batch.y for data_batch in data], 0).to(opt.device) 116 | 117 | out = model(data) 118 | test_value = test_metric(out.max(dim=1)[1], gt, opt.n_classes) 119 | opt.test_values.update(test_value, opt.batch_size) 120 | logging.info('Epoch: [{0}]\t Iter: [{1}]\t''TEST loss: {test_values.avg: .4f})\t'.format( 121 | opt.epoch, opt.iter, test_values=opt.test_values)) 122 | 123 | opt.test_value = opt.test_values.avg 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | 129 | 130 | -------------------------------------------------------------------------------- /gcn_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/gcn_lib/__init__.py -------------------------------------------------------------------------------- /gcn_lib/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | 
-------------------------------------------------------------------------------- /gcn_lib/dense/torch_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_cluster import knn_graph 4 | 5 | 6 | class DenseDilated(nn.Module): 7 | """ 8 | Find dilated neighbor from neighbor list 9 | 10 | edge_index: (2, batch_size, num_points, k) 11 | """ 12 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 13 | super(DenseDilated, self).__init__() 14 | self.dilation = dilation 15 | self.stochastic = stochastic 16 | self.epsilon = epsilon 17 | self.k = k 18 | 19 | def forward(self, edge_index): 20 | if self.stochastic: 21 | if torch.rand(1) < self.epsilon and self.training: 22 | num = self.k * self.dilation 23 | randnum = torch.randperm(num)[:self.k] 24 | edge_index = edge_index[:, :, :, randnum] 25 | else: 26 | edge_index = edge_index[:, :, :, ::self.dilation] 27 | else: 28 | edge_index = edge_index[:, :, :, ::self.dilation] 29 | return edge_index 30 | 31 | 32 | def pairwise_distance(x): 33 | """ 34 | Compute pairwise distance of a point cloud. 35 | Args: 36 | x: tensor (batch_size, num_points, num_dims) 37 | Returns: 38 | pairwise distance: (batch_size, num_points, num_points) 39 | """ 40 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 41 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 42 | return x_square + x_inner + x_square.transpose(2, 1) 43 | 44 | 45 | def dense_knn_matrix(x, k=16): 46 | """Get KNN based on the pairwise distance. 
47 | Args: 48 | x: (batch_size, num_dims, num_points, 1) 49 | k: int 50 | Returns: 51 | nearest neighbors: (batch_size, num_points ,k) (batch_size, num_points, k) 52 | """ 53 | with torch.no_grad(): 54 | x = x.transpose(2, 1).squeeze(-1) 55 | batch_size, n_points, n_dims = x.shape 56 | _, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k) 57 | center_idx = torch.arange(0, n_points, device=x.device).expand(batch_size, k, -1).transpose(2, 1) 58 | return torch.stack((nn_idx, center_idx), dim=0) 59 | 60 | 61 | class DenseDilatedKnnGraph(nn.Module): 62 | """ 63 | Find the neighbors' indices based on dilated knn 64 | """ 65 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 66 | super(DenseDilatedKnnGraph, self).__init__() 67 | self.dilation = dilation 68 | self.stochastic = stochastic 69 | self.epsilon = epsilon 70 | self.k = k 71 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 72 | self.knn = dense_knn_matrix 73 | 74 | def forward(self, x): 75 | edge_index = self.knn(x, self.k * self.dilation) 76 | return self._dilated(edge_index) 77 | 78 | 79 | class DilatedKnnGraph(nn.Module): 80 | """ 81 | Find the neighbors' indices based on dilated knn 82 | """ 83 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 84 | super(DilatedKnnGraph, self).__init__() 85 | self.dilation = dilation 86 | self.stochastic = stochastic 87 | self.epsilon = epsilon 88 | self.k = k 89 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 90 | self.knn = knn_graph 91 | 92 | def forward(self, x): 93 | x = x.squeeze(-1) 94 | B, C, N = x.shape 95 | edge_index = [] 96 | for i in range(B): 97 | edgeindex = self.knn(x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation) 98 | edgeindex = edgeindex.view(2, N, self.k * self.dilation) 99 | edge_index.append(edgeindex) 100 | edge_index = torch.stack(edge_index, dim=1) 101 | return self._dilated(edge_index) 102 | 
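The dilated k-NN logic in this file can be illustrated with a small dependency-free sketch. It is not the repository's implementation (which works on batched torch tensors and uses `torch.topk`); the helper names `pairwise_sq_dist` and `dilated_knn` are hypothetical and chosen only for this example. It assumes the same idea as `pairwise_distance` and `DenseDilated` above: expand the squared distance as ||a||^2 - 2 a.b + ||b||^2, keep the `k * dilation` closest points, then take every `dilation`-th one.

```python
# Dependency-free sketch of dilated k-NN selection (illustrative only;
# the repository code does the same thing with batched torch tensors).

def pairwise_sq_dist(points):
    """Squared Euclidean distances via ||a||^2 - 2 a.b + ||b||^2."""
    sq = [sum(c * c for c in p) for p in points]
    n = len(points)
    return [[sq[i] - 2 * sum(a * b for a, b in zip(points[i], points[j])) + sq[j]
             for j in range(n)]
            for i in range(n)]


def dilated_knn(points, k, dilation=1):
    """For each point, rank all points by distance, keep the k*dilation
    closest, then take every `dilation`-th one (k indices per point)."""
    dist = pairwise_sq_dist(points)
    result = []
    for row in dist:
        order = sorted(range(len(row)), key=lambda j: row[j])
        result.append(order[:k * dilation:dilation])
    return result


# Five points on a line; with k=2 and dilation=2 the 4 nearest are found
# first and every second one is kept.
points = [(0.0,), (1.0,), (2.0,), (4.0,), (8.0,)]
neighbours = dilated_knn(points, k=2, dilation=2)
print(neighbours[0])
```

As in `dense_knn_matrix`, each point is its own nearest neighbour (distance zero), so the first selected index is the point itself. Dilation widens the receptive field without increasing the number of neighbours aggregated per point.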
-------------------------------------------------------------------------------- /gcn_lib/dense/torch_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d 4 | 5 | 6 | ############################## 7 | # Basic layers 8 | ############################## 9 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 10 | # activation layer 11 | 12 | act = act.lower() 13 | if act == 'relu': 14 | layer = nn.ReLU(inplace) 15 | elif act == 'leakyrelu': 16 | layer = nn.LeakyReLU(neg_slope, inplace) 17 | elif act == 'prelu': 18 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 19 | else: 20 | raise NotImplementedError('activation layer [%s] is not found' % act) 21 | return layer 22 | 23 | 24 | def norm_layer(norm, nc): 25 | # normalization layer 2d 26 | norm = norm.lower() 27 | if norm == 'batch': 28 | layer = nn.BatchNorm2d(nc, affine=True) 29 | elif norm == 'instance': 30 | layer = nn.InstanceNorm2d(nc, affine=False) 31 | else: 32 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 33 | return layer 34 | 35 | 36 | class MLP(Seq): 37 | def __init__(self, channels, act='relu', norm=None, bias=True): 38 | m = [] 39 | for i in range(1, len(channels)): 40 | m.append(Lin(channels[i - 1], channels[i], bias)) 41 | if act is not None and act.lower() != 'none': 42 | m.append(act_layer(act)) 43 | if norm is not None and norm.lower() != 'none': 44 | m.append(norm_layer(norm, channels[-1])) 45 | super(MLP, self).__init__(*m) 46 | 47 | 48 | class BasicConv(Seq): 49 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.): 50 | m = [] 51 | for i in range(1, len(channels)): 52 | m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias)) 53 | if act is not None and act.lower() != 'none': 54 | m.append(act_layer(act)) 55 | if norm is not None and norm.lower() != 'none': 56 | 
m.append(norm_layer(norm, channels[-1])) 57 | if drop > 0: 58 | m.append(nn.Dropout2d(drop)) 59 | 60 | super(BasicConv, self).__init__(*m) 61 | 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight) 68 | if m.bias is not None: 69 | nn.init.zeros_(m.bias) 70 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | 75 | def batched_index_select(x, idx): 76 | r"""fetches neighbors features from a given neighbor idx 77 | 78 | Args: 79 | x (Tensor): input feature Tensor 80 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. 81 | idx (Tensor): edge_idx 82 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. 83 | Returns: 84 | Tensor: output neighbors features 85 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 86 | """ 87 | batch_size, num_dims, num_vertices = x.shape[:3] 88 | k = idx.shape[-1] 89 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices 90 | idx = idx + idx_base 91 | idx = idx.contiguous().view(-1) 92 | 93 | x = x.transpose(2, 1) 94 | feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :] 95 | feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() 96 | return feature 97 | -------------------------------------------------------------------------------- /gcn_lib/dense/torch_vertex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .torch_nn import BasicConv, batched_index_select 4 | from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph 5 | import torch.nn.functional as F 6 | 7 | 8 | class MRConv2d(nn.Module): 9 | """ 10 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type 11 | """ 12 | def __init__(self, 
in_channels, out_channels, act='relu', norm=None, bias=True): 13 | super(MRConv2d, self).__init__() 14 | self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias) 15 | 16 | def forward(self, x, edge_index): 17 | x_i = batched_index_select(x, edge_index[1]) 18 | x_j = batched_index_select(x, edge_index[0]) 19 | x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) 20 | return self.nn(torch.cat([x, x_j], dim=1)) 21 | 22 | 23 | class EdgeConv2d(nn.Module): 24 | """ 25 | Edge convolution layer (with activation, batch normalization) for dense data type 26 | """ 27 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 28 | super(EdgeConv2d, self).__init__() 29 | self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias) 30 | 31 | def forward(self, x, edge_index): 32 | x_i = batched_index_select(x, edge_index[1]) 33 | x_j = batched_index_select(x, edge_index[0]) 34 | max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) 35 | return max_value 36 | 37 | 38 | class GraphConv2d(nn.Module): 39 | """ 40 | Static graph convolution layer 41 | """ 42 | def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True): 43 | super(GraphConv2d, self).__init__() 44 | if conv == 'edge': 45 | self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias) 46 | elif conv == 'mr': 47 | self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias) 48 | else: 49 | raise NotImplementedError('conv:{} is not supported'.format(conv)) 50 | 51 | def forward(self, x, edge_index): 52 | return self.gconv(x, edge_index) 53 | 54 | 55 | class DynConv2d(GraphConv2d): 56 | """ 57 | Dynamic graph convolution layer 58 | """ 59 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu', 60 | norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 61 | super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias) 62 | self.k = 
kernel_size 63 | self.d = dilation 64 | if knn == 'matrix': 65 | self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) 66 | else: 67 | self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) 68 | 69 | def forward(self, x, edge_index=None): 70 | if edge_index is None: 71 | edge_index = self.dilated_knn_graph(x) 72 | return super(DynConv2d, self).forward(x, edge_index) 73 | 74 | 75 | class PlainDynBlock2d(nn.Module): 76 | """ 77 | Plain Dynamic graph convolution block 78 | """ 79 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 80 | bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 81 | super(PlainDynBlock2d, self).__init__() 82 | self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, 83 | act, norm, bias, stochastic, epsilon, knn) 84 | 85 | def forward(self, x, edge_index=None): 86 | return self.body(x, edge_index) 87 | 88 | 89 | class ResDynBlock2d(nn.Module): 90 | """ 91 | Residual Dynamic graph convolution block 92 | """ 93 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 94 | bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1): 95 | super(ResDynBlock2d, self).__init__() 96 | self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, 97 | act, norm, bias, stochastic, epsilon, knn) 98 | self.res_scale = res_scale 99 | 100 | def forward(self, x, edge_index=None): 101 | return self.body(x, edge_index) + x*self.res_scale 102 | 103 | 104 | class DenseDynBlock2d(nn.Module): 105 | """ 106 | Dense Dynamic graph convolution block 107 | """ 108 | def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', 109 | act='relu', norm=None,bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 110 | super(DenseDynBlock2d, self).__init__() 111 | self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv, 112 | 
act, norm, bias, stochastic, epsilon, knn) 113 | 114 | def forward(self, x, edge_index=None): 115 | dense = self.body(x, edge_index) 116 | return torch.cat((x, dense), 1) 117 | -------------------------------------------------------------------------------- /gcn_lib/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_cluster import knn_graph 4 | 5 | 6 | class Dilated(nn.Module): 7 | """ 8 | Find dilated neighbor from neighbor list 9 | """ 10 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 11 | super(Dilated, self).__init__() 12 | self.dilation = dilation 13 | self.stochastic = stochastic 14 | self.epsilon = epsilon 15 | self.k = k 16 | 17 | def forward(self, edge_index, batch=None): 18 | if self.stochastic: 19 | if torch.rand(1) < self.epsilon and self.training: 20 | num = self.k * self.dilation 21 | randnum = torch.randperm(num)[:self.k] 22 | edge_index = edge_index.view(2, -1, num) 23 | edge_index = edge_index[:, :, randnum] 24 | return edge_index.view(2, -1) 25 | else: 26 | edge_index = edge_index[:, ::self.dilation] 27 | else: 28 | edge_index = edge_index[:, ::self.dilation] 29 | return edge_index 30 | 31 | 32 | class DilatedKnnGraph(nn.Module): 33 | """ 34 | Find the neighbors' indices based on dilated knn 35 | """ 36 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'): 37 | super(DilatedKnnGraph, self).__init__() 38 | self.dilation = dilation 39 | self.stochastic = stochastic 40 | self.epsilon = epsilon 41 | self.k = k 42 | self._dilated = Dilated(k, dilation, stochastic, epsilon) 43 | if knn == 'matrix': 44 | 
self.knn = knn_graph_matrix 45 | else: 46 | self.knn = knn_graph 47 | 48 | def forward(self, x, batch): 49 | edge_index = self.knn(x, self.k * self.dilation, batch) 50 | return self._dilated(edge_index, batch) 51 | 52 | 53 | def pairwise_distance(x): 54 | """ 55 | Compute pairwise distance of a point cloud. 56 | Args: 57 | x: tensor (batch_size, num_points, num_dims) 58 | Returns: 59 | pairwise distance: (batch_size, num_points, num_points) 60 | """ 61 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 62 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 63 | return x_square + x_inner + x_square.transpose(2, 1) 64 | 65 | 66 | def knn_matrix(x, k=16, batch=None): 67 | """Get KNN based on the pairwise distance. 68 | Args: 69 | x: (num_points, num_dims) 70 | k: int 71 | Returns: 72 | nn_idx: (1, num_points*k), center_idx: (1, num_points*k) 73 | """ 74 | with torch.no_grad(): 75 | if batch is None: 76 | batch_size = 1 77 | else: 78 | batch_size = batch[-1] + 1 79 | x = x.view(batch_size, -1, x.shape[-1]) 80 | 81 | neg_adj = -pairwise_distance(x.detach()) 82 | 83 | _, nn_idx = torch.topk(neg_adj, k=k) 84 | 85 | n_points = x.shape[1] 86 | start_idx = torch.arange(0, n_points * batch_size, n_points, device=x.device).view(batch_size, 1, 1) 87 | nn_idx += start_idx 88 | 89 | nn_idx = nn_idx.view(1, -1) 90 | center_idx = torch.arange(0, n_points*batch_size, device=x.device).expand(k, -1).transpose(1, 0).contiguous().view(1, -1) 91 | return nn_idx, center_idx 92 | 93 | 94 | def knn_graph_matrix(x, k=16, batch=None): 95 | """Construct edge feature for each point 96 | Args: 97 | x: (num_points, num_dims) 98 | batch: (num_points, ) 99 | k: int 100 | Returns: 101 | edge_index: (2, num_points*k) 102 | """ 103 | nn_idx, center_idx = knn_matrix(x, k, batch) 104 | return torch.cat((nn_idx, center_idx), dim=0) 105 | 106 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_message.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import MessagePassing 4 | from torch_scatter import scatter, scatter_softmax 5 | from torch_geometric.utils import degree 6 | 7 | 8 | class GenMessagePassing(MessagePassing): 9 | def __init__(self, aggr='softmax', 10 | t=1.0, learn_t=False, 11 | p=1.0, learn_p=False, 12 | y=0.0, learn_y=False): 13 | 14 | if aggr in ['softmax_sg', 'softmax', 'softmax_sum']: 15 | 16 | super(GenMessagePassing, self).__init__(aggr=None) 17 | self.aggr = aggr 18 | 19 | if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'): 20 | self.learn_t = True 21 | self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True) 22 | else: 23 | self.learn_t = False 24 | self.t = t 25 | 26 | if aggr == 'softmax_sum': 27 | self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) 28 | 29 | elif aggr in ['power', 'power_sum']: 30 | 31 | super(GenMessagePassing, self).__init__(aggr=None) 32 | self.aggr = aggr 33 | 34 | if learn_p: 35 | self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True) 36 | else: 37 | self.p = p 38 | 39 | if aggr == 'power_sum': 40 | self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) 41 | else: 42 | super(GenMessagePassing, self).__init__(aggr=aggr) 43 | 44 | def aggregate(self, inputs, index, ptr=None, dim_size=None): 45 | 46 | if self.aggr in ['add', 'mean', 'max', None]: 47 | return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size) 48 | 49 | elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']: 50 | 51 | if self.learn_t: 52 | out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) 53 | else: 54 | with torch.no_grad(): 55 | out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) 56 | 57 | out = scatter(inputs*out, index, dim=self.node_dim, 58 | dim_size=dim_size, reduce='sum') 59 | 60 | if self.aggr == 'softmax_sum': 61 | 
self.sigmoid_y = torch.sigmoid(self.y) 62 | degrees = degree(index, num_nodes=dim_size).unsqueeze(1) 63 | out = torch.pow(degrees, self.sigmoid_y) * out 64 | 65 | return out 66 | 67 | 68 | elif self.aggr in ['power', 'power_sum']: 69 | min_value, max_value = 1e-7, 1e1 70 | torch.clamp_(inputs, min_value, max_value) 71 | out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim, 72 | dim_size=dim_size, reduce='mean') 73 | torch.clamp_(out, min_value, max_value) 74 | out = torch.pow(out, 1/self.p) 75 | # torch.clamp(out, min_value, max_value) 76 | 77 | if self.aggr == 'power_sum': 78 | self.sigmoid_y = torch.sigmoid(self.y) 79 | degrees = degree(index, num_nodes=dim_size).unsqueeze(1) 80 | out = torch.pow(degrees, self.sigmoid_y) * out 81 | 82 | return out 83 | 84 | else: 85 | raise NotImplementedError('To be implemented') 86 | 87 | 88 | class MsgNorm(torch.nn.Module): 89 | def __init__(self, learn_msg_scale=False): 90 | super(MsgNorm, self).__init__() 91 | 92 | self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]), 93 | requires_grad=learn_msg_scale) 94 | 95 | def forward(self, x, msg, p=2): 96 | msg = F.normalize(msg, p=p, dim=1) 97 | x_norm = x.norm(p=p, dim=1, keepdim=True) 98 | msg = msg * x_norm * self.msg_scale 99 | return msg 100 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_nn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import Sequential as Seq, Linear as Lin 3 | from utils.data_util import get_atom_feature_dims, get_bond_feature_dims 4 | 5 | 6 | ############################## 7 | # Basic layers 8 | ############################## 9 | def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1): 10 | # activation layer 11 | act = act_type.lower() 12 | if act == 'relu': 13 | layer = nn.ReLU(inplace) 14 | elif act == 'leakyrelu': 15 | layer = nn.LeakyReLU(neg_slope, inplace) 16 | elif act == 'prelu': 17 | 
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 18 | else: 19 | raise NotImplementedError('activation layer [%s] is not found' % act) 20 | return layer 21 | 22 | 23 | def norm_layer(norm_type, nc): 24 | # normalization layer 1d 25 | norm = norm_type.lower() 26 | if norm == 'batch': 27 | layer = nn.BatchNorm1d(nc, affine=True) 28 | elif norm == 'layer': 29 | layer = nn.LayerNorm(nc, elementwise_affine=True) 30 | elif norm == 'instance': 31 | layer = nn.InstanceNorm1d(nc, affine=False) 32 | else: 33 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 34 | return layer 35 | 36 | 37 | class MultiSeq(Seq): 38 | def __init__(self, *args): 39 | super(MultiSeq, self).__init__(*args) 40 | 41 | def forward(self, *inputs): 42 | for module in self._modules.values(): 43 | if type(inputs) == tuple: 44 | inputs = module(*inputs) 45 | else: 46 | inputs = module(inputs) 47 | return inputs 48 | 49 | 50 | class MLP(Seq): 51 | def __init__(self, channels, act='relu', 52 | norm=None, bias=True, 53 | drop=0., last_lin=False): 54 | m = [] 55 | 56 | for i in range(1, len(channels)): 57 | 58 | m.append(Lin(channels[i - 1], channels[i], bias)) 59 | 60 | if (i == len(channels) - 1) and last_lin: 61 | pass 62 | else: 63 | if norm is not None and norm.lower() != 'none': 64 | m.append(norm_layer(norm, channels[i])) 65 | if act is not None and act.lower() != 'none': 66 | m.append(act_layer(act)) 67 | if drop > 0: 68 | m.append(nn.Dropout2d(drop)) 69 | 70 | self.m = m 71 | super(MLP, self).__init__(*self.m) 72 | 73 | 74 | class AtomEncoder(nn.Module): 75 | 76 | def __init__(self, emb_dim): 77 | super(AtomEncoder, self).__init__() 78 | 79 | self.atom_embedding_list = nn.ModuleList() 80 | full_atom_feature_dims = get_atom_feature_dims() 81 | 82 | for i, dim in enumerate(full_atom_feature_dims): 83 | emb = nn.Embedding(dim, emb_dim) 84 | nn.init.xavier_uniform_(emb.weight.data) 85 | self.atom_embedding_list.append(emb) 86 | 87 | def forward(self, x): 88 | x_embedding 
= 0 89 | for i in range(x.shape[1]): 90 | x_embedding += self.atom_embedding_list[i](x[:, i]) 91 | 92 | return x_embedding 93 | 94 | 95 | class BondEncoder(nn.Module): 96 | 97 | def __init__(self, emb_dim): 98 | super(BondEncoder, self).__init__() 99 | 100 | self.bond_embedding_list = nn.ModuleList() 101 | full_bond_feature_dims = get_bond_feature_dims() 102 | 103 | for i, dim in enumerate(full_bond_feature_dims): 104 | emb = nn.Embedding(dim, emb_dim) 105 | nn.init.xavier_uniform_(emb.weight.data) 106 | self.bond_embedding_list.append(emb) 107 | 108 | def forward(self, edge_attr): 109 | bond_embedding = 0 110 | for i in range(edge_attr.shape[1]): 111 | bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) 112 | 113 | return bond_embedding 114 | 115 | 116 | -------------------------------------------------------------------------------- /misc/deeper_gcn_intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/deeper_gcn_intro.png -------------------------------------------------------------------------------- /misc/deeper_power_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/deeper_power_mean.png -------------------------------------------------------------------------------- /misc/deeper_softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/deeper_softmax.png -------------------------------------------------------------------------------- /misc/intro.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/intro.png -------------------------------------------------------------------------------- /misc/modelnet_cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/modelnet_cls.png -------------------------------------------------------------------------------- /misc/part_sem_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/part_sem_seg.png -------------------------------------------------------------------------------- /misc/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/pipeline.png -------------------------------------------------------------------------------- /misc/ppi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/ppi.png -------------------------------------------------------------------------------- /misc/sem_seg_s3dis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightaime/deep_gcns_torch/4f6681eee2290e217bda941b5536452a7c09decb/misc/sem_seg_s3dis.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckpt_util import * 2 | # from .data_util import * 3 | from .loss import * 4 | from .metrics import * 5 | from .optim import * 6 | # from .tf_logger import * 7 | 8 | 
-------------------------------------------------------------------------------- /utils/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | from collections import OrderedDict 5 | import logging 6 | import numpy as np 7 | 8 | 9 | def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post='best'): 10 | model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} 11 | state = { 12 | 'epoch': epoch, 13 | 'model_state_dict': model_cpu, 14 | 'optimizer_state_dict': optimizer.state_dict(), 15 | 'loss': loss 16 | } 17 | 18 | if not os.path.exists(save_path): 19 | os.makedirs(save_path) 20 | print("Directory ", save_path, " is created.") 21 | 22 | filename = '{}/{}_{}.pth'.format(save_path, name_pre, name_post) 23 | torch.save(state, filename) 24 | print('model has been saved as {}'.format(filename)) 25 | 26 | 27 | def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax: a higher metric value is better 28 | if ismax: 29 | best_value = -np.inf 30 | else: 31 | best_value = np.inf 32 | epoch = -1 33 | 34 | if pretrained_model: 35 | if os.path.isfile(pretrained_model): 36 | logging.info("===> Loading checkpoint '{}'".format(pretrained_model)) 37 | checkpoint = torch.load(pretrained_model) 38 | try: 39 | best_value = checkpoint['best_value'] 40 | if best_value == -np.inf or best_value == np.inf: 41 | show_best_value = False 42 | else: 43 | show_best_value = True 44 | except KeyError: 45 | # checkpoint has no 'best_value'; keep the default set above 46 | show_best_value = False 47 | 48 | model_dict = model.state_dict() 49 | ckpt_model_state_dict = checkpoint['state_dict'] 50 | 51 | # rename ckpt keys when only one side carries the multi-GPU (DataParallel) 'module.' prefix 52 | is_model_multi_gpus = list(model_dict)[0].startswith('module.') 53 | is_ckpt_multi_gpus = list(ckpt_model_state_dict)[0].startswith('module.') 54 | 55 | if not (is_model_multi_gpus == is_ckpt_multi_gpus): 56 | temp_dict = OrderedDict() 57 | for k, 
v in ckpt_model_state_dict.items(): 58 | if is_ckpt_multi_gpus: 59 | name = k[7:] # remove 'module.' 60 | else: 61 | name = 'module.' + k # add 'module.' 62 | temp_dict[name] = v 63 | # load params 64 | ckpt_model_state_dict = temp_dict 65 | 66 | model_dict.update(ckpt_model_state_dict) 67 | model.load_state_dict(ckpt_model_state_dict) 68 | 69 | if show_best_value: 70 | logging.info("The pretrained_model is at checkpoint {}. \t " 71 | "Best value: {}".format(checkpoint['epoch'], best_value)) 72 | else: 73 | logging.info("The pretrained_model is at checkpoint {}.".format(checkpoint['epoch'])) 74 | 75 | if phase == 'train': 76 | epoch = checkpoint['epoch'] 77 | else: 78 | epoch = -1 79 | else: 80 | raise ImportError("===> No checkpoint found at '{}'".format(pretrained_model)) 81 | else: 82 | logging.info('===> No pre-trained model') 83 | return model, best_value, epoch 84 | 85 | 86 | def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True): 87 | if pretrained_model: 88 | if os.path.isfile(pretrained_model): 89 | checkpoint = torch.load(pretrained_model) 90 | if 'optimizer_state_dict' in checkpoint.keys(): 91 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 92 | for state in optimizer.state.values(): 93 | for k, v in state.items(): 94 | if torch.is_tensor(v): 95 | state[k] = v.cuda() 96 | if 'scheduler_state_dict' in checkpoint.keys(): 97 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 98 | if use_ckpt_lr: 99 | try: 100 | lr = scheduler.get_lr()[0] 101 | except Exception: 102 | pass # keep the lr that was passed in 103 | 104 | return optimizer, scheduler, lr 105 | 106 | 107 | def save_checkpoint(state, is_best, save_path, postname): 108 | filename = '{}/{}_{}.pth'.format(save_path, postname, int(state['epoch'])) 109 | torch.save(state, filename) 110 | if is_best: 111 | shutil.copyfile(filename, '{}/{}_best.pth'.format(save_path, postname)) 112 | 113 | 114 | def change_ckpt_dict(model, optimizer, scheduler, opt): 115 | 116 | for _ in 
range(opt.epoch): 117 | scheduler.step() 118 | is_best = (opt.test_value < opt.best_value) 119 | opt.best_value = min(opt.test_value, opt.best_value) 120 | 121 | model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} 122 | # optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()} 123 | save_checkpoint({ 124 | 'epoch': opt.epoch, 125 | 'state_dict': model_cpu, 126 | 'optimizer_state_dict': optimizer.state_dict(), 127 | 'scheduler_state_dict': scheduler.state_dict(), 128 | 'best_value': opt.best_value, 129 | }, is_best, opt.save_path, opt.post) 130 | 131 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import csv 4 | 5 | 6 | def save_best_result(list_of_dict, file_name, dir_path='best_result'): 7 | if not os.path.exists(dir_path): 8 | os.mkdir(dir_path) 9 | print("Directory ", dir_path, " is created.") 10 | csv_file_name = '{}/{}.csv'.format(dir_path, file_name) 11 | with open(csv_file_name, 'a+') as csv_file: 12 | csv_writer = csv.writer(csv_file) 13 | for row in list_of_dict: 14 | csv_writer.writerow(row.values()) 15 | 16 | 17 | def create_exp_dir(path, scripts_to_save=None): 18 | if not os.path.exists(path): 19 | os.makedirs(path) 20 | print('Experiment dir : {}'.format(path)) 21 | 22 | if scripts_to_save is not None: 23 | os.mkdir(os.path.join(path, 'scripts')) 24 | for script in scripts_to_save: 25 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 26 | shutil.copyfile(script, dst_file) 27 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class SmoothCrossEntropy(torch.nn.Module): 6 | def __init__(self, smoothing=True, eps=0.2): 7 | 
super(SmoothCrossEntropy, self).__init__() 8 | self.smoothing = smoothing 9 | self.eps = eps 10 | 11 | def forward(self, pred, gt): 12 | gt = gt.contiguous().view(-1) 13 | 14 | if self.smoothing: 15 | n_class = pred.size(1) 16 | one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1) 17 | one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (n_class - 1) 18 | log_prb = F.log_softmax(pred, dim=1) 19 | 20 | loss = -(one_hot * log_prb).sum(dim=1).mean() 21 | else: 22 | loss = F.cross_entropy(pred, gt, reduction='mean') 23 | 24 | return loss 25 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from math import log10 2 | 3 | 4 | def PSNR(mse, peak=1.): 5 | return 10 * log10((peak ** 2) / mse) 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | -------------------------------------------------------------------------------- /utils/pyg_util.py: -------------------------------------------------------------------------------- 1 | import torch_scatter 2 | 3 | 4 | def scatter_(name, src, index, dim=0, dim_size=None): 5 | r"""Aggregates all values from the :attr:`src` tensor at the indices 6 | specified in the :attr:`index` tensor along dimension :attr:`dim`. 7 | If multiple indices reference the same location, their contributions 8 | are aggregated according to :attr:`name` (one of :obj:`"add"`, 9 | :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). 10 | 11 | Args: 12 | name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`, 13 | :obj:`"min"`, :obj:`"max"`). 
14 | src (Tensor): The source tensor. 15 | index (LongTensor): The indices of elements to scatter. 16 | dim (int, optional): The axis along which to index. (default: :obj:`0`) 17 | dim_size (int, optional): Automatically create output tensor with size 18 | :attr:`dim_size` along dimension :attr:`dim`. If set to :attr:`None`, a 19 | minimal sized output tensor is returned. (default: :obj:`None`) 20 | 21 | :rtype: :class:`Tensor` 22 | """ 23 | 24 | assert name in ['add', 'mean', 'min', 'max'] 25 | 26 | op = getattr(torch_scatter, 'scatter_{}'.format(name)) 27 | out = op(src, index, dim, None, dim_size) 28 | out = out[0] if isinstance(out, tuple) else out 29 | 30 | if name == 'max': 31 | out[out < -10000] = 0 32 | elif name == 'min': 33 | out[out > 10000] = 0 34 | 35 | return out 36 | -------------------------------------------------------------------------------- /utils/tf_logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | try: 3 | import tensorflow as tf 4 | import tensorboard.plugins.mesh.summary as meshsummary 5 | except ImportError: 6 | print('tensorflow is not installed.') 7 | import numpy as np 8 | import scipy.misc 9 | 10 | 11 | try: 12 | from StringIO import StringIO as BytesIO # Python 2.7 13 | except ImportError: 14 | from io import BytesIO # Python 3.x 15 | 16 | 17 | class TfLogger(object): 18 | 19 | def __init__(self, log_dir): 20 | """Create a summary writer logging to log_dir.""" 21 | self.writer = tf.compat.v1.summary.FileWriter(log_dir) 22 | 23 | # Camera and scene configuration. 
24 | self.config_dict = { 25 | 'camera': {'cls': 'PerspectiveCamera', 'fov': 75}, 26 | 'lights': [ 27 | { 28 | 'cls': 'AmbientLight', 29 | 'color': '#ffffff', 30 | 'intensity': 0.75, 31 | }, { 32 | 'cls': 'DirectionalLight', 33 | 'color': '#ffffff', 34 | 'intensity': 0.75, 35 | 'position': [0, -1, 2], 36 | }], 37 | 'material': { 38 | 'cls': 'MeshStandardMaterial', 39 | 'metalness': 0 40 | } 41 | } 42 | 43 | def scalar_summary(self, tag, value, step): 44 | """Log a scalar variable.""" 45 | summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]) 46 | self.writer.add_summary(summary, step) 47 | 48 | def image_summary(self, tag, images, step): 49 | """Log a list of images.""" 50 | img_summaries = [] 51 | for i, img in enumerate(images): 52 | # Write the image to an in-memory buffer 53 | s = BytesIO() 54 | scipy.misc.toimage(img).save(s, format="png") 55 | 56 | # Create an Image object 57 | img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(), 58 | height=img.shape[0], width=img.shape[1]) 59 | # Create a Summary value 60 | img_summaries.append(tf.compat.v1.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 61 | 62 | # Create and write Summary 63 | summary = tf.compat.v1.Summary(value=img_summaries) 64 | self.writer.add_summary(summary, step) 65 | 66 | def mesh_summary(self, tag, vertices, faces=None, colors=None, step=0): 67 | 68 | """Log a list of mesh images.""" 69 | if colors is None: 70 | colors = tf.constant(np.zeros_like(vertices)) 71 | vertices = tf.constant(vertices) 72 | if faces is not None: 73 | faces = tf.constant(faces) 74 | meshes_summaries = [] 75 | for i in range(vertices.shape[0]): 76 | meshes_summaries.append(meshsummary.op( 77 | tag, vertices=vertices, faces=faces, colors=colors, config_dict=self.config_dict)) 78 | 79 | sess = tf.compat.v1.Session() 80 | summaries = sess.run(meshes_summaries) 81 | for summary in summaries: 82 | self.writer.add_summary(summary, step) 83 | 84 | def histo_summary(self, tag, values, step, 
bins=1000): 85 | """Log a histogram of the tensor of values.""" 86 | 87 | # Create a histogram using numpy 88 | counts, bin_edges = np.histogram(values, bins=bins) 89 | 90 | # Fill the fields of the histogram proto 91 | hist = tf.compat.v1.HistogramProto() 92 | hist.min = float(np.min(values)) 93 | hist.max = float(np.max(values)) 94 | hist.num = int(np.prod(values.shape)) 95 | hist.sum = float(np.sum(values)) 96 | hist.sum_squares = float(np.sum(values**2)) 97 | 98 | # Drop the start of the first bin 99 | bin_edges = bin_edges[1:] 100 | 101 | # Add bin edges and counts 102 | for edge in bin_edges: 103 | hist.bucket_limit.append(edge) 104 | for c in counts: 105 | hist.bucket.append(c) 106 | 107 | # Create and write Summary 108 | summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)]) 109 | self.writer.add_summary(summary, step) 110 | self.writer.flush() 111 | 112 | --------------------------------------------------------------------------------
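The label-smoothing rule used by `SmoothCrossEntropy` in `/utils/loss.py` can be checked by hand on a single sample. The following is a minimal pure-Python sketch, not part of the repository: the helper names `smoothed_targets`, `log_softmax` and `smooth_cross_entropy` are hypothetical, and the code only mirrors the formula `one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)` without depending on PyTorch.

```python
import math


def smoothed_targets(label, n_class, eps=0.2):
    """Smoothed target distribution for one sample (hypothetical helper).

    The true class gets 1 - eps; the remaining eps is split evenly
    over the other n_class - 1 classes, as in SmoothCrossEntropy.
    """
    off_value = eps / (n_class - 1)
    return [1.0 - eps if c == label else off_value for c in range(n_class)]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def smooth_cross_entropy(logits, label, eps=0.2):
    """Cross entropy between the smoothed targets and softmax(logits)."""
    targets = smoothed_targets(label, len(logits), eps)
    log_prob = log_softmax(logits)
    return -sum(t * lp for t, lp in zip(targets, log_prob))


if __name__ == "__main__":
    print(smoothed_targets(label=1, n_class=5))
    print(smooth_cross_entropy([2.0, 0.5, -1.0], label=0))
```

Because the smoothed targets still sum to one, uniform logits give a loss of `log(n_class)` regardless of `eps`, which is a quick sanity check for the batched PyTorch version.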