├── requirements.txt
├── .gitignore
├── tu
│   ├── configs
│   │   ├── NCI1.json
│   │   ├── COLLAB.json
│   │   ├── ENZYMES.json
│   │   ├── REDDIT-BINARY.json
│   │   └── REDDIT-MULTI-12K.json
│   ├── run_script.sh
│   ├── model.py
│   ├── conv.py
│   └── main.py
├── qm9
│   ├── QM9.json
│   ├── run_script.sh
│   ├── model.py
│   ├── main.py
│   └── conv.py
├── ogbg
│   ├── ppa
│   │   ├── ogbg-ppa.json
│   │   ├── run_script.sh
│   │   ├── model.py
│   │   ├── main.py
│   │   └── conv.py
│   ├── mol
│   │   ├── ogbg-mol.json
│   │   ├── run_script.sh
│   │   ├── model.py
│   │   ├── main.py
│   │   └── conv.py
│   └── code
│       ├── ogbg-code.json
│       ├── run_script.sh
│       ├── model.py
│       ├── proc.py
│       ├── conv.py
│       └── main.py
├── utils
│   ├── jumping_knowledge.py
│   ├── config.py
│   └── qm9.py
├── LICENSE
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch>=1.4.0
3 | torch-scatter
4 | torch-sparse
5 | torch-cluster
6 | torch-spline-conv
7 | torch-geometric==1.5.0
8 | ogb==1.2.1
9 | easydict
10 | tensorboardX
11 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | ogbg/*/dataset/
3 | data/
4 | build/
5 | dist/
6 | alpha/
7 | .cache/
8 | .eggs/
9 | *.egg-info/
10 | .idea/
11 | .coverage
12 | .coverage.*
13 | core.[0-9]*
14 | 
15 | !torch_geometric/data/
16 | !test/data/
17 | 
--------------------------------------------------------------------------------
/tu/configs/NCI1.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "NCI1",
3 |     "num_workers": 4,
4 |     "hyperparams": {
5 |         "batch_size": 64,
6 |         "epochs": 501,
7 |         "learning_rate": 0.0004,
8 |         "step_size": 50,
9 |         "decay_rate": 0.75
10 |     },
11 |     "architecture": {
12 |         "layers": 4,
13 |         "hidden": 128,
14 |         "pooling": "add",
15 |         "JK": "cat",
16 |         "methods": "CB",
17 |         "dropout": 0.5,
18 |         "variants": {
19 |             "BN": "Y"
20 |         }
21 |     }
22 | }
23 | 
--------------------------------------------------------------------------------
/tu/configs/COLLAB.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "COLLAB",
3 |     "num_workers": 4,
4 |     "hyperparams": {
5 |         "batch_size": 64,
6 |         "epochs": 301,
7 |         "learning_rate": 0.0005,
8 |         "step_size": 50,
9 |         "decay_rate": 0.75
10 |     },
11 |     "architecture": {
12 |         "layers": 3,
13 |         "hidden": 128,
14 |         "pooling": "add",
15 |         "JK": "cat",
16 |         "methods": "EB4",
17 |         "dropout": 0.5,
18 |         "variants": {
19 |             "BN": "Y"
20 |         }
21 |     }
22 | }
23 | 
--------------------------------------------------------------------------------
/tu/configs/ENZYMES.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "ENZYMES",
3 |     "num_workers": 4,
4 |     "hyperparams": {
5 |         "batch_size": 32,
6 |         "epochs": 501,
7 |         "learning_rate": 0.0005,
8 |         "step_size": 50,
9 |         "decay_rate": 0.75
10 |     },
11 |     "architecture": {
12 |         "layers": 2,
13 |         "hidden": 512,
14 |         "pooling": "add",
15 |         "JK": "cat",
16 |         "methods": "EB4",
17 |         "dropout": 0.5,
18 |         "variants": {
19 |             "BN": "Y"
20 |         }
21 |     }
22 | }
23 | 
--------------------------------------------------------------------------------
/qm9/QM9.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "QM9",
3 |     "target": 0,
4 |     "seed": 777,
5 |     "hyperparams": {
6 |         "batch_size": 64,
7 |         "epochs": 601,
8 |         "learning_rate": 0.0001,
9 |         "step_size": 30,
10 |         "decay_rate": 0.85
11 |     },
12 |     "architecture": {
13 | "layers": 4, 14 | "hidden": 256, 15 | "pooling": "add", 16 | "JK": "cat", 17 | "methods": "EB1", 18 | "variants": { 19 | "BN": "N", 20 | "fea_activation": "ReLU" 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /tu/configs/REDDIT-BINARY.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "REDDIT-BINARY", 3 | "num_workers": 4, 4 | "hyperparams": { 5 | "batch_size": 64, 6 | "epochs": 301, 7 | "learning_rate": 0.001, 8 | "step_size": 10, 9 | "decay_rate": 0.8 10 | }, 11 | "architecture": { 12 | "layers": 3, 13 | "hidden": 256, 14 | "pooling": "add", 15 | "JK": "cat", 16 | "methods": "CB", 17 | "dropout": 0.5, 18 | "variants": { 19 | "BN": "Y" 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /tu/configs/REDDIT-MULTI-12K.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "REDDIT-MULTI-12K", 3 | "num_workers": 4, 4 | "hyperparams": { 5 | "batch_size": 64, 6 | "epochs": 301, 7 | "learning_rate": 0.001, 8 | "step_size": 50, 9 | "decay_rate": 0.75 10 | }, 11 | "architecture": { 12 | "layers": 4, 13 | "hidden": 128, 14 | "pooling": "add", 15 | "JK": "cat", 16 | "methods": "EB4", 17 | "dropout": 0.5, 18 | "variants": { 19 | "BN": "Y" 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /ogbg/ppa/ogbg-ppa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "ogbg-ppa", 3 | "num_workers": 4, 4 | "feature": "full", 5 | "hyperparams": { 6 | "batch_size": 32, 7 | "epochs": 601, 8 | "learning_rate": 0.0005, 9 | "step_size": 20, 10 | "decay_rate": 0.8 11 | }, 12 | "architecture": { 13 | "layers": 4, 14 | "hidden": 256, 15 | "pooling": "add", 16 | "JK": "cat", 17 | "methods": "EB1", 18 | "dropout": 0.5, 19 | "variants": { 20 | "BN": "Y" 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /utils/jumping_knowledge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class JumpingKnowledge(torch.nn.Module): 5 | 6 | def __init__(self, mode): 7 | super(JumpingKnowledge, self).__init__() 8 | self.mode = mode.lower() 9 | assert self.mode in ['cat'] 10 | 11 | def forward(self, xs): 12 | assert isinstance(xs, list) or isinstance(xs, tuple) 13 | 14 | return torch.cat(xs, dim=-1) 15 | 16 | def __repr__(self): 17 | return '{}({})'.format(self.__class__.__name__, self.mode) 18 | -------------------------------------------------------------------------------- /ogbg/mol/ogbg-mol.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "ogbg-molhiv", 3 | "seed": 777, 4 | "num_workers": 4, 5 | "feature": "full", 6 | "hyperparams": { 7 | "batch_size": 64, 8 | "epochs": 401, 9 | "learning_rate": 0.0001, 10 | "step_size": 20, 11 | "decay_rate": 0.8 12 | }, 13 | "architecture": { 14 | "layers": 3, 15 | "hidden": 64, 16 | "pooling": "mean", 17 | "JK": "cat", 18 | "methods": "EB1", 19 | "dropout": 0.5, 20 | "variants": { 21 | "BN": "Y" 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /ogbg/code/ogbg-code.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "ogbg-code", 3 | "seed": 777, 4 | 
"num_workers": 4, 5 | "feature": "full", 6 | "max_seq_len": 5, 7 | "num_vocab": 5000, 8 | "hyperparams": { 9 | "batch_size": 64, 10 | "epochs": 101, 11 | "learning_rate": 0.0001, 12 | "step_size": 5, 13 | "decay_rate": 0.6 14 | }, 15 | "architecture": { 16 | "layers": 4, 17 | "hidden": 512, 18 | "pooling": "mean", 19 | "JK": "cat", 20 | "methods": "EB1", 21 | "dropout": 0.5, 22 | "variants": { 23 | "BN": "Y" 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /qm9/run_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | export PYTHONIOENCODING=utf-8 5 | 6 | dataset="QM9" 7 | output_dir="../../epcb_results/QM9/" 8 | config_file="./"$dataset".json" 9 | 10 | time_stamp=`date '+%s'` 11 | commit_id=`git rev-parse HEAD` 12 | std_file=${output_dir}"stdout/"${time_stamp}_${commit_id}".txt" 13 | 14 | mkdir -p $output_dir"stdout/" 15 | 16 | nohup python3 -u ./main.py --config=$config_file --id=$commit_id --ts=$time_stamp --dir=$output_dir"board/" >> $std_file 2>&1 & 17 | 18 | pid=$! 19 | 20 | echo "Stdout dir: $std_file" 21 | echo "Start time: `date -d @$time_stamp '+%Y-%m-%d %H:%M:%S'`" 22 | echo "CUDA DEVICES: $CUDA_VISIBLE_DEVICES" 23 | echo "pid: $pid" 24 | cat $config_file 25 | 26 | sleep 1 27 | 28 | tail -f $std_file 29 | -------------------------------------------------------------------------------- /ogbg/mol/run_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | export PYTHONIOENCODING=utf-8 5 | 6 | dataset="ogbg-mol" 7 | output_dir="../../../epcb_results/ogbg/" 8 | config_file="./"$dataset".json" 9 | 10 | time_stamp=`date '+%s'` 11 | commit_id=`git rev-parse HEAD` 12 | std_file=${output_dir}"stdout/"${time_stamp}_${commit_id}".txt" 13 | 14 | mkdir -p $output_dir"stdout/" 15 | 16 | nohup python3 -u ./main.py --config=$config_file --id=$commit_id --ts=$time_stamp --dir=$output_dir"board/" >> $std_file 2>&1 & 17 | 18 | pid=$! 19 | 20 | echo "Stdout dir: $std_file" 21 | echo "Start time: `date -d @$time_stamp '+%Y-%m-%d %H:%M:%S'`" 22 | echo "CUDA DEVICES: $CUDA_VISIBLE_DEVICES" 23 | echo "pid: $pid" 24 | cat $config_file 25 | 26 | sleep 1 27 | 28 | tail -f $std_file 29 | -------------------------------------------------------------------------------- /ogbg/ppa/run_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | export PYTHONIOENCODING=utf-8 5 | 6 | dataset="ogbg-ppa" 7 | output_dir="../../../epcb_results/ogbg/" 8 | config_file="./"$dataset".json" 9 | 10 | time_stamp=`date '+%s'` 11 | commit_id=`git rev-parse HEAD` 12 | std_file=${output_dir}"stdout/"${time_stamp}_${commit_id}".txt" 13 | 14 | mkdir -p $output_dir"stdout/" 15 | 16 | nohup python3 -u ./main.py --config=$config_file --id=$commit_id --ts=$time_stamp --dir=$output_dir"board/" >> $std_file 2>&1 & 17 | 18 | pid=$! 
19 | 20 | echo "Stdout dir: $std_file" 21 | echo "Start time: `date -d @$time_stamp '+%Y-%m-%d %H:%M:%S'`" 22 | echo "CUDA DEVICES: $CUDA_VISIBLE_DEVICES" 23 | echo "pid: $pid" 24 | cat $config_file 25 | 26 | sleep 1 27 | 28 | tail -f $std_file 29 | -------------------------------------------------------------------------------- /ogbg/code/run_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | export PYTHONIOENCODING=utf-8 5 | 6 | dataset="ogbg-code" 7 | output_dir="../../../epcb_results/ogbg/" 8 | config_file="./"$dataset".json" 9 | 10 | time_stamp=`date '+%s'` 11 | commit_id=`git rev-parse HEAD` 12 | std_file=${output_dir}"stdout/"${time_stamp}_${commit_id}".txt" 13 | 14 | mkdir -p $output_dir"stdout/" 15 | 16 | nohup python3 -u ./main.py --config=$config_file --id=$commit_id --ts=$time_stamp --dir=$output_dir"board/" >> $std_file 2>&1 & 17 | 18 | pid=$! 19 | 20 | echo "Stdout dir: $std_file" 21 | echo "Start time: `date -d @$time_stamp '+%Y-%m-%d %H:%M:%S'`" 22 | echo "CUDA DEVICES: $CUDA_VISIBLE_DEVICES" 23 | echo "pid: $pid" 24 | cat $config_file 25 | 26 | sleep 1 27 | 28 | tail -f $std_file 29 | -------------------------------------------------------------------------------- /tu/run_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | export PYTHONIOENCODING=utf-8 5 | 6 | dataset="REDDIT-MULTI-12K" 7 | output_dir="../../epcb_results/tu/" 8 | config_file="./configs/"$dataset".json" 9 | 10 | time_stamp=`date '+%s'` 11 | commit_id=`git rev-parse HEAD` 12 | std_file=${output_dir}"stdout/"${time_stamp}_${commit_id}".txt" 13 | 14 | mkdir -p $output_dir"stdout/" 15 | 16 | nohup python3 -u ./main.py --config=$config_file --id=$commit_id --ts=$time_stamp --dir=$output_dir"board/" >> $std_file 2>&1 & 17 | 18 | pid=$! 19 | 20 | echo "Stdout dir: $std_file" 21 | echo "Start time: `date -d @$time_stamp '+%Y-%m-%d %H:%M:%S'`" 22 | echo "CUDA DEVICES: $CUDA_VISIBLE_DEVICES" 23 | echo "pid: $pid" 24 | cat $config_file 25 | 26 | sleep 1 27 | 28 | tail -f $std_file 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Anonymous Authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from easydict import EasyDict
4 | 
5 | 
6 | def get_args():
7 |     argparser = argparse.ArgumentParser(description=__doc__)
8 |     argparser.add_argument(
9 |         '-c', '--config',
10 |         metavar='C',
11 |         default=None,
12 |         help='The Configuration file')
13 |     argparser.add_argument(
14 |         '-i', '--id',
15 |         metavar='I',
16 |         default='',
17 |         help='The commit id')
18 |     argparser.add_argument(
19 |         '-t', '--ts',
20 |         metavar='T',
21 |         default='',
22 |         help='The time stamp')
23 |     argparser.add_argument(
24 |         '-d', '--dir',
25 |         metavar='D',
26 |         default='',
27 |         help='The output directory')
28 |     args = argparser.parse_args()
29 |     return args
30 | 
31 | 
32 | def get_config_from_json(json_file):
33 |     # parse the configurations from the config json file provided
34 |     with open(json_file, 'r') as config_file:
35 |         config_dict = json.load(config_file)
36 | 
37 |     # convert the dictionary to a namespace using EasyDict
38 |     config = EasyDict(config_dict)
39 | 
40 |     return config
41 | 
42 | 
43 | def process_config(args):
44 |     config = get_config_from_json(args.config)
45 |     config.commit_id = args.id
46 |     config.time_stamp = args.ts
47 |     config.directory = args.dir
48 |     return config
49 | 
50 | 
51 | if __name__ == '__main__':
52 |     config = get_config_from_json('../tu/configs/NCI1.json')
53 |     sub_configurations = config.architecture
54 |     print(sub_configurations['pooling'])
55 | 
--------------------------------------------------------------------------------
/tu/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Linear, Sequential, ReLU, ELU, Sigmoid
4 | from utils.jumping_knowledge import JumpingKnowledge
5 | from torch_geometric.nn import global_add_pool, global_mean_pool
6 | from conv import ExpC, CombC, ExpC_star, CombC_star
7 | 
8 | 
9 | class Net(torch.nn.Module):
10 |     def __init__(self,
11 |                  dataset,
12 |                  config):
13 |         super(Net, self).__init__()
14 |         self.lin0 = Linear(dataset.num_features, config.hidden)
15 | 
16 |         self.convs = torch.nn.ModuleList()
17 |         if config.methods[:2] == 'EB':
18 |             for i in range(config.layers):
19 |                 self.convs.append(ExpC(config.hidden,
20 |                                        int(config.methods[2:]),
21 |                                        config.variants))
22 |         elif config.methods[:2] == 'EA':
23 |             for i in range(config.layers):
24 |                 self.convs.append(ExpC_star(config.hidden,
25 |                                             int(config.methods[2:]),
26 |                                             config.variants))
27 |         elif config.methods == 'CB':
28 |             for i in range(config.layers):
29 |                 self.convs.append(CombC(config.hidden, config.variants))
30 |         elif config.methods == 'CA':
31 |             for i in range(config.layers):
32 |                 self.convs.append(CombC_star(config.hidden, config.variants))
33 |         else:
34 |             raise ValueError('Undefined gnn layer called {}'.format(config.methods))
35 | 
36 |         self.JK = JumpingKnowledge(config.JK)
37 | 
38 |         if config.JK == 'cat':
39 |             self.lin1 = Linear(config.layers * config.hidden, config.hidden)
40 |         else:
41 |             self.lin1 = Linear(config.hidden, config.hidden)
42 | 
43 |         self.lin2 = Linear(config.hidden, dataset.num_classes)
44 | 
45 |         if config.pooling == 'add':
46 |             self.pool = global_add_pool
47 |         elif config.pooling == 'mean':
48 |             self.pool = global_mean_pool
49 | 
50 |         self.dropout = config.dropout
51 | 
52 |     def forward(self, data):
53 |         x, edge_index, batch = data.x,
data.edge_index, data.batch 54 | x = self.lin0(x) 55 | xs = [] 56 | for conv in self.convs: 57 | x = conv(x, edge_index) 58 | xs += [x] 59 | 60 | x = self.JK(xs) 61 | 62 | x = self.pool(x, batch) 63 | 64 | x = F.relu(self.lin1(x)) 65 | x = F.dropout(x, p=self.dropout, training=self.training) 66 | x = self.lin2(x) 67 | return F.log_softmax(x, dim=-1) 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ 71 | -------------------------------------------------------------------------------- /ogbg/ppa/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.jumping_knowledge import JumpingKnowledge 4 | from torch_geometric.nn import global_add_pool, global_mean_pool 5 | from conv import GinConv, ExpC, CombC, ExpC_star, CombC_star 6 | 7 | 8 | class Net(torch.nn.Module): 9 | def __init__(self, 10 | config, 11 | num_class): 12 | super(Net, self).__init__() 13 | self.node_encoder = torch.nn.Embedding(1, config.hidden) 14 | 15 | self.convs = torch.nn.ModuleList() 16 | if config.methods == 'GIN': 17 | for i in range(config.layers): 18 | self.convs.append(GinConv(config.hidden, config.variants)) 19 | elif config.methods[:2] == 'EB': 20 | for i in range(config.layers): 21 | self.convs.append(ExpC(config.hidden, 22 | int(config.methods[2:]), 23 | config.variants)) 24 | elif config.methods[:2] == 'EA': 25 | for i in range(config.layers): 26 | self.convs.append(ExpC_star(config.hidden, 27 | int(config.methods[2:]), 28 | config.variants)) 29 | elif config.methods == 'CB': 30 | for i in range(config.layers): 31 | self.convs.append(CombC(config.hidden, config.variants)) 32 | elif config.methods == 'CA': 33 | for i in range(config.layers): 34 | self.convs.append(CombC_star(config.hidden, config.variants)) 35 | else: 36 | raise ValueError('Undefined gnn layer called {}'.format(config.methods)) 37 | 38 | self.JK = JumpingKnowledge(config.JK) 39 | 40 | if config.JK == 'cat': 41 | self.graph_pred_linear = torch.nn.Linear(config.hidden * config.layers, num_class) 42 | else: 43 | self.graph_pred_linear = torch.nn.Linear(config.hidden, num_class) 44 | 45 | if config.pooling == 'add': 46 | self.pool = global_add_pool 47 | elif config.pooling == 'mean': 48 | self.pool = global_mean_pool 49 | 50 | self.dropout = config.dropout 51 | 52 | def forward(self, batched_data): 53 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 54 | x = self.node_encoder(x) 55 | xs = [] 56 | for conv in self.convs: 57 | x = conv(x, edge_index, edge_attr) 58 | xs += [x] 59 | 60 | nr = self.JK(xs) 61 | 62 | nr = F.dropout(nr, p=self.dropout, training=self.training) 63 | h_graph = self.pool(nr, batched_data.batch) 64 | return self.graph_pred_linear(h_graph) 65 | 66 | def __repr__(self): 67 | return self.__class__.__name__ 68 | -------------------------------------------------------------------------------- /ogbg/mol/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.jumping_knowledge import JumpingKnowledge 4 | from torch_geometric.nn import global_add_pool, global_mean_pool 5 | from conv import GinConv, ExpC, CombC, ExpC_star, CombC_star 6 | from ogb.graphproppred.mol_encoder import AtomEncoder 7 | 8 | 9 | class Net(torch.nn.Module): 10 | def __init__(self, 11 | config, 12 | num_tasks): 13 | super(Net, self).__init__() 14 | self.atom_encoder = 
AtomEncoder(config.hidden) 15 | 16 | self.convs = torch.nn.ModuleList() 17 | if config.methods == 'GIN': 18 | for i in range(config.layers): 19 | self.convs.append(GinConv(config.hidden, config.variants)) 20 | elif config.methods[:2] == 'EB': 21 | for i in range(config.layers): 22 | self.convs.append(ExpC(config.hidden, 23 | int(config.methods[2:]), 24 | config.variants)) 25 | elif config.methods[:2] == 'EA': 26 | for i in range(config.layers): 27 | self.convs.append(ExpC_star(config.hidden, 28 | int(config.methods[2:]), 29 | config.variants)) 30 | elif config.methods == 'CB': 31 | for i in range(config.layers): 32 | self.convs.append(CombC(config.hidden, config.variants)) 33 | elif config.methods == 'CA': 34 | for i in range(config.layers): 35 | self.convs.append(CombC_star(config.hidden, config.variants)) 36 | else: 37 | raise ValueError('Undefined gnn layer called {}'.format(config.methods)) 38 | 39 | self.JK = JumpingKnowledge(config.JK) 40 | 41 | if config.JK == 'cat': 42 | self.graph_pred_linear = torch.nn.Linear(config.hidden * config.layers, num_tasks) 43 | else: 44 | self.graph_pred_linear = torch.nn.Linear(config.hidden, num_tasks) 45 | 46 | if config.pooling == 'add': 47 | self.pool = global_add_pool 48 | elif config.pooling == 'mean': 49 | self.pool = global_mean_pool 50 | 51 | self.dropout = config.dropout 52 | 53 | def forward(self, batched_data): 54 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 55 | x = self.atom_encoder(x) 56 | xs = [] 57 | for conv in self.convs: 58 | x = conv(x, edge_index, edge_attr) 59 | xs += [x] 60 | 61 | nr = self.JK(xs) 62 | 63 | nr = F.dropout(nr, p=self.dropout, training=self.training) 64 | h_graph = self.pool(nr, batched_data.batch) 65 | return self.graph_pred_linear(h_graph) 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ 69 | -------------------------------------------------------------------------------- /qm9/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear, Sequential, ReLU, ELU, Sigmoid 4 | from utils.jumping_knowledge import JumpingKnowledge 5 | from torch_geometric.nn import global_add_pool, global_mean_pool 6 | from conv import GinConv, ExpC, CombC, ExpC_star, CombC_star 7 | 8 | 9 | class Net(torch.nn.Module): 10 | def __init__(self, 11 | dataset, 12 | config): 13 | super(Net, self).__init__() 14 | self.pooling = config.pooling 15 | self.lin0 = Linear(dataset.num_features, config.hidden) 16 | 17 | self.convs = torch.nn.ModuleList() 18 | if config.methods[:2] == 'EB': 19 | for i in range(config.layers): 20 | self.convs.append(ExpC(config.hidden, 21 | int(config.methods[2:]), 22 | config.variants)) 23 | elif config.methods[:2] == 'EA': 24 | for i in range(config.layers): 25 | self.convs.append(ExpC_star(config.hidden, 26 | int(config.methods[2:]), 27 | config.variants)) 28 | elif config.methods == 'CB': 29 | for i in range(config.layers): 30 | self.convs.append(CombC(config.hidden, config.variants)) 31 | elif config.methods == 'CA': 32 | for i in range(config.layers): 33 | self.convs.append(CombC_star(config.hidden, config.variants)) 34 | elif config.methods == 'GIN': 35 | for i in range(config.layers): 36 | self.convs.append(GinConv(config.hidden, config.variants)) 37 | else: 38 | raise ValueError('Undefined gnn layer called {}'.format(config.methods)) 39 | 40 | self.JK = JumpingKnowledge(config.JK) 41 | if config.JK == 
'cat': 42 | self.lin1 = torch.nn.Linear(config.layers * config.hidden, (config.layers + 1) // 2 * config.hidden) 43 | self.lin2 = torch.nn.Linear((config.layers + 1) // 2 * config.hidden, config.hidden) 44 | else: 45 | self.lin1 = torch.nn.Linear(config.hidden, config.hidden) 46 | self.lin2 = torch.nn.Linear(config.hidden, config.hidden) 47 | self.lin3 = torch.nn.Linear(config.hidden, 1) 48 | 49 | if config.pooling == 'add': 50 | self.pool = global_add_pool 51 | elif config.pooling == 'mean': 52 | self.pool = global_mean_pool 53 | 54 | def forward(self, data): 55 | x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch 56 | x = self.lin0(x) 57 | xs = [] 58 | 59 | for conv in self.convs: 60 | x = conv(x, edge_index, edge_attr) 61 | xs += [self.pool(x, batch)] 62 | 63 | x = self.JK(xs) 64 | x = F.elu(self.lin1(x)) 65 | x = F.elu(self.lin2(x)) 66 | x = self.lin3(x) 67 | return x.view(-1) 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ 71 | -------------------------------------------------------------------------------- /utils/qm9.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch_geometric.data import (InMemoryDataset, download_url, extract_tar, 5 | Data) 6 | 7 | 8 | class QM9(InMemoryDataset): 9 | r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular 10 | Machine Learning" `_ paper, consisting of 11 | about 130,000 molecules with 13 regression targets. 12 | Each molecule includes complete spatial information for the single low 13 | energy conformation of the atoms in the molecule. 14 | In addition, we provide the atom features from the `"Neural Message 15 | Passing for Quantum Chemistry" `_ paper. 16 | 17 | Args: 18 | root (string): Root directory where the dataset should be saved. 19 | transform (callable, optional): A function/transform that takes in an 20 | :obj:`torch_geometric.data.Data` object and returns a transformed 21 | version. The data object will be transformed before every access. 22 | (default: :obj:`None`) 23 | pre_transform (callable, optional): A function/transform that takes in 24 | an :obj:`torch_geometric.data.Data` object and returns a 25 | transformed version. The data object will be transformed before 26 | being saved to disk. (default: :obj:`None`) 27 | pre_filter (callable, optional): A function that takes in an 28 | :obj:`torch_geometric.data.Data` object and returns a boolean 29 | value, indicating whether the data object should be included in the 30 | final dataset. 
(default: :obj:`None`) 31 | """ 32 | 33 | url = 'http://www.roemisch-drei.de/qm9.tar.gz' 34 | 35 | def __init__(self, 36 | root, 37 | transform=None, 38 | pre_transform=None, 39 | pre_filter=None): 40 | super(QM9, self).__init__(root, transform, pre_transform, pre_filter) 41 | self.data, self.slices = torch.load(self.processed_paths[0]) 42 | 43 | @property 44 | def raw_file_names(self): 45 | return 'qm9.pt' 46 | 47 | @property 48 | def processed_file_names(self): 49 | return 'data.pt' 50 | 51 | def download(self): 52 | file_path = download_url(self.url, self.raw_dir) 53 | extract_tar(file_path, self.raw_dir, mode='r') 54 | os.unlink(file_path) 55 | 56 | def process(self): 57 | raw_data_list = torch.load(self.raw_paths[0]) 58 | data_list = [ 59 | Data( 60 | x=d['x'], 61 | edge_index=d['edge_index'], 62 | edge_attr=d['edge_attr'], 63 | y=d['y'], 64 | pos=d['pos']) for d in raw_data_list 65 | ] 66 | 67 | if self.pre_filter is not None: 68 | data_list = [data for data in data_list if self.pre_filter(data)] 69 | 70 | if self.pre_transform is not None: 71 | data_list = [self.pre_transform(data) for data in data_list] 72 | 73 | data, slices = self.collate(data_list) 74 | torch.save((data, slices), self.processed_paths[0]) 75 | -------------------------------------------------------------------------------- /ogbg/code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear, Sequential, ReLU, ELU, Sigmoid 4 | from utils.jumping_knowledge import JumpingKnowledge 5 | from torch_geometric.nn import global_add_pool, global_mean_pool 6 | from conv import GinConv, ExpC, CombC, ExpC_star, CombC_star 7 | 8 | 9 | class Net(torch.nn.Module): 10 | def __init__(self, 11 | config, 12 | num_vocab, max_seq_len, node_encoder): 13 | super(Net, self).__init__() 14 | self.num_vocab = num_vocab 15 | self.max_seq_len = max_seq_len 16 | 17 | self.node_encoder = node_encoder 18 | 19 | self.convs = torch.nn.ModuleList() 20 | if config.methods == 'GIN': 21 | for i in range(config.layers): 22 | self.convs.append(GinConv(config.hidden, config.variants)) 23 | elif config.methods[:2] == 'EB': 24 | for i in range(config.layers): 25 | self.convs.append(ExpC(config.hidden, 26 | int(config.methods[2:]), 27 | config.variants)) 28 | elif config.methods[:2] == 'EA': 29 | for i in range(config.layers): 30 | self.convs.append(ExpC_star(config.hidden, 31 | int(config.methods[2:]), 32 | config.variants)) 33 | elif config.methods == 'CB': 34 | for i in range(config.layers): 35 | self.convs.append(CombC(config.hidden, config.variants)) 36 | elif config.methods == 'CA': 37 | for i in range(config.layers): 38 | self.convs.append(CombC_star(config.hidden, config.variants)) 39 | else: 40 | raise ValueError('Undefined gnn layer called {}'.format(config.methods)) 41 | 42 | self.JK = JumpingKnowledge(config.JK) 43 | 44 | self.graph_pred_linear_list = torch.nn.ModuleList() 45 | if config.JK == 'cat': 46 | for i in range(max_seq_len): 47 | self.graph_pred_linear_list.append(torch.nn.Linear(config.hidden * config.layers, self.num_vocab)) 48 | else: 49 | for i in range(max_seq_len): 50 | self.graph_pred_linear_list.append(torch.nn.Linear(config.hidden, self.num_vocab)) 51 | 52 | if config.pooling == 'add': 53 | self.pool = global_add_pool 54 | elif config.pooling == 'mean': 55 | self.pool = global_mean_pool 56 | 57 | self.dropout = config.dropout 58 | 59 | def forward(self, batched_data): 60 | ''' 61 | Return: 62 | A list of 
predictions. 63 | i-th element represents prediction at i-th position of the sequence. 64 | ''' 65 | x, edge_index, edge_attr, node_depth, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.node_depth, batched_data.batch 66 | 67 | x = self.node_encoder(x, node_depth.view(-1,)) 68 | xs = [] 69 | for conv in self.convs: 70 | x = conv(x, edge_index, edge_attr) 71 | xs += [x] 72 | 73 | nr = self.JK(xs) 74 | 75 | nr = F.dropout(nr, p=self.dropout, training=self.training) 76 | h_graph = self.pool(nr, batched_data.batch) 77 | 78 | pred_list = [] 79 | for i in range(self.max_seq_len): 80 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 81 | 82 | return pred_list 83 | 84 | def __repr__(self): 85 | return self.__class__.__name__ 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Breaking the Expression Bottleneck of Graph Neural Networks 2 | 3 | [![MIT License](https://img.shields.io/badge/license-MIT-blue)](LICENSE) 4 | 5 | This is the code of the paper 6 | *Breaking the Expression Bottleneck of Graph Neural Networks*. 7 | 8 | ## Table of Contents 9 | 10 | - [Requirements](#requirements) 11 | - [Run Experiments](#run-experiments) 12 | - [ogbg-ppa](#ogbg-ppa) 13 | - [ogbg-code](#ogbg-code) 14 | - [ogbg-molhiv & ogbg-molpcba](#ogbg-molhiv--ogbg-molpcba) 15 | - [QM9](#qm9) 16 | - [TU](#tu) 17 | - [License](#license) 18 | 19 | ## Requirements 20 | 21 | The code is built upon the 22 | [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), 23 | which has been tested running under Python 3.6.9. 24 | 25 | The following packages need to be installed manually: 26 | 27 | - `torch==1.5.1` 28 | - `torch-geometric==1.5.0` 29 | 30 | The following packages can be installed by running 31 | 32 | ```sh 33 | python3 -m pip install -r requirements.txt 34 | ``` 35 | 36 | - `ogb==1.2.1` 37 | - `numpy` 38 | - `easydict` 39 | - `tensorboardX` 40 | 41 | ## Run Experiments 42 | 43 | You can run experiments on multiple datasets as follows. 44 | 45 | ### ogbg-ppa 46 | 47 | To run experiments on `ogbg-ppa`, change directory to [ogbg/ppa](ogbg/ppa): 48 | 49 | ```sh 50 | cd ogbg/ppa 51 | ``` 52 | 53 | You can set hyper-parameters in 54 | [ogbg-ppa.json](ogbg/ppa/ogbg-ppa.json). 55 | 56 | You can change `CUDA_VISIBLE_DEVICES` and output directory in 57 | [run_script.sh](ogbg/ppa/run_script.sh). 58 | 59 | Then, run the following script: 60 | 61 | ```sh 62 | ./run_script.sh 63 | ``` 64 | 65 | ### ogbg-code 66 | 67 | To run experiments on `ogbg-code`, change directory to [ogbg/code](ogbg/code): 68 | 69 | ```sh 70 | cd ogbg/code 71 | ``` 72 | 73 | You can set hyper-parameters in 74 | [ogbg-code.json](ogbg/code/ogbg-code.json). 75 | 76 | You can change `CUDA_VISIBLE_DEVICES` and output directory in 77 | [run_script.sh](ogbg/code/run_script.sh). 78 | 79 | Then, run the following script: 80 | 81 | ```sh 82 | ./run_script.sh 83 | ``` 84 | 85 | ### ogbg-molhiv & ogbg-molpcba 86 | 87 | To run experiments on `ogbg-mol*`, change directory to [ogbg/mol](ogbg/mol): 88 | 89 | ```sh 90 | cd ogbg/mol 91 | ``` 92 | 93 | You can set dataset name and hyper-parameters in 94 | [ogbg-mol.json](ogbg/mol/ogbg-mol.json). 95 | Dataset name should be either `ogbg-molhiv` or `ogbg-molpcba`. 96 | 97 | You can change `CUDA_VISIBLE_DEVICES` and output directory in 98 | [run_script.sh](ogbg/mol/run_script.sh). 
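
The wrapper script only forwards a config path plus bookkeeping flags
(`--id`, `--ts`, `--dir`; parsed in [utils/config.py](utils/config.py)) to
`main.py`, so you can also launch a single run in the foreground. A minimal
sketch (the `manual` run label and the output directory are placeholders,
not required values):

```sh
python3 -u ./main.py --config=ogbg-mol.json --id=manual --ts=$(date '+%s') --dir=../../../epcb_results/ogbg/board/
```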
99 | 
100 | Then, run the following script:
101 | 
102 | ```sh
103 | ./run_script.sh
104 | ```
105 | 
106 | ### QM9
107 | 
108 | To run experiments on `QM9`, change directory to [qm9](qm9):
109 | 
110 | ```sh
111 | cd qm9
112 | ```
113 | 
114 | You can set hyper-parameters in [QM9.json](qm9/QM9.json).
115 | 
116 | You can change `CUDA_VISIBLE_DEVICES` and output directory in
117 | [run_script.sh](qm9/run_script.sh).
118 | 
119 | Then, run the following script:
120 | 
121 | ```sh
122 | ./run_script.sh
123 | ```
124 | 
125 | ### TU
126 | 
127 | To run experiments on `TU`, change directory to [tu](tu):
128 | 
129 | ```sh
130 | cd tu
131 | ```
132 | 
133 | There are several datasets in TU.
134 | You can set the dataset name in [run_script.sh](tu/run_script.sh)
135 | and set hyper-parameters in [configs/\<dataset\>.json](tu/configs).
136 | 
137 | You can change `CUDA_VISIBLE_DEVICES` and output directory in
138 | [run_script.sh](tu/run_script.sh).
139 | 
140 | Then, run the following script:
141 | 
142 | ```sh
143 | ./run_script.sh
144 | ```
145 | 
146 | ## Reference
147 | ```
148 | @ARTICLE{yang2022breaking,
149 |   author = {Yang, Mingqi and Wang, Renjian and Shen, Yanming and Qi, Heng and Yin, Baocai},
150 |   journal = {IEEE Transactions on Knowledge & Data Engineering},
151 |   title = {Breaking the Expression Bottleneck of Graph Neural Networks},
152 |   year = {2022},
153 |   doi = {10.1109/TKDE.2022.3168070},
154 |   address = {Los Alamitos, CA, USA}
155 | }
156 | ```
157 | 
158 | ## License
159 | 
160 | [MIT License](LICENSE)
161 | 
--------------------------------------------------------------------------------
/qm9/main.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | import torch.nn.functional as F
4 | import torch_geometric.transforms as T
5 | from torch_geometric.data import DataLoader
6 | from tensorboardX import SummaryWriter
7 | 
8 | import sys
9 | sys.path.append('..')
10 | 
11 | from utils.qm9 import QM9
12 | from model import Net
13 | from utils.config import process_config, get_args
14 | 
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | 
17 | 
18 | class MyTransform(object):
19 |     def __call__(self, data):
20 |         data.y = data.y[:, int(config.target)]
21 |         return data
22 | 
23 | 
24 | def train():
25 |     model.train()
26 |     loss_all = 0
27 | 
28 |     for data in train_loader:
29 |         data = data.to(device)
30 |         optimizer.zero_grad()
31 |         loss = F.mse_loss(model(data), data.y)
32 |         loss.backward()
33 |         loss_all += loss.item() * data.num_graphs
34 |         optimizer.step()
35 |     return loss_all / len(train_loader.dataset)
36 | 
37 | 
38 | def test(loader):
39 |     model.eval()
40 |     error = 0
41 | 
42 |     for data in loader:
43 |         data = data.to(device)
44 |         error += ((model(data) * std[config.target].to(device)) -
45 |                   (data.y * std[config.target].to(device))).abs().sum().item()
46 |     return error / len(loader.dataset)
47 | 
48 | 
49 | args = get_args()
50 | config = process_config(args)
51 | print(config)
52 | 
53 | if config.get('seed') is not None:
54 |     torch.manual_seed(config.seed)
55 |     if torch.cuda.is_available():
56 |         torch.cuda.manual_seed_all(config.seed)
57 | 
58 | 
59 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', 'QM9')
60 | dataset = QM9(path, transform=T.Compose([MyTransform(), T.Distance()]))
61 | dataset = dataset.shuffle()
62 | 
63 | # Normalize targets to mean = 0 and std = 1.
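# The dataset was shuffled above: the first 10% becomes the test split and the
# next 10% the validation split, so the statistics below are computed over the
# remaining 80% (the training split) only and then applied to all targets.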
64 | tenpercent = int(len(dataset) * 0.1) 65 | mean = dataset.data.y[tenpercent * 2:].mean(dim=0) 66 | std = dataset.data.y[tenpercent * 2:].std(dim=0) 67 | dataset.data.y = (dataset.data.y - mean) / std 68 | 69 | test_dataset = dataset[:tenpercent] 70 | val_dataset = dataset[tenpercent:tenpercent * 2] 71 | train_dataset = dataset[tenpercent * 2:] 72 | test_loader = DataLoader(test_dataset, batch_size=config.hyperparams.batch_size) 73 | val_loader = DataLoader(val_dataset, batch_size=config.hyperparams.batch_size) 74 | train_loader = DataLoader(train_dataset, batch_size=config.hyperparams.batch_size, shuffle=True) 75 | 76 | model = Net(dataset, config.architecture).to(device) 77 | optimizer = torch.optim.Adam(model.parameters(), lr=config.hyperparams.learning_rate) 78 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 79 | step_size=config.hyperparams.step_size, 80 | gamma=config.hyperparams.decay_rate) 81 | 82 | ts_algo_hp = str(config.time_stamp) + '_' \ 83 | + str(config.commit_id[0:7]) + '_' \ 84 | + str(config.architecture.methods) + '_' \ 85 | + str(config.architecture.variants.fea_activation) + '_' \ 86 | + str(config.architecture.pooling) + '_' \ 87 | + str(config.architecture.JK) + '_' \ 88 | + str(config.architecture.layers) + '_' \ 89 | + str(config.architecture.hidden) + '_' \ 90 | + str(config.architecture.variants.BN) + '_' \ 91 | + str(config.hyperparams.learning_rate) + '_' \ 92 | + str(config.hyperparams.step_size) + '_' \ 93 | + str(config.hyperparams.decay_rate) + '_' \ 94 | + 'B' + str(config.hyperparams.batch_size) + '_' \ 95 | + 'S' + str(config.seed) 96 | 97 | print('--------') 98 | print('QM9_' + str(config.target) + ', ' 99 | + ts_algo_hp 100 | + ', ID=' + config.commit_id) 101 | 102 | writer = SummaryWriter(config.directory) 103 | 104 | best_val_error = None 105 | for epoch in range(1, config.hyperparams.epochs): 106 | lr = scheduler.optimizer.param_groups[0]['lr'] 107 | loss = train() 108 | val_error = test(val_loader) 109 | scheduler.step() 110 | 111 | if best_val_error is None: 112 | best_val_error = val_error 113 | test_error = test(test_loader) 114 | if val_error <= best_val_error: 115 | best_val_error = val_error 116 | print( 117 | 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 118 | 'Test MAE: {:.7f}'.format(epoch, lr, loss, val_error, test_error)) 119 | else: 120 | print( 121 | 'Epoch: {:03d}, {:7f},{:.7f},{:.7f},' 122 | '{:.7f}'.format(epoch, lr, loss, val_error, test_error)) 123 | 124 | writer.add_scalars(config.dataset_name + '_' + str(config.target), {ts_algo_hp + '/lr': lr}, epoch) 125 | writer.add_scalars(config.dataset_name + '_' + str(config.target), {ts_algo_hp + '/te': test_error}, epoch) 126 | writer.add_scalars(config.dataset_name + '_' + str(config.target), {ts_algo_hp + '/ve': val_error}, epoch) 127 | writer.add_scalars(config.dataset_name + '_' + str(config.target), {ts_algo_hp + '/ls': loss}, epoch) 128 | 129 | writer.close() 130 | -------------------------------------------------------------------------------- /tu/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.conv import MessagePassing 3 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 4 | import torch.nn.functional as F 5 | from torch.nn import Linear, Sequential, Tanh, ReLU, ELU, BatchNorm1d as BN, Parameter 6 | 7 | 8 | class ExpC(MessagePassing): 9 | def __init__(self, hidden, num_aggr, config, **kwargs): 10 | super(ExpC, 
self).__init__(aggr='add', **kwargs) 11 | self.hidden = hidden 12 | self.num_aggr = num_aggr 13 | 14 | self.fea_mlp = Sequential( 15 | Linear(hidden * self.num_aggr, hidden), 16 | ReLU(), 17 | Linear(hidden, hidden), 18 | ReLU()) 19 | 20 | self.aggr_mlp = Sequential( 21 | Linear(hidden * 2, self.num_aggr), 22 | Tanh()) 23 | 24 | if config.BN == 'Y': 25 | self.BN = BN(hidden) 26 | else: 27 | self.BN = None 28 | 29 | def forward(self, x, edge_index): 30 | edge_index, _ = remove_self_loops(edge_index) 31 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 32 | out = self.propagate(edge_index, x=x) 33 | if self.BN is not None: 34 | out = self.BN(out) 35 | return out 36 | 37 | def message(self, x_i, x_j): 38 | xe = x_j 39 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 40 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 41 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 42 | return self.fea_mlp(feature2d) 43 | 44 | def update(self, aggr_out, x): 45 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 46 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 47 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 48 | return aggr_out + self.fea_mlp(feature2d) 49 | 50 | def __repr__(self): 51 | return self.__class__.__name__ 52 | 53 | 54 | class ExpC_star(MessagePassing): 55 | def __init__(self, hidden, num_aggr, config, **kwargs): 56 | super(ExpC_star, self).__init__(aggr='add', **kwargs) 57 | self.hidden = hidden 58 | self.num_aggr = num_aggr 59 | 60 | self.fea_mlp = Sequential( 61 | Linear(hidden * self.num_aggr, hidden), 62 | ReLU(), 63 | Linear(hidden, hidden), 64 | ReLU()) 65 | 66 | self.aggr_mlp = Sequential( 67 | Linear(hidden * 2, self.num_aggr), 68 | Tanh()) 69 | 70 | if config.BN == 'Y': 71 | self.BN = BN(hidden) 72 | else: 73 | self.BN = None 74 | 75 | def forward(self, x, edge_index): 76 | edge_index, _ = remove_self_loops(edge_index) 77 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 78 | out = self.fea_mlp(self.propagate(edge_index, x=x)) 79 | if self.BN is not None: 80 | out = self.BN(out) 81 | return out 82 | 83 | def message(self, x_i, x_j): 84 | xe = x_j 85 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 86 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 87 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 88 | return feature2d 89 | 90 | def update(self, aggr_out, x): 91 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 92 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 93 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 94 | return aggr_out + feature2d 95 | 96 | def __repr__(self): 97 | return self.__class__.__name__ 98 | 99 | 100 | class CombC(MessagePassing): 101 | def __init__(self, hidden, config, **kwargs): 102 | super(CombC, self).__init__(aggr='add', **kwargs) 103 | 104 | self.fea_mlp = Sequential( 105 | Linear(hidden, hidden), 106 | ReLU(), 107 | Linear(hidden, hidden), 108 | ReLU()) 109 | 110 | self.aggr_mlp = Sequential( 111 | Linear(hidden * 2, hidden), 112 | Tanh()) 113 | 114 | if config.BN == 'Y': 115 | self.BN = BN(hidden) 116 | else: 117 | self.BN = None 118 | 119 | def forward(self, x, edge_index): 120 | edge_index, _ = remove_self_loops(edge_index) 121 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 122 | out = self.propagate(edge_index, x=x) 123 | if self.BN is not None: 124 | out = self.BN(out) 125 | return out 126 | 127 | 
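    # message(): a Tanh gate computed from [x_i, x_j] rescales the neighbor
    # feature x_j element-wise before the shared fea_mlp is applied.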
def message(self, x_i, x_j): 128 | xe = x_j 129 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 130 | return self.fea_mlp(aggr_emb * xe) 131 | 132 | def update(self, aggr_out, x): 133 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 134 | return aggr_out + self.fea_mlp(root_emb * x) 135 | 136 | def __repr__(self): 137 | return self.__class__.__name__ 138 | 139 | 140 | class CombC_star(MessagePassing): 141 | def __init__(self, hidden, config, **kwargs): 142 | super(CombC_star, self).__init__(aggr='add', **kwargs) 143 | 144 | self.fea_mlp = Sequential( 145 | Linear(hidden, hidden), 146 | ReLU(), 147 | Linear(hidden, hidden), 148 | ReLU()) 149 | 150 | self.aggr_mlp = Sequential( 151 | Linear(hidden * 2, hidden), 152 | Tanh()) 153 | 154 | if config.BN == 'Y': 155 | self.BN = BN(hidden) 156 | else: 157 | self.BN = None 158 | 159 | def forward(self, x, edge_index): 160 | edge_index, _ = remove_self_loops(edge_index) 161 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 162 | out = self.fea_mlp(self.propagate(edge_index, x=x)) 163 | if self.BN is not None: 164 | out = self.BN(out) 165 | return out 166 | 167 | def message(self, x_i, x_j): 168 | xe = x_j 169 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 170 | return aggr_emb * xe 171 | 172 | def update(self, aggr_out, x): 173 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 174 | return aggr_out + root_emb * x 175 | 176 | def __repr__(self): 177 | return self.__class__.__name__ 178 | -------------------------------------------------------------------------------- /ogbg/ppa/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch_geometric.data import DataLoader 4 | from tensorboardX import SummaryWriter 5 | import torch.optim as optim 6 | import numpy as np 7 | ### importing OGB 8 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 9 | 10 | import sys 11 | sys.path.append('../..') 12 | 13 | from model import Net 14 | from utils.config import process_config, get_args 15 | 16 | 17 | class In: 18 | def readline(self): 19 | return "y\n" 20 | 21 | def close(self): 22 | pass 23 | 24 | 25 | def train(model, device, loader, optimizer, multicls_criterion): 26 | model.train() 27 | loss_all = 0 28 | 29 | for step, batch in enumerate(loader): 30 | batch = batch.to(device) 31 | 32 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 33 | pass 34 | else: 35 | pred = model(batch) 36 | optimizer.zero_grad() 37 | 38 | loss = multicls_criterion(pred.to(torch.float32), batch.y.view(-1,)) 39 | 40 | loss.backward() 41 | loss_all += loss.item() 42 | optimizer.step() 43 | 44 | return loss_all / len(loader) 45 | 46 | 47 | def eval(model, device, loader, evaluator): 48 | model.eval() 49 | y_true = [] 50 | y_pred = [] 51 | 52 | for step, batch in enumerate(loader): 53 | batch = batch.to(device) 54 | 55 | if batch.x.shape[0] == 1: 56 | pass 57 | else: 58 | with torch.no_grad(): 59 | pred = model(batch) 60 | 61 | y_true.append(batch.y.view(-1, 1).detach().cpu()) 62 | y_pred.append(torch.argmax(pred.detach(), dim=1).view(-1, 1).cpu()) 63 | 64 | y_true = torch.cat(y_true, dim=0).numpy() 65 | y_pred = torch.cat(y_pred, dim=0).numpy() 66 | 67 | input_dict = {"y_true": y_true, "y_pred": y_pred} 68 | 69 | return evaluator.eval(input_dict) 70 | 71 | 72 | def add_zeros(data): 73 | data.x = torch.zeros(data.num_nodes, dtype=torch.long) 74 | return data 75 | 76 | 77 | def main(): 78 | args = get_args() 79 | config = process_config(args) 80 | 
print(config) 81 | 82 | if config.get('seed') is not None: 83 | random.seed(config.seed) 84 | torch.manual_seed(config.seed) 85 | np.random.seed(config.seed) 86 | if torch.cuda.is_available(): 87 | torch.cuda.manual_seed_all(config.seed) 88 | 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | 91 | ### automatic dataloading and splitting 92 | 93 | sys.stdin = In() 94 | 95 | dataset = PygGraphPropPredDataset(name=config.dataset_name, transform=add_zeros) 96 | 97 | split_idx = dataset.get_idx_split() 98 | 99 | ### automatic evaluator. takes dataset name as input 100 | evaluator = Evaluator(config.dataset_name) 101 | 102 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=config.hyperparams.batch_size, shuffle=True, 103 | num_workers=config.num_workers) 104 | valid_loader = DataLoader(dataset[split_idx["valid"]], 105 | batch_size=config.hyperparams.batch_size, shuffle=False, num_workers=config.num_workers) 106 | test_loader = DataLoader(dataset[split_idx["test"]], 107 | batch_size=config.hyperparams.batch_size, shuffle=False, num_workers=config.num_workers) 108 | 109 | model = Net(config.architecture, num_class=dataset.num_classes).to(device) 110 | 111 | optimizer = optim.Adam(model.parameters(), lr=config.hyperparams.learning_rate) 112 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.hyperparams.step_size, 113 | gamma=config.hyperparams.decay_rate) 114 | 115 | multicls_criterion = torch.nn.CrossEntropyLoss() 116 | 117 | valid_curve = [] 118 | test_curve = [] 119 | train_curve = [] 120 | trainL_curve = [] 121 | 122 | writer = SummaryWriter(config.directory) 123 | 124 | ts_fk_algo_hp = str(config.time_stamp) + '_' \ 125 | + str(config.commit_id[0:7]) + '_' \ 126 | + str(config.architecture.methods) + '_' \ 127 | + str(config.architecture.pooling) + '_' \ 128 | + str(config.architecture.JK) + '_' \ 129 | + str(config.architecture.layers) + '_' \ 130 | + str(config.architecture.hidden) + '_' \ 131 | + str(config.architecture.variants.BN) + '_' \ 132 | + str(config.architecture.dropout) + '_' \ 133 | + str(config.hyperparams.learning_rate) + '_' \ 134 | + str(config.hyperparams.step_size) + '_' \ 135 | + str(config.hyperparams.decay_rate) + '_' \ 136 | + 'B' + str(config.hyperparams.batch_size) + '_' \ 137 | + 'S' + str(config.seed if config.get('seed') is not None else "na") + '_' \ 138 | + 'W' + str(config.num_workers if config.get('num_workers') is not None else "na") 139 | 140 | for epoch in range(1, config.hyperparams.epochs + 1): 141 | print("Epoch {} training...".format(epoch)) 142 | train_loss = train(model, device, train_loader, optimizer, multicls_criterion) 143 | 144 | scheduler.step() 145 | 146 | print('Evaluating...') 147 | train_perf = eval(model, device, train_loader, evaluator) 148 | valid_perf = eval(model, device, valid_loader, evaluator) 149 | test_perf = eval(model, device, test_loader, evaluator) 150 | 151 | print('Train:', train_perf[dataset.eval_metric], 152 | 'Validation:', valid_perf[dataset.eval_metric], 153 | 'Test:', test_perf[dataset.eval_metric], 154 | 'Train loss:', train_loss) 155 | 156 | train_curve.append(train_perf[dataset.eval_metric]) 157 | valid_curve.append(valid_perf[dataset.eval_metric]) 158 | test_curve.append(test_perf[dataset.eval_metric]) 159 | trainL_curve.append(train_loss) 160 | 161 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traP': train_perf[dataset.eval_metric]}, epoch) 162 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/valP': 
valid_perf[dataset.eval_metric]}, epoch) 163 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/tstP': test_perf[dataset.eval_metric]}, epoch) 164 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traL': train_loss}, epoch) 165 | writer.close() 166 | 167 | best_val_epoch = np.argmax(np.array(valid_curve)) 168 | best_train = max(train_curve) 169 | 170 | print('Best validation score: {}'.format(valid_curve[best_val_epoch])) 171 | print('Test score: {}'.format(test_curve[best_val_epoch])) 172 | 173 | print('Finished test: {}, Validation: {}, Train: {}, epoch: {}, best train: {}, best loss: {}' 174 | .format(test_curve[best_val_epoch], valid_curve[best_val_epoch], train_curve[best_val_epoch], 175 | best_val_epoch, best_train, min(trainL_curve))) 176 | 177 | 178 | if __name__ == "__main__": 179 | main() 180 | -------------------------------------------------------------------------------- /ogbg/mol/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch_geometric.data import DataLoader 4 | from tensorboardX import SummaryWriter 5 | import numpy as np 6 | ### importing OGB 7 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 8 | import torch.optim as optim 9 | 10 | import sys 11 | sys.path.append('../..') 12 | 13 | from model import Net 14 | from utils.config import process_config, get_args 15 | 16 | cls_criterion = torch.nn.BCEWithLogitsLoss() 17 | reg_criterion = torch.nn.MSELoss() 18 | 19 | 20 | class In: 21 | def readline(self): 22 | return "y\n" 23 | 24 | def close(self): 25 | pass 26 | 27 | 28 | def train(model, device, loader, optimizer, task_type): 29 | model.train() 30 | loss_all = 0 31 | 32 | for step, batch in enumerate(loader): 33 | batch = batch.to(device) 34 | 35 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 36 | pass 37 | else: 38 | pred = model(batch) 39 | optimizer.zero_grad() 40 | ## ignore nan targets (unlabeled) when computing training loss. 
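            ## NaN != NaN, so comparing y with itself yields a mask that is
            ## False exactly at the missing (NaN) labels.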
41 | is_labeled = batch.y == batch.y 42 | if "classification" in task_type: 43 | loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 44 | else: 45 | loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 46 | loss.backward() 47 | loss_all += loss.item() 48 | optimizer.step() 49 | 50 | return loss_all / len(loader) 51 | 52 | 53 | def eval(model, device, loader, evaluator): 54 | model.eval() 55 | y_true = [] 56 | y_pred = [] 57 | 58 | for step, batch in enumerate(loader): 59 | batch = batch.to(device) 60 | 61 | if batch.x.shape[0] == 1: 62 | pass 63 | else: 64 | with torch.no_grad(): 65 | pred = model(batch) 66 | 67 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 68 | y_pred.append(pred.detach().cpu()) 69 | 70 | y_true = torch.cat(y_true, dim=0).numpy() 71 | y_pred = torch.cat(y_pred, dim=0).numpy() 72 | 73 | input_dict = {"y_true": y_true, "y_pred": y_pred} 74 | 75 | return evaluator.eval(input_dict) 76 | 77 | 78 | def main(): 79 | args = get_args() 80 | config = process_config(args) 81 | print(config) 82 | 83 | if config.get('seed') is not None: 84 | random.seed(config.seed) 85 | torch.manual_seed(config.seed) 86 | np.random.seed(config.seed) 87 | if torch.cuda.is_available(): 88 | torch.cuda.manual_seed_all(config.seed) 89 | 90 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 91 | 92 | ### automatic dataloading and splitting 93 | 94 | sys.stdin = In() 95 | 96 | dataset = PygGraphPropPredDataset(name=config.dataset_name) 97 | 98 | if config.feature == 'full': 99 | pass 100 | elif config.feature == 'simple': 101 | print('using simple feature') 102 | # only retain the top two node/edge features 103 | dataset.data.x = dataset.data.x[:, :2] 104 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 105 | 106 | split_idx = dataset.get_idx_split() 107 | 108 | ### automatic evaluator. 
takes dataset name as input 109 | evaluator = Evaluator(config.dataset_name) 110 | 111 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=config.hyperparams.batch_size, shuffle=True, 112 | num_workers=config.num_workers) 113 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=config.hyperparams.batch_size, shuffle=False, 114 | num_workers=config.num_workers) 115 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=config.hyperparams.batch_size, shuffle=False, 116 | num_workers=config.num_workers) 117 | 118 | model = Net(config.architecture, num_tasks=dataset.num_tasks).to(device) 119 | 120 | optimizer = optim.Adam(model.parameters(), lr=config.hyperparams.learning_rate) 121 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.hyperparams.step_size, 122 | gamma=config.hyperparams.decay_rate) 123 | 124 | valid_curve = [] 125 | test_curve = [] 126 | train_curve = [] 127 | trainL_curve = [] 128 | 129 | writer = SummaryWriter(config.directory) 130 | 131 | ts_fk_algo_hp = str(config.time_stamp) + '_' \ 132 | + str(config.commit_id[0:7]) + '_' \ 133 | + str(config.architecture.methods) + '_' \ 134 | + str(config.architecture.pooling) + '_' \ 135 | + str(config.architecture.JK) + '_' \ 136 | + str(config.architecture.layers) + '_' \ 137 | + str(config.architecture.hidden) + '_' \ 138 | + str(config.architecture.variants.BN) + '_' \ 139 | + str(config.architecture.dropout) + '_' \ 140 | + str(config.hyperparams.learning_rate) + '_' \ 141 | + str(config.hyperparams.step_size) + '_' \ 142 | + str(config.hyperparams.decay_rate) + '_' \ 143 | + 'B' + str(config.hyperparams.batch_size) + '_' \ 144 | + 'S' + str(config.seed if config.get('seed') is not None else "na") + '_' \ 145 | + 'W' + str(config.num_workers if config.get('num_workers') is not None else "na") 146 | 147 | for epoch in range(1, config.hyperparams.epochs + 1): 148 | print("Epoch {} training...".format(epoch)) 149 | train_loss = train(model, device, train_loader, optimizer, dataset.task_type) 150 | 151 | scheduler.step() 152 | 153 | print('Evaluating...') 154 | train_perf = eval(model, device, train_loader, evaluator) 155 | valid_perf = eval(model, device, valid_loader, evaluator) 156 | test_perf = eval(model, device, test_loader, evaluator) 157 | 158 | print('Train:', train_perf[dataset.eval_metric], 159 | 'Validation:', valid_perf[dataset.eval_metric], 160 | 'Test:', test_perf[dataset.eval_metric], 161 | 'Train loss:', train_loss) 162 | 163 | train_curve.append(train_perf[dataset.eval_metric]) 164 | valid_curve.append(valid_perf[dataset.eval_metric]) 165 | test_curve.append(test_perf[dataset.eval_metric]) 166 | trainL_curve.append(train_loss) 167 | 168 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traP': train_perf[dataset.eval_metric]}, epoch) 169 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/valP': valid_perf[dataset.eval_metric]}, epoch) 170 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/tstP': test_perf[dataset.eval_metric]}, epoch) 171 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traL': train_loss}, epoch) 172 | 173 | writer.close() 174 | 175 | if 'classification' in dataset.task_type: 176 | best_val_epoch = np.argmax(np.array(valid_curve)) 177 | best_train = max(train_curve) 178 | else: 179 | best_val_epoch = np.argmin(np.array(valid_curve)) 180 | best_train = min(train_curve) 181 | 182 | print('Finished test: {}, Validation: {}, epoch: {}, best train: {}, best loss: {}' 183 | 
.format(test_curve[best_val_epoch], valid_curve[best_val_epoch],
184 | best_val_epoch, best_train, min(trainL_curve)))
185 | 
186 | 
187 | if __name__ == "__main__":
188 | main()
189 | 
--------------------------------------------------------------------------------
/ogbg/code/proc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import Counter
3 | import numpy as np
4 | 
5 | 
6 | class ASTNodeEncoder(torch.nn.Module):
7 | '''
8 | Input:
9 | x: default node features. The first and second columns represent the node type and node attribute, respectively.
10 | depth: The depth of the node in the AST.
11 | 
12 | Output:
13 | emb_dim-dimensional vector
14 | 
15 | '''
16 | def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
17 | super(ASTNodeEncoder, self).__init__()
18 | 
19 | self.max_depth = max_depth
20 | 
21 | self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
22 | self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
23 | self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)
24 | 
25 | 
26 | def forward(self, x, depth):
27 | depth[depth > self.max_depth] = self.max_depth  # clamp depths so the embedding lookup stays in range
28 | return self.type_encoder(x[:,0]) + self.attribute_encoder(x[:,1]) + self.depth_encoder(depth)
29 | 
30 | 
31 | 
32 | def get_vocab_mapping(seq_list, num_vocab):
33 | '''
34 | Input:
35 | seq_list: a list of sequences
36 | num_vocab: vocabulary size
37 | Output:
38 | vocab2idx:
39 | A dictionary that maps vocabulary into integer index.
40 | Additionally, we also index '__UNK__' and '__EOS__'
41 | '__UNK__' : out-of-vocabulary term
42 | '__EOS__' : end-of-sentence
43 | 
44 | idx2vocab:
45 | A list that maps idx to actual vocabulary.
46 | 
47 | '''
48 | 
49 | vocab_cnt = {}
50 | vocab_list = []
51 | for seq in seq_list:
52 | for w in seq:
53 | if w in vocab_cnt:
54 | vocab_cnt[w] += 1
55 | else:
56 | vocab_cnt[w] = 1
57 | vocab_list.append(w)
58 | 
59 | cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
60 | topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab]
61 | 
62 | print('Coverage of top {} vocabulary:'.format(num_vocab))
63 | print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list))
64 | 
65 | vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)}
66 | idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]
67 | 
68 | # print(topvocab)
69 | # print([vocab_list[v] for v in topvocab[:10]])
70 | # print([vocab_list[v] for v in topvocab[-10:]])
71 | 
72 | vocab2idx['__UNK__'] = num_vocab
73 | idx2vocab.append('__UNK__')
74 | 
75 | vocab2idx['__EOS__'] = num_vocab + 1
76 | idx2vocab.append('__EOS__')
77 | 
78 | # test the correspondence between vocab2idx and idx2vocab
79 | for idx, vocab in enumerate(idx2vocab):
80 | assert(idx == vocab2idx[vocab])
81 | 
82 | # test that the idx of '__EOS__' is len(idx2vocab) - 1.
83 | # This fact will be used in decode_arr_to_seq, when finding __EOS__
84 | assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)
85 | 
86 | return vocab2idx, idx2vocab
87 | 
88 | def augment_edge(data):
89 | '''
90 | Input:
91 | data: PyG data object
92 | Output:
93 | data (edges are augmented in the following ways):
94 | data.edge_index: next-token edges are added; the inverse edges are also added.
95 | data.edge_attr (torch.Long):
96 | data.edge_attr[:,0]: whether it is an AST edge (0) or a next-token edge (1)
97 | data.edge_attr[:,1]: whether it is the original direction (0) or the inverse direction (1)
98 | '''
99 | 
100 | ##### AST edge
101 | edge_index_ast = data.edge_index
102 | edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))
103 | 
104 | ##### Inverse AST edge
105 | edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim = 0)
106 | edge_attr_ast_inverse = torch.cat([torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim = 1)
107 | 
108 | 
109 | ##### Next-token edge
110 | 
111 | ## Obtain attributed nodes and get their indices in dfs order
112 | # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
113 | # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))]
114 | 
115 | ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.
116 | attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
117 | 
118 | ## build the next-token edges
119 | # Given: attributed_node_idx_in_dfs_order
120 | # [1, 3, 4, 5, 8, 9, 12]
121 | # Output:
122 | # [[1, 3, 4, 5, 8, 9]
123 | # [3, 4, 5, 8, 9, 12]]
124 | edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0)
125 | edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim = 1)
126 | 
127 | 
128 | ##### Inverse next-token edge
129 | edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim = 0)
130 | edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))
131 | 
132 | 
133 | data.edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim = 1)
134 | data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse], dim = 0)
135 | 
136 | return data
137 | 
138 | def encode_y_to_arr(data, vocab2idx, max_seq_len):
139 | '''
140 | Input:
141 | data: PyG graph object
142 | Output: data, with y_arr added
143 | '''
144 | 
145 | # PyG >= 1.5.0
146 | seq = data.y
147 | 
148 | # PyG = 1.4.3
149 | # seq = data.y[0]
150 | 
151 | data.y_arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len)
152 | 
153 | return data
154 | 
155 | def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
156 | '''
157 | Input:
158 | seq: a list of words
159 | Output: a torch.LongTensor of shape (1, max_seq_len), truncated and padded with '__EOS__'
160 | '''
161 | 
162 | augmented_seq = seq[:max_seq_len] + ['__EOS__'] * max(0, max_seq_len - len(seq))
163 | return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] for w in augmented_seq]], dtype = torch.long)
164 | 
165 | 
166 | def decode_arr_to_seq(arr, idx2vocab):
167 | '''
168 | Input: torch 1d array: y_arr
169 | Output: a sequence of words.
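Illustrative example (added by the editor; assumes a toy vocabulary, not part of the original docstring):
idx2vocab = ['f', 'g', '__UNK__', '__EOS__']
decode_arr_to_seq(torch.tensor([0, 1, 3, 0]), idx2vocab)  # -> ['f', 'g']; decoding stops at the first __EOS__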
170 | '''
171 | 
172 | 
173 | eos_idx_list = (arr == len(idx2vocab) - 1).nonzero()  # find the positions of __EOS__ (the last vocab in idx2vocab)
174 | if len(eos_idx_list) > 0:
175 | clipped_arr = arr[: torch.min(eos_idx_list)]  # clip at the earliest __EOS__
176 | else:
177 | clipped_arr = arr
178 | 
179 | return list(map(lambda x: idx2vocab[x], clipped_arr.cpu()))
180 | 
181 | 
182 | def test():
183 | seq_list = [['a', 'b'], ['a', 'b', 'c', 'df', 'f', '2edea', 'a'], ['eraea', 'a', 'c'], ['d'], ['4rq4f','f','a','a', 'g']]
184 | vocab2idx, idx2vocab = get_vocab_mapping(seq_list, 4)
185 | print(vocab2idx)
186 | print(idx2vocab)
187 | print()
188 | assert(len(vocab2idx) == len(idx2vocab))
189 | 
190 | for vocab, idx in vocab2idx.items():
191 | assert(idx2vocab[idx] == vocab)
192 | 
193 | 
194 | for seq in seq_list:
195 | print(seq)
196 | arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len = 4)[0]
197 | # Test the effect of predicting __EOS__
198 | # arr[2] = vocab2idx['__EOS__']
199 | print(arr)
200 | seq_dec = decode_arr_to_seq(arr, idx2vocab)
201 | 
202 | 
203 | print(seq_dec)
204 | print('')
205 | 
206 | 
207 | 
208 | 
209 | if __name__ == '__main__':
210 | test()
211 | 
--------------------------------------------------------------------------------
/ogbg/ppa/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn.conv import MessagePassing
3 | from torch_geometric.utils import remove_self_loops, add_self_loops
4 | import torch.nn.functional as F
5 | from torch.nn import Linear, Sequential, Tanh, ReLU, ELU, BatchNorm1d as BN
6 | 
7 | 
8 | class ExpC(MessagePassing):
9 | def __init__(self, hidden, num_aggr, config, **kwargs):
10 | super(ExpC, self).__init__(aggr='add', **kwargs)
11 | self.hidden = hidden
12 | self.num_aggr = num_aggr
13 | 
14 | self.fea_mlp = Sequential(
15 | Linear(hidden * self.num_aggr, hidden),
16 | ReLU(),
17 | Linear(hidden, hidden),
18 | ReLU())
19 | 
20 | self.aggr_mlp = Sequential(
21 | Linear(hidden * 2, self.num_aggr),
22 | Tanh())
23 | 
24 | if config.BN == 'Y':
25 | self.BN = BN(hidden)
26 | else:
27 | self.BN = None
28 | 
29 | self.edge_encoder = torch.nn.Linear(7, hidden)
30 | 
31 | def forward(self, x, edge_index, edge_attr):
32 | edge_attr = self.edge_encoder(edge_attr)
33 | edge_index, _ = remove_self_loops(edge_index)
34 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
35 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
36 | if self.BN is not None:
37 | out = self.BN(out)
38 | return out
39 | 
40 | def message(self, x_i, x_j, edge_attr):
41 | xe = x_j + edge_attr
42 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1))
43 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1)
44 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr)
45 | return self.fea_mlp(feature2d)
46 | 
47 | def update(self, aggr_out, x):
48 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1))
49 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1)
50 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr)
51 | return aggr_out + self.fea_mlp(feature2d)
52 | 
53 | def __repr__(self):
54 | return self.__class__.__name__
55 | 
56 | 
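# Added note (ours): shape walk-through of ExpC.message, with E messages,
# hidden size H and num_aggr K (a sketch, not part of the original source):
#   xe       = x_j + edge_attr                            # (E, H)
#   aggr_emb = aggr_mlp([x_i ; xe])                       # (E, K), Tanh-bounded weights
#   outer    = aggr_emb.unsqueeze(-1) @ xe.unsqueeze(-2)  # (E, K, H) outer product
#   out      = fea_mlp(outer.view(E, K * H))              # (E, H)
# i.e. each neighbor feature is re-weighted K ways before the MLP mixes them.
57 | class ExpC_star(MessagePassing):
58 | def __init__(self, hidden, num_aggr, config, **kwargs):
59 | super(ExpC_star, self).__init__(aggr='add', **kwargs)
60 | self.hidden = hidden
61 | self.num_aggr = num_aggr
62 | 
63 | self.fea_mlp = Sequential(
64 | Linear(hidden * self.num_aggr, hidden),
65 | 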
ReLU(), 66 | Linear(hidden, hidden), 67 | ReLU()) 68 | 69 | self.aggr_mlp = Sequential( 70 | Linear(hidden * 2, self.num_aggr), 71 | Tanh()) 72 | 73 | if config.BN == 'Y': 74 | self.BN = BN(hidden) 75 | else: 76 | self.BN = None 77 | 78 | self.edge_encoder = torch.nn.Linear(7, hidden) 79 | 80 | def forward(self, x, edge_index, edge_attr): 81 | edge_attr = self.edge_encoder(edge_attr) 82 | edge_index, _ = remove_self_loops(edge_index) 83 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 84 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 85 | if self.BN is not None: 86 | out = self.BN(out) 87 | return out 88 | 89 | def message(self, x_i, x_j, edge_attr): 90 | xe = x_j + edge_attr 91 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 92 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 93 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 94 | return feature2d 95 | 96 | def update(self, aggr_out, x): 97 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 98 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 99 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 100 | return aggr_out + feature2d 101 | 102 | def __repr__(self): 103 | return self.__class__.__name__ 104 | 105 | 106 | class CombC(MessagePassing): 107 | def __init__(self, hidden, config, **kwargs): 108 | super(CombC, self).__init__(aggr='add', **kwargs) 109 | 110 | self.fea_mlp = Sequential( 111 | Linear(hidden, hidden), 112 | ReLU(), 113 | Linear(hidden, hidden), 114 | ReLU()) 115 | 116 | self.aggr_mlp = Sequential( 117 | Linear(hidden * 2, hidden), 118 | Tanh()) 119 | 120 | if config.BN == 'Y': 121 | self.BN = BN(hidden) 122 | else: 123 | self.BN = None 124 | 125 | self.edge_encoder = torch.nn.Linear(7, hidden) 126 | 127 | def forward(self, x, edge_index, edge_attr): 128 | edge_attr = self.edge_encoder(edge_attr) 129 | edge_index, _ = remove_self_loops(edge_index) 130 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 131 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 132 | if self.BN is not None: 133 | out = self.BN(out) 134 | return out 135 | 136 | def message(self, x_i, x_j, edge_attr): 137 | xe = x_j + edge_attr 138 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 139 | return self.fea_mlp(aggr_emb * xe) 140 | 141 | def update(self, aggr_out, x): 142 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 143 | return aggr_out + self.fea_mlp(root_emb * x) 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ 147 | 148 | 149 | class CombC_star(MessagePassing): 150 | def __init__(self, hidden, config, **kwargs): 151 | super(CombC_star, self).__init__(aggr='add', **kwargs) 152 | 153 | self.fea_mlp = Sequential( 154 | Linear(hidden, hidden), 155 | ReLU(), 156 | Linear(hidden, hidden), 157 | ReLU()) 158 | 159 | self.aggr_mlp = Sequential( 160 | Linear(hidden * 2, hidden), 161 | Tanh()) 162 | 163 | if config.BN == 'Y': 164 | self.BN = BN(hidden) 165 | else: 166 | self.BN = None 167 | 168 | self.edge_encoder = torch.nn.Linear(7, hidden) 169 | 170 | def forward(self, x, edge_index, edge_attr): 171 | edge_attr = self.edge_encoder(edge_attr) 172 | edge_index, _ = remove_self_loops(edge_index) 173 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 174 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 175 | if self.BN is not None: 176 | out = self.BN(out) 177 | return out 178 | 179 | def message(self, x_i, x_j, 
edge_attr): 180 | xe = x_j + edge_attr 181 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 182 | return aggr_emb * xe 183 | 184 | def update(self, aggr_out, x): 185 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 186 | return aggr_out + root_emb * x 187 | 188 | def __repr__(self): 189 | return self.__class__.__name__ 190 | 191 | 192 | class GinConv(MessagePassing): 193 | def __init__(self, hidden, config, **kwargs): 194 | super(GinConv, self).__init__(aggr='add', **kwargs) 195 | 196 | self.fea_mlp = Sequential( 197 | Linear(hidden, hidden), 198 | ReLU(), 199 | Linear(hidden, hidden), 200 | ReLU()) 201 | 202 | if config.BN == 'Y': 203 | self.BN = BN(hidden) 204 | else: 205 | self.BN = None 206 | 207 | self.edge_encoder = torch.nn.Linear(7, hidden) 208 | 209 | def forward(self, x, edge_index, edge_attr): 210 | edge_attr = self.edge_encoder(edge_attr) 211 | edge_index, _ = remove_self_loops(edge_index) 212 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 213 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 214 | if self.BN is not None: 215 | out = self.BN(out) 216 | return out 217 | 218 | def message(self, x_j, edge_attr): 219 | return x_j + edge_attr 220 | 221 | def update(self, aggr_out, x): 222 | return aggr_out + x 223 | 224 | def __repr__(self): 225 | return self.__class__.__name__ 226 | -------------------------------------------------------------------------------- /ogbg/mol/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.conv import MessagePassing 3 | from torch_geometric.utils import remove_self_loops, add_self_loops 4 | import torch.nn.functional as F 5 | from torch.nn import Linear, Sequential, Tanh, ReLU, ELU, BatchNorm1d as BN 6 | from ogb.graphproppred.mol_encoder import BondEncoder 7 | 8 | 9 | class ExpC(MessagePassing): 10 | def __init__(self, hidden, num_aggr, config, **kwargs): 11 | super(ExpC, self).__init__(aggr='add', **kwargs) 12 | self.hidden = hidden 13 | self.num_aggr = num_aggr 14 | 15 | self.fea_mlp = Sequential( 16 | Linear(hidden * self.num_aggr, hidden), 17 | ReLU(), 18 | Linear(hidden, hidden), 19 | ReLU()) 20 | 21 | self.aggr_mlp = Sequential( 22 | Linear(hidden * 2, self.num_aggr), 23 | Tanh()) 24 | 25 | if config.BN == 'Y': 26 | self.BN = BN(hidden) 27 | else: 28 | self.BN = None 29 | 30 | self.bond_encoder = BondEncoder(emb_dim=hidden) 31 | 32 | def forward(self, x, edge_index, edge_attr): 33 | edge_attr = self.bond_encoder(edge_attr) 34 | edge_index, _ = remove_self_loops(edge_index) 35 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 36 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 37 | if self.BN is not None: 38 | out = self.BN(out) 39 | return out 40 | 41 | def message(self, x_i, x_j, edge_attr): 42 | xe = x_j + edge_attr 43 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 44 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 45 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 46 | return self.fea_mlp(feature2d) 47 | 48 | def update(self, aggr_out, x): 49 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 50 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 51 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 52 | return aggr_out + self.fea_mlp(feature2d) 53 | 54 | def __repr__(self): 55 | return self.__class__.__name__ 56 | 57 | 58 | class ExpC_star(MessagePassing): 59 | def 
__init__(self, hidden, num_aggr, config, **kwargs): 60 | super(ExpC_star, self).__init__(aggr='add', **kwargs) 61 | self.hidden = hidden 62 | self.num_aggr = num_aggr 63 | 64 | self.fea_mlp = Sequential( 65 | Linear(hidden * self.num_aggr, hidden), 66 | ReLU(), 67 | Linear(hidden, hidden), 68 | ReLU()) 69 | 70 | self.aggr_mlp = Sequential( 71 | Linear(hidden * 2, self.num_aggr), 72 | Tanh()) 73 | 74 | if config.BN == 'Y': 75 | self.BN = BN(hidden) 76 | else: 77 | self.BN = None 78 | 79 | self.bond_encoder = BondEncoder(emb_dim=hidden) 80 | 81 | def forward(self, x, edge_index, edge_attr): 82 | edge_attr = self.bond_encoder(edge_attr) 83 | edge_index, _ = remove_self_loops(edge_index) 84 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 85 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 86 | if self.BN is not None: 87 | out = self.BN(out) 88 | return out 89 | 90 | def message(self, x_i, x_j, edge_attr): 91 | xe = x_j + edge_attr 92 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 93 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 94 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 95 | return feature2d 96 | 97 | def update(self, aggr_out, x): 98 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 99 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 100 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 101 | return aggr_out + feature2d 102 | 103 | def __repr__(self): 104 | return self.__class__.__name__ 105 | 106 | 107 | class CombC(MessagePassing): 108 | def __init__(self, hidden, config, **kwargs): 109 | super(CombC, self).__init__(aggr='add', **kwargs) 110 | 111 | self.fea_mlp = Sequential( 112 | Linear(hidden, hidden), 113 | ReLU(), 114 | Linear(hidden, hidden), 115 | ReLU()) 116 | 117 | self.aggr_mlp = Sequential( 118 | Linear(hidden * 2, hidden), 119 | Tanh()) 120 | 121 | if config.BN == 'Y': 122 | self.BN = BN(hidden) 123 | else: 124 | self.BN = None 125 | 126 | self.bond_encoder = BondEncoder(emb_dim=hidden) 127 | 128 | def forward(self, x, edge_index, edge_attr): 129 | edge_attr = self.bond_encoder(edge_attr) 130 | edge_index, _ = remove_self_loops(edge_index) 131 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 132 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 133 | if self.BN is not None: 134 | out = self.BN(out) 135 | return out 136 | 137 | def message(self, x_i, x_j, edge_attr): 138 | xe = x_j + edge_attr 139 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 140 | return self.fea_mlp(aggr_emb * xe) 141 | 142 | def update(self, aggr_out, x): 143 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 144 | return aggr_out + self.fea_mlp(root_emb * x) 145 | 146 | def __repr__(self): 147 | return self.__class__.__name__ 148 | 149 | 150 | class CombC_star(MessagePassing): 151 | def __init__(self, hidden, config, **kwargs): 152 | super(CombC_star, self).__init__(aggr='add', **kwargs) 153 | 154 | self.fea_mlp = Sequential( 155 | Linear(hidden, hidden), 156 | ReLU(), 157 | Linear(hidden, hidden), 158 | ReLU()) 159 | 160 | self.aggr_mlp = Sequential( 161 | Linear(hidden * 2, hidden), 162 | Tanh()) 163 | 164 | if config.BN == 'Y': 165 | self.BN = BN(hidden) 166 | else: 167 | self.BN = None 168 | 169 | self.bond_encoder = BondEncoder(emb_dim=hidden) 170 | 171 | def forward(self, x, edge_index, edge_attr): 172 | edge_attr = self.bond_encoder(edge_attr) 173 | edge_index, _ = remove_self_loops(edge_index) 174 | # 
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 175 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 176 | if self.BN is not None: 177 | out = self.BN(out) 178 | return out 179 | 180 | def message(self, x_i, x_j, edge_attr): 181 | xe = x_j + edge_attr 182 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 183 | return aggr_emb * xe 184 | 185 | def update(self, aggr_out, x): 186 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 187 | return aggr_out + root_emb * x 188 | 189 | def __repr__(self): 190 | return self.__class__.__name__ 191 | 192 | 193 | class GinConv(MessagePassing): 194 | def __init__(self, hidden, config, **kwargs): 195 | super(GinConv, self).__init__(aggr='add', **kwargs) 196 | 197 | self.fea_mlp = Sequential( 198 | Linear(hidden, hidden), 199 | ReLU(), 200 | Linear(hidden, hidden), 201 | ReLU()) 202 | 203 | if config.BN == 'Y': 204 | self.BN = BN(hidden) 205 | else: 206 | self.BN = None 207 | 208 | self.bond_encoder = BondEncoder(emb_dim=hidden) 209 | 210 | def forward(self, x, edge_index, edge_attr): 211 | edge_attr = self.bond_encoder(edge_attr) 212 | edge_index, _ = remove_self_loops(edge_index) 213 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 214 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 215 | if self.BN is not None: 216 | out = self.BN(out) 217 | return out 218 | 219 | def message(self, x_j, edge_attr): 220 | return x_j + edge_attr 221 | 222 | def update(self, aggr_out, x): 223 | return aggr_out + x 224 | 225 | def __repr__(self): 226 | return self.__class__.__name__ 227 | -------------------------------------------------------------------------------- /ogbg/code/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.conv import MessagePassing 3 | from torch_geometric.utils import remove_self_loops, add_self_loops 4 | import torch.nn.functional as F 5 | from torch.nn import Linear, Sequential, Tanh, ReLU, ELU, BatchNorm1d as BN 6 | 7 | 8 | class ExpC(MessagePassing): 9 | def __init__(self, hidden, num_aggr, config, **kwargs): 10 | super(ExpC, self).__init__(aggr='add', **kwargs) 11 | self.hidden = hidden 12 | self.num_aggr = num_aggr 13 | 14 | self.fea_mlp = Sequential( 15 | Linear(hidden * self.num_aggr, hidden), 16 | ReLU(), 17 | Linear(hidden, hidden), 18 | ReLU()) 19 | 20 | self.aggr_mlp = Sequential( 21 | Linear(hidden * 2, self.num_aggr), 22 | Tanh()) 23 | 24 | if config.BN == 'Y': 25 | self.BN = BN(hidden) 26 | else: 27 | self.BN = None 28 | 29 | # edge_attr is two dimensional after augment_edge transformation 30 | self.edge_encoder = torch.nn.Linear(2, hidden) 31 | 32 | def forward(self, x, edge_index, edge_attr): 33 | edge_attr = self.edge_encoder(edge_attr) 34 | edge_index, _ = remove_self_loops(edge_index) 35 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 36 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 37 | if self.BN is not None: 38 | out = self.BN(out) 39 | return out 40 | 41 | def message(self, x_i, x_j, edge_attr): 42 | xe = x_j + edge_attr 43 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 44 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 45 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 46 | return self.fea_mlp(feature2d) 47 | 48 | def update(self, aggr_out, x): 49 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 50 | feature2d = torch.matmul(root_emb.unsqueeze(-1), 
x.unsqueeze(-1) 51 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 52 | return aggr_out + self.fea_mlp(feature2d) 53 | 54 | def __repr__(self): 55 | return self.__class__.__name__ 56 | 57 | 58 | class ExpC_star(MessagePassing): 59 | def __init__(self, hidden, num_aggr, config, **kwargs): 60 | super(ExpC_star, self).__init__(aggr='add', **kwargs) 61 | self.hidden = hidden 62 | self.num_aggr = num_aggr 63 | 64 | self.fea_mlp = Sequential( 65 | Linear(hidden * self.num_aggr, hidden), 66 | ReLU(), 67 | Linear(hidden, hidden), 68 | ReLU()) 69 | 70 | self.aggr_mlp = Sequential( 71 | Linear(hidden * 2, self.num_aggr), 72 | Tanh()) 73 | 74 | if config.BN == 'Y': 75 | self.BN = BN(hidden) 76 | else: 77 | self.BN = None 78 | 79 | # edge_attr is two dimensional after augment_edge transformation 80 | self.edge_encoder = torch.nn.Linear(2, hidden) 81 | 82 | def forward(self, x, edge_index, edge_attr): 83 | edge_attr = self.edge_encoder(edge_attr) 84 | edge_index, _ = remove_self_loops(edge_index) 85 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 86 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 87 | if self.BN is not None: 88 | out = self.BN(out) 89 | return out 90 | 91 | def message(self, x_i, x_j, edge_attr): 92 | xe = x_j + edge_attr 93 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 94 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 95 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 96 | return feature2d 97 | 98 | def update(self, aggr_out, x): 99 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 100 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 101 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 102 | return aggr_out + feature2d 103 | 104 | def __repr__(self): 105 | return self.__class__.__name__ 106 | 107 | 108 | class CombC(MessagePassing): 109 | def __init__(self, hidden, config, **kwargs): 110 | super(CombC, self).__init__(aggr='add', **kwargs) 111 | 112 | self.fea_mlp = Sequential( 113 | Linear(hidden, hidden), 114 | ReLU(), 115 | Linear(hidden, hidden), 116 | ReLU()) 117 | 118 | self.aggr_mlp = Sequential( 119 | Linear(hidden * 2, hidden), 120 | Tanh()) 121 | 122 | if config.BN == 'Y': 123 | self.BN = BN(hidden) 124 | else: 125 | self.BN = None 126 | 127 | # edge_attr is two dimensional after augment_edge transformation 128 | self.edge_encoder = torch.nn.Linear(2, hidden) 129 | 130 | def forward(self, x, edge_index, edge_attr): 131 | edge_attr = self.edge_encoder(edge_attr) 132 | edge_index, _ = remove_self_loops(edge_index) 133 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 134 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 135 | if self.BN is not None: 136 | out = self.BN(out) 137 | return out 138 | 139 | def message(self, x_i, x_j, edge_attr): 140 | xe = x_j + edge_attr 141 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 142 | return self.fea_mlp(aggr_emb * xe) 143 | 144 | def update(self, aggr_out, x): 145 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 146 | return aggr_out + self.fea_mlp(root_emb * x) 147 | 148 | def __repr__(self): 149 | return self.__class__.__name__ 150 | 151 | 152 | class CombC_star(MessagePassing): 153 | def __init__(self, hidden, config, **kwargs): 154 | super(CombC_star, self).__init__(aggr='add', **kwargs) 155 | 156 | self.fea_mlp = Sequential( 157 | Linear(hidden, hidden), 158 | ReLU(), 159 | Linear(hidden, hidden), 160 | ReLU()) 161 | 162 
| self.aggr_mlp = Sequential( 163 | Linear(hidden * 2, hidden), 164 | Tanh()) 165 | 166 | if config.BN == 'Y': 167 | self.BN = BN(hidden) 168 | else: 169 | self.BN = None 170 | 171 | # edge_attr is two dimensional after augment_edge transformation 172 | self.edge_encoder = torch.nn.Linear(2, hidden) 173 | 174 | def forward(self, x, edge_index, edge_attr): 175 | edge_attr = self.edge_encoder(edge_attr) 176 | edge_index, _ = remove_self_loops(edge_index) 177 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 178 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 179 | if self.BN is not None: 180 | out = self.BN(out) 181 | return out 182 | 183 | def message(self, x_i, x_j, edge_attr): 184 | xe = x_j + edge_attr 185 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 186 | return aggr_emb * xe 187 | 188 | def update(self, aggr_out, x): 189 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 190 | return aggr_out + root_emb * x 191 | 192 | def __repr__(self): 193 | return self.__class__.__name__ 194 | 195 | 196 | class GinConv(MessagePassing): 197 | def __init__(self, hidden, config, **kwargs): 198 | super(GinConv, self).__init__(aggr='add', **kwargs) 199 | 200 | self.fea_mlp = Sequential( 201 | Linear(hidden, hidden), 202 | ReLU(), 203 | Linear(hidden, hidden), 204 | ReLU()) 205 | 206 | if config.BN == 'Y': 207 | self.BN = BN(hidden) 208 | else: 209 | self.BN = None 210 | 211 | # edge_attr is two dimensional after augment_edge transformation 212 | self.edge_encoder = torch.nn.Linear(2, hidden) 213 | 214 | def forward(self, x, edge_index, edge_attr): 215 | edge_attr = self.edge_encoder(edge_attr) 216 | edge_index, _ = remove_self_loops(edge_index) 217 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 218 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 219 | if self.BN is not None: 220 | out = self.BN(out) 221 | return out 222 | 223 | def message(self, x_j, edge_attr): 224 | return x_j + edge_attr 225 | 226 | def update(self, aggr_out, x): 227 | return aggr_out + x 228 | 229 | def __repr__(self): 230 | return self.__class__.__name__ 231 | -------------------------------------------------------------------------------- /qm9/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.conv import MessagePassing 3 | from torch_geometric.utils import remove_self_loops 4 | from torch.nn import Linear, Sequential, Tanh, ReLU, ELU, BatchNorm1d as BN 5 | 6 | 7 | class ExpC(MessagePassing): 8 | def __init__(self, hidden, num_aggr, config, **kwargs): 9 | super(ExpC, self).__init__(aggr='add', **kwargs) 10 | self.hidden = hidden 11 | self.num_aggr = num_aggr 12 | 13 | if config.fea_activation == 'ELU': 14 | self.fea_activation = ELU() 15 | elif config.fea_activation == 'ReLU': 16 | self.fea_activation = ReLU() 17 | 18 | self.fea_mlp = Sequential( 19 | Linear(hidden * self.num_aggr, hidden), 20 | ReLU(), 21 | Linear(hidden, hidden), 22 | self.fea_activation) 23 | 24 | self.aggr_mlp = Sequential( 25 | Linear(hidden * 2, self.num_aggr), 26 | Tanh()) 27 | 28 | self.edge_encoder = torch.nn.Linear(5, hidden) 29 | 30 | if config.BN == 'Y': 31 | self.BN = BN(hidden) 32 | else: 33 | self.BN = None 34 | 35 | def forward(self, x, edge_index, edge_attr): 36 | edge_attr = self.edge_encoder(edge_attr) 37 | edge_index, _ = remove_self_loops(edge_index) 38 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 39 | out = 
self.propagate(edge_index, x=x, edge_attr=edge_attr) 40 | if self.BN is not None: 41 | out = self.BN(out) 42 | return out 43 | 44 | def message(self, x_i, x_j, edge_attr): 45 | xe = x_j + edge_attr 46 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 47 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 48 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 49 | return self.fea_mlp(feature2d) 50 | 51 | def update(self, aggr_out, x): 52 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 53 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 54 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 55 | return aggr_out + self.fea_mlp(feature2d) 56 | 57 | def __repr__(self): 58 | return self.__class__.__name__ 59 | 60 | 61 | class ExpC_star(MessagePassing): 62 | def __init__(self, hidden, num_aggr, config, **kwargs): 63 | super(ExpC_star, self).__init__(aggr='add', **kwargs) 64 | self.hidden = hidden 65 | self.num_aggr = num_aggr 66 | 67 | if config.fea_activation == 'ELU': 68 | self.fea_activation = ELU() 69 | elif config.fea_activation == 'ReLU': 70 | self.fea_activation = ReLU() 71 | 72 | self.fea_mlp = Sequential( 73 | Linear(hidden * self.num_aggr, hidden), 74 | ReLU(), 75 | Linear(hidden, hidden), 76 | self.fea_activation) 77 | 78 | self.aggr_mlp = Sequential( 79 | Linear(hidden * 2, self.num_aggr), 80 | Tanh()) 81 | 82 | self.edge_encoder = torch.nn.Linear(5, hidden) 83 | 84 | if config.BN == 'Y': 85 | self.BN = BN(hidden) 86 | else: 87 | self.BN = None 88 | 89 | def forward(self, x, edge_index, edge_attr): 90 | edge_attr = self.edge_encoder(edge_attr) 91 | edge_index, _ = remove_self_loops(edge_index) 92 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 93 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 94 | if self.BN is not None: 95 | out = self.BN(out) 96 | return out 97 | 98 | def message(self, x_i, x_j, edge_attr): 99 | xe = x_j + edge_attr 100 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 101 | feature2d = torch.matmul(aggr_emb.unsqueeze(-1), xe.unsqueeze(-1) 102 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 103 | return feature2d 104 | 105 | def update(self, aggr_out, x): 106 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 107 | feature2d = torch.matmul(root_emb.unsqueeze(-1), x.unsqueeze(-1) 108 | .transpose(-1, -2)).squeeze(-1).view(-1, self.hidden * self.num_aggr) 109 | return aggr_out + feature2d 110 | 111 | def __repr__(self): 112 | return self.__class__.__name__ 113 | 114 | 115 | class CombC(MessagePassing): 116 | def __init__(self, hidden, config, **kwargs): 117 | super(CombC, self).__init__(aggr='add', **kwargs) 118 | 119 | if config.fea_activation == 'ELU': 120 | self.fea_activation = ELU() 121 | elif config.fea_activation == 'ReLU': 122 | self.fea_activation = ReLU() 123 | 124 | self.fea_mlp = Sequential( 125 | Linear(hidden, hidden), 126 | ReLU(), 127 | Linear(hidden, hidden), 128 | self.fea_activation) 129 | 130 | self.aggr_mlp = Sequential( 131 | Linear(hidden * 2, hidden), 132 | Tanh()) 133 | 134 | self.edge_encoder = torch.nn.Linear(5, hidden) 135 | 136 | if config.BN == 'Y': 137 | self.BN = BN(hidden) 138 | else: 139 | self.BN = None 140 | 141 | def forward(self, x, edge_index, edge_attr): 142 | edge_attr = self.edge_encoder(edge_attr) 143 | edge_index, _ = remove_self_loops(edge_index) 144 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 145 | out = self.propagate(edge_index, 
x=x, edge_attr=edge_attr) 146 | if self.BN is not None: 147 | out = self.BN(out) 148 | return out 149 | 150 | def message(self, x_i, x_j, edge_attr): 151 | xe = x_j + edge_attr 152 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 153 | return self.fea_mlp(aggr_emb * xe) 154 | 155 | def update(self, aggr_out, x): 156 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 157 | return aggr_out + self.fea_mlp(root_emb * x) 158 | 159 | def __repr__(self): 160 | return self.__class__.__name__ 161 | 162 | 163 | class CombC_star(MessagePassing): 164 | def __init__(self, hidden, config, **kwargs): 165 | super(CombC_star, self).__init__(aggr='add', **kwargs) 166 | 167 | if config.fea_activation == 'ELU': 168 | self.fea_activation = ELU() 169 | elif config.fea_activation == 'ReLU': 170 | self.fea_activation = ReLU() 171 | 172 | self.fea_mlp = Sequential( 173 | Linear(hidden, hidden), 174 | ReLU(), 175 | Linear(hidden, hidden), 176 | self.fea_activation) 177 | 178 | self.aggr_mlp = Sequential( 179 | Linear(hidden * 2, hidden), 180 | Tanh()) 181 | 182 | self.edge_encoder = torch.nn.Linear(5, hidden) 183 | 184 | if config.BN == 'Y': 185 | self.BN = BN(hidden) 186 | else: 187 | self.BN = None 188 | 189 | def forward(self, x, edge_index, edge_attr): 190 | edge_attr = self.edge_encoder(edge_attr) 191 | edge_index, _ = remove_self_loops(edge_index) 192 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 193 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 194 | if self.BN is not None: 195 | out = self.BN(out) 196 | return out 197 | 198 | def message(self, x_i, x_j, edge_attr): 199 | xe = x_j + edge_attr 200 | aggr_emb = self.aggr_mlp(torch.cat([x_i, xe], dim=-1)) 201 | return aggr_emb * xe 202 | 203 | def update(self, aggr_out, x): 204 | root_emb = self.aggr_mlp(torch.cat([x, x], dim=-1)) 205 | return aggr_out + root_emb * x 206 | 207 | def __repr__(self): 208 | return self.__class__.__name__ 209 | 210 | 211 | class GinConv(MessagePassing): 212 | def __init__(self, hidden, config, **kwargs): 213 | super(GinConv, self).__init__(aggr='add', **kwargs) 214 | 215 | if config.fea_activation == 'ELU': 216 | self.fea_activation = ELU() 217 | elif config.fea_activation == 'ReLU': 218 | self.fea_activation = ReLU() 219 | 220 | self.fea_mlp = Sequential( 221 | Linear(hidden, hidden), 222 | ReLU(), 223 | Linear(hidden, hidden), 224 | self.fea_activation) 225 | 226 | self.edge_encoder = torch.nn.Linear(5, hidden) 227 | 228 | if config.BN == 'Y': 229 | self.BN = BN(hidden) 230 | else: 231 | self.BN = None 232 | 233 | def forward(self, x, edge_index, edge_attr): 234 | edge_attr = self.edge_encoder(edge_attr) 235 | edge_index, _ = remove_self_loops(edge_index) 236 | # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 237 | out = self.fea_mlp(self.propagate(edge_index, x=x, edge_attr=edge_attr)) 238 | if self.BN is not None: 239 | out = self.BN(out) 240 | return out 241 | 242 | def message(self, x_j, edge_attr): 243 | return x_j + edge_attr 244 | 245 | def update(self, aggr_out, x): 246 | return aggr_out + x 247 | 248 | def __repr__(self): 249 | return self.__class__.__name__ 250 | -------------------------------------------------------------------------------- /tu/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torch.nn.functional as F 4 | from torch_geometric.data import DataLoader 5 | from tensorboardX import SummaryWriter 6 | 7 | import sys 8 | sys.path.append('..') 9 | 
10 | from model import Net 11 | from sklearn.model_selection import StratifiedKFold 12 | import numpy as np 13 | import os.path as osp 14 | from torch_geometric.datasets import TUDataset 15 | from torch_geometric.utils import degree 16 | import torch_geometric.transforms as T 17 | 18 | from utils.config import process_config, get_args 19 | 20 | 21 | class NormalizedDegree(object): 22 | def __init__(self, mean, std): 23 | self.mean = mean 24 | self.std = std 25 | 26 | def __call__(self, data): 27 | deg = degree(data.edge_index[0], dtype=torch.float) 28 | deg = (deg - self.mean) / self.std 29 | data.x = deg.view(-1, 1) 30 | return data 31 | 32 | 33 | def get_dataset(name): 34 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', name) 35 | dataset = TUDataset(path, name) 36 | dataset.data.edge_attr = None 37 | 38 | if dataset.data.x is None: 39 | max_degree = 0 40 | degs = [] 41 | for data in dataset: 42 | degs += [degree(data.edge_index[0], dtype=torch.long)] 43 | max_degree = max(max_degree, degs[-1].max().item()) 44 | 45 | if max_degree < 1000: 46 | dataset.transform = T.OneHotDegree(max_degree) 47 | else: 48 | deg = torch.cat(degs, dim=0).to(torch.float) 49 | mean, std = deg.mean().item(), deg.std().item() 50 | dataset.transform = NormalizedDegree(mean, std) 51 | 52 | return dataset 53 | 54 | 55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 56 | 57 | 58 | def train(loader, model, optimizer): 59 | model.train() 60 | loss_all = 0 61 | 62 | for data in loader: 63 | data = data.to(device) 64 | optimizer.zero_grad() 65 | output = model(data) 66 | loss = F.nll_loss(output, data.y) 67 | loss.backward() 68 | loss_all += loss.item() * data.num_graphs 69 | optimizer.step() 70 | return loss_all / len(loader.dataset) 71 | 72 | 73 | def test(loader, model): 74 | model.eval() 75 | correct = 0 76 | 77 | for data in loader: 78 | data = data.to(device) 79 | pred = model(data).max(dim=1)[1] 80 | correct += pred.eq(data.y).sum().item() 81 | return correct / len(loader.dataset) 82 | 83 | 84 | def run_given_fold(net, 85 | dataset, 86 | train_loader, 87 | val_loader, 88 | writer, 89 | ts_kf_algo_hp, 90 | config): 91 | model = net(dataset, config=config.architecture) 92 | model.to(device) 93 | optimizer = torch.optim.Adam(model.parameters(), lr=config.hyperparams.learning_rate) 94 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.hyperparams.step_size, 95 | gamma=config.hyperparams.decay_rate) 96 | 97 | train_losses = [] 98 | train_accs = [] 99 | test_accs = [] 100 | for epoch in range(1, config.hyperparams.epochs): 101 | train_loss = train(train_loader, model, optimizer) 102 | train_acc = test(train_loader, model) 103 | test_acc = test(val_loader, model) 104 | 105 | scheduler.step() 106 | 107 | train_losses.append(train_loss) 108 | train_accs.append(train_acc) 109 | test_accs.append(test_acc) 110 | 111 | print('Epoch: {:03d}, Train Loss: {:.7f}, ' 112 | 'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss, 113 | train_acc, test_acc)) 114 | 115 | writer.add_scalars(config.dataset_name, {ts_kf_algo_hp + '/test_acc': test_acc}, epoch) 116 | writer.add_scalars(config.dataset_name, {ts_kf_algo_hp + '/train_acc': train_acc}, epoch) 117 | writer.add_scalars(config.dataset_name, {ts_kf_algo_hp + '/train_loss': train_loss}, epoch) 118 | 119 | return test_accs, train_losses, train_accs 120 | 121 | 122 | def k_fold(dataset, folds, seed): 123 | skf = StratifiedKFold(folds, shuffle=True, random_state=seed) 124 | 125 | val_indices, train_indices = 
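[], []
# Added note (ours, illustrative): StratifiedKFold yields label-balanced folds. A
# minimal sketch of what this function returns, assuming a toy dataset with 6 graphs
# and folds=3 (exact membership depends on the labels and the seed):
#   train_indices ~ [tensor([2,3,4,5]), tensor([0,1,4,5]), tensor([0,1,2,3])]
#   val_indices   ~ [tensor([0,1]),     tensor([2,3]),     tensor([4,5])]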
126 | for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y):
127 | val_indices.append(torch.from_numpy(idx))
128 | 
129 | for i in range(folds):
130 | train_mask = torch.ones(len(dataset), dtype=torch.bool)  # boolean mask (uint8 masks are deprecated)
131 | train_mask[val_indices[i]] = 0
132 | train_indices.append(train_mask.nonzero().view(-1))
133 | 
134 | return train_indices, val_indices
135 | 
136 | 
137 | def run_model(net, dataset, config):
138 | folds_test_accs = []
139 | folds_train_losses = []
140 | folds_train_accs = []
141 | 
142 | def k_folds_average(avg_folds):
143 | avg_folds = np.vstack(avg_folds)
144 | return np.mean(avg_folds, axis=0), np.std(avg_folds, axis=0)
145 | 
146 | writer = SummaryWriter(config.directory)
147 | 
148 | algo_hp = str(config.commit_id[0:7]) + '_' \
149 | + str(config.architecture.methods) + '_' \
150 | + str(config.architecture.pooling) + '_' \
151 | + str(config.architecture.JK) + '_' \
152 | + str(config.architecture.layers) + '_' \
153 | + str(config.architecture.hidden) + '_' \
154 | + str(config.architecture.variants.BN) + '_' \
155 | + str(config.architecture.dropout) + '_' \
156 | + str(config.hyperparams.learning_rate) + '_' \
157 | + str(config.hyperparams.step_size) + '_' \
158 | + str(config.hyperparams.decay_rate) + '_' \
159 | + 'B' + str(config.hyperparams.batch_size) + '_' \
160 | + 'S' + str(config.seed if config.get('seed') is not None else "na") + '_' \
161 | + 'W' + str(config.num_workers if config.get('num_workers') is not None else "na")
162 | 
163 | for fold, (train_idx, val_idx) in enumerate(
164 | zip(*k_fold(dataset, 10, config.get('seed', 12345)))):
165 | 
166 | if fold >= config.get('folds_cut', 10):
167 | break
168 | 
169 | train_dataset = dataset[train_idx]
170 | val_dataset = dataset[val_idx]
171 | 
172 | train_loader = DataLoader(train_dataset, config.hyperparams.batch_size, shuffle=True, num_workers=config.num_workers)
173 | val_loader = DataLoader(val_dataset, config.hyperparams.batch_size, shuffle=False, num_workers=config.num_workers)
174 | 
175 | print('-------- FOLD' + str(fold) +
176 | ' DATASET=' + config.dataset_name +
177 | ', COMMIT_ID=' + config.commit_id)
178 | 
179 | test_accs, train_losses, train_accs = run_given_fold(
180 | net,
181 | dataset,
182 | train_loader,
183 | val_loader,
184 | writer=writer,
185 | ts_kf_algo_hp=str(config.time_stamp) + '/f' + str(fold) + '/' + algo_hp,
186 | config=config
187 | )
188 | 
189 | folds_test_accs.append(np.array(test_accs))
190 | folds_train_losses.append(np.array(train_losses))
191 | folds_train_accs.append(np.array(train_accs))
192 | 
193 | # following the protocol of other GNN baselines
194 | avg_test_accs, std_test_accs = k_folds_average(folds_test_accs)
195 | sel_epoch = np.argmax(avg_test_accs)
196 | sel_test_acc = np.max(avg_test_accs)
197 | sel_test_acc_std = std_test_accs[sel_epoch]
198 | sel_test_with_std = str(sel_test_acc) + '_' + str(sel_test_acc_std)
199 | 
200 | avg_train_losses, std_train_losses = k_folds_average(folds_train_losses)
201 | sel_tl_with_std = str(np.min(avg_train_losses)) + '_' + str(std_train_losses[np.argmin(avg_train_losses)])
202 | 
203 | avg_train_accs, std_train_accs = k_folds_average(folds_train_accs)
204 | sel_ta_with_std = str(np.max(avg_train_accs)) + '_' + str(std_train_accs[np.argmax(avg_train_accs)])
205 | 
206 | print('--------')
207 | print('Best Test Acc: ' + sel_test_with_std + ', Epoch: ' + str(sel_epoch))
208 | print('Best Train Loss: ' + sel_tl_with_std)
209 | print('Best Train Acc: ' + sel_ta_with_std)
210 | 
211 | print('FOLD' + str(fold + 1) + ', '
212 | + 
config.dataset_name + ', ' 213 | + str(config.time_stamp) + '/' 214 | + str(config.get('seed', 'NoSeed')) + '/' 215 | + str(config.architecture.layers) + '_' 216 | + str(config.architecture.hidden) + '_' 217 | + str(config.hyperparams.learning_rate) + '_' 218 | + str(config.hyperparams.step_size) + '_' 219 | + str(config.hyperparams.decay_rate) 220 | + '_B' + str(config.hyperparams.batch_size) 221 | + ', BT=' + sel_test_with_std 222 | + ', BE=' + str(sel_epoch) 223 | + ', ID=' + config.commit_id) 224 | 225 | ts_fk_algo_hp = str(config.time_stamp) + '/fk' + str(config.get('folds_cut', 10)) + '/' + algo_hp 226 | 227 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/best_acc': sel_test_acc}, fold) 228 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/best_std': sel_test_acc_std}, fold) 229 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/best_epoch': sel_epoch}, fold) 230 | 231 | for i in range(1, config.hyperparams.epochs): 232 | test_acc = avg_test_accs[i - 1] 233 | train_loss = avg_train_losses[i - 1] 234 | train_acc = avg_train_accs[i - 1] 235 | 236 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/test_acc': test_acc}, i) 237 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/train_loss': train_loss}, i) 238 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/train_acc': train_acc}, i) 239 | 240 | # writer.export_scalars_to_json("./all_scalars.json") 241 | writer.close() 242 | 243 | 244 | def main(): 245 | args = get_args() 246 | config = process_config(args) 247 | print(config) 248 | 249 | if config.get('seed') is not None: 250 | random.seed(config.seed) 251 | torch.manual_seed(config.seed) 252 | np.random.seed(config.seed) 253 | if torch.cuda.is_available(): 254 | torch.cuda.manual_seed_all(config.seed) 255 | 256 | dataset = get_dataset(config.dataset_name).shuffle() 257 | run_model(Net, dataset, config=config) 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /ogbg/code/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch_geometric.data import DataLoader 4 | from tensorboardX import SummaryWriter 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | from tqdm import tqdm 10 | import time 11 | import numpy as np 12 | import pandas as pd 13 | import os 14 | 15 | ### importing OGB 16 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 17 | 18 | import sys 19 | sys.path.append('../..') 20 | 21 | ### importing utils 22 | from ogbg.code.proc import ASTNodeEncoder, get_vocab_mapping 23 | ### for data transform 24 | from ogbg.code.proc import augment_edge, encode_y_to_arr, decode_arr_to_seq 25 | 26 | from model import Net 27 | from utils.config import process_config, get_args 28 | 29 | 30 | class In: 31 | def readline(self): 32 | return "y\n" 33 | 34 | def close(self): 35 | pass 36 | 37 | 38 | def train(model, device, loader, optimizer, multicls_criterion): 39 | model.train() 40 | 41 | loss_accum = 0 42 | for step, batch in enumerate(loader): 43 | batch = batch.to(device) 44 | 45 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 46 | pass 47 | else: 48 | pred_list = model(batch) 49 | optimizer.zero_grad() 50 | 51 | loss = 0 52 | for i in range(len(pred_list)): 53 | loss += multicls_criterion(pred_list[i].to(torch.float32), batch.y_arr[:, i]) 54 | 55 | loss = loss / 
len(pred_list)
56 | 
57 | loss.backward()
58 | optimizer.step()
59 | 
60 | loss_accum += loss.item()
61 | 
62 | print('Average training loss: {}'.format(loss_accum / (step + 1)))
63 | return loss_accum / (step + 1)
64 | 
65 | 
66 | def eval(model, device, loader, evaluator, arr_to_seq):
67 | model.eval()
68 | seq_ref_list = []
69 | seq_pred_list = []
70 | 
71 | for step, batch in enumerate(loader):
72 | batch = batch.to(device)
73 | 
74 | if batch.x.shape[0] == 1:
75 | pass
76 | else:
77 | with torch.no_grad():
78 | pred_list = model(batch)
79 | 
80 | mat = []
81 | for i in range(len(pred_list)):
82 | mat.append(torch.argmax(pred_list[i], dim=1).view(-1, 1))
83 | mat = torch.cat(mat, dim=1)
84 | 
85 | seq_pred = [arr_to_seq(arr) for arr in mat]
86 | 
87 | # PyG = 1.4.3
88 | # seq_ref = [batch.y[i][0] for i in range(len(batch.y))]
89 | 
90 | # PyG >= 1.5.0
91 | seq_ref = [batch.y[i] for i in range(len(batch.y))]
92 | 
93 | seq_ref_list.extend(seq_ref)
94 | seq_pred_list.extend(seq_pred)
95 | 
96 | input_dict = {"seq_ref": seq_ref_list, "seq_pred": seq_pred_list}
97 | 
98 | return evaluator.eval(input_dict)
99 | 
100 | 
101 | def main():
102 | args = get_args()
103 | config = process_config(args)
104 | print(config)
105 | 
106 | if config.get('seed') is not None:
107 | random.seed(config.seed)
108 | torch.manual_seed(config.seed)
109 | np.random.seed(config.seed)
110 | if torch.cuda.is_available():
111 | torch.cuda.manual_seed_all(config.seed)
112 | 
113 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
114 | 
115 | ### automatic dataloading and splitting
116 | 
117 | sys.stdin = In()  # auto-answer OGB's interactive download prompt with "y"
118 | 
119 | dataset = PygGraphPropPredDataset(name=config.dataset_name)
120 | 
121 | seq_len_list = np.array([len(seq) for seq in dataset.data.y])
122 | print('Target sequences of length <= {}: {}%.'.format(config.max_seq_len, 100 * np.sum(seq_len_list <= config.max_seq_len) / len(seq_len_list)))
123 | 
124 | split_idx = dataset.get_idx_split()
125 | 
126 | # print(split_idx['train'])
127 | # print(split_idx['valid'])
128 | # print(split_idx['test'])
129 | 
130 | # train_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['train']]
131 | # valid_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['valid']]
132 | # test_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['test']]
133 | # print('#train')
134 | # print(len(train_method_name))
135 | # print('#valid')
136 | # print(len(valid_method_name))
137 | # print('#test')
138 | # print(len(test_method_name))
139 | 
140 | # train_method_name_set = set(train_method_name)
141 | # valid_method_name_set = set(valid_method_name)
142 | # test_method_name_set = set(test_method_name)
143 | 
144 | # # unique method name
145 | # print('#unique train')
146 | # print(len(train_method_name_set))
147 | # print('#unique valid')
148 | # print(len(valid_method_name_set))
149 | # print('#unique test')
150 | # print(len(test_method_name_set))
151 | 
152 | # # unique valid/test method name
153 | # print('#valid unseen during training')
154 | # print(len(valid_method_name_set - train_method_name_set))
155 | # print('#test unseen during training')
156 | # print(len(test_method_name_set - train_method_name_set))
157 | 
158 | 
159 | ### building vocabulary for sequence prediction. Only use training data.
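# Added example (ours, illustrative): the mapping keeps the num_vocab most frequent
# tokens and appends the two special tokens, e.g.
#   get_vocab_mapping([['f', 'g'], ['f']], num_vocab=1)
#   -> vocab2idx = {'f': 0, '__UNK__': 1, '__EOS__': 2}, idx2vocab = ['f', '__UNK__', '__EOS__']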
160 | 161 | vocab2idx, idx2vocab = get_vocab_mapping([dataset.data.y[i] for i in split_idx['train']], config.num_vocab) 162 | 163 | # test encoder and decoder 164 | # for data in dataset: 165 | # # PyG >= 1.5.0 166 | # print(data.y) 167 | # 168 | # # PyG 1.4.3 169 | # # print(data.y[0]) 170 | # data = encode_y_to_arr(data, vocab2idx, config.max_seq_len) 171 | # print(data.y_arr[0]) 172 | # decoded_seq = decode_arr_to_seq(data.y_arr[0], idx2vocab) 173 | # print(decoded_seq) 174 | # print('') 175 | 176 | ## test augment_edge 177 | # data = dataset[2] 178 | # print(data) 179 | # data_augmented = augment_edge(data) 180 | # print(data_augmented) 181 | 182 | ### set the transform function 183 | # augment_edge: add next-token edge as well as inverse edges. add edge attributes. 184 | # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence. 185 | dataset.transform = transforms.Compose([augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, config.max_seq_len)]) 186 | 187 | ### automatic evaluator. takes dataset name as input 188 | evaluator = Evaluator(config.dataset_name) 189 | 190 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=config.hyperparams.batch_size, shuffle=True, num_workers=config.num_workers) 191 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=config.hyperparams.batch_size, shuffle=False, num_workers=config.num_workers) 192 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=config.hyperparams.batch_size, shuffle=False, num_workers=config.num_workers) 193 | 194 | nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz')) 195 | nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz')) 196 | 197 | ### Encoding node features into emb_dim vectors. 198 | ### The following three node features are used. 199 | # 1. node type 200 | # 2. node attribute 201 | # 3. 
node depth 202 | node_encoder = ASTNodeEncoder(config.architecture.hidden, num_nodetypes=len(nodetypes_mapping['type']), num_nodeattributes=len(nodeattributes_mapping['attr']), max_depth=20) 203 | 204 | model = Net(config.architecture, 205 | num_vocab=len(vocab2idx), 206 | max_seq_len=config.max_seq_len, 207 | node_encoder=node_encoder).to(device) 208 | 209 | optimizer = optim.Adam(model.parameters(), lr=config.hyperparams.learning_rate) 210 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.hyperparams.step_size, 211 | gamma=config.hyperparams.decay_rate) 212 | 213 | multicls_criterion = torch.nn.CrossEntropyLoss() 214 | 215 | valid_curve = [] 216 | test_curve = [] 217 | train_curve = [] 218 | trainL_curve = [] 219 | 220 | writer = SummaryWriter(config.directory) 221 | 222 | ts_fk_algo_hp = str(config.time_stamp) + '_' \ 223 | + str(config.commit_id[0:7]) + '_' \ 224 | + str(config.architecture.methods) + '_' \ 225 | + str(config.architecture.pooling) + '_' \ 226 | + str(config.architecture.JK) + '_' \ 227 | + str(config.architecture.layers) + '_' \ 228 | + str(config.architecture.hidden) + '_' \ 229 | + str(config.architecture.variants.BN) + '_' \ 230 | + str(config.architecture.dropout) + '_' \ 231 | + str(config.hyperparams.learning_rate) + '_' \ 232 | + str(config.hyperparams.step_size) + '_' \ 233 | + str(config.hyperparams.decay_rate) + '_' \ 234 | + 'B' + str(config.hyperparams.batch_size) + '_' \ 235 | + 'S' + str(config.seed if config.get('seed') is not None else "na") + '_' \ 236 | + 'W' + str(config.num_workers if config.get('num_workers') is not None else "na") 237 | 238 | for epoch in range(1, config.hyperparams.epochs + 1): 239 | print("Epoch {} training...".format(epoch)) 240 | train_loss = train(model, device, train_loader, optimizer, multicls_criterion) 241 | 242 | scheduler.step() 243 | 244 | print('Evaluating...') 245 | train_perf = eval(model, device, train_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab)) 246 | valid_perf = eval(model, device, valid_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab)) 247 | test_perf = eval(model, device, test_loader, evaluator, arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab)) 248 | 249 | print('Train:', train_perf[dataset.eval_metric], 250 | 'Validation:', valid_perf[dataset.eval_metric], 251 | 'Test:', test_perf[dataset.eval_metric], 252 | 'Train loss:', train_loss) 253 | 254 | train_curve.append(train_perf[dataset.eval_metric]) 255 | valid_curve.append(valid_perf[dataset.eval_metric]) 256 | test_curve.append(test_perf[dataset.eval_metric]) 257 | trainL_curve.append(train_loss) 258 | 259 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traP': train_perf[dataset.eval_metric]}, epoch) 260 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/valP': valid_perf[dataset.eval_metric]}, epoch) 261 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/tstP': test_perf[dataset.eval_metric]}, epoch) 262 | writer.add_scalars(config.dataset_name, {ts_fk_algo_hp + '/traL': train_loss}, epoch) 263 | writer.close() 264 | 265 | print('F1') 266 | best_val_epoch = np.argmax(np.array(valid_curve)) 267 | best_train = max(train_curve) 268 | print('Finished training!') 269 | print('Best validation score: {}'.format(valid_curve[best_val_epoch])) 270 | print('Test score: {}'.format(test_curve[best_val_epoch])) 271 | 272 | print('Finished test: {}, Validation: {}, Train: {}, epoch: {}, best train: {}, best loss: {}' 273 | 
.format(test_curve[best_val_epoch], valid_curve[best_val_epoch], train_curve[best_val_epoch], 274 | best_val_epoch, best_train, min(trainL_curve))) 275 | 276 | 277 | if __name__ == "__main__": 278 | main() 279 | --------------------------------------------------------------------------------