├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── analysis
│   ├── LICENSE
│   ├── design_space.ipynb
│   ├── example.ipynb
│   ├── figs
│   │   └── __init__.py
│   └── idgnn.csv
├── docs
│   ├── IDGNN.png
│   ├── design_space.png
│   ├── evaluation.png
│   ├── overview.png
│   └── rank.png
├── graphgym
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── cmd_args.py
│   ├── config.py
│   ├── contrib
│   │   ├── __init__.py
│   │   ├── act
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── feature_augment
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── feature_encoder
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── head
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── layer
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── attconv.py
│   │   │   ├── example.py
│   │   │   ├── generalconv.py
│   │   │   ├── generalconv_ogb.py
│   │   │   ├── generalconv_v2.py
│   │   │   ├── idconv.py
│   │   │   └── sageinitconv.py
│   │   ├── loader
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── network
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── optimizer
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── pooling
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── stage
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   ├── train
│   │   │   ├── __init__.py
│   │   │   └── example.py
│   │   └── transform
│   │       ├── __init__.py
│   │       └── identity.py
│   ├── init.py
│   ├── loader.py
│   ├── loader_pyg.py
│   ├── logger.py
│   ├── loss.py
│   ├── model_builder.py
│   ├── model_builder_pyg.py
│   ├── models
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── act.py
│   │   ├── feature_augment.py
│   │   ├── feature_encoder.py
│   │   ├── feature_encoder_pyg.py
│   │   ├── gnn.py
│   │   ├── gnn_pyg.py
│   │   ├── head.py
│   │   ├── head_pyg.py
│   │   ├── layer.py
│   │   ├── layer_pyg.py
│   │   ├── pooling.py
│   │   └── transform.py
│   ├── optimizer.py
│   ├── register.py
│   ├── train.py
│   ├── train_pyg.py
│   └── utils
│       ├── LICENSE
│       ├── __init__.py
│       ├── agg_runs.py
│       ├── comp_budget.py
│       ├── device.py
│       ├── epoch.py
│       ├── io.py
│       ├── plot.py
│       └── tools.py
├── install.sh
├── requirements.txt
├── run
│   ├── LICENSE
│   ├── __init__.py
│   ├── agg_batch.py
│   ├── configs
│   │   ├── IDGNN
│   │   │   ├── edge.yaml
│   │   │   ├── graph.yaml
│   │   │   ├── graph_enzyme.yaml
│   │   │   ├── graph_ogb.yaml
│   │   │   ├── node.yaml
│   │   │   └── node_clustering.yaml
│   │   ├── design
│   │   │   ├── design_v1.yaml
│   │   │   ├── design_v1att.yaml
│   │   │   ├── design_v2.yaml
│   │   │   ├── design_v2link.yaml
│   │   │   └── design_v2ogb.yaml
│   │   ├── example.yaml
│   │   ├── example_cpu.yaml
│   │   └── pyg
│   │       ├── example_graph.yaml
│   │       ├── example_link.yaml
│   │       └── example_node.yaml
│   ├── configs_gen.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── ba.pkl
│   │   ├── ba500.pkl
│   │   ├── scalefree.pkl
│   │   ├── smallworld.pkl
│   │   ├── syn_graph.py
│   │   ├── ws.pkl
│   │   └── ws500.pkl
│   ├── grids
│   │   ├── IDGNN
│   │   │   ├── graph.txt
│   │   │   ├── graph_enzyme.txt
│   │   │   ├── graph_ogb.txt
│   │   │   ├── link.txt
│   │   │   ├── node.txt
│   │   │   ├── node_clustering.txt
│   │   │   └── path.txt
│   │   ├── design
│   │   │   ├── round1.txt
│   │   │   ├── round1att.txt
│   │   │   ├── round2.txt
│   │   │   ├── round2link.txt
│   │   │   └── round2ogb.txt
│   │   ├── example.txt
│   │   └── pyg
│   │       └── example.txt
│   ├── main.py
│   ├── main_pyg.py
│   ├── parallel.sh
│   ├── results
│   │   ├── design_v1_grid_round1
│   │   │   └── agg
│   │   │       ├── train.csv
│   │   │       └── val.csv
│   │   ├── design_v1_grid_round1att
│   │   │   └── agg
│   │   │       ├── train.csv
│   │   │       ├── train_best.csv
│   │   │       ├── val.csv
│   │   │       └── val_best.csv
│   │   ├── design_v2_grid_round2
│   │   │   └── agg
│   │   │       ├── train.csv
│   │   │       └── val.csv
│   │   ├── design_v2link_grid_round2link
│   │   │   └── agg
│   │   │       ├── train.csv
│   │   │       ├── train_best.csv
│   │   │       ├── train_bestepoch.csv
│   │   │       ├── val.csv
│   │   │       ├── val_best.csv
│   │   │       └── val_bestepoch.csv
│   │   └── design_v2ogb_grid_round2ogb
│   │       └── agg
│   │           ├── test.csv
│   │           ├── test_best.csv
│   │           ├── train.csv
│   │           ├── train_best.csv
│   │           ├── val.csv
│   │           └── val_best.csv
│   ├── run_batch.sh
│   ├── run_batch_pyg.sh
│   ├── run_single.sh
│   ├── run_single_cpu.sh
│   ├── run_single_pyg.sh
│   ├── sample
│   │   ├── dimensions.txt
│   │   └── dimensionsatt.txt
│   └── scripts
│       ├── IDGNN
│       │   ├── run_idgnn_edge.sh
│       │   ├── run_idgnn_graph.sh
│       │   └── run_idgnn_node.sh
│       └── design
│           ├── run_design_round1.sh
│           ├── run_design_round2.sh
│           └── run_design_round2ogb.sh
└── setup.py
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
*.ipynb linguist-vendored

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
**/data_dir/
run/datasets/data/
**/__pycache__/
**/.ipynb_checkpoints
.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021 Jiaxuan You
Copyright (c) 2021 Jiaxuan You, Matthias Fey
Copyright (c) 2020 Jiaxuan You, Rex Ying, Jonathan Gomes Selman
Copyright (c) Facebook, Inc. and its affiliates.
Additional copyrights are specified in relevant subdirectories.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/analysis/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2020 Jiaxuan You

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

--------------------------------------------------------------------------------
/analysis/figs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/analysis/figs/__init__.py

--------------------------------------------------------------------------------
/docs/IDGNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/docs/IDGNN.png

--------------------------------------------------------------------------------
/docs/design_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/docs/design_space.png

--------------------------------------------------------------------------------
/docs/evaluation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/docs/evaluation.png

--------------------------------------------------------------------------------
/docs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/docs/overview.png

--------------------------------------------------------------------------------
/docs/rank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/docs/rank.png

--------------------------------------------------------------------------------
/graphgym/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/graphgym/__init__.py

--------------------------------------------------------------------------------
/graphgym/checkpoint.py:
--------------------------------------------------------------------------------
import glob
import os
import os.path as osp
from typing import Any, Dict, List, Optional, Union

import torch

from graphgym.config import cfg

MODEL_STATE = 'model_state'
OPTIMIZER_STATE = 'optimizer_state'
SCHEDULER_STATE = 'scheduler_state'


def load_ckpt(
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    epoch: int = -1,
) -> int:
    r"""Loads the model checkpoint at a given epoch."""
    epoch = get_ckpt_epoch(epoch)
    path = get_ckpt_path(epoch)

    if not osp.exists(path):
        return 0

    ckpt = torch.load(path)
    model.load_state_dict(ckpt[MODEL_STATE])
    if optimizer is not None and OPTIMIZER_STATE in ckpt:
        optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])
    if scheduler is not None and SCHEDULER_STATE in ckpt:
        scheduler.load_state_dict(ckpt[SCHEDULER_STATE])

    return epoch + 1


def save_ckpt(
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    epoch: int = 0,
):
    r"""Saves the model checkpoint at a given epoch."""
    ckpt: Dict[str, Any] = {}
    ckpt[MODEL_STATE] = model.state_dict()
    if optimizer is not None:
        ckpt[OPTIMIZER_STATE] = optimizer.state_dict()
    if scheduler is not None:
        ckpt[SCHEDULER_STATE] = scheduler.state_dict()

    os.makedirs(get_ckpt_dir(), exist_ok=True)
    torch.save(ckpt, get_ckpt_path(get_ckpt_epoch(epoch)))


def remove_ckpt(epoch: int = -1):
    r"""Removes the model checkpoint at a given epoch."""
    os.remove(get_ckpt_path(get_ckpt_epoch(epoch)))


def clean_ckpt():
    r"""Removes all but the last model checkpoint."""
    for epoch in get_ckpt_epochs()[:-1]:
        os.remove(get_ckpt_path(epoch))


###############################################################################


def get_ckpt_dir() -> str:
    return osp.join(cfg.run_dir, 'ckpt')


def get_ckpt_path(epoch: Union[int, str]) -> str:
    return osp.join(get_ckpt_dir(), f'{epoch}.ckpt')


def get_ckpt_epochs() -> List[int]:
    paths = glob.glob(get_ckpt_path('*'))
    return sorted([int(osp.basename(path).split('.')[0]) for path in paths])


def get_ckpt_epoch(epoch: int) -> int:
    if epoch < 0:
        epochs = get_ckpt_epochs()
        epoch = epochs[epoch] if len(epochs) > 0 else 0
    return epoch
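
The checkpoint helpers above are driven entirely by cfg.run_dir and the epoch number. A minimal usage sketch (the model and the run directory below are illustrative, not taken from the repo; cfg.run_dir is normally set by the surrounding pipeline):

import torch

from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt
from graphgym.config import cfg

cfg.run_dir = 'results/example/0'  # illustrative path
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.Adam(model.parameters())

save_ckpt(model, optimizer, epoch=0)        # writes results/example/0/ckpt/0.ckpt
start_epoch = load_ckpt(model, optimizer)   # latest epoch is 0, so this returns 1
clean_ckpt()                                # keeps only the newest checkpoint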
--------------------------------------------------------------------------------
/graphgym/cmd_args.py:
--------------------------------------------------------------------------------
import argparse


def parse_args() -> argparse.Namespace:
    r"""Parses the command line arguments."""
    parser = argparse.ArgumentParser(description='GraphGym')

    parser.add_argument('--cfg',
                        dest='cfg_file',
                        type=str,
                        required=True,
                        help='The configuration file path.')
    parser.add_argument('--repeat',
                        type=int,
                        default=1,
                        help='The number of repeated jobs.')
    parser.add_argument('--mark_done',
                        action='store_true',
                        help='Mark yaml as done after a job has finished.')
    parser.add_argument('opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='See graphgym/config.py for remaining options.')

    return parser.parse_args()

--------------------------------------------------------------------------------
/graphgym/contrib/__init__.py:
--------------------------------------------------------------------------------
from .act import *  # noqa
from .config import *  # noqa
from .feature_augment import *  # noqa
from .feature_encoder import *  # noqa
from .head import *  # noqa
from .layer import *  # noqa
from .loader import *  # noqa
from .loss import *  # noqa
from .network import *  # noqa
from .optimizer import *  # noqa
from .pooling import *  # noqa
from .stage import *  # noqa
from .train import *  # noqa
from .transform import *  # noqa

--------------------------------------------------------------------------------
/graphgym/contrib/act/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/act/example.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from graphgym.config import cfg
from graphgym.register import register_act


class SWISH(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


register_act('swish', SWISH(inplace=cfg.mem.inplace))

register_act('lrelu_03',
             nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace))
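
register_act binds a name to an activation module so that configs can refer to it by string. A sketch of adding one more activation in the same style ('elu_custom' is an illustrative name, and the gnn.act YAML key is assumed from GraphGym's standard config layout):

import torch.nn as nn

from graphgym.config import cfg
from graphgym.register import register_act

register_act('elu_custom', nn.ELU(inplace=cfg.mem.inplace))

# In an experiment YAML the new activation is then selected by name, e.g.:
#   gnn:
#     act: elu_custom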
--------------------------------------------------------------------------------
/graphgym/contrib/config/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/config/example.py:
--------------------------------------------------------------------------------
from yacs.config import CfgNode as CN

from graphgym.register import register_config


def set_cfg_example(cfg):
    r'''
    This function sets the default config values for customized options
    :return: customized configuration used by the experiment.
    '''

    # ----------------------------------------------------------------------- #
    # Customized options
    # ----------------------------------------------------------------------- #

    # example argument
    cfg.example_arg = 'example'

    # example argument group
    cfg.example_group = CN()

    # then arguments can be specified within the group
    cfg.example_group.example_arg = 'example'


register_config('example', set_cfg_example)

--------------------------------------------------------------------------------
/graphgym/contrib/feature_augment/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/feature_augment/example.py:
--------------------------------------------------------------------------------
import networkx as nx

from graphgym.register import register_feature_augment


def example_node_augmentation_func(graph, **kwargs):
    '''
    Compute the node clustering coefficient as a feature augmentation
    :param graph: deepsnap graph. graph.G is networkx
    :param kwargs: required, in case additional kwargs are provided
    :return: List of node feature values, length equals number of nodes
    Note: these returned values are later processed and treated as node
    features as specified in "cfg.dataset.augment_feature_repr"
    '''
    return list(nx.clustering(graph.G).values())


register_feature_augment('example', example_node_augmentation_func)
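
Any callable with the same signature can be registered as an additional augmentation: it receives the deepsnap graph and returns one value per node, which GraphGym later turns into node features according to cfg.dataset.augment_feature_repr. A sketch registering node degree ('node_degree' is an illustrative name):

import networkx as nx

from graphgym.register import register_feature_augment


def degree_augmentation_func(graph, **kwargs):
    # One scalar per node, in node order, mirroring the clustering example.
    return [deg for _, deg in graph.G.degree()]


register_feature_augment('node_degree', degree_augmentation_func)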
--------------------------------------------------------------------------------
/graphgym/contrib/feature_encoder/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/feature_encoder/example.py:
--------------------------------------------------------------------------------
import torch
from ogb.utils.features import get_bond_feature_dims

from graphgym.register import register_edge_encoder, register_node_encoder


class ExampleNodeEncoder(torch.nn.Module):
    """
    Provides an encoder for integer node features

    Parameters:
        num_classes - the number of classes for the embedding mapping to learn
    """
    def __init__(self, emb_dim, num_classes=None):
        super(ExampleNodeEncoder, self).__init__()

        self.encoder = torch.nn.Embedding(num_classes, emb_dim)
        torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        # Encode just the first dimension if more exist
        batch.node_feature = self.encoder(batch.node_feature[:, 0])

        return batch


register_node_encoder('example', ExampleNodeEncoder)


class ExampleEdgeEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(ExampleEdgeEncoder, self).__init__()

        self.bond_embedding_list = torch.nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, batch):
        bond_embedding = 0
        for i in range(batch.edge_feature.shape[1]):
            bond_embedding += \
                self.bond_embedding_list[i](batch.edge_feature[:, i])

        batch.edge_feature = bond_embedding
        return batch


register_edge_encoder('example', ExampleEdgeEncoder)

--------------------------------------------------------------------------------
/graphgym/contrib/head/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/head/example.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from graphgym.register import register_head


class ExampleNodeHead(nn.Module):
    '''Head of GNN, node prediction'''
    def __init__(self, dim_in, dim_out):
        super(ExampleNodeHead, self).__init__()
        self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True)

    def _apply_index(self, batch):
        if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
            return batch.node_feature[batch.node_label_index], batch.node_label
        else:
            return batch.node_feature[batch.node_label_index], \
                batch.node_label[batch.node_label_index]

    def forward(self, batch):
        # The post-message-passing layer is a plain nn.Linear, so it must be
        # applied to the node feature tensor rather than to the batch object
        # (the original example passed the whole batch into the layer).
        batch.node_feature = self.layer_post_mp(batch.node_feature)
        pred, label = self._apply_index(batch)
        return pred, label


register_head('example', ExampleNodeHead)

--------------------------------------------------------------------------------
/graphgym/contrib/layer/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021 Jiaxuan You
Copyright (c) 2020 Jiaxuan You, Jonathan Gomes Selman
Copyright (c) 2020 Matthias Fey

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/graphgym/contrib/layer/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/layer/example.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros

from graphgym.config import cfg
from graphgym.register import register_layer

# Note: A registered GNN layer should take 'batch' as input
# and 'batch' as output


# Example 1: Directly define a GraphGym format Conv
# take 'batch' as input and 'batch' as output
class ExampleConv1(MessagePassing):
    r"""Example GNN layer

    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(ExampleConv1, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, batch):
        """"""
        x, edge_index = batch.node_feature, batch.edge_index
        x = torch.matmul(x, self.weight)

        batch.node_feature = self.propagate(edge_index, x=x)

        return batch

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# Remember to register your layer!
register_layer('exampleconv1', ExampleConv1)


# Example 2: First define a PyG format Conv layer
# Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer

    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(ExampleConv2Layer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index):
        """"""
        x = torch.matmul(x, self.weight)

        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class ExampleConv2(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(ExampleConv2, self).__init__()
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index)
        return batch


# Remember to register your layer!
register_layer('exampleconv2', ExampleConv2)
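
Once registered, the string name is all that other code needs. A short sketch of both ways a registered layer is typically consumed (register.layer_dict and the gnn.layer_type YAML key are assumed from GraphGym's standard layout; the dimensions are illustrative):

import graphgym.register as register

# Programmatic lookup of the layer registered above.
conv = register.layer_dict['exampleconv2'](dim_in=16, dim_out=16)

# Or, in an experiment YAML, by name:
#   gnn:
#     layer_type: exampleconv2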
--------------------------------------------------------------------------------
/graphgym/contrib/layer/generalconv.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter_add

from graphgym.config import cfg


class GeneralConvLayer(MessagePassing):
    r"""General GNN layer
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 **kwargs):
        super(GeneralConvLayer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = cfg.gnn.normalize_adj

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        if cfg.gnn.self_msg == 'concat':
            self.weight_self = Parameter(
                torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        if cfg.gnn.self_msg == 'concat':
            glorot(self.weight_self)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index,
             num_nodes,
             edge_weight=None,
             improved=False,
             dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1.0 if not improved else 2.0
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None, edge_feature=None):
        """"""
        if cfg.gnn.self_msg == 'concat':
            x_self = torch.matmul(x, self.weight_self)
        x = torch.matmul(x, self.weight)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
                                             edge_weight, self.improved,
                                             x.dtype)
            else:
                norm = edge_weight
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        x_msg = self.propagate(edge_index,
                               x=x,
                               norm=norm,
                               edge_feature=edge_feature)
        if cfg.gnn.self_msg == 'none':
            return x_msg
        elif cfg.gnn.self_msg == 'add':
            return x_msg + x
        elif cfg.gnn.self_msg == 'concat':
            return x_msg + x_self
        else:
            raise ValueError('self_msg {} not defined'.format(
                cfg.gnn.self_msg))

    def message(self, x_j, norm, edge_feature):
        if edge_feature is None:
            return norm.view(-1, 1) * x_j if norm is not None else x_j
        else:
            return norm.view(-1, 1) * (
                x_j + edge_feature) if norm is not None else (x_j +
                                                              edge_feature)

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GeneralEdgeConvLayer(MessagePassing):
    r"""General GNN layer, with edge features
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 **kwargs):
        super(GeneralEdgeConvLayer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = cfg.gnn.normalize_adj
        self.msg_direction = cfg.gnn.msg_direction

        if self.msg_direction == 'single':
            self.linear_msg = nn.Linear(in_channels + cfg.dataset.edge_dim,
                                        out_channels,
                                        bias=False)
        else:
            self.linear_msg = nn.Linear(in_channels * 2 + cfg.dataset.edge_dim,
                                        out_channels,
                                        bias=False)
        if cfg.gnn.self_msg == 'concat':
            self.linear_self = nn.Linear(in_channels, out_channels, bias=False)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index,
             num_nodes,
             edge_weight=None,
             improved=False,
             dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1.0 if not improved else 2.0
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None, edge_feature=None):
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
                                             edge_weight, self.improved,
                                             x.dtype)
            else:
                norm = edge_weight
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        x_msg = self.propagate(edge_index,
                               x=x,
                               norm=norm,
                               edge_feature=edge_feature)

        if cfg.gnn.self_msg == 'concat':
            x_self = self.linear_self(x)
            return x_self + x_msg
        elif cfg.gnn.self_msg == 'add':
            return x + x_msg
        else:
            return x_msg

    def message(self, x_i, x_j, norm, edge_feature):
        if self.msg_direction == 'both':
            x_j = torch.cat((x_i, x_j, edge_feature), dim=-1)
        else:
            x_j = torch.cat((x_j, edge_feature), dim=-1)
        x_j = self.linear_msg(x_j)
        return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
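
norm() implements the usual GCN-style symmetric normalization: self-loops are added, then each edge weight w_ij becomes deg(i)^(-1/2) * w_ij * deg(j)^(-1/2). A toy check of the static method on a single undirected edge (values worked out by hand under the default arguments):

import torch

from graphgym.contrib.layer.generalconv import GeneralConvLayer

edge_index = torch.tensor([[0, 1], [1, 0]])  # one edge, both directions
ei, w = GeneralConvLayer.norm(edge_index, num_nodes=2)
# Self-loops are appended and both nodes end up with degree 2, so every
# edge weight becomes 1/sqrt(2) * 1 * 1/sqrt(2) = 0.5:
# ei = [[0, 1, 0, 1], [1, 0, 0, 1]], w = [0.5, 0.5, 0.5, 0.5]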
--------------------------------------------------------------------------------
/graphgym/contrib/layer/generalconv_ogb.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from ogb.utils.features import get_bond_feature_dims
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter_add

from graphgym.config import cfg
from graphgym.register import register_layer

full_bond_feature_dims = get_bond_feature_dims()


class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_feature):
        bond_embedding = 0
        for i in range(edge_feature.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_feature[:, i])

        return bond_embedding


class GeneralOGBConvLayer(MessagePassing):
    r"""General GNN layer, for OGB
    """
    def __init__(self, in_channels, out_channels, improved=False, cached=False,
                 bias=True, **kwargs):
        super(GeneralOGBConvLayer, self).__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = cfg.gnn.normalize_adj

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.bond_encoder = BondEncoder(emb_dim=out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1.0 if not improved else 2.0
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_feature, edge_weight=None):
        """"""
        x = torch.matmul(x, self.weight)
        edge_feature = self.bond_encoder(edge_feature)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
                                             edge_weight, self.improved,
                                             x.dtype)
            else:
                norm = edge_weight
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm,
                              edge_feature=edge_feature)

    def message(self, x_j, norm, edge_feature):
        return norm.view(-1, 1) * (
            x_j + edge_feature) if norm is not None else (
            x_j + edge_feature)

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GeneralOGBConv(nn.Module):
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super(GeneralOGBConv, self).__init__()
        self.model = GeneralOGBConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.node_feature = self.model(batch.node_feature, batch.edge_index,
                                        batch.edge_feature)
        return batch


register_layer('generalogbconv', GeneralOGBConv)
34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, normalize=False, 37 | concat=False, bias=True, **kwargs): 38 | super(SAGEConvLayer, self).__init__(aggr='mean', **kwargs) 39 | 40 | self.in_channels = in_channels 41 | self.out_channels = out_channels 42 | self.normalize = normalize 43 | self.concat = concat 44 | 45 | in_channels = 2 * in_channels if concat else in_channels 46 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 47 | 48 | if bias: 49 | self.bias = Parameter(torch.Tensor(out_channels)) 50 | else: 51 | self.register_parameter('bias', None) 52 | 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | # original initialization 57 | # uniform(self.weight.size(0), self.weight) 58 | # uniform(self.weight.size(0), self.bias) 59 | # change to new initialization 60 | glorot(self.weight) 61 | zeros(self.bias) 62 | 63 | def forward(self, x, edge_index, edge_weight=None, size=None, 64 | res_n_id=None): 65 | """ 66 | Args: 67 | res_n_id (Tensor, optional): Residual node indices coming from 68 | :obj:`DataFlow` generated by :obj:`NeighborSampler` are used to 69 | select central node features in :obj:`x`. 70 | Required if operating in a bipartite graph and :obj:`concat` is 71 | :obj:`True`. (default: :obj:`None`) 72 | """ 73 | if not self.concat and torch.is_tensor(x): 74 | edge_index, edge_weight = add_remaining_self_loops( 75 | edge_index, edge_weight, 1, x.size(self.node_dim)) 76 | 77 | return self.propagate(edge_index, size=size, x=x, 78 | edge_weight=edge_weight, res_n_id=res_n_id) 79 | 80 | def message(self, x_j, edge_weight): 81 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 82 | 83 | def update(self, aggr_out, x, res_n_id): 84 | if self.concat and torch.is_tensor(x): 85 | aggr_out = torch.cat([x, aggr_out], dim=-1) 86 | elif self.concat and (isinstance(x, tuple) or isinstance(x, list)): 87 | assert res_n_id is not None 88 | aggr_out = torch.cat([x[0][res_n_id], aggr_out], dim=-1) 89 | 90 | aggr_out = torch.matmul(aggr_out, self.weight) 91 | 92 | if self.bias is not None: 93 | aggr_out = aggr_out + self.bias 94 | 95 | if self.normalize: 96 | aggr_out = F.normalize(aggr_out, p=2, dim=-1) 97 | 98 | return aggr_out 99 | 100 | def __repr__(self): 101 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 102 | self.out_channels) 103 | 104 | 105 | class SAGEinitConv(nn.Module): 106 | def __init__(self, dim_in, dim_out, bias=False, **kwargs): 107 | super(SAGEinitConv, self).__init__() 108 | self.model = SAGEConvLayer(dim_in, dim_out, bias=bias, concat=True) 109 | 110 | def forward(self, batch): 111 | batch.node_feature = self.model(batch.node_feature, batch.edge_index) 112 | return batch 113 | 114 | 115 | register_layer('sageinitconv', SAGEinitConv) 116 | -------------------------------------------------------------------------------- /graphgym/contrib/loader/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isfile, join 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/contrib/loader/example.py: -------------------------------------------------------------------------------- 1 | from deepsnap.dataset import GraphDataset 2 | from torch_geometric.datasets import QM7b 3 | 4 | from 
--------------------------------------------------------------------------------
/graphgym/contrib/loader/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/loader/example.py:
--------------------------------------------------------------------------------
from deepsnap.dataset import GraphDataset
from torch_geometric.datasets import QM7b

from graphgym.register import register_loader


def load_dataset_example(format, name, dataset_dir):
    dataset_dir = '{}/{}'.format(dataset_dir, name)
    if format == 'PyG':
        if name == 'QM7b':
            dataset_raw = QM7b(dataset_dir)
            graphs = GraphDataset.pyg_to_graphs(dataset_raw)
            return graphs


register_loader('example', load_dataset_example)

--------------------------------------------------------------------------------
/graphgym/contrib/loss/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/loss/example.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from graphgym.config import cfg
from graphgym.register import register_loss


def loss_example(pred, true):
    if cfg.model.loss_fun == 'smoothl1':
        l1_loss = nn.SmoothL1Loss()
        loss = l1_loss(pred, true)
        return loss, pred


register_loss('smoothl1', loss_example)

--------------------------------------------------------------------------------
/graphgym/contrib/network/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/network/example.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models import MLP

from graphgym.config import cfg
from graphgym.register import register_network


class GNNNodeHead(nn.Module):
    '''Head of GNN, node prediction'''
    def __init__(self, dim_in, dim_out):
        super(GNNNodeHead, self).__init__()
        self.layer_post_mp = MLP(dim_in,
                                 dim_out,
                                 num_layers=cfg.gnn.layers_post_mp,
                                 bias=True)

    def _apply_index(self, batch):
        if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
            return batch.node_feature[batch.node_label_index], batch.node_label
        else:
            return batch.node_feature[batch.node_label_index], \
                batch.node_label[batch.node_label_index]

    def forward(self, batch):
        batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        return pred, label


class ExampleGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):
        super(ExampleGNN, self).__init__()
        conv_model = self.build_conv_model(model_type)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(dim_in, dim_in))
        # Dropout rate taken from the global config; the original example
        # referenced self.dropout in forward() without ever defining it.
        self.dropout = cfg.gnn.dropout

        for _ in range(num_layers - 1):
            self.convs.append(conv_model(dim_in, dim_in))

        self.post_mp = GNNNodeHead(dim_in=dim_in, dim_out=dim_out)

    def build_conv_model(self, model_type):
        if model_type == 'GCN':
            return pyg_nn.GCNConv
        elif model_type == 'GAT':
            return pyg_nn.GATConv
        elif model_type == "GraphSage":
            return pyg_nn.SAGEConv
        else:
            raise ValueError("Model {} unavailable".format(model_type))

    def forward(self, batch):
        x, edge_index, x_batch = \
            batch.node_feature, batch.edge_index, batch.batch

        for i in range(len(self.convs)):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = pyg_nn.global_add_pool(x, x_batch)
        x = self.post_mp(x)
        x = F.log_softmax(x, dim=1)
        batch.node_feature = x
        return batch


register_network('example', ExampleGNN)
--------------------------------------------------------------------------------
/graphgym/contrib/optimizer/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/optimizer/example.py:
--------------------------------------------------------------------------------
import torch.optim as optim

from graphgym.config import cfg
from graphgym.register import register_optimizer, register_scheduler


def optimizer_example(params):
    if cfg.optim.optimizer == 'adagrad':
        optimizer = optim.Adagrad(params, lr=cfg.optim.base_lr,
                                  weight_decay=cfg.optim.weight_decay)
        return optimizer


register_optimizer('adagrad', optimizer_example)


def scheduler_example(optimizer):
    # Check the scheduler key here; the original example tested
    # cfg.optim.optimizer, which can never equal 'reduce'.
    if cfg.optim.scheduler == 'reduce':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return scheduler


register_scheduler('reduce', scheduler_example)

--------------------------------------------------------------------------------
/graphgym/contrib/pooling/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/pooling/example.py:
--------------------------------------------------------------------------------
from torch_scatter import scatter

from graphgym.register import register_pooling


def global_example_pool(x, batch, size=None):
    size = batch.max().item() + 1 if size is None else size
    return scatter(x, batch, dim=0, dim_size=size, reduce='add')


register_pooling('example', global_example_pool)
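
The registered pooling function reduces node features to one vector per graph using the batch assignment vector. A toy run (shapes and values are illustrative):

import torch

from graphgym.contrib.pooling.example import global_example_pool

x = torch.ones(5, 8)                   # 5 nodes, 8 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 in graph 0, nodes 3-4 in graph 1
out = global_example_pool(x, batch)    # shape [2, 8]; row 0 sums 3 nodes, row 1 sums 2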
--------------------------------------------------------------------------------
/graphgym/contrib/stage/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/stage/example.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from graphgym.config import cfg
from graphgym.register import register_stage


class GNNStackStage(nn.Module):
    '''Simple stage that stacks GNN layers'''
    def __init__(self, dim_in, dim_out, num_layers):
        super(GNNStackStage, self).__init__()
        for i in range(num_layers):
            d_in = dim_in if i == 0 else dim_out
            layer = GCNConv(d_in, dim_out)
            self.add_module('layer{}'.format(i), layer)
        self.dim_out = dim_out

    def forward(self, batch):
        for layer in self.children():
            # GCNConv is a raw PyG layer, so feed it tensors and write the
            # result back into the batch (the original example passed the
            # batch object straight into the layer).
            batch.node_feature = layer(batch.node_feature, batch.edge_index)
        if cfg.gnn.l2norm:
            batch.node_feature = F.normalize(batch.node_feature, p=2, dim=-1)
        return batch


register_stage('example', GNNStackStage)

--------------------------------------------------------------------------------
/graphgym/contrib/train/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/train/example.py:
--------------------------------------------------------------------------------
import logging
import time

import torch

from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt
from graphgym.config import cfg
from graphgym.loss import compute_loss
from graphgym.register import register_train
from graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch


def train_epoch(logger, loader, model, optimizer, scheduler):
    model.train()
    time_start = time.time()
    for batch in loader:
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        loss.backward()
        optimizer.step()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(),
                            loss=loss.item(),
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()
    scheduler.step()


def eval_epoch(logger, loader, model):
    model.eval()
    time_start = time.time()
    for batch in loader:
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(),
                            loss=loss.item(),
                            lr=0,
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()


def train_example(loggers, loaders, model, optimizer, scheduler):
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch {}'.format(start_epoch))

    num_splits = len(loggers)
    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        loggers[0].write_epoch(cur_epoch)
        if is_eval_epoch(cur_epoch):
            for i in range(1, num_splits):
                eval_epoch(loggers[i], loaders[i], model)
                loggers[i].write_epoch(cur_epoch)
        if is_ckpt_epoch(cur_epoch):
            save_ckpt(model, optimizer, scheduler, cur_epoch)
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in {}'.format(cfg.out_dir))


register_train('example', train_example)
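
Custom training loops registered this way are selected by name at run time. In GraphGym's standard setup (assumed here, not shown in this dump), run/main.py looks the name up in the registry when cfg.train.mode is not 'standard', roughly as follows:

import graphgym.register as register

# In an experiment YAML (standard GraphGym key assumed):
#   train:
#     mode: example
train_func = register.train_dict['example']  # resolves to train_example above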
--------------------------------------------------------------------------------
/graphgym/contrib/transform/__init__.py:
--------------------------------------------------------------------------------
import glob
from os.path import basename, dirname, isfile, join

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]

--------------------------------------------------------------------------------
/graphgym/contrib/transform/identity.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter_add


def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ),
                                 dtype=dtype,
                                 device=edge_index.device)

    fill_value = 1.0 if not improved else 2.0
    edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight,
                                                       fill_value, num_nodes)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


# cpu version
def compute_identity(edge_index, n, k):
    id, value = norm(edge_index, n)
    adj_sparse = torch.sparse.FloatTensor(id, value, torch.Size([n, n]))
    adj = adj_sparse.to_dense()
    diag_all = [torch.diag(adj)]
    adj_power = adj
    for i in range(1, k):
        adj_power = adj_power @ adj
        diag_all.append(torch.diag(adj_power))
    diag_all = torch.stack(diag_all, dim=1)
    return diag_all
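
compute_identity returns, for every node, the diagonal entries of the first k powers of the normalized adjacency matrix, i.e. a length-k vector of "return probabilities" that ID-GNN-style models use as identity features. A toy call (CPU only, as the comment above notes; the graph is illustrative):

import torch

from graphgym.contrib.transform.identity import compute_identity

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # path graph 0-1-2
feat = compute_identity(edge_index, n=3, k=3)            # shape [3, 3]
# feat[i, j] is the (i, i) entry of A_hat^(j+1), where A_hat is the
# self-loop-augmented, symmetrically normalized adjacency built by norm().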
setup_printing(): 14 | """ 15 | Set up printing options 16 | 17 | """ 18 | logging.root.handlers = [] 19 | logging_cfg = {'level': logging.INFO, 'format': '%(message)s'} 20 | makedirs(cfg.run_dir) 21 | h_file = logging.FileHandler('{}/logging.log'.format(cfg.run_dir)) 22 | h_stdout = logging.StreamHandler(sys.stdout) 23 | if cfg.print == 'file': 24 | logging_cfg['handlers'] = [h_file] 25 | elif cfg.print == 'stdout': 26 | logging_cfg['handlers'] = [h_stdout] 27 | elif cfg.print == 'both': 28 | logging_cfg['handlers'] = [h_file, h_stdout] 29 | else: 30 | raise ValueError('Print option not supported') 31 | logging.basicConfig(**logging_cfg) 32 | 33 | 34 | class Logger(object): 35 | def __init__(self, name='train', task_type=None): 36 | self.name = name 37 | self.task_type = task_type 38 | 39 | self._epoch_total = cfg.optim.max_epoch 40 | self._time_total = 0 # won't be reset 41 | 42 | self.out_dir = '{}/{}'.format(cfg.run_dir, name) 43 | makedirs(self.out_dir) 44 | if cfg.tensorboard_each_run: 45 | from tensorboardX import SummaryWriter 46 | self.tb_writer = SummaryWriter(self.out_dir) 47 | 48 | self.reset() 49 | 50 | def __getitem__(self, key): 51 | return getattr(self, key, None) 52 | 53 | def __setitem__(self, key, value): 54 | setattr(self, key, value) 55 | 56 | def reset(self): 57 | self._iter = 0 58 | self._size_current = 0 59 | self._loss = 0 60 | self._lr = 0 61 | self._params = 0 62 | self._time_used = 0 63 | self._true = [] 64 | self._pred = [] 65 | self._custom_stats = {} 66 | 67 | # basic properties 68 | def basic(self): 69 | stats = { 70 | 'loss': round(self._loss / self._size_current, cfg.round), 71 | 'lr': round(self._lr, cfg.round), 72 | 'params': self._params, 73 | 'time_iter': round(self.time_iter(), cfg.round), 74 | } 75 | gpu_memory = get_current_gpu_usage() 76 | if gpu_memory > 0: 77 | stats['gpu_memory'] = gpu_memory 78 | return stats 79 | 80 | # customized input properties 81 | def custom(self): 82 | if len(self._custom_stats) == 0: 83 | return {} 84 | out = {} 85 | for key, val in self._custom_stats.items(): 86 | out[key] = val / self._size_current 87 | return out 88 | 89 | def _get_pred_int(self, pred_score): 90 | if len(pred_score.shape) == 1 or pred_score.shape[1] == 1: 91 | return (pred_score > cfg.model.thresh).long() 92 | else: 93 | return pred_score.max(dim=1)[1] 94 | 95 | # task properties 96 | def classification_binary(self): 97 | from sklearn.metrics import (accuracy_score, f1_score, precision_score, 98 | recall_score, roc_auc_score) 99 | 100 | true, pred_score = torch.cat(self._true), torch.cat(self._pred) 101 | pred_int = self._get_pred_int(pred_score) 102 | try: 103 | r_a_score = roc_auc_score(true, pred_score) 104 | except ValueError: 105 | r_a_score = 0.0 106 | return { 107 | 'accuracy': round(accuracy_score(true, pred_int), cfg.round), 108 | 'precision': round(precision_score(true, pred_int), cfg.round), 109 | 'recall': round(recall_score(true, pred_int), cfg.round), 110 | 'f1': round(f1_score(true, pred_int), cfg.round), 111 | 'auc': round(r_a_score, cfg.round), 112 | } 113 | 114 | def classification_multi(self): 115 | from sklearn.metrics import accuracy_score 116 | 117 | true, pred_score = torch.cat(self._true), torch.cat(self._pred) 118 | pred_int = self._get_pred_int(pred_score) 119 | return {'accuracy': round(accuracy_score(true, pred_int), cfg.round)} 120 | 121 | def regression(self): 122 | from sklearn.metrics import mean_absolute_error, mean_squared_error 123 | 124 | true, pred = torch.cat(self._true), torch.cat(self._pred) 125 | return { 126 
| 'mae': 127 | float(round(mean_absolute_error(true, pred), cfg.round)), 128 | 'mse': 129 | float(round(mean_squared_error(true, pred), cfg.round)), 130 | 'rmse': 131 | float(round(math.sqrt(mean_squared_error(true, pred)), cfg.round)) 132 | } 133 | 134 | def time_iter(self): 135 | return self._time_used / self._iter 136 | 137 | def eta(self, epoch_current): 138 | epoch_current += 1 # since counter starts from 0 139 | time_per_epoch = self._time_total / epoch_current 140 | return time_per_epoch * (self._epoch_total - epoch_current) 141 | 142 | def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs): 143 | assert true.shape[0] == pred.shape[0] 144 | self._iter += 1 145 | self._true.append(true) 146 | self._pred.append(pred) 147 | batch_size = true.shape[0] 148 | self._size_current += batch_size 149 | self._loss += loss * batch_size 150 | self._lr = lr 151 | self._params = params 152 | self._time_used += time_used 153 | self._time_total += time_used 154 | for key, val in kwargs.items(): 155 | if key not in self._custom_stats: 156 | self._custom_stats[key] = val * batch_size 157 | else: 158 | self._custom_stats[key] += val * batch_size 159 | 160 | def write_iter(self): 161 | raise NotImplementedError 162 | 163 | def write_epoch(self, cur_epoch): 164 | basic_stats = self.basic() 165 | 166 | # Try to load customized metrics 167 | task_stats = {} 168 | for custom_metric in cfg.custom_metrics: 169 | func = register.metric_dict.get(custom_metric) 170 | if not func: 171 | raise ValueError( 172 | f'Unknown custom metric function name: {custom_metric}') 173 | custom_metric_score = func(self._true, self._pred, self.task_type) 174 | task_stats[custom_metric] = custom_metric_score 175 | 176 | if not task_stats: # use default metrics if no matching custom metric 177 | if self.task_type == 'regression': 178 | task_stats = self.regression() 179 | elif self.task_type == 'classification_binary': 180 | task_stats = self.classification_binary() 181 | elif self.task_type == 'classification_multi': 182 | task_stats = self.classification_multi() 183 | else: 184 | raise ValueError('Task has to be regression or classification') 185 | 186 | epoch_stats = {'epoch': cur_epoch} 187 | eta_stats = {'eta': round(self.eta(cur_epoch), cfg.round)} 188 | custom_stats = self.custom() 189 | 190 | if self.name == 'train': 191 | stats = { 192 | **epoch_stats, 193 | **eta_stats, 194 | **basic_stats, 195 | **task_stats, 196 | **custom_stats 197 | } 198 | else: 199 | stats = { 200 | **epoch_stats, 201 | **basic_stats, 202 | **task_stats, 203 | **custom_stats 204 | } 205 | 206 | # print 207 | logging.info('{}: {}'.format(self.name, stats)) 208 | # json 209 | dict_to_json(stats, '{}/stats.json'.format(self.out_dir)) 210 | # tensorboard 211 | if cfg.tensorboard_each_run: 212 | dict_to_tb(stats, self.tb_writer, cur_epoch) 213 | self.reset() 214 | 215 | def close(self): 216 | if cfg.tensorboard_each_run: 217 | self.tb_writer.close() 218 | 219 | 220 | def infer_task(): 221 | num_label = cfg.share.dim_out 222 | if cfg.dataset.task_type == 'classification': 223 | if num_label <= 2: 224 | task_type = 'classification_binary' 225 | else: 226 | task_type = 'classification_multi' 227 | else: 228 | task_type = cfg.dataset.task_type 229 | return task_type 230 | 231 | 232 | def create_logger(): 233 | """ 234 | Create logger for the experiment 235 | 236 | Returns: List of logger objects 237 | 238 | """ 239 | loggers = [] 240 | names = ['train', 'val', 'test'] 241 | for i, dataset in enumerate(range(cfg.share.num_splits)): 242 | 
        loggers.append(Logger(name=names[i], task_type=infer_task()))
243 |     return loggers
244 | 
--------------------------------------------------------------------------------
/graphgym/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | import graphgym.register as register
6 | from graphgym.config import cfg
7 | 
8 | 
9 | def compute_loss(pred, true):
10 |     """
11 |     Compute loss and prediction score
12 | 
13 |     Args:
14 |         pred (torch.tensor): Unnormalized prediction
15 |         true (torch.tensor): Ground truth labels
16 | 
17 |     Returns: Loss, normalized prediction score
18 | 
19 |     """
20 |     bce_loss = nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
21 |     mse_loss = nn.MSELoss(reduction=cfg.model.size_average)
22 | 
23 |     # default manipulation for pred and true
24 |     # can be skipped if special loss computation is needed
25 |     pred = pred.squeeze(-1) if pred.ndim > 1 else pred
26 |     true = true.squeeze(-1) if true.ndim > 1 else true
27 | 
28 |     # Try to load customized loss
29 |     for func in register.loss_dict.values():
30 |         value = func(pred, true)
31 |         if value is not None:
32 |             return value
33 | 
34 |     if cfg.model.loss_fun == 'cross_entropy':
35 |         # multiclass
36 |         if pred.ndim > 1 and true.ndim == 1:
37 |             pred = F.log_softmax(pred, dim=-1)
38 |             return F.nll_loss(pred, true), pred
39 |         # binary or multilabel
40 |         else:
41 |             true = true.float()
42 |             return bce_loss(pred, true), torch.sigmoid(pred)
43 |     elif cfg.model.loss_fun == 'mse':
44 |         true = true.float()
45 |         return mse_loss(pred, true), pred
46 |     else:
47 |         raise ValueError('Loss func {} not supported'.format(
48 |             cfg.model.loss_fun))
49 | 
--------------------------------------------------------------------------------
/graphgym/model_builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | import graphgym.register as register
4 | from graphgym.config import cfg
5 | from graphgym.models.gnn import GNN
6 | 
7 | network_dict = {
8 |     'gnn': GNN,
9 | }
10 | network_dict = {**register.network_dict, **network_dict}
11 | 
12 | 
13 | def create_model(to_device=True, dim_in=None, dim_out=None):
14 |     r"""
15 |     Create model for graph machine learning
16 | 
17 |     Args:
18 |         to_device (bool): Whether to transfer the model to the configured device
19 |         dim_in (int, optional): Input dimension to the model
20 |         dim_out (int, optional): Output dimension to the model
21 |     """
22 |     dim_in = cfg.share.dim_in if dim_in is None else dim_in
23 |     dim_out = cfg.share.dim_out if dim_out is None else dim_out
24 |     # binary classification, output dim = 1
25 |     if 'classification' in cfg.dataset.task_type and dim_out == 2:
26 |         dim_out = 1
27 | 
28 |     model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out)
29 |     if to_device:
30 |         model.to(torch.device(cfg.device))
31 |     return model
32 | 
--------------------------------------------------------------------------------
/graphgym/model_builder_pyg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from graphgym.config import cfg
4 | from graphgym.models.gnn_pyg import GNN
5 | from graphgym.register import network_dict, register_network
6 | 
7 | register_network('gnn', GNN)
8 | 
9 | 
10 | def create_model(to_device=True, dim_in=None, dim_out=None):
11 |     r"""
12 |     Create model for graph machine learning
13 | 
14 |     Args:
15 |         to_device (bool): Whether to transfer the model to the configured device
16 |         dim_in (int,
optional): Input dimension to the model 17 | dim_out (int, optional): Output dimension to the model 18 | """ 19 | dim_in = cfg.share.dim_in if dim_in is None else dim_in 20 | dim_out = cfg.share.dim_out if dim_out is None else dim_out 21 | # binary classification, output dim = 1 22 | if 'classification' in cfg.dataset.task_type and dim_out == 2: 23 | dim_out = 1 24 | 25 | model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out) 26 | if to_device: 27 | model.to(torch.device(cfg.device)) 28 | return model 29 | -------------------------------------------------------------------------------- /graphgym/models/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Jiaxuan You 2 | Copyright (c) 2020 Jiaxuan You, Rex Ying, Jonathan Gomes Selman 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 
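A minimal sketch of how the two create_model builders above are typically driven (the config values here are illustrative assumptions; in a real run the dataset loader fills in cfg.share):

import torch
from graphgym.config import cfg
from graphgym.model_builder import create_model

cfg.share.dim_in = 16                  # assumed node feature dimension
cfg.share.dim_out = 2                  # binary task; the builder collapses this to 1
cfg.dataset.task_type = 'classification'
cfg.model.type = 'gnn'
cfg.device = 'cpu'

model = create_model()                 # instantiates network_dict['gnn'] on cfg.device
assert isinstance(model, torch.nn.Module)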
-------------------------------------------------------------------------------- /graphgym/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/graphgym/models/__init__.py -------------------------------------------------------------------------------- /graphgym/models/act.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import graphgym.register as register 4 | from graphgym.config import cfg 5 | 6 | act_dict = { 7 | 'relu': nn.ReLU(inplace=cfg.mem.inplace), 8 | 'selu': nn.SELU(inplace=cfg.mem.inplace), 9 | 'prelu': nn.PReLU(), 10 | 'elu': nn.ELU(inplace=cfg.mem.inplace), 11 | 'lrelu_01': nn.LeakyReLU(negative_slope=0.1, inplace=cfg.mem.inplace), 12 | 'lrelu_025': nn.LeakyReLU(negative_slope=0.25, inplace=cfg.mem.inplace), 13 | 'lrelu_05': nn.LeakyReLU(negative_slope=0.5, inplace=cfg.mem.inplace), 14 | } 15 | 16 | act_dict = {**register.act_dict, **act_dict} 17 | -------------------------------------------------------------------------------- /graphgym/models/feature_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 3 | 4 | import graphgym.register as register 5 | 6 | # Used for the OGB Encoders 7 | full_atom_feature_dims = get_atom_feature_dims() 8 | full_bond_feature_dims = get_bond_feature_dims() 9 | 10 | 11 | class IntegerFeatureEncoder(torch.nn.Module): 12 | """ 13 | Provides an encoder for integer node features. 14 | 15 | Args: 16 | emb_dim (int): Output embedding dimension 17 | num_classes (int): the number of classes for the 18 | embedding mapping to learn from 19 | """ 20 | def __init__(self, emb_dim, num_classes=None): 21 | super(IntegerFeatureEncoder, self).__init__() 22 | 23 | self.encoder = torch.nn.Embedding(num_classes, emb_dim) 24 | torch.nn.init.xavier_uniform_(self.encoder.weight.data) 25 | 26 | def forward(self, batch): 27 | # Encode just the first dimension if more exist 28 | batch.node_feature = self.encoder(batch.node_feature[:, 0]) 29 | 30 | return batch 31 | 32 | 33 | class SingleAtomEncoder(torch.nn.Module): 34 | """ 35 | Only encode the first dimension of atom integer features. 36 | This feature encodes just the atom type 37 | Args: 38 | emb_dim (int): Output embedding dimension 39 | num_classes: None 40 | """ 41 | def __init__(self, emb_dim, num_classes=None): 42 | super(SingleAtomEncoder, self).__init__() 43 | 44 | num_atom_types = full_atom_feature_dims[0] 45 | self.atom_type_embedding = torch.nn.Embedding(num_atom_types, emb_dim) 46 | torch.nn.init.xavier_uniform_(self.atom_type_embedding.weight.data) 47 | 48 | def forward(self, batch): 49 | batch.node_feature = self.atom_type_embedding(batch.node_feature[:, 0]) 50 | 51 | return batch 52 | 53 | 54 | class AtomEncoder(torch.nn.Module): 55 | """ 56 | The atom Encoder used in OGB molecule dataset. 
57 | 58 | Args: 59 | emb_dim (int): Output embedding dimension 60 | num_classes: None 61 | """ 62 | def __init__(self, emb_dim, num_classes=None): 63 | super(AtomEncoder, self).__init__() 64 | 65 | self.atom_embedding_list = torch.nn.ModuleList() 66 | 67 | for i, dim in enumerate(full_atom_feature_dims): 68 | emb = torch.nn.Embedding(dim, emb_dim) 69 | torch.nn.init.xavier_uniform_(emb.weight.data) 70 | self.atom_embedding_list.append(emb) 71 | 72 | def forward(self, batch): 73 | encoded_features = 0 74 | for i in range(batch.node_feature.shape[1]): 75 | encoded_features += self.atom_embedding_list[i]( 76 | batch.node_feature[:, i]) 77 | batch.node_feature = encoded_features 78 | return batch 79 | 80 | 81 | class BondEncoder(torch.nn.Module): 82 | """ 83 | The bond Encoder used in OGB molecule dataset. 84 | 85 | Args: 86 | emb_dim (int): Output edge embedding dimension 87 | """ 88 | def __init__(self, emb_dim): 89 | super(BondEncoder, self).__init__() 90 | 91 | self.bond_embedding_list = torch.nn.ModuleList() 92 | 93 | for i, dim in enumerate(full_bond_feature_dims): 94 | emb = torch.nn.Embedding(dim, emb_dim) 95 | torch.nn.init.xavier_uniform_(emb.weight.data) 96 | self.bond_embedding_list.append(emb) 97 | 98 | def forward(self, batch): 99 | bond_embedding = 0 100 | for i in range(batch.edge_feature.shape[1]): 101 | bond_embedding += self.bond_embedding_list[i]( 102 | batch.edge_feature[:, i]) 103 | 104 | batch.edge_feature = bond_embedding 105 | return batch 106 | 107 | 108 | node_encoder_dict = { 109 | 'Integer': IntegerFeatureEncoder, 110 | 'SingleAtom': SingleAtomEncoder, 111 | 'Atom': AtomEncoder 112 | } 113 | 114 | node_encoder_dict = {**register.node_encoder_dict, **node_encoder_dict} 115 | 116 | edge_encoder_dict = {'Bond': BondEncoder} 117 | 118 | edge_encoder_dict = {**register.edge_encoder_dict, **edge_encoder_dict} 119 | -------------------------------------------------------------------------------- /graphgym/models/feature_encoder_pyg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 3 | 4 | import graphgym.register as register 5 | 6 | # Used for the OGB Encoders 7 | full_atom_feature_dims = get_atom_feature_dims() 8 | full_bond_feature_dims = get_bond_feature_dims() 9 | 10 | 11 | # Feature Encoders 12 | class IntegerFeatureEncoder(torch.nn.Module): 13 | """ 14 | Provides an encoder for integer node features 15 | 16 | Parameters: 17 | num_classes - the number of classes for the embedding mapping to learn 18 | """ 19 | 20 | def __init__(self, emb_dim, num_classes=None): 21 | super(IntegerFeatureEncoder, self).__init__() 22 | 23 | self.encoder = torch.nn.Embedding(num_classes, emb_dim) 24 | torch.nn.init.xavier_uniform_(self.encoder.weight.data) 25 | 26 | def forward(self, batch): 27 | # Encode just the first dimension if more exist 28 | batch.x = self.encoder(batch.x[:, 0]) 29 | 30 | return batch 31 | 32 | 33 | class SingleAtomEncoder(torch.nn.Module): 34 | """ 35 | Only encode the first dimension of atom integer features. 36 | This feature encodes just the atom type 37 | 38 | Parameters: 39 | num_classes: Not used! 
40 | """ 41 | 42 | def __init__(self, emb_dim, num_classes=None): 43 | super(SingleAtomEncoder, self).__init__() 44 | 45 | num_atom_types = full_atom_feature_dims[0] 46 | self.atom_type_embedding = torch.nn.Embedding(num_atom_types, emb_dim) 47 | torch.nn.init.xavier_uniform_(self.atom_type_embedding.weight.data) 48 | 49 | def forward(self, batch): 50 | batch.x = self.atom_type_embedding(batch.x[:, 0]) 51 | 52 | return batch 53 | 54 | 55 | class AtomEncoder(torch.nn.Module): 56 | """ 57 | The complete Atom Encoder used in OGB dataset 58 | 59 | Parameters: 60 | num_classes: Not used! 61 | """ 62 | 63 | def __init__(self, emb_dim, num_classes=None): 64 | super(AtomEncoder, self).__init__() 65 | 66 | self.atom_embedding_list = torch.nn.ModuleList() 67 | 68 | for i, dim in enumerate(full_atom_feature_dims): 69 | emb = torch.nn.Embedding(dim, emb_dim) 70 | torch.nn.init.xavier_uniform_(emb.weight.data) 71 | self.atom_embedding_list.append(emb) 72 | 73 | def forward(self, batch): 74 | encoded_features = 0 75 | for i in range(batch.x.shape[1]): 76 | encoded_features += self.atom_embedding_list[i]( 77 | batch.x[:, i]) 78 | 79 | batch.x = encoded_features 80 | return batch 81 | 82 | 83 | class BondEncoder(torch.nn.Module): 84 | 85 | def __init__(self, emb_dim): 86 | super(BondEncoder, self).__init__() 87 | 88 | self.bond_embedding_list = torch.nn.ModuleList() 89 | 90 | for i, dim in enumerate(full_bond_feature_dims): 91 | emb = torch.nn.Embedding(dim, emb_dim) 92 | torch.nn.init.xavier_uniform_(emb.weight.data) 93 | self.bond_embedding_list.append(emb) 94 | 95 | def forward(self, batch): 96 | bond_embedding = 0 97 | for i in range(batch.edge_attr.shape[1]): 98 | bond_embedding += self.bond_embedding_list[i]( 99 | batch.edge_attr[:, i]) 100 | 101 | batch.edge_attr = bond_embedding 102 | return batch 103 | 104 | 105 | node_encoder_dict = { 106 | 'Integer': IntegerFeatureEncoder, 107 | 'SingleAtom': SingleAtomEncoder, 108 | 'Atom': AtomEncoder 109 | } 110 | 111 | node_encoder_dict = {**register.node_encoder_dict, **node_encoder_dict} 112 | 113 | edge_encoder_dict = { 114 | 'Bond': BondEncoder 115 | } 116 | 117 | edge_encoder_dict = {**register.edge_encoder_dict, **edge_encoder_dict} 118 | -------------------------------------------------------------------------------- /graphgym/models/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import graphgym.register as register 6 | from graphgym.config import cfg 7 | from graphgym.init import init_weights 8 | from graphgym.models.act import act_dict 9 | from graphgym.models.feature_augment import Preprocess 10 | from graphgym.models.feature_encoder import (edge_encoder_dict, 11 | node_encoder_dict) 12 | from graphgym.models.head import head_dict 13 | from graphgym.models.layer import (BatchNorm1dEdge, BatchNorm1dNode, 14 | GeneralLayer, GeneralMultiLayer) 15 | 16 | 17 | # Layer 18 | def GNNLayer(dim_in, dim_out, has_act=True): 19 | """ 20 | Wrapper for a GNN layer 21 | 22 | Args: 23 | dim_in (int): Input dimension 24 | dim_out (int): Output dimension 25 | has_act (bool): Whether has activation function after the layer 26 | 27 | """ 28 | return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act) 29 | 30 | 31 | def GNNPreMP(dim_in, dim_out): 32 | """ 33 | Wrapper for NN layer before GNN message passing 34 | 35 | Args: 36 | dim_in (int): Input dimension 37 | dim_out (int): Output dimension 38 | num_layers (int): Number of layers 39 
| 40 | """ 41 | return GeneralMultiLayer('linear', 42 | cfg.gnn.layers_pre_mp, 43 | dim_in, 44 | dim_out, 45 | dim_inner=dim_out, 46 | final_act=True) 47 | 48 | 49 | # Block: multiple layers 50 | class GNNSkipBlock(nn.Module): 51 | '''Skip block for GNN''' 52 | def __init__(self, dim_in, dim_out, num_layers): 53 | super(GNNSkipBlock, self).__init__() 54 | if num_layers == 1: 55 | self.f = [GNNLayer(dim_in, dim_out, has_act=False)] 56 | else: 57 | self.f = [] 58 | for i in range(num_layers - 1): 59 | d_in = dim_in if i == 0 else dim_out 60 | self.f.append(GNNLayer(d_in, dim_out)) 61 | d_in = dim_in if num_layers == 1 else dim_out 62 | self.f.append(GNNLayer(d_in, dim_out, has_act=False)) 63 | self.f = nn.Sequential(*self.f) 64 | self.act = act_dict[cfg.gnn.act] 65 | if cfg.gnn.stage_type == 'skipsum': 66 | assert dim_in == dim_out, 'Sum skip must have same dim_in, dim_out' 67 | 68 | def forward(self, batch): 69 | node_feature = batch.node_feature 70 | if cfg.gnn.stage_type == 'skipsum': 71 | batch.node_feature = \ 72 | node_feature + self.f(batch).node_feature 73 | elif cfg.gnn.stage_type == 'skipconcat': 74 | batch.node_feature = \ 75 | torch.cat((node_feature, self.f(batch).node_feature), 1) 76 | else: 77 | raise ValueError( 78 | 'cfg.gnn.stage_type must in [skipsum, skipconcat]') 79 | batch.node_feature = self.act(batch.node_feature) 80 | return batch 81 | 82 | 83 | # Stage: NN except start and head 84 | class GNNStackStage(nn.Module): 85 | '''Simple Stage that stack GNN layers''' 86 | def __init__(self, dim_in, dim_out, num_layers): 87 | super(GNNStackStage, self).__init__() 88 | for i in range(num_layers): 89 | d_in = dim_in if i == 0 else dim_out 90 | layer = GNNLayer(d_in, dim_out) 91 | self.add_module('layer{}'.format(i), layer) 92 | self.dim_out = dim_out 93 | 94 | def forward(self, batch): 95 | for layer in self.children(): 96 | batch = layer(batch) 97 | if cfg.gnn.l2norm: 98 | batch.node_feature = F.normalize(batch.node_feature, p=2, dim=-1) 99 | return batch 100 | 101 | 102 | class GNNSkipStage(nn.Module): 103 | ''' Stage with skip connections''' 104 | def __init__(self, dim_in, dim_out, num_layers): 105 | super(GNNSkipStage, self).__init__() 106 | assert num_layers % cfg.gnn.skip_every == 0, \ 107 | 'cfg.gnn.skip_every must be multiples of cfg.gnn.layer_mp' \ 108 | '(excluding head layer)' 109 | for i in range(num_layers // cfg.gnn.skip_every): 110 | if cfg.gnn.stage_type == 'skipsum': 111 | d_in = dim_in if i == 0 else dim_out 112 | elif cfg.gnn.stage_type == 'skipconcat': 113 | d_in = dim_in if i == 0 else dim_in + i * dim_out 114 | block = GNNSkipBlock(d_in, dim_out, cfg.gnn.skip_every) 115 | self.add_module('block{}'.format(i), block) 116 | if cfg.gnn.stage_type == 'skipconcat': 117 | self.dim_out = d_in + dim_out 118 | else: 119 | self.dim_out = dim_out 120 | 121 | def forward(self, batch): 122 | for layer in self.children(): 123 | batch = layer(batch) 124 | if cfg.gnn.l2norm: 125 | batch.node_feature = F.normalize(batch.node_feature, p=2, dim=-1) 126 | return batch 127 | 128 | 129 | stage_dict = { 130 | 'stack': GNNStackStage, 131 | 'skipsum': GNNSkipStage, 132 | 'skipconcat': GNNSkipStage, 133 | } 134 | 135 | stage_dict = {**register.stage_dict, **stage_dict} 136 | 137 | 138 | # Model: start + stage + head 139 | class GNN(nn.Module): 140 | '''General GNN model''' 141 | def __init__(self, dim_in, dim_out, **kwargs): 142 | """ 143 | Parameters: 144 | node_encoding_classes - For integer features, gives the number 145 | of possible integer features to map. 
146 | """ 147 | super(GNN, self).__init__() 148 | GNNStage = stage_dict[cfg.gnn.stage_type] 149 | GNNHead = head_dict[cfg.dataset.task] 150 | 151 | if cfg.dataset.node_encoder: 152 | # Encode integer node features via nn.Embeddings 153 | NodeEncoder = node_encoder_dict[cfg.dataset.node_encoder_name] 154 | self.node_encoder = NodeEncoder(cfg.dataset.encoder_dim) 155 | if cfg.dataset.node_encoder_bn: 156 | self.node_encoder_bn = BatchNorm1dNode(cfg.dataset.encoder_dim) 157 | # Update dim_in to reflect the new dimension fo the node features 158 | dim_in = cfg.dataset.encoder_dim 159 | if cfg.dataset.edge_encoder: 160 | # Encode integer edge features via nn.Embeddings 161 | EdgeEncoder = edge_encoder_dict[cfg.dataset.edge_encoder_name] 162 | self.edge_encoder = EdgeEncoder(cfg.dataset.encoder_dim) 163 | if cfg.dataset.edge_encoder_bn: 164 | self.edge_encoder_bn = BatchNorm1dEdge(cfg.dataset.edge_dim) 165 | 166 | self.preprocess = Preprocess(dim_in) 167 | d_in = self.preprocess.dim_out 168 | if cfg.gnn.layers_pre_mp > 0: 169 | self.pre_mp = GNNPreMP(d_in, cfg.gnn.dim_inner) 170 | d_in = cfg.gnn.dim_inner 171 | if cfg.gnn.layers_mp > 0: 172 | self.mp = GNNStage(dim_in=d_in, 173 | dim_out=cfg.gnn.dim_inner, 174 | num_layers=cfg.gnn.layers_mp) 175 | d_in = self.mp.dim_out 176 | self.post_mp = GNNHead(dim_in=d_in, dim_out=dim_out) 177 | 178 | self.apply(init_weights) 179 | 180 | def forward(self, batch): 181 | for module in self.children(): 182 | batch = module(batch) 183 | return batch 184 | -------------------------------------------------------------------------------- /graphgym/models/gnn_pyg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import graphgym.register as register 6 | from graphgym.config import cfg 7 | from graphgym.init import init_weights 8 | from graphgym.models.feature_encoder_pyg import (edge_encoder_dict, 9 | node_encoder_dict) 10 | from graphgym.models.head_pyg import head_dict 11 | from graphgym.models.layer_pyg import (BatchNorm1dEdge, BatchNorm1dNode, 12 | GeneralLayer, GeneralMultiLayer) 13 | 14 | 15 | # Layer 16 | def GNNLayer(dim_in, dim_out, has_act=True): 17 | """ 18 | Wrapper for a GNN layer 19 | 20 | Args: 21 | dim_in (int): Input dimension 22 | dim_out (int): Output dimension 23 | has_act (bool): Whether has activation function after the layer 24 | 25 | """ 26 | return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act) 27 | 28 | 29 | def GNNPreMP(dim_in, dim_out): 30 | """ 31 | Wrapper for NN layer before GNN message passing 32 | 33 | Args: 34 | dim_in (int): Input dimension 35 | dim_out (int): Output dimension 36 | num_layers (int): Number of layers 37 | 38 | """ 39 | return GeneralMultiLayer('linear', 40 | cfg.gnn.layers_pre_mp, 41 | dim_in, 42 | dim_out, 43 | dim_inner=dim_out, 44 | final_act=True) 45 | 46 | 47 | # Stage: NN except start and head 48 | class GNNStackStage(nn.Module): 49 | """ 50 | Simple Stage that stack GNN layers 51 | 52 | Args: 53 | dim_in (int): Input dimension 54 | dim_out (int): Output dimension 55 | num_layers (int): Number of GNN layers 56 | """ 57 | def __init__(self, dim_in, dim_out, num_layers): 58 | super(GNNStackStage, self).__init__() 59 | self.num_layers = num_layers 60 | for i in range(num_layers): 61 | if cfg.gnn.stage_type == 'skipconcat': 62 | d_in = dim_in if i == 0 else dim_in + i * dim_out 63 | else: 64 | d_in = dim_in if i == 0 else dim_out 65 | layer = GNNLayer(d_in, dim_out) 66 | 
66 |             self.add_module('layer{}'.format(i), layer)
67 | 
68 |     def forward(self, batch):
69 |         for i, layer in enumerate(self.children()):
70 |             x = batch.x
71 |             batch = layer(batch)
72 |             if cfg.gnn.stage_type == 'skipsum':
73 |                 batch.x = x + batch.x
74 |             elif cfg.gnn.stage_type == 'skipconcat' and \
75 |                     i < self.num_layers - 1:
76 |                 batch.x = torch.cat([x, batch.x], dim=1)
77 |         if cfg.gnn.l2norm:
78 |             batch.x = F.normalize(batch.x, p=2, dim=-1)
79 |         return batch
80 | 
81 | 
82 | stage_dict = {
83 |     'stack': GNNStackStage,
84 |     'skipsum': GNNStackStage,
85 |     'skipconcat': GNNStackStage,
86 | }
87 | 
88 | stage_dict = {**register.stage_dict, **stage_dict}
89 | 
90 | 
91 | # Feature encoder
92 | class FeatureEncoder(nn.Module):
93 |     """
94 |     Encoding node and edge features
95 | 
96 |     Args:
97 |         dim_in (int): Input feature dimension
98 |     """
99 |     def __init__(self, dim_in):
100 |         super(FeatureEncoder, self).__init__()
101 |         self.dim_in = dim_in
102 |         if cfg.dataset.node_encoder:
103 |             # Encode integer node features via nn.Embeddings
104 |             NodeEncoder = node_encoder_dict[cfg.dataset.node_encoder_name]
105 |             self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
106 |             if cfg.dataset.node_encoder_bn:
107 |                 self.node_encoder_bn = BatchNorm1dNode(cfg.gnn.dim_inner)
108 |             # Update dim_in to reflect the new dimension of the node features
109 |             self.dim_in = cfg.gnn.dim_inner
110 |         if cfg.dataset.edge_encoder:
111 |             # Encode integer edge features via nn.Embeddings
112 |             EdgeEncoder = edge_encoder_dict[cfg.dataset.edge_encoder_name]
113 |             self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
114 |             if cfg.dataset.edge_encoder_bn:
115 |                 self.edge_encoder_bn = BatchNorm1dEdge(cfg.gnn.dim_inner)
116 | 
117 |     def forward(self, batch):
118 |         for module in self.children():
119 |             batch = module(batch)
120 |         return batch
121 | 
122 | 
123 | # Model: start + stage + head
124 | class GNN(nn.Module):
125 |     """
126 |     General GNN model: encoder + stage + head
127 | 
128 |     Args:
129 |         dim_in (int): Input dimension
130 |         dim_out (int): Output dimension
131 |         **kwargs (optional): Optional additional args
132 |     """
133 |     def __init__(self, dim_in, dim_out, **kwargs):
134 |         super(GNN, self).__init__()
135 |         GNNStage = stage_dict[cfg.gnn.stage_type]
136 |         GNNHead = head_dict[cfg.dataset.task]
137 | 
138 |         self.encoder = FeatureEncoder(dim_in)
139 |         dim_in = self.encoder.dim_in
140 | 
141 |         if cfg.gnn.layers_pre_mp > 0:
142 |             self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner)
143 |             dim_in = cfg.gnn.dim_inner
144 |         if cfg.gnn.layers_mp > 0:
145 |             self.mp = GNNStage(dim_in=dim_in,
146 |                                dim_out=cfg.gnn.dim_inner,
147 |                                num_layers=cfg.gnn.layers_mp)
148 |         self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)
149 | 
150 |         self.apply(init_weights)
151 | 
152 |     def forward(self, batch):
153 |         for module in self.children():
154 |             batch = module(batch)
155 |         return batch
156 | 
--------------------------------------------------------------------------------
/graphgym/models/head.py:
--------------------------------------------------------------------------------
1 | """ GNN heads are the last layer of a GNN right before loss computation.
2 | 
3 | They are constructed in the init function of the gnn.GNN.
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import graphgym.register as register 10 | from graphgym.config import cfg 11 | from graphgym.models.layer import MLP 12 | from graphgym.models.pooling import pooling_dict 13 | 14 | 15 | # Head 16 | class GNNNodeHead(nn.Module): 17 | '''Head of GNN, node prediction''' 18 | def __init__(self, dim_in, dim_out): 19 | super(GNNNodeHead, self).__init__() 20 | self.layer_post_mp = MLP(dim_in, 21 | dim_out, 22 | num_layers=cfg.gnn.layers_post_mp, 23 | bias=True) 24 | 25 | def _apply_index(self, batch): 26 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]: 27 | return batch.node_feature[batch.node_label_index], batch.node_label 28 | else: 29 | return batch.node_feature[batch.node_label_index], \ 30 | batch.node_label[batch.node_label_index] 31 | 32 | def forward(self, batch): 33 | batch = self.layer_post_mp(batch) 34 | pred, label = self._apply_index(batch) 35 | return pred, label 36 | 37 | 38 | class GNNEdgeHead(nn.Module): 39 | '''Head of GNN, edge prediction''' 40 | def __init__(self, dim_in, dim_out): 41 | ''' Head of Edge and link prediction models. 42 | 43 | Args: 44 | dim_out: output dimension. For binary prediction, dim_out=1. 45 | ''' 46 | # Use dim_in for graph conv, since link prediction dim_out could be 47 | # binary 48 | # E.g. if decoder='dot', link probability is dot product between 49 | # node embeddings, of dimension dim_in 50 | super(GNNEdgeHead, self).__init__() 51 | # module to decode edges from node embeddings 52 | 53 | if cfg.model.edge_decoding == 'concat': 54 | self.layer_post_mp = MLP(dim_in * 2, 55 | dim_out, 56 | num_layers=cfg.gnn.layers_post_mp, 57 | bias=True) 58 | # requires parameter 59 | self.decode_module = lambda v1, v2: \ 60 | self.layer_post_mp(torch.cat((v1, v2), dim=-1)) 61 | else: 62 | if dim_out > 1: 63 | raise ValueError( 64 | 'Binary edge decoding ({})is used for multi-class ' 65 | 'edge/link prediction.'.format(cfg.model.edge_decoding)) 66 | self.layer_post_mp = MLP(dim_in, 67 | dim_in, 68 | num_layers=cfg.gnn.layers_post_mp, 69 | bias=True) 70 | if cfg.model.edge_decoding == 'dot': 71 | self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1) 72 | elif cfg.model.edge_decoding == 'cosine_similarity': 73 | self.decode_module = nn.CosineSimilarity(dim=-1) 74 | else: 75 | raise ValueError('Unknown edge decoding {}.'.format( 76 | cfg.model.edge_decoding)) 77 | 78 | def _apply_index(self, batch): 79 | return batch.node_feature[batch.edge_label_index], \ 80 | batch.edge_label 81 | 82 | def forward(self, batch): 83 | if cfg.model.edge_decoding != 'concat': 84 | batch = self.layer_post_mp(batch) 85 | pred, label = self._apply_index(batch) 86 | nodes_first = pred[0] 87 | nodes_second = pred[1] 88 | pred = self.decode_module(nodes_first, nodes_second) 89 | return pred, label 90 | 91 | 92 | class GNNGraphHead(nn.Module): 93 | '''Head of GNN, graph prediction 94 | 95 | The optional post_mp layer (specified by cfg.gnn.post_mp) is used 96 | to transform the pooled embedding using an MLP. 
97 | ''' 98 | def __init__(self, dim_in, dim_out): 99 | super(GNNGraphHead, self).__init__() 100 | # todo: PostMP before or after global pooling 101 | self.layer_post_mp = MLP(dim_in, 102 | dim_out, 103 | num_layers=cfg.gnn.layers_post_mp, 104 | bias=True) 105 | self.pooling_fun = pooling_dict[cfg.model.graph_pooling] 106 | 107 | def _apply_index(self, batch): 108 | return batch.graph_feature, batch.graph_label 109 | 110 | def forward(self, batch): 111 | if cfg.dataset.transform == 'ego': 112 | graph_emb = self.pooling_fun(batch.node_feature, batch.batch, 113 | batch.node_id_index) 114 | else: 115 | graph_emb = self.pooling_fun(batch.node_feature, batch.batch) 116 | graph_emb = self.layer_post_mp(graph_emb) 117 | batch.graph_feature = graph_emb 118 | pred, label = self._apply_index(batch) 119 | return pred, label 120 | 121 | 122 | # Head models for external interface 123 | head_dict = { 124 | 'node': GNNNodeHead, 125 | 'edge': GNNEdgeHead, 126 | 'link_pred': GNNEdgeHead, 127 | 'graph': GNNGraphHead 128 | } 129 | 130 | head_dict = {**register.head_dict, **head_dict} 131 | -------------------------------------------------------------------------------- /graphgym/models/head_pyg.py: -------------------------------------------------------------------------------- 1 | """ GNN heads are the last layer of a GNN right before loss computation. 2 | 3 | They are constructed in the init function of the gnn.GNN. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import graphgym.register as register 10 | from graphgym.config import cfg 11 | from graphgym.models.layer_pyg import MLP 12 | from graphgym.models.pooling import pooling_dict 13 | 14 | 15 | # Head 16 | class GNNNodeHead(nn.Module): 17 | '''Head of GNN, node prediction''' 18 | def __init__(self, dim_in, dim_out): 19 | super(GNNNodeHead, self).__init__() 20 | self.layer_post_mp = MLP(dim_in, 21 | dim_out, 22 | num_layers=cfg.gnn.layers_post_mp, 23 | bias=True) 24 | 25 | def _apply_index(self, batch): 26 | mask = '{}_mask'.format(batch.split) 27 | return batch.x[batch[mask]], batch.y[batch[mask]] 28 | 29 | def forward(self, batch): 30 | batch = self.layer_post_mp(batch) 31 | pred, label = self._apply_index(batch) 32 | return pred, label 33 | 34 | 35 | class GNNEdgeHead(nn.Module): 36 | '''Head of GNN, edge prediction''' 37 | def __init__(self, dim_in, dim_out): 38 | ''' Head of Edge and link prediction models. 39 | 40 | Args: 41 | dim_out: output dimension. For binary prediction, dim_out=1. 42 | ''' 43 | # Use dim_in for graph conv, since link prediction dim_out could be 44 | # binary 45 | # E.g. 
if decoder='dot', link probability is dot product between
46 |         # node embeddings, of dimension dim_in
47 |         super(GNNEdgeHead, self).__init__()
48 |         # module to decode edges from node embeddings
49 | 
50 |         if cfg.model.edge_decoding == 'concat':
51 |             self.layer_post_mp = MLP(dim_in * 2,
52 |                                      dim_out,
53 |                                      num_layers=cfg.gnn.layers_post_mp,
54 |                                      bias=True)
55 |             # requires parameter
56 |             self.decode_module = lambda v1, v2: \
57 |                 self.layer_post_mp(torch.cat((v1, v2), dim=-1))
58 |         else:
59 |             if dim_out > 1:
60 |                 raise ValueError(
61 |                     'Binary edge decoding ({}) cannot be used for multi-class '
62 |                     'edge/link prediction.'.format(cfg.model.edge_decoding))
63 |             self.layer_post_mp = MLP(dim_in,
64 |                                      dim_in,
65 |                                      num_layers=cfg.gnn.layers_post_mp,
66 |                                      bias=True)
67 |             if cfg.model.edge_decoding == 'dot':
68 |                 self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1)
69 |             elif cfg.model.edge_decoding == 'cosine_similarity':
70 |                 self.decode_module = nn.CosineSimilarity(dim=-1)
71 |             else:
72 |                 raise ValueError('Unknown edge decoding {}.'.format(
73 |                     cfg.model.edge_decoding))
74 | 
75 |     def _apply_index(self, batch):
76 |         index = '{}_edge_index'.format(batch.split)
77 |         label = '{}_edge_label'.format(batch.split)
78 |         return batch.x[batch[index]], batch[label]
79 | 
80 |     def forward(self, batch):
81 |         if cfg.model.edge_decoding != 'concat':
82 |             batch = self.layer_post_mp(batch)
83 |         pred, label = self._apply_index(batch)
84 |         nodes_first = pred[0]
85 |         nodes_second = pred[1]
86 |         pred = self.decode_module(nodes_first, nodes_second)
87 |         return pred, label
88 | 
89 | 
90 | class GNNGraphHead(nn.Module):
91 |     '''Head of GNN, graph prediction
92 | 
93 |     The optional post_mp layer (specified by cfg.gnn.layers_post_mp) is used
94 |     to transform the pooled embedding using an MLP.
95 | ''' 96 | def __init__(self, dim_in, dim_out): 97 | super(GNNGraphHead, self).__init__() 98 | # todo: PostMP before or after global pooling 99 | self.layer_post_mp = MLP(dim_in, 100 | dim_out, 101 | num_layers=cfg.gnn.layers_post_mp, 102 | bias=True) 103 | self.pooling_fun = pooling_dict[cfg.model.graph_pooling] 104 | 105 | def _apply_index(self, batch): 106 | return batch.graph_feature, batch.y 107 | 108 | def forward(self, batch): 109 | graph_emb = self.pooling_fun(batch.x, batch.batch) 110 | graph_emb = self.layer_post_mp(graph_emb) 111 | batch.graph_feature = graph_emb 112 | pred, label = self._apply_index(batch) 113 | return pred, label 114 | 115 | 116 | # Head models for external interface 117 | head_dict = { 118 | 'node': GNNNodeHead, 119 | 'edge': GNNEdgeHead, 120 | 'link_pred': GNNEdgeHead, 121 | 'graph': GNNGraphHead 122 | } 123 | 124 | head_dict = {**register.head_dict, **head_dict} 125 | -------------------------------------------------------------------------------- /graphgym/models/layer_pyg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch_geometric as pyg 5 | 6 | import graphgym.register as register 7 | from graphgym.config import cfg 8 | from graphgym.contrib.layer.generalconv import (GeneralConvLayer, 9 | GeneralEdgeConvLayer) 10 | from graphgym.models.act import act_dict 11 | 12 | 13 | # General classes 14 | class GeneralLayer(nn.Module): 15 | '''General wrapper for layers''' 16 | def __init__(self, 17 | name, 18 | dim_in, 19 | dim_out, 20 | has_act=True, 21 | has_bn=True, 22 | has_l2norm=False, 23 | **kwargs): 24 | super(GeneralLayer, self).__init__() 25 | self.has_l2norm = has_l2norm 26 | has_bn = has_bn and cfg.gnn.batchnorm 27 | self.layer = layer_dict[name](dim_in, 28 | dim_out, 29 | bias=not has_bn, 30 | **kwargs) 31 | layer_wrapper = [] 32 | if has_bn: 33 | layer_wrapper.append( 34 | nn.BatchNorm1d(dim_out, eps=cfg.bn.eps, momentum=cfg.bn.mom)) 35 | if cfg.gnn.dropout > 0: 36 | layer_wrapper.append( 37 | nn.Dropout(p=cfg.gnn.dropout, inplace=cfg.mem.inplace)) 38 | if has_act: 39 | layer_wrapper.append(act_dict[cfg.gnn.act]) 40 | self.post_layer = nn.Sequential(*layer_wrapper) 41 | 42 | def forward(self, batch): 43 | batch = self.layer(batch) 44 | if isinstance(batch, torch.Tensor): 45 | batch = self.post_layer(batch) 46 | if self.has_l2norm: 47 | batch = F.normalize(batch, p=2, dim=1) 48 | else: 49 | batch.x = self.post_layer(batch.x) 50 | if self.has_l2norm: 51 | batch.x = F.normalize(batch.x, p=2, dim=1) 52 | return batch 53 | 54 | 55 | class GeneralMultiLayer(nn.Module): 56 | '''General wrapper for stack of layers''' 57 | def __init__(self, 58 | name, 59 | num_layers, 60 | dim_in, 61 | dim_out, 62 | dim_inner=None, 63 | final_act=True, 64 | **kwargs): 65 | super(GeneralMultiLayer, self).__init__() 66 | dim_inner = dim_in if dim_inner is None else dim_inner 67 | for i in range(num_layers): 68 | d_in = dim_in if i == 0 else dim_inner 69 | d_out = dim_out if i == num_layers - 1 else dim_inner 70 | has_act = final_act if i == num_layers - 1 else True 71 | layer = GeneralLayer(name, d_in, d_out, has_act, **kwargs) 72 | self.add_module('Layer_{}'.format(i), layer) 73 | 74 | def forward(self, batch): 75 | for layer in self.children(): 76 | batch = layer(batch) 77 | return batch 78 | 79 | 80 | # Core basic layers 81 | # Input: batch; Output: batch 82 | class Linear(nn.Module): 83 | def __init__(self, dim_in, dim_out, bias=False, 
**kwargs):
84 |         super(Linear, self).__init__()
85 |         self.model = nn.Linear(dim_in, dim_out, bias=bias)
86 | 
87 |     def forward(self, batch):
88 |         if isinstance(batch, torch.Tensor):
89 |             batch = self.model(batch)
90 |         else:
91 |             batch.x = self.model(batch.x)
92 |         return batch
93 | 
94 | 
95 | class BatchNorm1dNode(nn.Module):
96 |     '''BatchNorm wrapper for node features'''
97 |     def __init__(self, dim_in):
98 |         super(BatchNorm1dNode, self).__init__()
99 |         self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)
100 | 
101 |     def forward(self, batch):
102 |         batch.x = self.bn(batch.x)
103 |         return batch
104 | 
105 | 
106 | class BatchNorm1dEdge(nn.Module):
107 |     '''BatchNorm wrapper for edge features'''
108 |     def __init__(self, dim_in):
109 |         super(BatchNorm1dEdge, self).__init__()
110 |         self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)
111 | 
112 |     def forward(self, batch):
113 |         batch.edge_attr = self.bn(batch.edge_attr)
114 |         return batch
115 | 
116 | 
117 | class MLP(nn.Module):
118 |     def __init__(self,
119 |                  dim_in,
120 |                  dim_out,
121 |                  bias=True,
122 |                  dim_inner=None,
123 |                  num_layers=2,
124 |                  **kwargs):
125 |         '''
126 |         Note: MLP works for 0 layers
127 |         '''
128 |         super(MLP, self).__init__()
129 |         dim_inner = dim_in if dim_inner is None else dim_inner
130 |         layers = []
131 |         if num_layers > 1:
132 |             layers.append(
133 |                 GeneralMultiLayer('linear',
134 |                                   num_layers - 1,
135 |                                   dim_in,
136 |                                   dim_inner,
137 |                                   dim_inner,
138 |                                   final_act=True))
139 |             layers.append(Linear(dim_inner, dim_out, bias))
140 |         else:
141 |             layers.append(Linear(dim_in, dim_out, bias))
142 |         self.model = nn.Sequential(*layers)
143 | 
144 |     def forward(self, batch):
145 |         if isinstance(batch, torch.Tensor):
146 |             batch = self.model(batch)
147 |         else:
148 |             batch.x = self.model(batch.x)
149 |         return batch
150 | 
151 | 
152 | class GCNConv(nn.Module):
153 |     def __init__(self, dim_in, dim_out, bias=False, **kwargs):
154 |         super(GCNConv, self).__init__()
155 |         self.model = pyg.nn.GCNConv(dim_in, dim_out, bias=bias)
156 | 
157 |     def forward(self, batch):
158 |         batch.x = self.model(batch.x, batch.edge_index)
159 |         return batch
160 | 
161 | 
162 | class SAGEConv(nn.Module):
163 |     def __init__(self, dim_in, dim_out, bias=False, **kwargs):
164 |         super(SAGEConv, self).__init__()
165 |         self.model = pyg.nn.SAGEConv(dim_in, dim_out, bias=bias)
166 | 
167 |     def forward(self, batch):
168 |         batch.x = self.model(batch.x, batch.edge_index)
169 |         return batch
170 | 
171 | 
172 | class GATConv(nn.Module):
173 |     def __init__(self, dim_in, dim_out, bias=False, **kwargs):
174 |         super(GATConv, self).__init__()
175 |         self.model = pyg.nn.GATConv(dim_in, dim_out, bias=bias)
176 | 
177 |     def forward(self, batch):
178 |         batch.x = self.model(batch.x, batch.edge_index)
179 |         return batch
180 | 
181 | 
182 | class GINConv(nn.Module):
183 |     def __init__(self, dim_in, dim_out, bias=False, **kwargs):
184 |         super(GINConv, self).__init__()
185 |         gin_nn = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(),
186 |                                nn.Linear(dim_out, dim_out))
187 |         self.model = pyg.nn.GINConv(gin_nn)
188 | 
189 |     def forward(self, batch):
190 |         batch.x = self.model(batch.x, batch.edge_index)
191 |         return batch
192 | 
193 | 
194 | class SplineConv(nn.Module):
195 |     def __init__(self, dim_in, dim_out, bias=False, **kwargs):
196 |         super(SplineConv, self).__init__()
197 |         self.model = pyg.nn.SplineConv(dim_in,
198 |                                        dim_out,
199 |                                        dim=1,
200 |                                        kernel_size=2,
201 |                                        bias=bias)
202 | 
203 |     def forward(self, batch):
204 |         batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr)
205 |
return batch 206 | 207 | 208 | class GeneralConv(nn.Module): 209 | def __init__(self, dim_in, dim_out, bias=False, **kwargs): 210 | super(GeneralConv, self).__init__() 211 | self.model = GeneralConvLayer(dim_in, dim_out, bias=bias) 212 | 213 | def forward(self, batch): 214 | batch.x = self.model(batch.x, batch.edge_index) 215 | return batch 216 | 217 | 218 | class GeneralEdgeConv(nn.Module): 219 | def __init__(self, dim_in, dim_out, bias=False, **kwargs): 220 | super(GeneralEdgeConv, self).__init__() 221 | self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias) 222 | 223 | def forward(self, batch): 224 | batch.x = self.model(batch.x, 225 | batch.edge_index, 226 | edge_feature=batch.edge_attr) 227 | return batch 228 | 229 | 230 | class GeneralSampleEdgeConv(nn.Module): 231 | def __init__(self, dim_in, dim_out, bias=False, **kwargs): 232 | super(GeneralSampleEdgeConv, self).__init__() 233 | self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias) 234 | 235 | def forward(self, batch): 236 | edge_mask = torch.rand(batch.edge_index.shape[1]) < cfg.gnn.keep_edge 237 | edge_index = batch.edge_index[:, edge_mask] 238 | edge_feature = batch.edge_attr[edge_mask, :] 239 | batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature) 240 | return batch 241 | 242 | 243 | layer_dict = { 244 | 'linear': Linear, 245 | 'mlp': MLP, 246 | 'gcnconv': GCNConv, 247 | 'sageconv': SAGEConv, 248 | 'gatconv': GATConv, 249 | 'splineconv': SplineConv, 250 | 'ginconv': GINConv, 251 | 'generalconv': GeneralConv, 252 | 'generaledgeconv': GeneralEdgeConv, 253 | 'generalsampleedgeconv': GeneralSampleEdgeConv, 254 | } 255 | 256 | # register additional convs 257 | layer_dict = {**register.layer_dict, **layer_dict} 258 | -------------------------------------------------------------------------------- /graphgym/models/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter 3 | 4 | import graphgym.register as register 5 | from graphgym.config import cfg 6 | 7 | 8 | # Pooling options (pool nodes into graph representations) 9 | # pooling function takes in node embedding [num_nodes x emb_dim] and 10 | # batch (indices) and outputs graph embedding [num_graphs x emb_dim]. 
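# A concrete shape sketch for the pooling contract above (illustrative values only):
#   x     = torch.randn(5, 8)                # 5 nodes with 8-dim embeddings
#   batch = torch.tensor([0, 0, 1, 1, 1])    # node-to-graph assignment
#   global_mean_pool(x, batch).shape         # -> torch.Size([2, 8])
# i.e. scatter(...) reduces the rows of x grouped by their graph index in `batch`.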
11 | def global_add_pool(x, batch, id=None, size=None): 12 | size = batch.max().item() + 1 if size is None else size 13 | if cfg.dataset.transform == 'ego': 14 | x = torch.index_select(x, dim=0, index=id) 15 | batch = torch.index_select(batch, dim=0, index=id) 16 | return scatter(x, batch, dim=0, dim_size=size, reduce='add') 17 | 18 | 19 | def global_mean_pool(x, batch, id=None, size=None): 20 | size = batch.max().item() + 1 if size is None else size 21 | if cfg.dataset.transform == 'ego': 22 | x = torch.index_select(x, dim=0, index=id) 23 | batch = torch.index_select(batch, dim=0, index=id) 24 | return scatter(x, batch, dim=0, dim_size=size, reduce='mean') 25 | 26 | 27 | def global_max_pool(x, batch, id=None, size=None): 28 | size = batch.max().item() + 1 if size is None else size 29 | if cfg.dataset.transform == 'ego': 30 | x = torch.index_select(x, dim=0, index=id) 31 | batch = torch.index_select(batch, dim=0, index=id) 32 | return scatter(x, batch, dim=0, dim_size=size, reduce='max') 33 | 34 | 35 | pooling_dict = { 36 | 'add': global_add_pool, 37 | 'mean': global_mean_pool, 38 | 'max': global_max_pool 39 | } 40 | 41 | pooling_dict = {**register.pooling_dict, **pooling_dict} 42 | -------------------------------------------------------------------------------- /graphgym/models/transform.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | from torch_geometric.utils import negative_sampling 4 | 5 | 6 | def remove_node_feature(graph): 7 | '''set node feature to be constant''' 8 | graph.node_feature = torch.ones(graph.num_nodes, 1) 9 | 10 | 11 | def ego_nets(graph, radius=2): 12 | '''get networks for mini batch node/graph prediction tasks''' 13 | # color center 14 | egos = [] 15 | n = graph.num_nodes 16 | # A proper deepsnap.G should have nodes indexed from 0 to n-1 17 | for i in range(n): 18 | if radius > 4: 19 | egos.append(graph.G) 20 | else: 21 | egos.append(nx.ego_graph(graph.G, i, radius=radius)) 22 | # relabel egos: keep center node ID, relabel other node IDs 23 | G = graph.G.__class__() 24 | id_bias = n 25 | for i in range(len(egos)): 26 | G.add_node(i, **egos[i].nodes(data=True)[i]) 27 | for i in range(len(egos)): 28 | keys = list(egos[i].nodes) 29 | keys.remove(i) 30 | id_cur = egos[i].number_of_nodes() - 1 31 | vals = range(id_bias, id_bias + id_cur) 32 | id_bias += id_cur 33 | mapping = dict(zip(keys, vals)) 34 | ego = nx.relabel_nodes(egos[i], mapping, copy=True) 35 | G.add_nodes_from(ego.nodes(data=True)) 36 | G.add_edges_from(ego.edges(data=True)) 37 | graph.G = G 38 | graph.node_id_index = torch.arange(len(egos)) 39 | 40 | 41 | def edge_nets(graph): 42 | '''get networks for mini batch edge prediction tasks''' 43 | # color center 44 | n = graph.num_nodes 45 | # A proper deepsnap.G should have nodes indexed from 0 to n-1 46 | # relabel egos: keep center node ID, relabel other node IDs 47 | G_raw = graph.G 48 | G = graph.G.__class__() 49 | for i in range(n): 50 | keys = list(G_raw.nodes) 51 | vals = range(i * n, (i + 1) * n) 52 | mapping = dict(zip(keys, vals)) 53 | G_raw = nx.relabel_nodes(G_raw, mapping, copy=True) 54 | G.add_nodes_from(G_raw.nodes(data=True)) 55 | G.add_edges_from(G_raw.edges(data=True)) 56 | graph.G = G 57 | graph.node_id_index = torch.arange(0, n * n, n + 1) 58 | 59 | # change link_pred to conditional node classification task 60 | graph.node_label = graph.edge_label 61 | bias = graph.edge_label_index[0] * n 62 | graph.node_label_index = graph.edge_label_index[1] + bias 63 | 64 | 
graph.edge_label = None 65 | graph.edge_label_index = None 66 | 67 | 68 | def path_len(graph): 69 | '''get networks for mini batch shortest path prediction tasks''' 70 | n = graph.num_nodes 71 | # shortest path label 72 | num_label = 1000 73 | edge_label_index = torch.randint(n, 74 | size=(2, num_label), 75 | device=graph.edge_index.device) 76 | path_dict = dict(nx.all_pairs_shortest_path_length(graph.G)) 77 | edge_label = [] 78 | index_keep = [] 79 | for i in range(num_label): 80 | start = edge_label_index[0, i].item() 81 | end = edge_label_index[1, i].item() 82 | try: 83 | dist = path_dict[start][end] 84 | except Exception: 85 | continue 86 | edge_label.append(min(dist, 4)) 87 | index_keep.append(i) 88 | 89 | edge_label = torch.tensor(edge_label, device=edge_label_index.device) 90 | graph.edge_label_index = edge_label_index[:, index_keep] 91 | graph.edge_label = edge_label 92 | 93 | 94 | def create_link_label(pos_edge_index, neg_edge_index): 95 | """ 96 | Create labels for link prediction, based on positive and negative edges 97 | 98 | Args: 99 | pos_edge_index (torch.tensor): Positive edge index [2, num_edges] 100 | neg_edge_index (torch.tensor): Negative edge index [2, num_edges] 101 | 102 | Returns: Link label tensor, [num_positive_edges + num_negative_edges] 103 | 104 | """ 105 | num_links = pos_edge_index.size(1) + neg_edge_index.size(1) 106 | link_labels = torch.zeros(num_links, 107 | dtype=torch.float, 108 | device=pos_edge_index.device) 109 | link_labels[:pos_edge_index.size(1)] = 1. 110 | return link_labels 111 | 112 | 113 | def neg_sampling_transform(data): 114 | """ 115 | Do negative sampling for link prediction tasks 116 | 117 | Args: 118 | data (torch_geometric.data): Input data object 119 | 120 | Returns: Transformed data object with negative edges + link pred labels 121 | 122 | """ 123 | train_neg_edge_index = negative_sampling( 124 | edge_index=data.train_pos_edge_index, 125 | num_nodes=data.num_nodes, 126 | num_neg_samples=data.train_pos_edge_index.size(1)) 127 | data.train_edge_index = torch.cat( 128 | [data.train_pos_edge_index, train_neg_edge_index], dim=-1) 129 | data.train_edge_label = create_link_label(data.train_pos_edge_index, 130 | train_neg_edge_index) 131 | 132 | return data 133 | -------------------------------------------------------------------------------- /graphgym/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | import graphgym.register as register 4 | from graphgym.config import cfg 5 | 6 | 7 | def create_optimizer(params): 8 | r"""Creates a config-driven optimizer.""" 9 | params = filter(lambda p: p.requires_grad, params) 10 | # Try to load customized optimizer 11 | for func in register.optimizer_dict.values(): 12 | optimizer = func(params) 13 | if optimizer is not None: 14 | return optimizer 15 | if cfg.optim.optimizer == 'adam': 16 | optimizer = optim.Adam(params, 17 | lr=cfg.optim.base_lr, 18 | weight_decay=cfg.optim.weight_decay) 19 | elif cfg.optim.optimizer == 'sgd': 20 | optimizer = optim.SGD(params, 21 | lr=cfg.optim.base_lr, 22 | momentum=cfg.optim.momentum, 23 | weight_decay=cfg.optim.weight_decay) 24 | else: 25 | raise ValueError('Optimizer {} not supported'.format( 26 | cfg.optim.optimizer)) 27 | 28 | return optimizer 29 | 30 | 31 | def create_scheduler(optimizer): 32 | r"""Creates a config-driven learning rate scheduler.""" 33 | # Try to load customized scheduler 34 | for func in register.scheduler_dict.values(): 35 | scheduler = func(optimizer) 36 | if 
scheduler is not None: 37 | return scheduler 38 | if cfg.optim.scheduler == 'none': 39 | scheduler = optim.lr_scheduler.StepLR(optimizer, 40 | step_size=cfg.optim.max_epoch + 41 | 1) 42 | elif cfg.optim.scheduler == 'step': 43 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, 44 | milestones=cfg.optim.steps, 45 | gamma=cfg.optim.lr_decay) 46 | elif cfg.optim.scheduler == 'cos': 47 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 48 | optimizer, T_max=cfg.optim.max_epoch) 49 | else: 50 | raise ValueError('Scheduler {} not supported'.format( 51 | cfg.optim.scheduler)) 52 | return scheduler 53 | -------------------------------------------------------------------------------- /graphgym/register.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Callable, Dict, Union 3 | 4 | act_dict: Dict[str, Any] = {} 5 | node_encoder_dict: Dict[str, Any] = {} 6 | edge_encoder_dict: Dict[str, Any] = {} 7 | stage_dict: Dict[str, Any] = {} 8 | head_dict: Dict[str, Any] = {} 9 | layer_dict: Dict[str, Any] = {} 10 | pooling_dict: Dict[str, Any] = {} 11 | network_dict: Dict[str, Any] = {} 12 | config_dict: Dict[str, Any] = {} 13 | dataset_dict: Dict[str, Any] = {} 14 | loader_dict: Dict[str, Any] = {} 15 | optimizer_dict: Dict[str, Any] = {} 16 | scheduler_dict: Dict[str, Any] = {} 17 | loss_dict: Dict[str, Any] = {} 18 | train_dict: Dict[str, Any] = {} 19 | feature_augment_dict: Dict[str, Any] = {} 20 | metric_dict: Dict[str, Any] = {} 21 | 22 | 23 | def register_base(mapping: Dict[str, Any], 24 | key: str, 25 | module: Any = None) -> Union[None, Callable]: 26 | r"""Base function for registering a module in GraphGym. 27 | 28 | Args: 29 | mapping (dict): Python dictionary to register the module. 30 | hosting all the registered modules 31 | key (string): The name of the module. 32 | module (any, optional): The module. If set to :obj:`None`, will return 33 | a decorator to register a module. 
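    Example (a sketch of both call forms, using a hypothetical MyAct module):
        register_act('my_act', MyAct)       # direct registration

        @register_act('my_act2')            # decorator form
        class MyAct2(torch.nn.Module):
            ...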
34 | """ 35 | if module is not None: 36 | if key in mapping: 37 | raise KeyError(f"Module with '{key}' already defined") 38 | mapping[key] = module 39 | return 40 | 41 | # Other-wise, use it as a decorator: 42 | def bounded_register(module): 43 | register_base(mapping, key, module) 44 | return module 45 | 46 | return bounded_register 47 | 48 | 49 | def register_act(key: str, module: Any = None): 50 | r"""Registers an activation function in GraphGym.""" 51 | return register_base(act_dict, key, module) 52 | 53 | 54 | def register_node_encoder(key: str, module: Any = None): 55 | r"""Registers a node feature encoder in GraphGym.""" 56 | return register_base(node_encoder_dict, key, module) 57 | 58 | 59 | def register_edge_encoder(key: str, module: Any = None): 60 | r"""Registers an edge feature encoder in GraphGym.""" 61 | return register_base(edge_encoder_dict, key, module) 62 | 63 | 64 | def register_stage(key: str, module: Any = None): 65 | r"""Registers a customized GNN stage in GraphGym.""" 66 | return register_base(stage_dict, key, module) 67 | 68 | 69 | def register_head(key: str, module: Any = None): 70 | r"""Registers a GNN prediction head in GraphGym.""" 71 | return register_base(head_dict, key, module) 72 | 73 | 74 | def register_layer(key: str, module: Any = None): 75 | r"""Registers a GNN layer in GraphGym.""" 76 | return register_base(layer_dict, key, module) 77 | 78 | 79 | def register_pooling(key: str, module: Any = None): 80 | r"""Registers a GNN global pooling/readout layer in GraphGym.""" 81 | return register_base(pooling_dict, key, module) 82 | 83 | 84 | def register_network(key: str, module: Any = None): 85 | r"""Registers a GNN model in GraphGym.""" 86 | return register_base(network_dict, key, module) 87 | 88 | 89 | def register_config(key: str, module: Any = None): 90 | r"""Registers a configuration group in GraphGym.""" 91 | return register_base(config_dict, key, module) 92 | 93 | 94 | def register_dataset(key: str, module: Any = None): 95 | r"""Registers a dataset in GraphGym.""" 96 | return register_base(dataset_dict, key, module) 97 | 98 | 99 | def register_loader(key: str, module: Any = None): 100 | r"""Registers a data loader in GraphGym.""" 101 | return register_base(loader_dict, key, module) 102 | 103 | 104 | def register_optimizer(key: str, module: Any = None): 105 | r"""Registers an optimizer in GraphGym.""" 106 | return register_base(optimizer_dict, key, module) 107 | 108 | 109 | def register_scheduler(key: str, module: Any = None): 110 | r"""Registers a learning rate scheduler in GraphGym.""" 111 | return register_base(scheduler_dict, key, module) 112 | 113 | 114 | def register_loss(key: str, module: Any = None): 115 | r"""Registers a loss function in GraphGym.""" 116 | return register_base(loss_dict, key, module) 117 | 118 | 119 | def register_train(key: str, module: Any = None): 120 | r"""Registers a training function in GraphGym.""" 121 | return register_base(train_dict, key, module) 122 | 123 | 124 | def register_metric(key: str, module: Any = None): 125 | r"""Register a metric function in GraphGym.""" 126 | return register_base(metric_dict, key, module) 127 | 128 | 129 | def register_feature_augment(key, module): 130 | return register_base(feature_augment_dict, key, module) 131 | 132 | 133 | class ModuleStore(dict): 134 | def __init__(self): 135 | super().__init__() 136 | 137 | def register(self, 138 | module_group: str, 139 | key: str, 140 | module: Any = None) -> Union[None, Callable]: 141 | r"""Base function for registering a module in GraphGym. 
142 | 143 | Args: 144 | module_group (str): The name of the module group 145 | key (string): The name of the module. 146 | module (any, optional): The module. If set to :obj:`None`, will 147 | return a decorator to register a module. 148 | """ 149 | 150 | if module_group not in self.keys(): 151 | self[module_group] = {} 152 | 153 | if module is not None: 154 | if key in self[module_group]: 155 | warnings.warn( 156 | f"Module group {module_group} with '{key}' already " 157 | f"defined, registration failed") 158 | self[module_group][key] = module 159 | return 160 | 161 | # Otherwise, use it as a decorator: 162 | def bounded_register(module): 163 | self.register(module_group, key, module) 164 | return module 165 | 166 | return bounded_register 167 | 168 | 169 | module = ModuleStore() 170 | -------------------------------------------------------------------------------- /graphgym/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | 6 | from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt 7 | from graphgym.config import cfg 8 | from graphgym.loss import compute_loss 9 | from graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch 10 | 11 | 12 | def train_epoch(logger, loader, model, optimizer, scheduler): 13 | model.train() 14 | time_start = time.time() 15 | for batch in loader: 16 | optimizer.zero_grad() 17 | batch.to(torch.device(cfg.device)) 18 | pred, true = model(batch) 19 | loss, pred_score = compute_loss(pred, true) 20 | loss.backward() 21 | optimizer.step() 22 | logger.update_stats(true=true.detach().cpu(), 23 | pred=pred_score.detach().cpu(), 24 | loss=loss.item(), 25 | lr=scheduler.get_last_lr()[0], 26 | time_used=time.time() - time_start, 27 | params=cfg.params) 28 | time_start = time.time() 29 | scheduler.step() 30 | 31 | 32 | @torch.no_grad() 33 | def eval_epoch(logger, loader, model): 34 | model.eval() 35 | time_start = time.time() 36 | for batch in loader: 37 | batch.to(torch.device(cfg.device)) 38 | pred, true = model(batch) 39 | loss, pred_score = compute_loss(pred, true) 40 | logger.update_stats(true=true.detach().cpu(), 41 | pred=pred_score.detach().cpu(), 42 | loss=loss.item(), 43 | lr=0, 44 | time_used=time.time() - time_start, 45 | params=cfg.params) 46 | time_start = time.time() 47 | 48 | 49 | def train(loggers, loaders, model, optimizer, scheduler): 50 | r""" 51 | The core training pipeline 52 | 53 | Args: 54 | loggers: List of loggers 55 | loaders: List of loaders 56 | model: GNN model 57 | optimizer: PyTorch optimizer 58 | scheduler: PyTorch learning rate scheduler 59 | 60 | """ 61 | start_epoch = 0 62 | if cfg.train.auto_resume: 63 | start_epoch = load_ckpt(model, optimizer, scheduler) 64 | if start_epoch == cfg.optim.max_epoch: 65 | logging.info('Checkpoint found, Task already done') 66 | else: 67 | logging.info('Start from epoch {}'.format(start_epoch)) 68 | 69 | num_splits = len(loggers) 70 | for cur_epoch in range(start_epoch, cfg.optim.max_epoch): 71 | train_epoch(loggers[0], loaders[0], model, optimizer, scheduler) 72 | loggers[0].write_epoch(cur_epoch) 73 | if is_eval_epoch(cur_epoch): 74 | for i in range(1, num_splits): 75 | eval_epoch(loggers[i], loaders[i], model) 76 | loggers[i].write_epoch(cur_epoch) 77 | if is_ckpt_epoch(cur_epoch): 78 | save_ckpt(model, optimizer, scheduler, cur_epoch) 79 | for logger in loggers: 80 | logger.close() 81 | if cfg.train.ckpt_clean: 82 | clean_ckpt() 83 | 84 | logging.info('Task done, results saved in 
{}'.format(cfg.out_dir)) 85 | -------------------------------------------------------------------------------- /graphgym/train_pyg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | 6 | from graphgym.checkpoint import clean_ckpt, load_ckpt, save_ckpt 7 | from graphgym.config import cfg 8 | from graphgym.loss import compute_loss 9 | from graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch 10 | 11 | 12 | def train_epoch(logger, loader, model, optimizer, scheduler): 13 | model.train() 14 | time_start = time.time() 15 | for batch in loader: 16 | batch.split = 'train' 17 | optimizer.zero_grad() 18 | batch.to(torch.device(cfg.device)) 19 | pred, true = model(batch) 20 | loss, pred_score = compute_loss(pred, true) 21 | loss.backward() 22 | optimizer.step() 23 | logger.update_stats(true=true.detach().cpu(), 24 | pred=pred_score.detach().cpu(), 25 | loss=loss.item(), 26 | lr=scheduler.get_last_lr()[0], 27 | time_used=time.time() - time_start, 28 | params=cfg.params) 29 | time_start = time.time() 30 | scheduler.step() 31 | 32 | 33 | @torch.no_grad() 34 | def eval_epoch(logger, loader, model, split='val'): 35 | model.eval() 36 | time_start = time.time() 37 | for batch in loader: 38 | batch.split = split 39 | batch.to(torch.device(cfg.device)) 40 | pred, true = model(batch) 41 | loss, pred_score = compute_loss(pred, true) 42 | logger.update_stats(true=true.detach().cpu(), 43 | pred=pred_score.detach().cpu(), 44 | loss=loss.item(), 45 | lr=0, 46 | time_used=time.time() - time_start, 47 | params=cfg.params) 48 | time_start = time.time() 49 | 50 | 51 | def train(loggers, loaders, model, optimizer, scheduler): 52 | r""" 53 | The core training pipeline 54 | 55 | Args: 56 | loggers: List of loggers 57 | loaders: List of loaders 58 | model: GNN model 59 | optimizer: PyTorch optimizer 60 | scheduler: PyTorch learning rate scheduler 61 | 62 | """ 63 | start_epoch = 0 64 | if cfg.train.auto_resume: 65 | start_epoch = load_ckpt(model, optimizer, scheduler) 66 | if start_epoch == cfg.optim.max_epoch: 67 | logging.info('Checkpoint found, Task already done') 68 | else: 69 | logging.info('Start from epoch {}'.format(start_epoch)) 70 | 71 | num_splits = len(loggers) 72 | split_names = ['val', 'test'] 73 | for cur_epoch in range(start_epoch, cfg.optim.max_epoch): 74 | train_epoch(loggers[0], loaders[0], model, optimizer, scheduler) 75 | loggers[0].write_epoch(cur_epoch) 76 | if is_eval_epoch(cur_epoch): 77 | for i in range(1, num_splits): 78 | eval_epoch(loggers[i], loaders[i], model, 79 | split=split_names[i - 1]) 80 | loggers[i].write_epoch(cur_epoch) 81 | if is_ckpt_epoch(cur_epoch): 82 | save_ckpt(model, optimizer, scheduler, cur_epoch) 83 | for logger in loggers: 84 | logger.close() 85 | if cfg.train.ckpt_clean: 86 | clean_ckpt() 87 | 88 | logging.info('Task done, results saved in {}'.format(cfg.out_dir)) 89 | -------------------------------------------------------------------------------- /graphgym/utils/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Jiaxuan You 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the 
Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /graphgym/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/graphgym/utils/__init__.py -------------------------------------------------------------------------------- /graphgym/utils/comp_budget.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from graphgym.config import cfg, set_cfg 4 | from graphgym.model_builder import create_model 5 | 6 | 7 | def params_count(model): 8 | ''' 9 | Computes the number of parameters. 10 | 11 | Args: 12 | model (nn.Module): PyTorch model 13 | 14 | ''' 15 | return sum([p.numel() for p in model.parameters()]) 16 | 17 | 18 | def get_stats(): 19 | model = create_model(to_device=False, dim_in=1, dim_out=1) 20 | return params_count(model) 21 | 22 | 23 | def match_computation(stats_baseline, key=['gnn', 'dim_inner'], mode='sqrt'): 24 | '''Match computation budget by modifying cfg.gnn.dim_inner''' 25 | stats = get_stats() 26 | if stats != stats_baseline: 27 | # Phase 1: fast approximation 28 | while True: 29 | if mode == 'sqrt': 30 | scale = math.sqrt(stats_baseline / stats) 31 | elif mode == 'linear': 32 | scale = stats_baseline / stats 33 | step = int(round(cfg[key[0]][key[1]] * scale)) \ 34 | - cfg[key[0]][key[1]] 35 | cfg[key[0]][key[1]] += step 36 | stats = get_stats() 37 | if abs(step) <= 1: 38 | break 39 | # Phase 2: fine tune 40 | flag_init = 1 if stats < stats_baseline else -1 41 | step = 1 42 | while True: 43 | cfg[key[0]][key[1]] += flag_init * step 44 | stats = get_stats() 45 | flag = 1 if stats < stats_baseline else -1 46 | if stats == stats_baseline: 47 | return stats 48 | if flag != flag_init: 49 | if not cfg.model.match_upper: # stats is SMALLER 50 | if flag < 0: 51 | cfg[key[0]][key[1]] -= flag_init * step 52 | return get_stats() 53 | else: 54 | if flag > 0: 55 | cfg[key[0]][key[1]] -= flag_init * step 56 | return get_stats() 57 | return stats 58 | 59 | 60 | def dict_to_stats(cfg_dict): 61 | from yacs.config import CfgNode as CN 62 | set_cfg(cfg) 63 | cfg_new = CN(cfg_dict) 64 | cfg.merge_from_other_cfg(cfg_new) 65 | stats = get_stats() 66 | set_cfg(cfg) 67 | return stats 68 | 69 | 70 | def match_baseline_cfg(cfg_dict, cfg_dict_baseline, verbose=True): 71 | ''' 72 | Match the computational budget of a given baseline model. The current 73 | configuration dictionary will be modified and returned. 
74 | 75 | Args: 76 | cfg_dict (dict): Current experiment's configuration 77 | cfg_dict_baseline (dict): Baseline configuration 78 | verbose (bool, optional): Whether to print the matched parameter counts 79 | 80 | 81 | ''' 82 | from yacs.config import CfgNode as CN 83 | stats_baseline = dict_to_stats(cfg_dict_baseline) 84 | set_cfg(cfg) 85 | cfg_new = CN(cfg_dict) 86 | cfg.merge_from_other_cfg(cfg_new) 87 | stats = match_computation(stats_baseline, key=['gnn', 'dim_inner']) 88 | if 'gnn' in cfg_dict: 89 | cfg_dict['gnn']['dim_inner'] = cfg.gnn.dim_inner 90 | else: 91 | cfg_dict['gnn'] = {'dim_inner': cfg.gnn.dim_inner} 92 | set_cfg(cfg) 93 | if verbose: 94 | print('Computational budget has been matched: Baseline params {}, ' 95 | 'Current params {}'.format(stats_baseline, stats)) 96 | return cfg_dict 97 | -------------------------------------------------------------------------------- /graphgym/utils/device.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from graphgym.config import cfg 9 | 10 | 11 | def get_gpu_memory_map(): 12 | '''Get the current gpu usage.''' 13 | result = subprocess.check_output([ 14 | 'nvidia-smi', '--query-gpu=memory.used', 15 | '--format=csv,nounits,noheader' 16 | ], encoding='utf-8') 17 | gpu_memory = np.array([int(x) for x in result.strip().split('\n')]) 18 | return gpu_memory 19 | 20 | 21 | def get_current_gpu_usage(): 22 | ''' 23 | Get the current GPU memory usage. 24 | ''' 25 | if cfg.gpu_mem and cfg.device != 'cpu' and torch.cuda.is_available(): 26 | result = subprocess.check_output([ 27 | 'nvidia-smi', '--query-compute-apps=pid,used_memory', 28 | '--format=csv,nounits,noheader' 29 | ], encoding='utf-8') 30 | current_pid = os.getpid() 31 | used_memory = 0 32 | for line in result.strip().split('\n'): 33 | line = line.split(', ') 34 | if current_pid == int(line[0]): 35 | used_memory += int(line[1]) 36 | return used_memory 37 | else: 38 | return -1 39 | 40 | 41 | def auto_select_device(memory_max=8000, memory_bias=200, strategy='random'): 42 | r''' 43 | Auto select device for the experiment. Useful when having multiple GPUs. 44 | 45 | Args: 46 | memory_max (int): Threshold of existing GPU memory usage. GPUs with 47 | memory usage beyond this threshold will be deprioritized. 48 | memory_bias (int): A bias GPU memory usage added to all the GPUs. 49 | Avoids division-by-zero errors. 
50 | strategy (str, optional): 'random' (random select GPU) or 'greedy' 51 | (greedily select GPU) 52 | 53 | ''' 54 | if cfg.device != 'cpu' and torch.cuda.is_available(): 55 | if cfg.device == 'auto': 56 | memory_raw = get_gpu_memory_map() 57 | if strategy == 'greedy' or np.all(memory_raw > memory_max): 58 | cuda = np.argmin(memory_raw) 59 | logging.info('GPU Mem: {}'.format(memory_raw)) 60 | logging.info( 61 | 'Greedy select GPU, select GPU {} with mem: {}'.format( 62 | cuda, memory_raw[cuda])) 63 | elif strategy == 'random': 64 | memory = 1 / (memory_raw + memory_bias) 65 | memory[memory_raw > memory_max] = 0 66 | gpu_prob = memory / memory.sum() 67 | cuda = np.random.choice(len(gpu_prob), p=gpu_prob) 68 | logging.info('GPU Mem: {}'.format(memory_raw)) 69 | logging.info('GPU Prob: {}'.format(gpu_prob.round(2))) 70 | logging.info( 71 | 'Random select GPU, select GPU {} with mem: {}'.format( 72 | cuda, memory_raw[cuda])) 73 | 74 | cfg.device = 'cuda:{}'.format(cuda) 75 | else: 76 | cfg.device = 'cpu' 77 | -------------------------------------------------------------------------------- /graphgym/utils/epoch.py: -------------------------------------------------------------------------------- 1 | from graphgym.config import cfg 2 | 3 | 4 | def is_train_eval_epoch(cur_epoch): 5 | """Determines if the model should be evaluated at the training epoch.""" 6 | return is_eval_epoch(cur_epoch) or not cfg.train.skip_train_eval 7 | 8 | 9 | def is_eval_epoch(cur_epoch): 10 | """Determines if the model should be evaluated at the current epoch.""" 11 | return ((cur_epoch + 1) % cfg.train.eval_period == 0 or cur_epoch == 0 12 | or (cur_epoch + 1) == cfg.optim.max_epoch) 13 | 14 | 15 | def is_ckpt_epoch(cur_epoch): 16 | """Determines if the model should be checkpointed at the current epoch.""" 17 | return ((cur_epoch + 1) % cfg.train.ckpt_period == 0 18 | or (cur_epoch + 1) == cfg.optim.max_epoch) 19 | -------------------------------------------------------------------------------- /graphgym/utils/io.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import shutil 5 | 6 | 7 | def string_to_python(string): 8 | try: 9 | return ast.literal_eval(string) 10 | except Exception: 11 | return string 12 | 13 | 14 | def dict_to_json(dict, fname): 15 | ''' 16 | Dump a Python dictionary to JSON file 17 | 18 | Args: 19 | dict (dict): Python dictionary 20 | fname (str): Output file name 21 | 22 | ''' 23 | with open(fname, 'a') as f: 24 | json.dump(dict, f) 25 | f.write('\n') 26 | 27 | 28 | def dict_list_to_json(dict_list, fname): 29 | ''' 30 | Dump a list of Python dictionaries to JSON file 31 | 32 | Args: 33 | dict_list (list of dict): List of Python dictionaries 34 | fname (str): Output file name 35 | 36 | ''' 37 | with open(fname, 'a') as f: 38 | for dict in dict_list: 39 | json.dump(dict, f) 40 | f.write('\n') 41 | 42 | 43 | def json_to_dict_list(fname): 44 | dict_list = [] 45 | epoch_set = set() 46 | with open(fname) as f: 47 | lines = f.readlines() 48 | for line in lines: 49 | line = line.rstrip() 50 | dict = json.loads(line) 51 | if dict['epoch'] not in epoch_set: 52 | dict_list.append(dict) 53 | epoch_set.add(dict['epoch']) 54 | return dict_list 55 | 56 | 57 | def dict_to_tb(dict, writer, epoch): 58 | ''' 59 | Add a dictionary of statistics to a Tensorboard writer 60 | 61 | Args: 62 | dict (dict): Statistics of experiments, the keys are attribute names, 63 | the values are the attribute values 64 | writer: Tensorboard writer object 65 
| epoch (int): The current epoch 66 | 67 | ''' 68 | for key in dict: 69 | writer.add_scalar(key, dict[key], epoch) 70 | 71 | 72 | def dict_list_to_tb(dict_list, writer): 73 | for dict in dict_list: 74 | assert 'epoch' in dict, 'Key epoch must exist in stats dict' 75 | dict_to_tb(dict, writer, dict['epoch']) 76 | 77 | 78 | def makedirs(dir): 79 | os.makedirs(dir, exist_ok=True) 80 | 81 | 82 | def makedirs_rm_exist(dir): 83 | ''' 84 | Make a directory, remove any existing data. 85 | 86 | Args: 87 | dir (str): The directory to be created. 88 | 89 | Returns: 90 | 91 | ''' 92 | if os.path.isdir(dir): 93 | shutil.rmtree(dir) 94 | os.makedirs(dir, exist_ok=True) 95 | -------------------------------------------------------------------------------- /graphgym/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | from sklearn.decomposition import PCA 5 | 6 | from graphgym.utils.io import makedirs 7 | 8 | # from sklearn.manifold import TSNE 9 | 10 | 11 | sns.set_context('poster') 12 | 13 | 14 | def view_emb(emb, dir): 15 | ''' 16 | Visualize an embedding matrix. 17 | 18 | Args: 19 | emb (torch.tensor): Embedding matrix with shape (N, D). D is the 20 | feature dimension. 21 | dir (str): Output directory for the embedding figure. 22 | 23 | ''' 24 | if emb.shape[1] > 2: 25 | pca = PCA(n_components=2) 26 | emb = pca.fit_transform(emb) 27 | plt.figure(figsize=(10, 10)) 28 | plt.scatter(emb[:, 0], emb[:, 1]) 29 | plt.savefig('{}/emb_pca.png'.format(dir), dpi=100) 30 | 31 | 32 | def view_emb_kg(emb1, emb2, dir, epoch=0): 33 | pca = PCA(n_components=2) 34 | emb = np.concatenate((emb1, emb2), axis=0) 35 | print(emb.shape) 36 | split = emb1.shape[0] 37 | emb = pca.fit_transform(emb) 38 | plt.figure(figsize=(10, 10)) 39 | plt.scatter(emb[:split, 0], emb[:split, 1], c='green', s=100) 40 | plt.scatter(emb[split:, 0], emb[split:, 1], c='blue', marker='x', s=800) 41 | ax = plt.gca() 42 | annotate = {-3: 'LogP', -2: 'QED', -1: 'Label'} 43 | for i, txt in annotate.items(): 44 | ax.annotate(txt, (emb[i, 0], emb[i, 1])) 45 | makedirs('{}/emb'.format(dir)) 46 | plt.savefig('{}/emb/pca_{}.png'.format(dir, epoch), dpi=100) 47 | -------------------------------------------------------------------------------- /graphgym/utils/tools.py: -------------------------------------------------------------------------------- 1 | class dummy_context(): 2 | '''Default context manager that does nothing''' 3 | def __enter__(self): 4 | return None 5 | 6 | def __exit__(self, exc_type, exc_value, traceback): 7 | return False 8 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install -e . 
2 | pip install -r requirements.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | yacs 2 | tensorboardx 3 | torch 4 | torch-geometric 5 | deepsnap 6 | ogb 7 | numpy 8 | pandas 9 | scipy 10 | scikit-learn 11 | matplotlib 12 | seaborn 13 | notebook -------------------------------------------------------------------------------- /run/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Jiaxuan You 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/__init__.py -------------------------------------------------------------------------------- /run/agg_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from graphgym.utils.agg_runs import agg_batch 4 | 5 | 6 | def parse_args(): 7 | """Parses the arguments.""" 8 | parser = argparse.ArgumentParser( 9 | description='Aggregate results across a batch of experiments') 10 | parser.add_argument('--dir', dest='dir', help='Dir for batch of results', 11 | required=True, type=str) 12 | parser.add_argument('--metric', dest='metric', 13 | help='metric to select best epoch', required=False, 14 | type=str, default='auto') 15 | return parser.parse_args() 16 | 17 | 18 | args = parse_args() 19 | agg_batch(args.dir, args.metric) 20 | -------------------------------------------------------------------------------- /run/configs/IDGNN/edge.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: nx 4 | name: ba 5 | task: link_pred 6 | task_type: classification 7 | transductive: True 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: original 12 | augment_label: '' 13 | augment_label_dims: 5 14 | transform: edge 15 | edge_train_mode: disjoint 16 | train: 17 | batch_size: 32 18 | eval_period: 20 19 | ckpt_period: 100 20 | model: 21 | type: gnn 22 | loss_fun: cross_entropy 23 | edge_decoding: concat 24 | graph_pooling: add 25 | gnn: 26 | layers_pre_mp: 1 27 | layers_mp: 5 28 | layers_post_mp: 1 
29 | dim_inner: 128 30 | layer_type: idconv 31 | stage_type: stack 32 | batchnorm: True 33 | act: relu 34 | dropout: 0.0 35 | agg: add 36 | normalize_adj: False 37 | l2norm: True 38 | optim: 39 | optimizer: adam 40 | base_lr: 0.01 41 | max_epoch: 500 -------------------------------------------------------------------------------- /run/configs/IDGNN/graph.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: nx 4 | name: ba500 5 | task: graph 6 | task_type: classification 7 | transductive: False 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: graph_path_len 13 | augment_label_dims: 10 14 | transform: ego 15 | train: 16 | batch_size: 64 17 | eval_period: 20 18 | ckpt_period: 100 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 3 27 | layers_post_mp: 3 28 | dim_inner: 128 29 | layer_type: idconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | l2norm: True 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.1 40 | max_epoch: 1000 -------------------------------------------------------------------------------- /run/configs/IDGNN/graph_enzyme.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: nx 4 | name: ba500 5 | task: graph 6 | task_type: classification 7 | transductive: False 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: graph_path_len 13 | augment_label_dims: 10 14 | transform: ego 15 | train: 16 | batch_size: 128 17 | eval_period: 20 18 | ckpt_period: 100 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 3 27 | layers_post_mp: 3 28 | dim_inner: 64 29 | layer_type: gatidconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | l2norm: True 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.01 40 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/IDGNN/graph_ogb.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbg-molhiv 5 | task: graph 6 | task_type: classification 7 | transductive: False 8 | split: [0.0, 0.0, 0.0] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: original 12 | augment_label: '' 13 | augment_label_dims: 10 14 | transform: ego 15 | node_encoder: True 16 | node_encoder_name: 'SingleAtom' 17 | node_encoder_bn: True 18 | edge_encoder: False 19 | encoder_dim: 128 20 | train: 21 | batch_size: 128 22 | eval_period: 1 23 | ckpt_period: 100 24 | model: 25 | type: gnn 26 | loss_fun: cross_entropy 27 | edge_decoding: dot 28 | graph_pooling: add 29 | gnn: 30 | layers_pre_mp: 1 31 | layers_mp: 3 32 | layers_post_mp: 3 33 | dim_inner: 128 34 | layer_type: idconv 35 | stage_type: stack 36 | batchnorm: True 37 | act: relu 38 | dropout: 0.0 39 | agg: add 40 | normalize_adj: False 41 | l2norm: True 42 | optim: 43 | optimizer: adam 44 | scheduler: none 45 | base_lr: 0.001 46 | max_epoch: 100 
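The YAML files in run/configs/ are plain yacs configs that are merged into the global graphgym.config.cfg node. Below is a minimal sketch of how one of them drives the create_optimizer/create_scheduler factories from graphgym/optimizer.py shown earlier, assuming the package is installed and paths are relative to the repo root; the nn.Linear stand-in model and the hard-coded config path are illustrative assumptions, not how GraphGym builds its GNN:

import torch.nn as nn

from graphgym.config import cfg
from graphgym.optimizer import create_optimizer, create_scheduler

# Merge one of the YAML files above into the global yacs config node.
cfg.merge_from_file('run/configs/IDGNN/graph_ogb.yaml')

# Stand-in module with trainable parameters; a real run builds the GNN
# via GraphGym's model builder instead.
model = nn.Linear(cfg.gnn.dim_inner, 1)

optimizer = create_optimizer(model.parameters())  # 'adam', base_lr 0.001 here
scheduler = create_scheduler(optimizer)  # 'none' -> effectively constant LR

for epoch in range(cfg.optim.max_epoch):
    # ... one training epoch would run here ...
    scheduler.step()

In practice the same wiring is performed end to end by the launcher, e.g. running python main.py --cfg configs/IDGNN/graph_ogb.yaml --repeat 3 from the run/ directory.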
-------------------------------------------------------------------------------- /run/configs/IDGNN/node.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: nx 4 | name: ws 5 | task: node 6 | task_type: classification 7 | transductive: False 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: original 12 | augment_label: node_clustering_coefficient 13 | augment_label_dims: 10 14 | transform: ego 15 | train: 16 | batch_size: 128 17 | eval_period: 10 18 | ckpt_period: 1000 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 3 27 | layers_post_mp: 1 28 | dim_inner: 128 29 | layer_type: idconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | l2norm: True 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.01 40 | max_epoch: 1000 -------------------------------------------------------------------------------- /run/configs/IDGNN/node_clustering.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: nx 4 | name: ws 5 | task: node 6 | task_type: classification 7 | transductive: False 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: original 12 | augment_label: node_clustering_coefficient 13 | augment_label_dims: 10 14 | transform: ego 15 | train: 16 | batch_size: 32 17 | eval_period: 10 18 | ckpt_period: 1000 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 3 27 | layers_post_mp: 1 28 | dim_inner: 64 29 | layer_type: idconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | l2norm: True 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.01 40 | max_epoch: 1000 -------------------------------------------------------------------------------- /run/configs/design/design_v1.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: PyG 4 | name: Cora 5 | task: node 6 | task_type: classification 7 | transductive: True 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: '' 13 | augment_label_dims: 5 14 | transform: none 15 | train: 16 | batch_size: 128 17 | eval_period: 20 18 | ckpt_period: 100 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 2 27 | layers_post_mp: 1 28 | dim_inner: 256 29 | layer_type: generalconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | optim: 37 | optimizer: adam 38 | base_lr: 0.01 39 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/design/design_v1att.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: PyG 4 | name: Cora 5 | task: node 6 | task_type: classification 7 | transductive: True 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | 
augment_label: '' 13 | augment_label_dims: 5 14 | transform: none 15 | train: 16 | batch_size: 128 17 | eval_period: 20 18 | ckpt_period: 100 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 2 27 | layers_post_mp: 1 28 | dim_inner: 256 29 | layer_type: generalmulattconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: relu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | optim: 37 | optimizer: adam 38 | base_lr: 0.01 39 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/design/design_v2.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: PyG 4 | name: Cora 5 | task: node 6 | task_type: classification 7 | transductive: True 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: '' 13 | augment_label_dims: 5 14 | transform: none 15 | train: 16 | batch_size: 32 17 | eval_period: 20 18 | ckpt_period: 100 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 1 26 | layers_mp: 2 27 | layers_post_mp: 1 28 | dim_inner: 256 29 | layer_type: generalconv 30 | stage_type: stack 31 | batchnorm: True 32 | act: prelu 33 | dropout: 0.0 34 | agg: add 35 | normalize_adj: False 36 | optim: 37 | optimizer: adam 38 | base_lr: 0.01 39 | max_epoch: 400 -------------------------------------------------------------------------------- /run/configs/design/design_v2link.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: PyG 4 | name: TU_ENZYMES 5 | task: link_pred 6 | task_type: classification 7 | transductive: True 8 | split: [0.8, 0.2] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: '' 13 | augment_label_dims: 5 14 | transform: none 15 | edge_train_mode: disjoint 16 | train: 17 | batch_size: 32 18 | eval_period: 20 19 | ckpt_period: 100 20 | model: 21 | type: gnn 22 | loss_fun: cross_entropy 23 | edge_decoding: dot 24 | graph_pooling: add 25 | gnn: 26 | layers_pre_mp: 1 27 | layers_mp: 2 28 | layers_post_mp: 1 29 | dim_inner: 256 30 | layer_type: generalconv 31 | stage_type: stack 32 | batchnorm: True 33 | act: prelu 34 | dropout: 0.0 35 | agg: add 36 | normalize_adj: False 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.01 40 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/design/design_v2ogb.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbg-molhiv 5 | task: graph 6 | task_type: classification 7 | transductive: False 8 | split: [0.0, 0.0, 0.0] 9 | augment_feature: [] 10 | augment_feature_dims: [10] 11 | augment_feature_repr: position 12 | augment_label: '' 13 | augment_label_dims: 5 14 | transform: none 15 | node_encoder: True 16 | node_encoder_name: 'Atom' 17 | node_encoder_bn: True 18 | edge_encoder: False 19 | encoder_dim: 300 20 | train: 21 | batch_size: 128 22 | eval_period: 1 23 | ckpt_period: 100 24 | model: 25 | type: gnn 26 | loss_fun: cross_entropy 27 | edge_decoding: dot 28 | graph_pooling: add 29 | gnn: 30 | layers_pre_mp: 1 31 | layers_mp: 5 32 | layers_post_mp: 1 33 | dim_inner: 
300 34 | layer_type: generalogbconv 35 | stage_type: stack 36 | batchnorm: True 37 | act: prelu 38 | dropout: 0.0 39 | agg: add 40 | normalize_adj: False 41 | optim: 42 | optimizer: adam 43 | base_lr: 0.01 44 | max_epoch: 100 45 | num_threads: 10 -------------------------------------------------------------------------------- /run/configs/example.yaml: -------------------------------------------------------------------------------- 1 | # The recommended basic settings for GNN 2 | out_dir: results 3 | dataset: 4 | format: PyG 5 | name: Cora 6 | task: node 7 | task_type: classification 8 | transductive: True 9 | split: [0.8, 0.2] 10 | augment_feature: [] 11 | augment_feature_dims: [0] 12 | augment_feature_repr: position 13 | augment_label: '' 14 | augment_label_dims: 0 15 | transform: none 16 | train: 17 | batch_size: 32 18 | eval_period: 20 19 | ckpt_period: 100 20 | model: 21 | type: gnn 22 | loss_fun: cross_entropy 23 | edge_decoding: dot 24 | graph_pooling: add 25 | gnn: 26 | layers_pre_mp: 1 27 | layers_mp: 2 28 | layers_post_mp: 1 29 | dim_inner: 256 30 | layer_type: generalconv 31 | stage_type: stack 32 | batchnorm: True 33 | act: prelu 34 | dropout: 0.0 35 | agg: add 36 | normalize_adj: False 37 | optim: 38 | optimizer: adam 39 | base_lr: 0.01 40 | max_epoch: 400 -------------------------------------------------------------------------------- /run/configs/example_cpu.yaml: -------------------------------------------------------------------------------- 1 | # The recommended basic settings for GNN 2 | out_dir: results 3 | device: cpu 4 | dataset: 5 | format: PyG 6 | name: Cora 7 | task: node 8 | task_type: classification 9 | transductive: True 10 | split: [0.8, 0.2] 11 | augment_feature: [] 12 | augment_feature_dims: [0] 13 | augment_feature_repr: position 14 | augment_label: '' 15 | augment_label_dims: 0 16 | transform: none 17 | train: 18 | batch_size: 32 19 | eval_period: 20 20 | ckpt_period: 100 21 | model: 22 | type: gnn 23 | loss_fun: cross_entropy 24 | edge_decoding: dot 25 | graph_pooling: add 26 | gnn: 27 | layers_pre_mp: 1 28 | layers_mp: 2 29 | layers_post_mp: 1 30 | dim_inner: 256 31 | layer_type: generalconv 32 | stage_type: stack 33 | batchnorm: True 34 | act: prelu 35 | dropout: 0.0 36 | agg: add 37 | normalize_adj: False 38 | optim: 39 | optimizer: adam 40 | base_lr: 0.01 41 | max_epoch: 400 -------------------------------------------------------------------------------- /run/configs/pyg/example_graph.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbg-molhiv 5 | task: graph 6 | task_type: classification 7 | node_encoder: True 8 | node_encoder_name: Atom 9 | edge_encoder: True 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 1 23 | layers_mp: 2 24 | layers_post_mp: 1 25 | dim_inner: 300 26 | layer_type: generalconv 27 | stage_type: stack 28 | batchnorm: True 29 | act: prelu 30 | dropout: 0.0 31 | agg: mean 32 | normalize_adj: False 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/pyg/example_link.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: 
ogbl-collab 5 | task: link_pred 6 | task_type: classification 7 | node_encoder: False 8 | node_encoder_name: Atom 9 | edge_encoder: False 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 1 23 | layers_mp: 2 24 | layers_post_mp: 1 25 | dim_inner: 300 26 | layer_type: gcnconv 27 | stage_type: stack 28 | batchnorm: True 29 | act: prelu 30 | dropout: 0.0 31 | agg: mean 32 | normalize_adj: False 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 100 -------------------------------------------------------------------------------- /run/configs/pyg/example_node.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbn-arxiv 5 | task: node 6 | task_type: classification 7 | node_encoder: False 8 | node_encoder_name: Atom 9 | edge_encoder: False 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 1 23 | layers_mp: 3 24 | layers_post_mp: 1 25 | dim_inner: 128 26 | layer_type: sageconv 27 | stage_type: skipsum 28 | batchnorm: True 29 | act: prelu 30 | dropout: 0.1 31 | agg: mean 32 | normalize_adj: False 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 200 -------------------------------------------------------------------------------- /run/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/__init__.py -------------------------------------------------------------------------------- /run/datasets/ba.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/ba.pkl -------------------------------------------------------------------------------- /run/datasets/ba500.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/ba500.pkl -------------------------------------------------------------------------------- /run/datasets/scalefree.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/scalefree.pkl -------------------------------------------------------------------------------- /run/datasets/smallworld.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/smallworld.pkl -------------------------------------------------------------------------------- /run/datasets/syn_graph.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import pickle 4 | 5 | import matplotlib.pyplot as plt 6 | import networkx as nx 7 | import numpy as np 8 | import torch 9 | 10 | dirname = 
os.path.dirname(__file__) 11 | 12 | 13 | def degree_dist(G): 14 | degree_sequence = sorted([d for n, d in G.degree()], reverse=True) 15 | degreeCount = collections.Counter(degree_sequence) 16 | deg, cnt = zip(*degreeCount.items()) 17 | 18 | fig, ax = plt.subplots() 19 | plt.bar(deg, cnt, width=0.80, color='b') 20 | 21 | plt.title("Degree Histogram") 22 | plt.ylabel("Count") 23 | plt.xlabel("Degree") 24 | ax.set_xticks([d + 0.4 for d in deg]) 25 | ax.set_xticklabels(deg) 26 | plt.show() 27 | 28 | 29 | def save_syn(): 30 | clustering_bins = np.linspace(0.3, 0.6, 7) 31 | print(clustering_bins) 32 | path_bins = np.linspace(1.8, 3.0, 7) 33 | print(path_bins) 34 | 35 | powerlaw_k = np.arange(2, 12) 36 | powerlaw_p = np.linspace(0, 1, 101) 37 | ws_k = np.arange(4, 23, 2) 38 | ws_p = np.linspace(0, 1, 101) 39 | 40 | counts = np.zeros((8, 8)) 41 | thresh = 4 42 | graphs = [] 43 | n = 64 44 | while True: 45 | k, p = np.random.choice(powerlaw_k), np.random.choice(powerlaw_p) 46 | g = nx.powerlaw_cluster_graph(n, k, p) 47 | clustering = nx.average_clustering(g) 48 | path = nx.average_shortest_path_length(g) 49 | clustering_id = np.digitize(clustering, clustering_bins) 50 | path_id = np.digitize(path, path_bins) 51 | if counts[clustering_id, path_id] < thresh: 52 | counts[clustering_id, path_id] += 1 53 | default_feature = torch.ones(1) 54 | nx.set_node_attributes(g, default_feature, 'node_feature') 55 | graphs.append(g) 56 | print(np.sum(counts)) 57 | if np.sum(counts) == 8 * 8 * thresh: 58 | break 59 | 60 | with open('scalefree.pkl', 'wb') as file: 61 | pickle.dump(graphs, file) 62 | 63 | counts = np.zeros((8, 8)) 64 | thresh = 4 65 | graphs = [] 66 | n = 64 67 | while True: 68 | k, p = np.random.choice(ws_k), np.random.choice(ws_p) 69 | g = nx.watts_strogatz_graph(n, k, p) 70 | clustering = nx.average_clustering(g) 71 | path = nx.average_shortest_path_length(g) 72 | clustering_id = np.digitize(clustering, clustering_bins) 73 | path_id = np.digitize(path, path_bins) 74 | if counts[clustering_id, path_id] < thresh: 75 | counts[clustering_id, path_id] += 1 76 | default_feature = torch.ones(1) 77 | nx.set_node_attributes(g, default_feature, 'node_feature') 78 | graphs.append(g) 79 | print(np.sum(counts)) 80 | if np.sum(counts) == 8 * 8 * thresh: 81 | break 82 | 83 | with open('smallworld.pkl', 'wb') as file: 84 | pickle.dump(graphs, file) 85 | 86 | 87 | def load_syn(): 88 | with open('{}/smallworld.pkl'.format(dirname), 'rb') as file: 89 | graphs = pickle.load(file) 90 | for graph in graphs: 91 | print(nx.average_clustering(graph), 92 | nx.average_shortest_path_length(graph)) 93 | # degree_dist(graph) 94 | -------------------------------------------------------------------------------- /run/datasets/ws.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/ws.pkl -------------------------------------------------------------------------------- /run/datasets/ws500.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/GraphGym/daded21169ec92fde8b1252b439a8fac35b07d79/run/datasets/ws500.pkl -------------------------------------------------------------------------------- /run/grids/IDGNN/graph.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | # task: node, dataset: syn; model: gnn 4 | 
dataset.format format ['PyG'] 5 | dataset.name dataset ['TU_BZR','TU_COX2','TU_PROTEINS'] 6 | dataset.transform transform ['none'] 7 | dataset.task task ['graph'] 8 | dataset.transductive trans [False] 9 | dataset.augment_feature feature [[],['node_identity']] 10 | dataset.augment_label label [''] 11 | gnn.layers_pre_mp l_pre [1] 12 | gnn.layers_mp l_mp [3] 13 | gnn.layers_post_mp l_post [3] 14 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 15 | optim.max_epoch epoch [1000] 16 | 17 | 18 | # task: node, dataset: syn; model: idgnn 19 | dataset.format format ['PyG'] 20 | dataset.name dataset ['TU_BZR','TU_COX2','TU_PROTEINS'] 21 | dataset.transform transform ['ego'] 22 | dataset.task task ['graph'] 23 | dataset.transductive trans [False] 24 | dataset.augment_feature feature [[]] 25 | dataset.augment_label label [''] 26 | gnn.layers_pre_mp l_pre [1] 27 | gnn.layers_mp l_mp [3] 28 | gnn.layers_post_mp l_post [3] 29 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 30 | optim.max_epoch epoch [1000] 31 | 32 | 33 | # task: node, dataset: syn; model: gnn 34 | dataset.format format ['nx'] 35 | dataset.name dataset ['ws500','ba500'] 36 | dataset.transform transform ['none'] 37 | dataset.task task ['graph'] 38 | dataset.transductive trans [False] 39 | dataset.augment_feature feature [[],['node_identity']] 40 | dataset.augment_label label ['graph_clustering_coefficient'] 41 | gnn.layers_pre_mp l_pre [1] 42 | gnn.layers_mp l_mp [3] 43 | gnn.layers_post_mp l_post [3] 44 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 45 | optim.max_epoch epoch [1000] 46 | 47 | 48 | # task: node, dataset: syn; model: idgnn 49 | dataset.format format ['nx'] 50 | dataset.name dataset ['ws500','ba500'] 51 | dataset.transform transform ['ego'] 52 | dataset.task task ['graph'] 53 | dataset.transductive trans [False] 54 | dataset.augment_feature feature [[]] 55 | dataset.augment_label label ['graph_clustering_coefficient'] 56 | gnn.layers_pre_mp l_pre [1] 57 | gnn.layers_mp l_mp [3] 58 | gnn.layers_post_mp l_post [3] 59 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 60 | optim.max_epoch epoch [1000] 61 | -------------------------------------------------------------------------------- /run/grids/IDGNN/graph_enzyme.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | 4 | # task: node, dataset: syn; model: gnn 5 | dataset.format format ['PyG'] 6 | dataset.name dataset ['TU_ENZYMES'] 7 | dataset.transform transform ['none'] 8 | dataset.task task ['graph'] 9 | dataset.transductive trans [False] 10 | dataset.augment_feature feature [[],['node_identity']] 11 | dataset.augment_label label [''] 12 | gnn.layers_pre_mp l_pre [1] 13 | gnn.layers_mp l_mp [3] 14 | gnn.layers_post_mp l_post [3] 15 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 16 | optim.max_epoch epoch [100] 17 | 18 | 19 | # task: node, dataset: syn; model: idgnn 20 | dataset.format format ['PyG'] 21 | dataset.name dataset ['TU_ENZYMES'] 22 | dataset.transform transform ['ego'] 23 | dataset.task task ['graph'] 24 | dataset.transductive trans [False] 25 | dataset.augment_feature feature [[]] 26 | dataset.augment_label label [''] 27 | gnn.layers_pre_mp l_pre [1] 28 | gnn.layers_mp l_mp [3] 29 | gnn.layers_post_mp l_post [3] 30 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 31 | optim.max_epoch epoch [100] 32 | 
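Each grid file is a whitespace-separated table: the config key as named in config.py, a short alias used in result directory names, and a Python-literal list of candidate values; blank lines separate independent search spaces, each expanded as a Cartesian product. The actual expansion is implemented in run/configs_gen.py; the sketch below only illustrates the format, and its function names are illustrative assumptions rather than that script's API:

import ast
from itertools import product


def parse_grid(path):
    """Parse a grid file into blocks of (config_key, alias, values) rows."""
    blocks, rows = [], []
    with open(path) as f:
        for raw in f:
            line = raw.split('#')[0].strip()  # drop comments and headers
            if not line:  # a blank line closes the current block
                if rows:
                    blocks.append(rows)
                    rows = []
                continue
            key, alias, values = line.split(None, 2)
            rows.append((key, alias, ast.literal_eval(values)))
    if rows:
        blocks.append(rows)
    return blocks


def expand(block):
    """Yield one flat {config_key: value} dict per grid point."""
    keys = [key for key, _, _ in block]
    for combo in product(*[values for _, _, values in block]):
        yield dict(zip(keys, combo))


for block in parse_grid('run/grids/IDGNN/graph_enzyme.txt'):
    for overrides in expand(block):
        print(overrides)  # e.g. {'dataset.format': 'PyG', ...}

For graph_enzyme.txt above, the first block expands to 2 x 4 = 8 grid points (two feature settings times four layer types) and the second to 4, each of which configs_gen.py would merge into a base YAML config.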
-------------------------------------------------------------------------------- /run/grids/IDGNN/graph_ogb.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | 4 | # task: node, dataset: syn; model: gnn 5 | dataset.format format ['OGB'] 6 | dataset.name dataset ['ogbg-molhiv'] 7 | dataset.transform transform ['none'] 8 | dataset.task task ['graph'] 9 | dataset.transductive trans [False] 10 | dataset.augment_feature feature [[],['node_identity']] 11 | dataset.augment_label label [''] 12 | dataset.feature_encoder_name encoder ['Atom'] 13 | gnn.layers_pre_mp l_pre [1] 14 | gnn.layers_mp l_mp [3] 15 | gnn.layers_post_mp l_post [3] 16 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 17 | optim.max_epoch epoch [100] 18 | 19 | 20 | # task: node, dataset: syn; model: idgnn 21 | dataset.format format ['OGB'] 22 | dataset.name dataset ['ogbg-molhiv'] 23 | dataset.transform transform ['ego'] 24 | dataset.task task ['graph'] 25 | dataset.transductive trans [False] 26 | dataset.augment_feature feature [[]] 27 | dataset.augment_label label [''] 28 | dataset.feature_encoder_name encoder ['Atom'] 29 | gnn.layers_pre_mp l_pre [1] 30 | gnn.layers_mp l_mp [3] 31 | gnn.layers_post_mp l_post [3] 32 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 33 | optim.max_epoch epoch [100] 34 | -------------------------------------------------------------------------------- /run/grids/IDGNN/link.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | # task: node, dataset: syn; model: gnn 4 | dataset.format format ['nx'] 5 | dataset.name dataset ['ws','ba'] 6 | dataset.transform transform ['none'] 7 | dataset.task task ['link_pred'] 8 | dataset.transductive trans [False] 9 | dataset.augment_feature feature [[],['node_identity']] 10 | dataset.augment_label label [''] 11 | gnn.layers_pre_mp l_pre [1] 12 | gnn.layers_mp l_mp [5] 13 | gnn.layers_post_mp l_post [1] 14 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 15 | optim.max_epoch epoch [100] 16 | train.batch_size batch [32] 17 | 18 | 19 | # task: node, dataset: syn; model: idgnn 20 | dataset.format format ['nx'] 21 | dataset.name dataset ['ws','ba'] 22 | dataset.transform transform ['edge'] 23 | dataset.task task ['link_pred'] 24 | dataset.transductive trans [False] 25 | dataset.augment_feature feature [[]] 26 | dataset.augment_label label [''] 27 | gnn.layers_pre_mp l_pre [1] 28 | gnn.layers_mp l_mp [5] 29 | gnn.layers_post_mp l_post [1] 30 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 31 | optim.max_epoch epoch [100] 32 | train.batch_size batch [32] 33 | 34 | # task: node, dataset: real; model: gnn 35 | dataset.format format ['PyG'] 36 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 37 | dataset.transform transform ['none'] 38 | dataset.task task ['link_pred'] 39 | dataset.transductive trans [False] 40 | dataset.augment_feature feature [[],['node_identity']] 41 | dataset.augment_label label [''] 42 | gnn.layers_pre_mp l_pre [1] 43 | gnn.layers_mp l_mp [5] 44 | gnn.layers_post_mp l_post [1] 45 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 46 | optim.max_epoch epoch [100] 47 | train.batch_size batch [8] 48 | 49 | 50 | # task: node, dataset: real; model: idgnn 51 | dataset.format format ['PyG'] 52 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 53 | dataset.transform 
transform ['edge'] 54 | dataset.task task ['link_pred'] 55 | dataset.transductive trans [False] 56 | dataset.augment_feature feature [[]] 57 | dataset.augment_label label [''] 58 | gnn.layers_pre_mp l_pre [1] 59 | gnn.layers_mp l_mp [5] 60 | gnn.layers_post_mp l_post [1] 61 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 62 | optim.max_epoch epoch [100] 63 | train.batch_size batch [8] 64 | -------------------------------------------------------------------------------- /run/grids/IDGNN/node.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | 4 | # task: node, dataset: syn; model: idgnn 5 | dataset.format format ['nx'] 6 | dataset.name dataset ['ba'] 7 | dataset.transform transform ['ego'] 8 | dataset.task task ['node'] 9 | dataset.transductive trans [False] 10 | dataset.augment_feature feature [[]] 11 | dataset.augment_label label ['node_clustering_coefficient'] 12 | gnn.layers_pre_mp l_pre [1] 13 | gnn.layers_mp l_mp [2] 14 | gnn.layers_post_mp l_post [1] 15 | gnn.layer_type layer ['gatidconv'] 16 | optim.max_epoch epoch [1000] 17 | 18 | 19 | 20 | # task: node, dataset: real; model: idgnn 21 | dataset.format format ['PyG'] 22 | dataset.name dataset ['Cora'] 23 | dataset.transform transform ['ego'] 24 | dataset.task task ['node'] 25 | dataset.transductive trans [True] 26 | dataset.augment_feature feature [[]] 27 | dataset.augment_label label [''] 28 | gnn.layers_pre_mp l_pre [1] 29 | gnn.layers_mp l_mp [2] 30 | gnn.layers_post_mp l_post [1] 31 | gnn.layer_type layer ['gatidconv'] 32 | optim.max_epoch epoch [100] 33 | 34 | 35 | # task: node, dataset: syn; model: gnn 36 | dataset.format format ['nx'] 37 | dataset.name dataset ['ba'] 38 | dataset.transform transform ['none'] 39 | dataset.task task ['node'] 40 | dataset.transductive trans [False] 41 | dataset.augment_feature feature [[],['node_identity']] 42 | dataset.augment_label label ['node_clustering_coefficient'] 43 | gnn.layers_pre_mp l_pre [1] 44 | gnn.layers_mp l_mp [2] 45 | gnn.layers_post_mp l_post [1] 46 | gnn.layer_type layer ['gatconv'] 47 | optim.max_epoch epoch [1000] 48 | 49 | 50 | # task: node, dataset: real; model: gnn 51 | dataset.format format ['PyG'] 52 | dataset.name dataset ['Cora'] 53 | dataset.transform transform ['none'] 54 | dataset.task task ['node'] 55 | dataset.transductive trans [True] 56 | dataset.augment_feature feature [[],['node_identity']] 57 | dataset.augment_label label [''] 58 | gnn.layers_pre_mp l_pre [1] 59 | gnn.layers_mp l_mp [2] 60 | gnn.layers_post_mp l_post [1] 61 | gnn.layer_type layer ['gatconv'] 62 | optim.max_epoch epoch [100] 63 | 64 | 65 | 66 | # name in config.py; short name; range to search 67 | 68 | # task: node, dataset: syn; model: gnn 69 | dataset.format format ['nx'] 70 | dataset.name dataset ['ws','ba'] 71 | dataset.transform transform ['none'] 72 | dataset.task task ['node'] 73 | dataset.transductive trans [False] 74 | dataset.augment_feature feature [[],['node_identity']] 75 | dataset.augment_label label ['node_clustering_coefficient'] 76 | gnn.layers_pre_mp l_pre [1] 77 | gnn.layers_mp l_mp [3] 78 | gnn.layers_post_mp l_post [1] 79 | gnn.layer_type layer ['gcnconv','sageconv','ginconv'] 80 | optim.max_epoch epoch [1000] 81 | 82 | # task: node, dataset: syn; model: idgnn 83 | dataset.format format ['nx'] 84 | dataset.name dataset ['ws','ba'] 85 | dataset.transform transform ['ego'] 86 | dataset.task task ['node'] 87 | dataset.transductive trans [False] 88 | 
dataset.augment_feature feature [[]] 89 | dataset.augment_label label ['node_clustering_coefficient'] 90 | gnn.layers_pre_mp l_pre [1] 91 | gnn.layers_mp l_mp [3] 92 | gnn.layers_post_mp l_post [1] 93 | gnn.layer_type layer ['gcnidconv','sageidconv','ginidconv'] 94 | optim.max_epoch epoch [1000] 95 | 96 | # task: node, dataset: real; model: gnn 97 | dataset.format format ['PyG'] 98 | dataset.name dataset ['Cora','CiteSeer'] 99 | dataset.transform transform ['none'] 100 | dataset.task task ['node'] 101 | dataset.transductive trans [True] 102 | dataset.augment_feature feature [[],['node_identity']] 103 | dataset.augment_label label [''] 104 | gnn.layers_pre_mp l_pre [1] 105 | gnn.layers_mp l_mp [3] 106 | gnn.layers_post_mp l_post [1] 107 | gnn.layer_type layer ['gcnconv','sageconv','ginconv'] 108 | optim.max_epoch epoch [100] 109 | 110 | # task: node, dataset: real; model: idgnn 111 | dataset.format format ['PyG'] 112 | dataset.name dataset ['Cora','CiteSeer'] 113 | dataset.transform transform ['ego'] 114 | dataset.task task ['node'] 115 | dataset.transductive trans [True] 116 | dataset.augment_feature feature [[]] 117 | dataset.augment_label label [''] 118 | gnn.layers_pre_mp l_pre [1] 119 | gnn.layers_mp l_mp [3] 120 | gnn.layers_post_mp l_post [1] 121 | gnn.layer_type layer ['gcnidconv','sageidconv','ginidconv'] 122 | optim.max_epoch epoch [100] -------------------------------------------------------------------------------- /run/grids/IDGNN/node_clustering.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | # task: node, dataset: real; model: gnn 4 | dataset.format format ['PyG'] 5 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 6 | dataset.transform transform ['none'] 7 | dataset.task task ['node'] 8 | dataset.transductive trans [False] 9 | dataset.augment_feature feature [[],['node_identity']] 10 | dataset.augment_label label ['node_clustering_coefficient'] 11 | gnn.layers_pre_mp l_pre [1] 12 | gnn.layers_mp l_mp [3] 13 | gnn.layers_post_mp l_post [1] 14 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 15 | optim.max_epoch epoch [1000] 16 | 17 | # task: node, dataset: real; model: idgnn 18 | dataset.format format ['PyG'] 19 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 20 | dataset.transform transform ['ego'] 21 | dataset.task task ['node'] 22 | dataset.transductive trans [False] 23 | dataset.augment_feature feature [[]] 24 | dataset.augment_label label ['node_clustering_coefficient'] 25 | gnn.layers_pre_mp l_pre [1] 26 | gnn.layers_mp l_mp [3] 27 | gnn.layers_post_mp l_post [1] 28 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 29 | optim.max_epoch epoch [1000] -------------------------------------------------------------------------------- /run/grids/IDGNN/path.txt: -------------------------------------------------------------------------------- 1 | # name in config.py; short name; range to search 2 | 3 | # task: edge, dataset: syn; model: gnn 4 | dataset.format format ['nx'] 5 | dataset.name dataset ['ws','ba'] 6 | dataset.transform transform ['none'] 7 | dataset.task task ['edge'] 8 | dataset.transductive trans [False] 9 | dataset.augment_feature feature [[],['node_identity']] 10 | dataset.augment_label label [''] 11 | gnn.layers_pre_mp l_pre [1] 12 | gnn.layers_mp l_mp [5] 13 | gnn.layers_post_mp l_post [1] 14 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 15 | optim.max_epoch epoch [100] 16 | train.batch_size batch
[32] 17 | 18 | 19 | # task: edge, dataset: syn; model: idgnn 20 | dataset.format format ['nx'] 21 | dataset.name dataset ['ws','ba'] 22 | dataset.transform transform ['edge'] 23 | dataset.task task ['edge'] 24 | dataset.transductive trans [False] 25 | dataset.augment_feature feature [[]] 26 | dataset.augment_label label [''] 27 | gnn.layers_pre_mp l_pre [1] 28 | gnn.layers_mp l_mp [5] 29 | gnn.layers_post_mp l_post [1] 30 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 31 | optim.max_epoch epoch [100] 32 | train.batch_size batch [32] 33 | 34 | # task: edge, dataset: real; model: gnn 35 | dataset.format format ['PyG'] 36 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 37 | dataset.transform transform ['none'] 38 | dataset.task task ['edge'] 39 | dataset.transductive trans [False] 40 | dataset.augment_feature feature [[],['node_identity']] 41 | dataset.augment_label label [''] 42 | gnn.layers_pre_mp l_pre [1] 43 | gnn.layers_mp l_mp [5] 44 | gnn.layers_post_mp l_post [1] 45 | gnn.layer_type layer ['gcnconv','sageconv','gatconv','ginconv'] 46 | optim.max_epoch epoch [100] 47 | train.batch_size batch [8] 48 | 49 | 50 | # task: edge, dataset: real; model: idgnn 51 | dataset.format format ['PyG'] 52 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 53 | dataset.transform transform ['edge'] 54 | dataset.task task ['edge'] 55 | dataset.transductive trans [False] 56 | dataset.augment_feature feature [[]] 57 | dataset.augment_label label [''] 58 | gnn.layers_pre_mp l_pre [1] 59 | gnn.layers_mp l_mp [5] 60 | gnn.layers_post_mp l_post [1] 61 | gnn.layer_type layer ['gcnidconv','sageidconv','gatidconv','ginidconv'] 62 | optim.max_epoch epoch [100] 63 | train.batch_size batch [8] 64 | -------------------------------------------------------------------------------- /run/grids/design/round1.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | # dataset: TU, task: graph 7 | dataset.format format ['PyG'] 8 | dataset.name dataset ['TU_BZR','TU_COX2','TU_DD','TU_IMDB','TU_ENZYMES','TU_PROTEINS'] 9 | dataset.task task ['graph'] 10 | dataset.transductive trans [False] 11 | dataset.augment_feature feature [[]] 12 | dataset.augment_label label [''] 13 | train.batch_size batch [16,32,64] 14 | gnn.layers_pre_mp l_pre [1,2,3] 15 | gnn.layers_mp l_mp [2,4,6,8] 16 | gnn.layers_post_mp l_post [1,2,3] 17 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 18 | gnn.batchnorm bn [True,False] 19 | gnn.act act ['relu','prelu','swish'] 20 | gnn.dropout drop [0.0,0.3,0.6] 21 | gnn.agg agg ['add','mean','max'] 22 | optim.optimizer optim ['adam','sgd'] 23 | optim.base_lr lr [0.1,0.01,0.001] 24 | optim.max_epoch epoch [100,200,400] 25 | 26 | # dataset: Single, task: node 27 | dataset.format format ['PyG'] 28 | dataset.name dataset ['Cora','CiteSeer','CoauthorCS','CoauthorPhysics','AmazonComputers','AmazonPhoto'] 29 | dataset.task task ['node'] 30 | dataset.transductive trans [True] 31 | dataset.augment_feature feature [[]] 32 | dataset.augment_label label [''] 33 | train.batch_size batch [16,32,64] 34 | gnn.layers_pre_mp l_pre [1,2,3] 35 | gnn.layers_mp l_mp [2,4,6,8] 36 | gnn.layers_post_mp l_post [1,2,3] 37 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 38 | gnn.batchnorm bn [True,False] 39 | gnn.act act
['relu','prelu','swish'] 40 | gnn.dropout drop [0.0,0.3,0.6] 41 | gnn.agg agg ['add','mean','max'] 42 | optim.optimizer optim ['adam','sgd'] 43 | optim.base_lr lr [0.1,0.01,0.001] 44 | optim.max_epoch epoch [100,200,400] 45 | 46 | 47 | # dataset: nx, task: node, label: clustering 48 | dataset.format format ['nx'] 49 | dataset.name dataset ['smallworld','scalefree'] 50 | dataset.task task ['node'] 51 | dataset.transductive trans [True] 52 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_pagerank']] 53 | dataset.augment_label label ['node_clustering_coefficient'] 54 | train.batch_size batch [16,32,64] 55 | gnn.layers_pre_mp l_pre [1,2,3] 56 | gnn.layers_mp l_mp [2,4,6,8] 57 | gnn.layers_post_mp l_post [1,2,3] 58 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 59 | gnn.batchnorm bn [True,False] 60 | gnn.act act ['relu','prelu','swish'] 61 | gnn.dropout drop [0.0,0.3,0.6] 62 | gnn.agg agg ['add','mean','max'] 63 | optim.optimizer optim ['adam','sgd'] 64 | optim.base_lr lr [0.1,0.01,0.001] 65 | optim.max_epoch epoch [100,200,400] 66 | 67 | # dataset: nx, task: node, label: pagerank 68 | dataset.format format ['nx'] 69 | dataset.name dataset ['smallworld','scalefree'] 70 | dataset.task task ['node'] 71 | dataset.transductive trans [True] 72 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_clustering_coefficient']] 73 | dataset.augment_label label ['node_pagerank'] 74 | train.batch_size batch [16,32,64] 75 | gnn.layers_pre_mp l_pre [1,2,3] 76 | gnn.layers_mp l_mp [2,4,6,8] 77 | gnn.layers_post_mp l_post [1,2,3] 78 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 79 | gnn.batchnorm bn [True,False] 80 | gnn.act act ['relu','prelu','swish'] 81 | gnn.dropout drop [0.0,0.3,0.6] 82 | gnn.agg agg ['add','mean','max'] 83 | optim.optimizer optim ['adam','sgd'] 84 | optim.base_lr lr [0.1,0.01,0.001] 85 | optim.max_epoch epoch [100,200,400] 86 | 87 | # dataset: nx, task: graph, label: graph_path_len 88 | dataset.format format ['nx'] 89 | dataset.name dataset ['smallworld','scalefree'] 90 | dataset.task task ['graph'] 91 | dataset.transductive trans [True] 92 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_clustering_coefficient'],['node_pagerank']] 93 | dataset.augment_label label ['graph_path_len'] 94 | train.batch_size batch [16,32,64] 95 | gnn.layers_pre_mp l_pre [1,2,3] 96 | gnn.layers_mp l_mp [2,4,6,8] 97 | gnn.layers_post_mp l_post [1,2,3] 98 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 99 | gnn.batchnorm bn [True,False] 100 | gnn.act act ['relu','prelu','swish'] 101 | gnn.dropout drop [0.0,0.3,0.6] 102 | gnn.agg agg ['add','mean','max'] 103 | optim.optimizer optim ['adam','sgd'] 104 | optim.base_lr lr [0.1,0.01,0.001] 105 | optim.max_epoch epoch [100,200,400] -------------------------------------------------------------------------------- /run/grids/design/round1att.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # dataset: TU, task: graph 8 | dataset.format format ['PyG'] 9 | dataset.name dataset ['TU_BZR','TU_COX2','TU_DD','TU_IMDB','TU_ENZYMES','TU_PROTEINS'] 10 | dataset.task task ['graph'] 11 | dataset.transductive trans [False] 12 | dataset.augment_feature feature [[]] 13 | dataset.augment_label label [''] 14 | # 
train.batch_size batch [16,32,64] 15 | gnn.layer_type l_t ['generalconv','gaddconv','gmulconv'] 16 | gnn.layers_pre_mp l_pre [1,2,3] 17 | gnn.layers_mp l_mp [2,4,6,8] 18 | gnn.layers_post_mp l_post [1,2,3] 19 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 20 | gnn.batchnorm bn [True,False] 21 | gnn.act act ['relu','prelu','swish'] 22 | gnn.dropout drop [0.0,0.3,0.6] 23 | gnn.agg agg ['add','mean','max'] 24 | optim.optimizer optim ['adam','sgd'] 25 | optim.base_lr lr [0.1,0.01,0.001] 26 | optim.max_epoch epoch [100,200,400] 27 | 28 | # dataset: Single, task: node 29 | dataset.format format ['PyG'] 30 | dataset.name dataset ['Cora','CiteSeer','CoauthorCS','CoauthorPhysics','AmazonComputers','AmazonPhoto'] 31 | dataset.task task ['node'] 32 | dataset.transductive trans [True] 33 | dataset.augment_feature feature [[]] 34 | dataset.augment_label label [''] 35 | # train.batch_size batch [16,32,64] 36 | gnn.layer_type l_t ['generalconv','gaddconv','gmulconv'] 37 | gnn.layers_pre_mp l_pre [1,2,3] 38 | gnn.layers_mp l_mp [2,4,6,8] 39 | gnn.layers_post_mp l_post [1,2,3] 40 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 41 | gnn.batchnorm bn [True,False] 42 | gnn.act act ['relu','prelu','swish'] 43 | gnn.dropout drop [0.0,0.3,0.6] 44 | gnn.agg agg ['add','mean','max'] 45 | optim.optimizer optim ['adam','sgd'] 46 | optim.base_lr lr [0.1,0.01,0.001] 47 | optim.max_epoch epoch [100,200,400] 48 | 49 | 50 | # dataset: nx, task: node, label: clustering 51 | dataset.format format ['nx'] 52 | dataset.name dataset ['smallworld','scalefree'] 53 | dataset.task task ['node'] 54 | dataset.transductive trans [True] 55 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_pagerank']] 56 | dataset.augment_label label ['node_clustering_coefficient'] 57 | # train.batch_size batch [16,32,64] 58 | gnn.layer_type l_t ['generalconv','gaddconv','gmulconv'] 59 | gnn.layers_pre_mp l_pre [1,2,3] 60 | gnn.layers_mp l_mp [2,4,6,8] 61 | gnn.layers_post_mp l_post [1,2,3] 62 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 63 | gnn.batchnorm bn [True,False] 64 | gnn.act act ['relu','prelu','swish'] 65 | gnn.dropout drop [0.0,0.3,0.6] 66 | gnn.agg agg ['add','mean','max'] 67 | optim.optimizer optim ['adam','sgd'] 68 | optim.base_lr lr [0.1,0.01,0.001] 69 | optim.max_epoch epoch [100,200,400] 70 | 71 | # dataset: nx, task: node, label: pagerank 72 | dataset.format format ['nx'] 73 | dataset.name dataset ['smallworld','scalefree'] 74 | dataset.task task ['node'] 75 | dataset.transductive trans [True] 76 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_clustering_coefficient']] 77 | dataset.augment_label label ['node_pagerank'] 78 | # train.batch_size batch [16,32,64] 79 | gnn.layer_type l_t ['generalconv','gaddconv','gmulconv'] 80 | gnn.layers_pre_mp l_pre [1,2,3] 81 | gnn.layers_mp l_mp [2,4,6,8] 82 | gnn.layers_post_mp l_post [1,2,3] 83 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 84 | gnn.batchnorm bn [True,False] 85 | gnn.act act ['relu','prelu','swish'] 86 | gnn.dropout drop [0.0,0.3,0.6] 87 | gnn.agg agg ['add','mean','max'] 88 | optim.optimizer optim ['adam','sgd'] 89 | optim.base_lr lr [0.1,0.01,0.001] 90 | optim.max_epoch epoch [100,200,400] 91 | 92 | # dataset: nx, task: graph, label: graph_path_len 93 | dataset.format format ['nx'] 94 | dataset.name dataset ['smallworld','scalefree'] 95 | dataset.task task ['graph'] 96 | dataset.transductive trans [True] 97 | dataset.augment_feature feature 
[['node_const'],['node_onehot'],['node_clustering_coefficient'],['node_pagerank']] 98 | dataset.augment_label label ['graph_path_len'] 99 | # train.batch_size batch [16,32,64] 100 | gnn.layer_type l_t ['generalconv','gaddconv','gmulconv'] 101 | gnn.layers_pre_mp l_pre [1,2,3] 102 | gnn.layers_mp l_mp [2,4,6,8] 103 | gnn.layers_post_mp l_post [1,2,3] 104 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 105 | gnn.batchnorm bn [True,False] 106 | gnn.act act ['relu','prelu','swish'] 107 | gnn.dropout drop [0.0,0.3,0.6] 108 | gnn.agg agg ['add','mean','max'] 109 | optim.optimizer optim ['adam','sgd'] 110 | optim.base_lr lr [0.1,0.01,0.001] 111 | optim.max_epoch epoch [100,200,400] -------------------------------------------------------------------------------- /run/grids/design/round2.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # dataset: TU, task: graph 8 | dataset.format format ['PyG'] 9 | dataset.name dataset ['TU_BZR','TU_COX2','TU_DD','TU_IMDB','TU_ENZYMES','TU_PROTEINS'] 10 | dataset.task task ['graph'] 11 | dataset.transductive trans [False] 12 | dataset.augment_feature feature [[]] 13 | dataset.augment_label label [''] 14 | gnn.layers_pre_mp l_pre [1,2] 15 | gnn.layers_mp l_mp [2,4,6,8] 16 | gnn.layers_post_mp l_post [2,3] 17 | gnn.stage_type stage ['skipsum','skipconcat'] 18 | gnn.agg agg ['add','mean','max'] 19 | 20 | # dataset: Single, task: node 21 | dataset.format format ['PyG'] 22 | dataset.name dataset ['Cora','CiteSeer','CoauthorCS','CoauthorPhysics','AmazonComputers','AmazonPhoto'] 23 | dataset.task task ['node'] 24 | dataset.transductive trans [True] 25 | dataset.augment_feature feature [[]] 26 | dataset.augment_label label [''] 27 | gnn.layers_pre_mp l_pre [1,2] 28 | gnn.layers_mp l_mp [2,4,6,8] 29 | gnn.layers_post_mp l_post [2,3] 30 | gnn.stage_type stage ['skipsum','skipconcat'] 31 | gnn.agg agg ['add','mean','max'] 32 | 33 | 34 | # dataset: nx, task: node, label: clustering 35 | dataset.format format ['nx'] 36 | dataset.name dataset ['smallworld','scalefree'] 37 | dataset.task task ['node'] 38 | dataset.transductive trans [True] 39 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_pagerank']] 40 | dataset.augment_label label ['node_clustering_coefficient'] 41 | gnn.layers_pre_mp l_pre [1,2] 42 | gnn.layers_mp l_mp [2,4,6,8] 43 | gnn.layers_post_mp l_post [2,3] 44 | gnn.stage_type stage ['skipsum','skipconcat'] 45 | gnn.agg agg ['add','mean','max'] 46 | 47 | # dataset: nx, task: node, label: pagerank 48 | dataset.format format ['nx'] 49 | dataset.name dataset ['smallworld','scalefree'] 50 | dataset.task task ['node'] 51 | dataset.transductive trans [True] 52 | dataset.augment_feature feature [['node_const'],['node_onehot'],['node_clustering_coefficient']] 53 | dataset.augment_label label ['node_pagerank'] 54 | gnn.layers_pre_mp l_pre [1,2] 55 | gnn.layers_mp l_mp [2,4,6,8] 56 | gnn.layers_post_mp l_post [2,3] 57 | gnn.stage_type stage ['skipsum','skipconcat'] 58 | gnn.agg agg ['add','mean','max'] 59 | 60 | # dataset: nx, task: graph, label: graph_path_len 61 | dataset.format format ['nx'] 62 | dataset.name dataset ['smallworld','scalefree'] 63 | dataset.task task ['graph'] 64 | dataset.transductive trans [True] 65 | dataset.augment_feature feature 
[['node_const'],['node_onehot'],['node_clustering_coefficient'],['node_pagerank']] 66 | dataset.augment_label label ['graph_path_len'] 67 | gnn.layers_pre_mp l_pre [1,2] 68 | gnn.layers_mp l_mp [2,4,6,8] 69 | gnn.layers_post_mp l_post [2,3] 70 | gnn.stage_type stage ['skipsum','skipconcat'] 71 | gnn.agg agg ['add','mean','max'] -------------------------------------------------------------------------------- /run/grids/design/round2link.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # dataset: TU + node-classification benchmarks, task: link_pred 8 | dataset.format format ['PyG'] 9 | dataset.name dataset ['TU_BZR','TU_COX2','TU_DD','TU_IMDB','TU_ENZYMES','TU_PROTEINS','Cora','CiteSeer','CoauthorCS','CoauthorPhysics','AmazonComputers','AmazonPhoto'] 10 | dataset.task task ['link_pred'] 11 | dataset.transductive trans [True] 12 | dataset.augment_feature feature [[]] 13 | dataset.augment_label label [''] 14 | gnn.layers_pre_mp l_pre [1,2] 15 | gnn.layers_mp l_mp [2,4,6,8] 16 | gnn.layers_post_mp l_post [2,3] 17 | gnn.stage_type stage ['skipsum','skipconcat'] 18 | gnn.agg agg ['add','mean','max'] 19 | -------------------------------------------------------------------------------- /run/grids/design/round2ogb.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # dataset: OGB, task: graph 8 | dataset.format format ['OGB'] 9 | dataset.name dataset ['ogbg-molhiv'] 10 | dataset.task task ['graph'] 11 | dataset.transductive trans [False] 12 | dataset.augment_feature feature [[]] 13 | dataset.augment_label label [''] 14 | gnn.layers_pre_mp l_pre [1,2] 15 | gnn.layers_mp l_mp [2,4,6,8] 16 | gnn.layers_post_mp l_post [2,3] 17 | gnn.stage_type stage ['skipsum','skipconcat'] 18 | gnn.agg agg ['add','mean','max'] -------------------------------------------------------------------------------- /run/grids/example.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # (1) dataset configurations 8 | dataset.format format ['PyG'] 9 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 10 | dataset.task task ['graph'] 11 | dataset.transductive trans [False] 12 | dataset.augment_feature feature [[]] 13 | dataset.augment_label label [''] 14 | # (2) The recommended GNN design space, 96 models in total 15 | gnn.layers_pre_mp l_pre [1,2] 16 | gnn.layers_mp l_mp [2,4,6,8] 17 | gnn.layers_post_mp l_post [2,3] 18 | gnn.stage_type stage ['skipsum','skipconcat'] 19 | gnn.agg agg ['add','mean','max'] 20 | 21 | -------------------------------------------------------------------------------- /run/grids/pyg/example.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid
search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | gnn.layers_pre_mp l_pre [1,2] 8 | gnn.layers_mp l_mp [2,4,6] 9 | gnn.layers_post_mp l_post [1,2] 10 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 11 | gnn.dim_inner dim [64,128,256] 12 | optim.base_lr lr [0.001,0.01] 13 | optim.max_epoch epoch [200,800,1600] 14 | -------------------------------------------------------------------------------- /run/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | from torch_geometric import seed_everything 6 | 7 | from graphgym.cmd_args import parse_args 8 | from graphgym.config import cfg, dump_cfg, load_cfg, set_run_dir, set_out_dir 9 | from graphgym.loader import create_dataset, create_loader 10 | from graphgym.logger import create_logger, setup_printing 11 | from graphgym.model_builder import create_model 12 | from graphgym.optimizer import create_optimizer, create_scheduler 13 | from graphgym.register import train_dict 14 | from graphgym.train import train 15 | from graphgym.utils.agg_runs import agg_runs 16 | from graphgym.utils.comp_budget import params_count 17 | from graphgym.utils.device import auto_select_device 18 | 19 | if __name__ == '__main__': 20 | # Load cmd line args 21 | args = parse_args() 22 | # Load config file 23 | load_cfg(cfg, args) 24 | set_out_dir(cfg.out_dir, args.cfg_file) 25 | # Set PyTorch environment 26 | torch.set_num_threads(cfg.num_threads) 27 | dump_cfg(cfg) 28 | # Repeat for different random seeds 29 | for i in range(args.repeat): 30 | set_run_dir(cfg.out_dir) 31 | setup_printing() 32 | # Set configurations for each run 33 | cfg.seed = cfg.seed + 1 34 | seed_everything(cfg.seed) 35 | auto_select_device() 36 | # Set machine learning pipeline 37 | datasets = create_dataset() 38 | loaders = create_loader(datasets) 39 | loggers = create_logger() 40 | model = create_model() 41 | optimizer = create_optimizer(model.parameters()) 42 | scheduler = create_scheduler(optimizer) 43 | # Print model info 44 | logging.info(model) 45 | logging.info(cfg) 46 | cfg.params = params_count(model) 47 | logging.info('Num parameters: %s', cfg.params) 48 | # Start training 49 | if cfg.train.mode == 'standard': 50 | train(loggers, loaders, model, optimizer, scheduler) 51 | else: 52 | train_dict[cfg.train.mode](loggers, loaders, model, optimizer, 53 | scheduler) 54 | # Aggregate results from different seeds 55 | agg_runs(cfg.out_dir, cfg.metric_best) 56 | # When launched in batch mode, mark the yaml as done 57 | if args.mark_done: 58 | os.rename(args.cfg_file, f'{args.cfg_file}_done') 59 | -------------------------------------------------------------------------------- /run/main_pyg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | from torch_geometric import seed_everything 6 | 7 | from graphgym.cmd_args import parse_args 8 | from graphgym.config import cfg, dump_cfg, load_cfg, set_run_dir, set_out_dir 9 | from graphgym.loader_pyg import create_loader 10 | from graphgym.logger import create_logger, setup_printing 11 | from graphgym.model_builder_pyg import create_model 12 | from graphgym.optimizer import create_optimizer, create_scheduler 13 | from graphgym.register import train_dict 14 | from graphgym.train_pyg import train 15 | from graphgym.utils.agg_runs import agg_runs 16 | from graphgym.utils.comp_budget import params_count 17 | from graphgym.utils.device import
auto_select_device 18 | 19 | if __name__ == '__main__': 20 | # Load cmd line args 21 | args = parse_args() 22 | # Load config file 23 | load_cfg(cfg, args) 24 | set_out_dir(cfg.out_dir, args.cfg_file) 25 | # Set PyTorch environment 26 | torch.set_num_threads(cfg.num_threads) 27 | dump_cfg(cfg) 28 | # Repeat for different random seeds 29 | for i in range(args.repeat): 30 | set_run_dir(cfg.out_dir) 31 | setup_printing() 32 | # Set configurations for each run 33 | cfg.seed = cfg.seed + 1 34 | seed_everything(cfg.seed) 35 | auto_select_device() 36 | # Set machine learning pipeline 37 | loaders = create_loader() 38 | loggers = create_logger() 39 | model = create_model() 40 | optimizer = create_optimizer(model.parameters()) 41 | scheduler = create_scheduler(optimizer) 42 | # Print model info 43 | logging.info(model) 44 | logging.info(cfg) 45 | cfg.params = params_count(model) 46 | logging.info('Num parameters: %s', cfg.params) 47 | # Start training 48 | if cfg.train.mode == 'standard': 49 | train(loggers, loaders, model, optimizer, scheduler) 50 | else: 51 | train_dict[cfg.train.mode](loggers, loaders, model, optimizer, 52 | scheduler) 53 | # Aggregate results from different seeds 54 | agg_runs(cfg.out_dir, cfg.metric_best) 55 | # When launched in batch mode, mark the yaml as done 56 | if args.mark_done: 57 | os.rename(args.cfg_file, f'{args.cfg_file}_done') 58 | -------------------------------------------------------------------------------- /run/parallel.sh: -------------------------------------------------------------------------------- 1 | CONFIG_DIR=$1 2 | REPEAT=$2 3 | MAX_JOBS=${3:-2} 4 | MAIN=${4:-main} 5 | 6 | ( 7 | trap 'kill 0' SIGINT 8 | CUR_JOBS=0 9 | for CONFIG in "$CONFIG_DIR"/*.yaml; do 10 | if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then 11 | ((CUR_JOBS >= MAX_JOBS)) && wait -n 12 | echo "Job launched: $CONFIG" 13 | python $MAIN.py --cfg $CONFIG --repeat $REPEAT --mark_done & 14 | ((CUR_JOBS < MAX_JOBS)) && sleep 1 15 | ((++CUR_JOBS)) 16 | fi 17 | done 18 | 19 | wait 20 | ) 21 | -------------------------------------------------------------------------------- /run/run_batch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=example 4 | GRID=example 5 | REPEAT=3 6 | MAX_JOBS=8 7 | 8 | # generate configs (after controlling computational budget) 9 | # please remove --config_budget if you don't want to control the computational budget 10 | python configs_gen.py --config configs/${CONFIG}.yaml \ 11 | --config_budget configs/${CONFIG}.yaml \ 12 | --grid grids/${GRID}.txt \ 13 | --out_dir configs 14 | #python configs_gen.py --config configs/ChemKG/${CONFIG}.yaml --config_budget configs/ChemKG/${CONFIG}.yaml --grid grids/ChemKG/${GRID}.txt --out_dir configs 15 | # run batch of configs 16 | # Args: config_dir, num of repeats, max jobs running 17 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 18 | # rerun missed / stopped experiments 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 22 | 23 | # aggregate results for the batch 24 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 25 | -------------------------------------------------------------------------------- /run/run_batch_pyg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=example_node 4 | GRID=example 5 | REPEAT=3 6 | MAX_JOBS=8 7 |
MAIN=main_pyg 8 | 9 | # generate configs (after controlling computational budget) 10 | # please remove --config_budget if you don't want to control the computational budget 11 | python configs_gen.py --config configs/pyg/${CONFIG}.yaml \ 12 | --grid grids/pyg/${GRID}.txt \ 13 | --out_dir configs 14 | #python configs_gen.py --config configs/ChemKG/${CONFIG}.yaml --config_budget configs/ChemKG/${CONFIG}.yaml --grid grids/ChemKG/${GRID}.txt --out_dir configs 15 | # run batch of configs 16 | # Args: config_dir, num of repeats, max jobs running 17 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $MAIN 18 | # rerun missed / stopped experiments 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $MAIN 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $MAIN 22 | 23 | # aggregate results for the batch 24 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 25 | -------------------------------------------------------------------------------- /run/run_single.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run a single experiment; --repeat specifies how many random seeds to run. 4 | python main.py --cfg configs/example.yaml --repeat 3 5 | -------------------------------------------------------------------------------- /run/run_single_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run a single experiment; --repeat specifies how many random seeds to run. 4 | python main.py --cfg configs/example_cpu.yaml --repeat 3 5 | -------------------------------------------------------------------------------- /run/run_single_pyg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run a single experiment; --repeat specifies how many random seeds to run. 4 | python main_pyg.py --cfg configs/pyg/example_node.yaml --repeat 3 # node classification 5 | python main_pyg.py --cfg configs/pyg/example_link.yaml --repeat 3 # link prediction 6 | python main_pyg.py --cfg configs/pyg/example_graph.yaml --repeat 3 # graph classification 7 | -------------------------------------------------------------------------------- /run/sample/dimensions.txt: -------------------------------------------------------------------------------- 1 | act bn drop agg l_mp l_pre l_post stage batch lr optim epoch -------------------------------------------------------------------------------- /run/sample/dimensionsatt.txt: -------------------------------------------------------------------------------- 1 | l_tw -------------------------------------------------------------------------------- /run/scripts/IDGNN/run_idgnn_edge.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../..
4 | 5 | DIR=IDGNN 6 | CONFIG=edge 7 | GRID=link 8 | REPEAT=3 9 | MAX_JOBS=4 10 | 11 | # generate configs (after controlling computational budget) 12 | # please remove --config_budget if you don't want to control the computational budget 13 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 14 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 15 | --grid grids/${DIR}/${GRID}.txt \ 16 | --out_dir configs 17 | # run batch of configs 18 | # Args: config_dir, num of repeats, max jobs running 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 22 | # rerun missed / stopped experiments 23 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 24 | 25 | # aggregate results for the batch 26 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 27 | 28 | 29 | 30 | 31 | # path prediction task 32 | 33 | DIR=IDGNN 34 | CONFIG=edge 35 | GRID=path 36 | REPEAT=3 37 | MAX_JOBS=4 38 | 39 | # generate configs (after controlling computational budget) 40 | # please remove --config_budget if you don't want to control the computational budget 41 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 42 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 43 | --grid grids/${DIR}/${GRID}.txt \ 44 | --out_dir configs 45 | # run batch of configs 46 | # Args: config_dir, num of repeats, max jobs running 47 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 48 | # rerun missed / stopped experiments 49 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 50 | # rerun missed / stopped experiments 51 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 52 | 53 | # aggregate results for the batch 54 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 55 | -------------------------------------------------------------------------------- /run/scripts/IDGNN/run_idgnn_graph.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../..
4 | 5 | DIR=IDGNN 6 | CONFIG=graph 7 | GRID=graph 8 | REPEAT=3 9 | MAX_JOBS=4 10 | 11 | # generate configs (after controlling computational budget) 12 | # please remove --config_budget if you don't want to control the computational budget 13 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 14 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 15 | --grid grids/${DIR}/${GRID}.txt \ 16 | --out_dir configs 17 | # run batch of configs 18 | # Args: config_dir, num of repeats, max jobs running 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 22 | # rerun missed / stopped experiments 23 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 24 | 25 | # aggregate results for the batch 26 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 27 | 28 | 29 | 30 | 31 | # predicting enzymes dataset (use a smaller model) 32 | 33 | DIR=IDGNN 34 | CONFIG=graph_enzyme 35 | GRID=graph_enzyme 36 | REPEAT=3 37 | MAX_JOBS=4 38 | 39 | # generate configs (after controlling computational budget) 40 | # please remove --config_budget if you don't want to control the computational budget 41 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 42 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 43 | --grid grids/${DIR}/${GRID}.txt \ 44 | --out_dir configs 45 | # run batch of configs 46 | # Args: config_dir, num of repeats, max jobs running 47 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 48 | # rerun missed / stopped experiments 49 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 50 | # rerun missed / stopped experiments 51 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 52 | 53 | # aggregate results for the batch 54 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 55 | 56 | 57 | 58 | 59 | # predicting ogb dataset (use a bigger model) 60 | 61 | DIR=IDGNN 62 | CONFIG=graph_ogb 63 | GRID=graph_ogb 64 | REPEAT=3 65 | MAX_JOBS=4 66 | 67 | # generate configs (after controlling computational budget) 68 | # please remove --config_budget if you don't want to control the computational budget 69 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 70 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 71 | --grid grids/${DIR}/${GRID}.txt \ 72 | --out_dir configs 73 | # run batch of configs 74 | # Args: config_dir, num of repeats, max jobs running 75 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 76 | # rerun missed / stopped experiments 77 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 78 | # rerun missed / stopped experiments 79 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 80 | 81 | # aggregate results for the batch 82 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} -------------------------------------------------------------------------------- /run/scripts/IDGNN/run_idgnn_node.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../..
4 | 5 | DIR=IDGNN 6 | CONFIG=node 7 | GRID=node 8 | REPEAT=3 9 | MAX_JOBS=4 10 | 11 | # generate configs (after controlling computational budget) 12 | # please remove --config_budget if you don't want to control the computational budget 13 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 14 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 15 | --grid grids/${DIR}/${GRID}.txt \ 16 | --out_dir configs 17 | # run batch of configs 18 | # Args: config_dir, num of repeats, max jobs running 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 22 | # rerun missed / stopped experiments 23 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 24 | 25 | # aggregate results for the batch 26 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 27 | 28 | 29 | 30 | 31 | # predicting node clustering coefficient 32 | 33 | DIR=IDGNN 34 | CONFIG=node 35 | GRID=node_clustering 36 | REPEAT=3 37 | MAX_JOBS=4 38 | 39 | # generate configs (after controlling computational budget) 40 | # please remove --config_budget if you don't want to control the computational budget 41 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 42 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 43 | --grid grids/${DIR}/${GRID}.txt \ 44 | --out_dir configs 45 | # run batch of configs 46 | # Args: config_dir, num of repeats, max jobs running 47 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 48 | # rerun missed / stopped experiments 49 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 50 | # rerun missed / stopped experiments 51 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 52 | 53 | # aggregate results for the batch 54 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} -------------------------------------------------------------------------------- /run/scripts/design/run_design_round1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../.. 4 | 5 | DIR=design 6 | CONFIG=design_v1 7 | GRID=round1 8 | SAMPLE_ALIAS=dimensions 9 | SAMPLE_NUM=96 10 | REPEAT=3 11 | MAX_JOBS=8 12 | 13 | # generate configs (after controlling computational budget) 14 | # please remove --config_budget if you don't want to control the computational budget 15 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 16 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 17 | --grid grids/${DIR}/${GRID}.txt \ 18 | --sample_alias sample/${SAMPLE_ALIAS}.txt \ 19 | --sample_num $SAMPLE_NUM \ 20 | --out_dir configs 21 | # run batch of configs 22 | # Args: config_dir, num of repeats, max jobs running 23 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 24 | # rerun missed / stopped experiments 25 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 26 | # rerun missed / stopped experiments 27 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 28 | 29 | # aggregate results for the batch 30 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 31 | -------------------------------------------------------------------------------- /run/scripts/design/run_design_round2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../..
4 | 5 | DIR=design 6 | CONFIG=design_v2 7 | GRID=round2 8 | REPEAT=3 9 | MAX_JOBS=8 10 | 11 | # generate configs (after controlling computational budget) 12 | # please remove --config_budget if you don't want to control the computational budget 13 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 14 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 15 | --grid grids/${DIR}/${GRID}.txt \ 16 | --out_dir configs 17 | # run batch of configs 18 | # Args: config_dir, num of repeats, max jobs running 19 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 20 | # rerun missed / stopped experiments 21 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 22 | # rerun missed / stopped experiments 23 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 24 | 25 | # aggregate results for the batch 26 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 27 | -------------------------------------------------------------------------------- /run/scripts/design/run_design_round2ogb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../.. 4 | 5 | # OGB results 6 | DIR=design 7 | CONFIG=design_v2ogb 8 | GRID=round2ogb 9 | REPEAT=3 10 | MAX_JOBS=8 11 | 12 | # generate configs (after controlling computational budget) 13 | # please remove --config_budget if you don't want to control the computational budget 14 | python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \ 15 | --config_budget configs/${DIR}/${CONFIG}.yaml \ 16 | --grid grids/${DIR}/${GRID}.txt \ 17 | --out_dir configs 18 | # run batch of configs 19 | # Args: config_dir, num of repeats, max jobs running 20 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 21 | # rerun missed / stopped experiments 22 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 23 | # rerun missed / stopped experiments 24 | bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS 25 | 26 | # aggregate results for the batch 27 | python agg_batch.py --dir results/${CONFIG}_grid_${GRID} 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="graphgym", 8 | version="0.4.0", 9 | author="Jiaxuan You", 10 | author_email="jiaxuan@cs.stanford.edu", 11 | description="GraphGym: platform for designing and " 12 | "evaluating Graph Neural Networks (GNN)", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/snap-stanford/graphgym", 16 | packages=setuptools.find_packages(), 17 | install_requires=[ 18 | 'yacs', 19 | 'tensorboardx', 20 | 'torch', 21 | 'torch-geometric', 22 | 'networkx', 23 | 'numpy', 24 | 'deepsnap', 25 | 'ogb', 26 | ], 27 | classifiers=[ 28 | "Programming Language :: Python :: 3", 29 | "License :: OSI Approved :: MIT License", 30 | "Operating System :: OS Independent", 31 | ], 32 | python_requires='>=3.6', 33 | ) 34 | --------------------------------------------------------------------------------
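For readers skimming the grid files above: each non-comment row is `config_key alias [candidate values]`, rows within a blank-line-separated group are combined by Cartesian product, and the groups themselves are unioned, which is what the header comment "Line breaks are used to union different grid search spaces" refers to. Below is a minimal, illustrative sketch of that expansion. It is not the repository's configs_gen.py (which additionally handles random sampling via --sample_alias/--sample_num and budget control via --config_budget), and the helper names parse_grid and expand_spaces are invented for this sketch.

import ast
import itertools

def parse_grid(text):
    # Hypothetical helper: split a grid file into search spaces.
    # Each non-comment row is 'config_key alias [value,...]';
    # blank-line-separated groups are independent spaces.
    spaces, rows = [], []
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith('#'):
            continue
        if not line:
            if rows:
                spaces.append(rows)
                rows = []
            continue
        key, alias, rng = line.split(' ', 2)
        rows.append((key, alias, ast.literal_eval(rng)))
    if rows:
        spaces.append(rows)
    return spaces

def expand_spaces(spaces):
    # Yield one flat {config_key: value} dict per grid point:
    # Cartesian product within a space, union across spaces.
    for rows in spaces:
        keys = [key for key, _, _ in rows]
        for values in itertools.product(*(rng for _, _, rng in rows)):
            yield dict(zip(keys, values))

# The design-space rows from run/grids/example.txt expand to
# 2 * 4 * 2 * 2 * 3 = 96 models, matching the file's own comment.
grid = """\
gnn.layers_pre_mp l_pre [1,2]
gnn.layers_mp l_mp [2,4,6,8]
gnn.layers_post_mp l_post [2,3]
gnn.stage_type stage ['skipsum','skipconcat']
gnn.agg agg ['add','mean','max']
"""
print(sum(1 for _ in expand_spaces(parse_grid(grid))))  # -> 96

Each emitted dict maps dotted config keys (e.g. 'gnn.layers_mp') to one candidate value; in the scripts above, configs_gen.py plays this role, merging each grid point into the base YAML passed via --config to produce one experiment config per point.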