├── configs ├── code2 │ ├── gcn │ │ └── baseline+run1+seed.yml │ ├── gcn-virtual │ │ └── baseline+run1+seed.yml │ ├── transformer │ │ └── pooling=cls.yml │ ├── pna │ │ └── base.yml │ ├── gnn-transformer │ │ ├── no-virtual │ │ │ └── pooling=cls+norm_input.yml │ │ └── JK=cat │ │ │ └── pooling=cls+norm_input.yml │ ├── transformer-gnn │ │ ├── no-virtual │ │ │ └── pooling=cls+norm_input.yml │ │ └── JK=cat │ │ │ └── pooling=cls+norm_input.yml │ └── pna-transformer │ │ └── pooling=cls+norm_input.yml ├── NCI1 │ ├── gcn │ │ └── base.yml │ ├── gcn-virtual │ │ └── base.yml │ ├── transformer │ │ └── pooling=cls.yml │ ├── gnn-transformer │ │ └── no-virtual │ │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ │ ├── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml │ │ │ └── ablation-pos_encoder │ │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ │ └── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml │ └── transformer-gnn │ │ └── no-virtual │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ └── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml ├── NCI109 │ ├── gcn │ │ └── base.yml │ ├── gcn-virtual │ │ └── base.yml │ ├── transformer │ │ └── pooling=cls.yml │ ├── gnn-transformer │ │ └── no-virtual │ │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ │ ├── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml │ │ │ └── ablation-pos_encoder │ │ │ ├── gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ │ └── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml │ └── transformer_gnn │ │ └── no-virtual │ │ ├── gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ ├── gin+gdp=0.1+tdp=0.1+l=4+cosine.yml │ │ └── gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml └── molpcba │ ├── gin │ └── baseline+batch=256.yml │ ├── gin-virtual │ └── baseline+batch=256.yml │ ├── transformer │ └── pooling=cls.yml │ └── gnn-transformer │ ├── JK=cat │ └── pooling=cls+gin+norm_input.yml │ └── no-virtual │ └── JK=cat │ └── pooling=cls+gin+norm_input.yml ├── utils.py ├── trainers ├── baseline_trainer.py ├── base_trainer.py ├── __init__.py └── flag_trainer.py ├── dataset ├── __init__.py ├── tud.py ├── mol.py ├── code.py └── utils.py ├── run.sh ├── models ├── base_model.py ├── __init__.py ├── pna.py ├── gnn.py ├── transformer.py ├── pna_transformer.py ├── gnn_transformer.py └── transformer_gnn.py ├── modules ├── pna │ ├── scalers.py │ ├── aggregators.py │ └── pna_module.py ├── utils.py ├── conv.py ├── transformer_encoder.py ├── masked_transformer_encoder.py ├── gnn_module.py └── pna_layer.py ├── data └── adj_list.py ├── .gitignore ├── README.md ├── requirement.yml ├── LICENSE └── main.py /configs/code2/gcn/baseline+run1+seed.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: false 6 | runs: 1 7 | seed: 12344 -------------------------------------------------------------------------------- /configs/code2/gcn-virtual/baseline+run1+seed.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: true 6 | runs: 1 7 | seed: 12344 -------------------------------------------------------------------------------- /configs/NCI1/gcn/base.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: false 6 | runs: 5 7 | epochs: 200 8 | seed: 12344 9 | scheduler: plateau 10 | 
-------------------------------------------------------------------------------- /configs/NCI109/gcn/base.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: false 6 | runs: 5 7 | epochs: 200 8 | seed: 12344 9 | scheduler: plateau 10 | -------------------------------------------------------------------------------- /configs/NCI1/gcn-virtual/base.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: true 6 | runs: 5 7 | epochs: 200 8 | seed: 12344 9 | scheduler: plateau 10 | -------------------------------------------------------------------------------- /configs/NCI109/gcn-virtual/base.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gcn 5 | gnn_virtual_node: true 6 | runs: 5 7 | epochs: 200 8 | seed: 12344 9 | scheduler: plateau 10 | -------------------------------------------------------------------------------- /configs/molpcba/gin/baseline+batch=256.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-molpcba 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gin 5 | gnn_virtual_node: false 6 | runs: 1 7 | batch_size: 256 8 | lr: 0.01 9 | scheduler: plateau -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def num_total_parameters(model): 2 | return sum(p.numel() for p in model.parameters()) 3 | 4 | 5 | def num_trainable_parameters(model): 6 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 7 | -------------------------------------------------------------------------------- /configs/molpcba/gin-virtual/baseline+batch=256.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-molpcba 2 | aug: baseline 3 | model_type: gnn 4 | gnn_type: gin 5 | gnn_virtual_node: true 6 | runs: 1 7 | batch_size: 256 8 | gnn_dropout: 0.5 9 | epochs: 100 10 | lr: 0.01 11 | scheduler: plateau -------------------------------------------------------------------------------- /trainers/baseline_trainer.py: -------------------------------------------------------------------------------- 1 | from trainers import register_trainer 2 | 3 | from .base_trainer import BaseTrainer 4 | 5 | 6 | @register_trainer("baseline") 7 | class BaselineTrainer(BaseTrainer): 8 | @staticmethod 9 | def name(args): 10 | return "baseline" 11 | -------------------------------------------------------------------------------- /configs/code2/transformer/pooling=cls.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: transformer 4 | graph_pooling: cls 5 | 6 | seed: 12344 7 | batch_size: 16 8 | lr: 0.0001 9 | start-eval: 1 10 | scheduler: cosine 11 | 12 | num_encoder_layers: 5 13 | 14 | d_model: 256 15 | gnn_emb_dim: 256 16 | -------------------------------------------------------------------------------- /configs/molpcba/transformer/pooling=cls.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-molpcba 2 | aug: baseline 3 | model_type: transformer 4 | graph_pooling: cls 5 | 6 | seed: 12344 7 | 
batch_size: 256 8 | lr: 0.0001 9 | start-eval: 1 10 | scheduler: cosine 11 | 12 | num_encoder_layers: 5 13 | 14 | d_model: 256 15 | gnn_emb_dim: 256 16 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .code import CodeUtil 2 | from .mol import MolUtil 3 | from .tud import TUUtil 4 | 5 | DATASET_UTILS = { 6 | 'ogbg-code': CodeUtil, 7 | 'ogbg-code2': CodeUtil, 8 | 'ogbg-molhiv': MolUtil, 9 | 'ogbg-molpcba': MolUtil, 10 | 'NCI1': TUUtil, 11 | 'NCI109': TUUtil, 12 | } 13 | -------------------------------------------------------------------------------- /configs/code2/pna/base.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: pna 4 | runs: 1 5 | seed: 12344 6 | gnn_emb_dim: 272 7 | scheduler: plateau 8 | lr: 0.00063096 9 | 10 | gnn_num_layer: 4 11 | batch_size: 128 12 | epochs: 30 13 | weight_decay: 3e-6 14 | gnn_dropout: 0 15 | 16 | start-eval: 10 17 | test-freq: 5 -------------------------------------------------------------------------------- /configs/code2/gnn-transformer/no-virtual/pooling=cls+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | batch_size: 16 11 | lr: 0.0001 12 | 13 | runs: 1 14 | seed: 12344 15 | 16 | transformer_norm_input: true 17 | -------------------------------------------------------------------------------- /configs/NCI1/transformer/pooling=cls.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: transformer 4 | graph_pooling: cls 5 | 6 | seed: 12344 7 | batch_size: 128 8 | lr: 0.0001 9 | scheduler: cosine 10 | 11 | num_encoder_layers: 5 12 | runs: 20 13 | 14 | d_model: 256 15 | gnn_emb_dim: 256 16 | 17 | transformer_dropout: 0.1 18 | 19 | epochs: 100 -------------------------------------------------------------------------------- /configs/code2/gnn-transformer/JK=cat/pooling=cls+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: true 7 | 8 | graph_pooling: cls 9 | 10 | batch_size: 16 11 | lr: 0.0001 12 | 13 | runs: 1 14 | seed: 12344 15 | 16 | transformer_norm_input: true 17 | gnn_JK: cat 18 | -------------------------------------------------------------------------------- /configs/code2/transformer-gnn/no-virtual/pooling=cls+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | batch_size: 16 11 | lr: 0.0001 12 | 13 | runs: 1 14 | seed: 12344 15 | 16 | transformer_norm_input: true 17 | graph_input_dim: 300 -------------------------------------------------------------------------------- /configs/NCI109/transformer/pooling=cls.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: transformer 4 | graph_pooling: cls 5 | 6 | seed: 12344 7 | batch_size: 128 8 | lr: 0.0001 9 | scheduler: cosine 10 | 11 | 
num_encoder_layers: 5 12 | runs: 20 13 | 14 | d_model: 256 15 | gnn_emb_dim: 256 16 | 17 | transformer_dropout: 0.1 18 | 19 | epochs: 100 20 | -------------------------------------------------------------------------------- /configs/code2/transformer-gnn/JK=cat/pooling=cls+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: true 7 | 8 | graph_pooling: mean 9 | 10 | batch_size: 16 11 | lr: 0.0001 12 | 13 | runs: 1 14 | seed: 12344 15 | 16 | transformer_norm_input: true 17 | gnn_JK: cat 18 | graph_input_dim: 300 19 | -------------------------------------------------------------------------------- /configs/molpcba/gnn-transformer/JK=cat/pooling=cls+gin+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-molpcba 2 | aug: baseline 3 | model_type: gnn-transformer 4 | gnn_type: gin 5 | gnn_virtual_node: true 6 | graph_pooling: cls 7 | runs: 1 8 | seed: 12344 9 | batch_size: 256 10 | lr: 0.0001 11 | start-eval: 1 12 | scheduler: plateau 13 | 14 | transformer_norm_input: true 15 | gnn_dropout: 0.3 16 | gnn_JK: cat 17 | -------------------------------------------------------------------------------- /configs/molpcba/gnn-transformer/no-virtual/JK=cat/pooling=cls+gin+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-molpcba 2 | aug: baseline 3 | model_type: gnn-transformer 4 | gnn_type: gin 5 | gnn_virtual_node: false 6 | graph_pooling: cls 7 | runs: 1 8 | seed: 12344 9 | batch_size: 256 10 | lr: 0.0001 11 | start-eval: 1 12 | scheduler: plateau 13 | 14 | transformer_norm_input: true 15 | gnn_dropout: 0.3 16 | gnn_JK: cat 17 | -------------------------------------------------------------------------------- /configs/code2/pna-transformer/pooling=cls+norm_input.yml: -------------------------------------------------------------------------------- 1 | dataset: ogbg-code2 2 | aug: baseline 3 | model_type: pna-transformer 4 | 5 | graph_pooling: cls 6 | 7 | batch_size: 16 8 | lr: 0.0001 9 | 10 | runs: 1 11 | seed: 12344 12 | 13 | transformer_norm_input: true 14 | 15 | 16 | gnn_emb_dim: 272 17 | scheduler: plateau 18 | 19 | gnn_num_layer: 4 20 | epochs: 30 21 | batch_size: 16 22 | weight_decay: 3e-6 23 | gnn_dropout: 0 24 | 25 | start-eval: 10 26 | test-freq: 5 -------------------------------------------------------------------------------- /configs/NCI1/gnn-transformer/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | -------------------------------------------------------------------------------- /configs/NCI1/transformer-gnn/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | lr: 0.0001 11 | 12 | seed: 
12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | -------------------------------------------------------------------------------- /configs/NCI109/gnn-transformer/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 5 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | -------------------------------------------------------------------------------- /configs/NCI109/transformer_gnn/no-virtual/gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | -------------------------------------------------------------------------------- /configs/NCI109/transformer_gnn/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | -------------------------------------------------------------------------------- /configs/NCI1/gnn-transformer/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 -------------------------------------------------------------------------------- /configs/NCI1/transformer-gnn/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 
0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 -------------------------------------------------------------------------------- /configs/NCI109/gnn-transformer/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 -------------------------------------------------------------------------------- /configs/NCI109/transformer_gnn/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: transformer-gnn 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: mean 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 -------------------------------------------------------------------------------- /configs/NCI1/gnn-transformer/no-virtual/ablation-pos_encoder/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | 27 | pos_encoder: True 28 | -------------------------------------------------------------------------------- /configs/NCI109/gnn-transformer/no-virtual/ablation-pos_encoder/gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | 27 | pos_encoder: True 28 | -------------------------------------------------------------------------------- /configs/NCI109/gnn-transformer/no-virtual/ablation-pos_encoder/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gin 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 4 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | 27 | pos_encoder: True 
28 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | 8 | config=$1 9 | 10 | echo $(scontrol show hostnames $SLURM_JOB_NODELIST) 11 | source ~/.bashrc 12 | conda activate graph-aug 13 | 14 | echo CUDA_VISIBLE_DEVICES $CUDA_VISIBLE_DEVICES 15 | 16 | echo "python main.py --configs $config --num_workers 8 --devices $CUDA_VISIBLE_DEVICES" 17 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python main.py --configs $config --num_workers 8 --devices $CUDA_VISIBLE_DEVICES 18 | -------------------------------------------------------------------------------- /configs/NCI1/gnn-transformer/no-virtual/ablation-pos_encoder/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI1 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 27 | 28 | pos_encoder: True -------------------------------------------------------------------------------- /configs/NCI109/gnn-transformer/no-virtual/ablation-pos_encoder/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml: -------------------------------------------------------------------------------- 1 | dataset: NCI109 2 | aug: baseline 3 | model_type: gnn-transformer 4 | 5 | gnn_type: gcn 6 | gnn_virtual_node: false 7 | 8 | graph_pooling: cls 9 | 10 | lr: 0.0001 11 | 12 | seed: 12344 13 | 14 | transformer_norm_input: true 15 | 16 | runs: 20 17 | epochs: 100 18 | 19 | d_model: 128 20 | dim_feedforward: 256 21 | num_encoder_layers: 3 22 | scheduler: cosine 23 | 24 | transformer_dropout: 0.1 25 | gnn_dropout: 0.1 26 | gnn_emb_dim: 128 27 | 28 | pos_encoder: True 29 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(nn.Module): 6 | @staticmethod 7 | def need_deg(): 8 | return False 9 | 10 | @staticmethod 11 | def add_args(parser): 12 | return 13 | 14 | @staticmethod 15 | def name(args): 16 | raise NotImplementedError 17 | 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def forward(self, batched_data, perturb=None): 22 | raise NotImplementedError 23 | 24 | def epoch_callback(self, epoch): 25 | return 26 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gnn import GNN 2 | from .gnn_transformer import GNNTransformer 3 | from .pna import PNANet 4 | from .pna_transformer import PNATransformer 5 | from .transformer import Transformer 6 | from .transformer_gnn import TransformerGNN 7 | 8 | 9 | def get_model_and_parser(args, parser): 10 | model_cls = MODELS[args.model_type] 11 | model_cls.add_args(parser) 12 | return model_cls 13 | 14 | 15 | MODELS = {"gnn": GNN, "pna": PNANet, "gnn-transformer": GNNTransformer, "transformer": Transformer, 
"pna-transformer": PNATransformer, "transformer-gnn": TransformerGNN} 16 | -------------------------------------------------------------------------------- /modules/pna/scalers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | # Implemented with the help of Matthias Fey, author of PyTorch Geometric 7 | # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py 8 | 9 | 10 | def scale_identity(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): 11 | return src 12 | 13 | 14 | def scale_amplification(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): 15 | return src * (torch.log(deg + 1) / avg_deg["log"]) 16 | 17 | 18 | def scale_attenuation(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): 19 | scale = avg_deg["log"] / torch.log(deg + 1) 20 | scale[deg == 0] = 1 21 | return src * scale 22 | 23 | 24 | def scale_linear(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): 25 | return src * (deg / avg_deg["lin"]) 26 | 27 | 28 | def scale_inverse_linear(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): 29 | scale = avg_deg["lin"] / deg 30 | scale[deg == 0] = 1 31 | return src * scale 32 | 33 | 34 | SCALERS = { 35 | "identity": scale_identity, 36 | "amplification": scale_amplification, 37 | "attenuation": scale_attenuation, 38 | "linear": scale_linear, 39 | "inverse_linear": scale_inverse_linear, 40 | } 41 | -------------------------------------------------------------------------------- /modules/pna/aggregators.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch_scatter import scatter 6 | 7 | # Implemented with the help of Matthias Fey, author of PyTorch Geometric 8 | # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py 9 | 10 | 11 | def aggregate_sum(src: Tensor, index: Tensor, dim_size: Optional[int]): 12 | return scatter(src, index, 0, None, dim_size, reduce="sum") 13 | 14 | 15 | def aggregate_mean(src: Tensor, index: Tensor, dim_size: Optional[int]): 16 | return scatter(src, index, 0, None, dim_size, reduce="mean") 17 | 18 | 19 | def aggregate_min(src: Tensor, index: Tensor, dim_size: Optional[int]): 20 | return scatter(src, index, 0, None, dim_size, reduce="min") 21 | 22 | 23 | def aggregate_max(src: Tensor, index: Tensor, dim_size: Optional[int]): 24 | return scatter(src, index, 0, None, dim_size, reduce="max") 25 | 26 | 27 | def aggregate_var(src, index, dim_size): 28 | mean = aggregate_mean(src, index, dim_size) 29 | mean_squares = aggregate_mean(src * src, index, dim_size) 30 | return mean_squares - mean * mean 31 | 32 | 33 | def aggregate_std(src, index, dim_size): 34 | return torch.sqrt(torch.relu(aggregate_var(src, index, dim_size)) + 1e-5) 35 | 36 | 37 | AGGREGATORS = { 38 | "sum": aggregate_sum, 39 | "mean": aggregate_mean, 40 | "min": aggregate_min, 41 | "max": aggregate_max, 42 | "var": aggregate_var, 43 | "std": aggregate_std, 44 | } 45 | -------------------------------------------------------------------------------- /data/adj_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | from loguru import logger 7 | from tqdm import tqdm 8 | 9 | 10 | def make_adj_list(N, edge_index_transposed): 11 | A = np.eye(N) 12 | for edge in 
edge_index_transposed: 13 | A[edge[0], edge[1]] = 1 14 | adj_list = A != 0 15 | return adj_list 16 | 17 | 18 | def make_adj_list_wrapper(x): 19 | return make_adj_list(x["num_nodes"], x["edge_index"].T) 20 | 21 | 22 | def compute_adjacency_list(data): 23 | out = [] 24 | for x in tqdm(data, "adjacency list", leave=False): 25 | out.append(make_adj_list_wrapper(x)) 26 | return out 27 | 28 | 29 | def combine_results(data, adj_list): 30 | out_data = [] 31 | for x, l in tqdm(zip(data, adj_list), "assembling adj_list result", total=len(data), leave=False): 32 | x["adj_list"] = l 33 | out_data.append(x) 34 | return out_data 35 | 36 | 37 | def compute_adjacency_list_cached(data, key, root="/data/zhwu/tmp"): 38 | cachefile = f"{root}/OGB_ADJLIST_{key}.pickle" 39 | if os.path.exists(cachefile): 40 | with open(cachefile, "rb") as cachehandle: 41 | logger.debug("using cached result from '%s'" % cachefile) 42 | result = pickle.load(cachehandle) 43 | return combine_results(data, result) 44 | result = compute_adjacency_list(data) 45 | with open(cachefile, "wb") as cachehandle: 46 | logger.debug("saving result to cache '%s'" % cachefile) 47 | pickle.dump(result, cachehandle) 48 | logger.info("Got adjacency list data for key %s" % key) 49 | return combine_results(data, result) 50 |
-------------------------------------------------------------------------------- /trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from loguru import logger 4 | from tqdm import tqdm 5 | 6 | 7 | class BaseTrainer: 8 | @staticmethod 9 | def transform(args): 10 | return None 11 | 12 | @staticmethod 13 | def add_args(parser): 14 | pass 15 | 16 | @staticmethod 17 | def train(model, device, loader, optimizer, args, calc_loss, scheduler=None): 18 | model.train() 19 | 20 | loss_accum = 0 21 | t = tqdm(loader, desc="Train") 22 | for step, batch in enumerate(t): 23 | batch = batch.to(device) 24 | 25 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 26 | pass 27 | else: 28 | optimizer.zero_grad() 29 | pred_list = model(batch) 30 | 31 | loss = calc_loss(pred_list, batch) 32 | 33 | loss.backward() 34 | if args.grad_clip is not None: 35 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 36 | optimizer.step() 37 | 38 | if scheduler: 39 | scheduler.step() 40 | 41 | detached_loss = loss.item() 42 | loss_accum += detached_loss 43 | t.set_description(f"Train (loss = {detached_loss:.4f}, smoothed = {loss_accum / (step + 1):.4f})") 44 | wandb.log({"train/iter-loss": detached_loss, "train/iter-loss-smoothed": loss_accum / (step + 1)}) 45 | 46 | logger.info("Average training loss: {:.4f}".format(loss_accum / (step + 1))) 47 | return loss_accum / (step + 1) 48 | 49 | @staticmethod 50 | def name(args): 51 | raise NotImplementedError 52 |
-------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from .base_trainer import BaseTrainer 5 | 6 | TRAINER_REGISTRY = {} 7 | TRAINER_CLASS_NAMES = set() 8 | 9 | __all__ = [ 10 | "BaseTrainer", 11 | ] 12 | 13 | 14 | def get_trainer_and_parser(args, parser): 15 | trainer = TRAINER_REGISTRY[args.aug] 16 | trainer.add_args(parser) 17 | return trainer 18 | 19 | 20 | def register_trainer(name, dataclass=None): 21 | """ 22 | New trainers can be added with the 23 | :func:`register_trainer` function decorator. 24 | For example:: 25 | @register_trainer('baseline') 26 | class BaselineTrainer(BaseTrainer): 27 | (...) 28 | .. note:: 29 | All trainers must extend the :class:`BaseTrainer` 30 | interface. 31 | Args: 32 | name (str): the name of the trainer 33 | """ 34 | 35 | def register_trainer_cls(cls): 36 | if name in TRAINER_REGISTRY: 37 | raise ValueError("Cannot register duplicate trainer ({})".format(name)) 38 | if not issubclass(cls, BaseTrainer): 39 | raise ValueError("Trainer ({}: {}) must extend BaseTrainer".format(name, cls.__name__)) 40 | if cls.__name__ in TRAINER_CLASS_NAMES: 41 | raise ValueError("Cannot register trainer with duplicate class name ({})".format(cls.__name__)) 42 | TRAINER_REGISTRY[name] = cls 43 | TRAINER_CLASS_NAMES.add(cls.__name__) 44 | 45 | return cls 46 | 47 | return register_trainer_cls 48 | 49 | 50 | # automatically import any Python files in the trainers/ directory 51 | trainers_dir = os.path.dirname(__file__) 52 | for file in os.listdir(trainers_dir): 53 | path = os.path.join(trainers_dir, file) 54 | if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): 55 | trainer_name = file[: file.find(".py")] if file.endswith(".py") else file 56 | module = importlib.import_module("trainers." + trainer_name) 57 |
-------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loguru import logger 3 | 4 | 5 | def pad_batch(h_node, batch, max_input_len, get_mask=False): 6 | num_batch = batch[-1] + 1 7 | num_nodes = [] 8 | masks = [] 9 | for i in range(num_batch): 10 | mask = batch.eq(i) 11 | masks.append(mask) 12 | num_node = mask.sum() 13 | num_nodes.append(num_node) 14 | 15 | # logger.info(max(num_nodes)) 16 | max_num_nodes = min(max(num_nodes), max_input_len) 17 | padded_h_node = h_node.data.new(max_num_nodes, num_batch, h_node.size(-1)).fill_(0) 18 | src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(0).bool() 19 | 20 | for i, mask in enumerate(masks): 21 | num_node = num_nodes[i] 22 | if num_node > max_num_nodes: 23 | num_node = max_num_nodes 24 | padded_h_node[-num_node:, i] = h_node[mask][-num_node:] 25 | src_padding_mask[i, : max_num_nodes - num_node] = True # [b, s] 26 | 27 | if get_mask: 28 | return padded_h_node, src_padding_mask, num_nodes, masks, max_num_nodes 29 | return padded_h_node, src_padding_mask 30 | 31 | 32 | def unpad_batch(padded_h_node, prev_h_node, num_nodes, origin_mask, max_num_nodes): 33 | """ 34 | padded_h_node: [s, b, f] 35 | prev_h_node: [n, f] flat node matrix to scatter the padded output back into 36 | num_nodes: per-graph node counts 37 | origin_mask: per-graph boolean node masks produced by pad_batch 38 | """ 39 | 40 | for i, mask in enumerate(origin_mask): 41 | num_node = num_nodes[i] 42 | if num_node > max_num_nodes: 43 | num_node = max_num_nodes 44 | # cutoff mask 45 | indices = mask.nonzero() 46 | indices = indices[-num_node:] 47 | mask = torch.zeros_like(mask) 48 | mask[indices] = True 49 | # logger.info("prev_h_node:", prev_h_node.size()) 50 | # logger.info("padded_h_node:", padded_h_node.size()) 51 | # logger.info("mask:", mask.size()) 52 | prev_h_node = prev_h_node.masked_scatter(mask.unsqueeze(-1), padded_h_node[-num_node:, i]) 53 | return prev_h_node 54 |
-------------------------------------------------------------------------------- /trainers/flag_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from tqdm import tqdm 4 | 5 | from trainers import 
register_trainer 6 | 7 | from .base_trainer import BaseTrainer 8 | 9 | 10 | @register_trainer("flag") 11 | class FlagTrainer(BaseTrainer): 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--step-size', type=float, default=8e-3) 16 | parser.add_argument('-m', type=int, default=3) 17 | # fmt: on 18 | 19 | @staticmethod 20 | def train(model, device, loader, optimizer, args, calc_loss): 21 | model.train() 22 | 23 | loss_accum = 0 24 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 25 | batch = batch.to(device) 26 | 27 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 28 | pass 29 | else: 30 | optimizer.zero_grad() 31 | 32 | perturb = torch.FloatTensor(batch.x.shape[0], args.gnn_emb_dim).uniform_(-args.step_size, args.step_size).to(device) 33 | perturb.requires_grad_() 34 | 35 | pred_list = model(batch, perturb) 36 | 37 | loss = calc_loss(pred_list, batch, args.m) 38 | 39 | for _ in range(args.m - 1): 40 | loss.backward() 41 | perturb_data = perturb.detach() + args.step_size * torch.sign(perturb.grad.detach()) 42 | perturb.data = perturb_data.data 43 | perturb.grad[:] = 0 44 | 45 | pred_list = model(batch, perturb) 46 | 47 | loss = calc_loss(pred_list, batch, args.m) 48 | 49 | loss.backward() 50 | optimizer.step() 51 | 52 | detached_loss = loss.item() 53 | loss_accum += detached_loss 54 | wandb.log({"train/iter-loss": detached_loss}) 55 | 56 | return loss_accum / (step + 1) 57 | 58 | @staticmethod 59 | def name(args): 60 | return "flag" 61 | -------------------------------------------------------------------------------- /dataset/tud.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import random_split 8 | from torch_geometric.datasets import TUDataset 9 | from torch_geometric.utils import degree 10 | from tqdm import tqdm 11 | 12 | 13 | class TUUtil: 14 | @staticmethod 15 | def add_args(parser): 16 | parser.set_defaults(batch_size=128) 17 | parser.set_defaults(epochs=10000) 18 | parser.set_defaults(lr=0.0005) 19 | parser.set_defaults(weight_decay=0.0001) 20 | parser.set_defaults(gnn_dropout=0.5) 21 | parser.set_defaults(gnn_emb_dim=128) 22 | 23 | @staticmethod 24 | def loss_fn(task_type): 25 | def calc_loss(pred, batch, m=1.0): 26 | loss = F.cross_entropy(pred, batch.y) 27 | return loss 28 | 29 | return calc_loss 30 | 31 | @staticmethod 32 | @torch.no_grad() 33 | def eval(model, device, loader, evaluator): 34 | model.eval() 35 | 36 | correct = 0 37 | for step, batch in enumerate(tqdm(loader, desc="Eval")): 38 | batch = batch.to(device) 39 | 40 | pred = model(batch) 41 | pred = pred.max(dim=1)[1] 42 | correct += pred.eq(batch.y).sum().item() 43 | return {"acc": correct / len(loader.dataset)} 44 | 45 | @staticmethod 46 | def preprocess(args): 47 | dataset = TUDataset(os.path.join(args.data_root, args.dataset), name=args.dataset) 48 | num_tasks = dataset.num_classes 49 | 50 | num_features = dataset.num_features 51 | 52 | num_training = int(len(dataset) * 0.8) 53 | num_val = int(len(dataset) * 0.1) 54 | num_test = len(dataset) - (num_training + num_val) 55 | training_set, validation_set, test_set = random_split(dataset, [num_training, num_val, num_test]) 56 | 57 | class Dataset(dict): 58 | pass 59 | 60 | dataset = Dataset({"train": training_set, "valid": validation_set, "test": test_set}) 61 | dataset.eval_metric = "acc" 62 | dataset.task_type = "classification" 63 
| dataset.get_idx_split = lambda: {"train": "train", "valid": "valid", "test": "test"} 64 | 65 | node_encoder_cls = lambda: nn.Linear(num_features, args.gnn_emb_dim) 66 | 67 | def edge_encoder_cls(_): 68 | def zero(_): 69 | return 0 70 | 71 | return zero 72 | 73 | return dataset, num_tasks, node_encoder_cls, edge_encoder_cls, None 74 | -------------------------------------------------------------------------------- /modules/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool 6 | from torch_geometric.utils import degree 7 | 8 | 9 | ### GIN convolution along the graph structure 10 | class GINConv(MessagePassing): 11 | def __init__(self, emb_dim: int, edge_encoder_cls): 12 | """ 13 | emb_dim (int): node embedding dimensionality 14 | """ 15 | 16 | super(GINConv, self).__init__(aggr="add") 17 | 18 | self.mlp = torch.nn.Sequential( 19 | torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim) 20 | ) 21 | self.eps = torch.nn.Parameter(torch.Tensor([0])) 22 | 23 | # edge_attr is two dimensional after augment_edge transformation 24 | self.edge_encoder = edge_encoder_cls(emb_dim) 25 | 26 | def forward(self, x, edge_index, edge_attr): 27 | edge_embedding = self.edge_encoder(edge_attr) 28 | out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) 29 | 30 | return out 31 | 32 | def message(self, x_j, edge_attr): 33 | return F.relu(x_j + edge_attr) 34 | 35 | def update(self, aggr_out): 36 | return aggr_out 37 | 38 | 39 | ### GCN convolution along the graph structure 40 | class GCNConv(MessagePassing): 41 | def __init__(self, emb_dim, edge_encoder_cls): 42 | super(GCNConv, self).__init__(aggr="add") 43 | 44 | self.linear = torch.nn.Linear(emb_dim, emb_dim) 45 | self.root_emb = torch.nn.Embedding(1, emb_dim) 46 | 47 | # edge_attr is two dimensional after augment_edge transformation 48 | self.edge_encoder = edge_encoder_cls(emb_dim) 49 | 50 | def forward(self, x, edge_index, edge_attr): 51 | x = self.linear(x) 52 | edge_embedding = self.edge_encoder(edge_attr) 53 | 54 | row, col = edge_index 55 | 56 | # edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) 57 | deg = degree(row, x.size(0), dtype=x.dtype) + 1 58 | deg_inv_sqrt = deg.pow(-0.5) 59 | deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 60 | 61 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] 62 | 63 | return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1.0 / deg.view( 64 | -1, 1 65 | ) 66 | 67 | def message(self, x_j, edge_attr, norm): 68 | return norm.view(-1, 1) * F.relu(x_j + edge_attr) 69 | 70 | def update(self, aggr_out): 71 | return aggr_out 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | slurm*.out 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 
| # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | wandb 142 | .vscode 143 | exps 144 |
-------------------------------------------------------------------------------- /modules/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import models.gnn as gnn 7 | 8 | 9 | class TransformerNodeEncoder(nn.Module): 10 | @staticmethod 11 | def add_args(parser): 12 | group = parser.add_argument_group("transformer") 13 | group.add_argument("--d_model", type=int, default=128, help="transformer d_model.") 14 | group.add_argument("--nhead", type=int, default=4, help="transformer heads") 15 | group.add_argument("--dim_feedforward", type=int, default=512, help="transformer feedforward dim") 16 | group.add_argument("--transformer_dropout", type=float, default=0.3) 17 | group.add_argument("--transformer_activation", type=str, default="relu") 18 | group.add_argument("--num_encoder_layers", type=int, default=4) 19 | group.add_argument("--max_input_len", type=int, default=1000, help="the maximum number of nodes fed to the transformer") 20 | group.add_argument("--transformer_norm_input", action="store_true", default=False) 21 | 22 | def __init__(self, args): 23 | 
super().__init__() 24 | 25 | self.d_model = args.d_model 26 | self.num_layer = args.num_encoder_layers 27 | # Creating Transformer Encoder Model 28 | encoder_layer = nn.TransformerEncoderLayer( 29 | args.d_model, args.nhead, args.dim_feedforward, args.transformer_dropout, args.transformer_activation 30 | ) 31 | encoder_norm = nn.LayerNorm(args.d_model) 32 | self.transformer = nn.TransformerEncoder(encoder_layer, args.num_encoder_layers, encoder_norm) 33 | self.max_input_len = args.max_input_len 34 | 35 | self.norm_input = None 36 | if args.transformer_norm_input: 37 | self.norm_input = nn.LayerNorm(args.d_model) 38 | self.cls_embedding = None 39 | if args.graph_pooling == "cls": 40 | self.cls_embedding = nn.Parameter(torch.randn([1, 1, args.d_model], requires_grad=True)) 41 | 42 | def forward(self, padded_h_node, src_padding_mask): 43 | """ 44 | padded_h_node: n_b x B x h_d 45 | src_key_padding_mask: B x n_b 46 | """ 47 | 48 | # (S, B, h_d), (B, S) 49 | 50 | if self.cls_embedding is not None: 51 | expand_cls_embedding = self.cls_embedding.expand(1, padded_h_node.size(1), -1) 52 | padded_h_node = torch.cat([padded_h_node, expand_cls_embedding], dim=0) 53 | 54 | zeros = src_padding_mask.data.new(src_padding_mask.size(0), 1).fill_(0) 55 | src_padding_mask = torch.cat([src_padding_mask, zeros], dim=1) 56 | if self.norm_input is not None: 57 | padded_h_node = self.norm_input(padded_h_node) 58 | 59 | transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) # (S, B, h_d) 60 | 61 | return transformer_out, src_padding_mask 62 | -------------------------------------------------------------------------------- /modules/pna/pna_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ogb.graphproppred.mol_encoder import AtomEncoder 5 | from torch_geometric.nn import ( 6 | BatchNorm, 7 | GlobalAttention, 8 | PNAConv, 9 | Set2Set, 10 | global_add_pool, 11 | global_max_pool, 12 | global_mean_pool, 13 | ) 14 | 15 | 16 | class PNANodeEmbedding(nn.Module): 17 | @staticmethod 18 | def add_args(parser): 19 | group = parser.add_argument_group("PNANet configs") 20 | group.add_argument("--aggregators", type=str, nargs="+", default=["mean", "max", "min", "std"]) 21 | group.add_argument("--scalers", type=str, nargs="+", default=["identity", "amplification", "attenuation"]) 22 | group.add_argument("--post_layers", type=int, default=1) 23 | group.add_argument("--add_edge", type=str, default="none") 24 | group.set_defaults(gnn_residual=True) 25 | group.set_defaults(gnn_dropout=0.3) 26 | group.set_defaults(gnn_emb_dim=70) 27 | group.set_defaults(gnn_num_layer=4) 28 | 29 | def __init__(self, node_encoder, args): 30 | super().__init__() 31 | self.num_layer = args.gnn_num_layer 32 | self.max_seq_len = args.max_seq_len 33 | self.aggregators = args.aggregators 34 | self.scalers = args.scalers 35 | self.residual = args.gnn_residual 36 | self.drop_ratio = args.gnn_dropout 37 | self.graph_pooling = args.graph_pooling 38 | 39 | self.node_encoder = node_encoder 40 | 41 | self.layers = nn.ModuleList( 42 | [ 43 | PNAConv( 44 | args.gnn_emb_dim, 45 | args.gnn_emb_dim, 46 | aggregators=self.aggregators, 47 | scalers=self.scalers, 48 | deg=args.deg, 49 | towers=4, 50 | divide_input=True, 51 | ) 52 | for _ in range(self.num_layer) 53 | ] 54 | ) 55 | self.batch_norms = nn.ModuleList([BatchNorm(args.gnn_emb_dim) for _ in range(self.num_layer)]) 56 | 57 | def forward(self, 
batched_data, perturb=None): 58 | x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr 59 | node_depth = batched_data.node_depth if hasattr(batched_data, "node_depth") else None 60 | encoded_node = ( 61 | self.node_encoder(x) 62 | if node_depth is None 63 | else self.node_encoder( 64 | x, 65 | node_depth.view( 66 | -1, 67 | ), 68 | ) 69 | ) 70 | x = encoded_node + perturb if perturb is not None else encoded_node 71 | 72 | for conv, batch_norm in zip(self.layers, self.batch_norms): 73 | h = F.relu(batch_norm(conv(x, edge_index))) 74 | if self.residual: 75 | x = h + x 76 | x = F.dropout(x, self.drop_ratio, training=self.training) 77 | 78 | return x 79 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Representing Long-Range Context for Graph Neural Networks with Global Attention 2 | ``` 3 | @inproceedings{Wu2021GraphTrans, 4 | title={Representing Long-Range Context for Graph Neural Networks with Global Attention}, 5 | author={Wu, Zhanghao and Jain, Paras and Wright, Matthew and Mirhoseini, Azalia and Gonzalez, Joseph E and Stoica, Ion}, 6 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 7 | year={2021} 8 | } 9 | ``` 10 | ## Overview 11 | We release the PyTorch code for GraphTrans [[paper](https://proceedings.neurips.cc//paper/2021/hash/6e67691b60ed3e4a55935261314dd534-Abstract.html)]. 12 | 13 | ## Installation 14 | To set up the Python environment, please install conda first. 15 | All required dependencies are listed in [requirement.yml](./requirement.yml): 16 | ```bash 17 | conda env create -f requirement.yml 18 | ``` 19 | ## How to Run 20 | 21 | To run the experiments, use the commands below (taking OGBG-Code2 as an example): 22 | ```bash 23 | # GraphTrans (GCN-Virtual) 24 | python main.py --configs configs/code2/gnn-transformer/JK=cat/pooling=cls+norm_input.yml --runs 5 25 | # GraphTrans (GCN) 26 | python main.py --configs configs/code2/gnn-transformer/no-virtual/pooling=cls+norm_input.yml --runs 5 27 | # Or submit via Slurm 28 | sbatch ./run.sh "configs/code2/gnn-transformer/JK=cat/pooling=cls+norm_input.yml --runs 5" 29 | ``` 30 | The config path for each dataset/model can be found in the result table below. 
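Each key in these YAML files mirrors a command-line flag of the same name; `requirement.yml` pins `configargparse==1.3`, which suggests the configs are loaded through ConfigArgParse (command line > config file > defaults). Below is a minimal sketch of that pattern; the wiring is an assumption for illustration, not a copy of the repository's `main.py`:
```python
# Sketch only: the ConfigArgParse pattern implied by configs/ and run.sh.
import pathlib
import tempfile

import configargparse

# A tiny stand-in config file (the real ones live under configs/).
cfg = pathlib.Path(tempfile.mkdtemp()) / "demo.yml"
cfg.write_text("dataset: NCI1\nruns: 5\n")

parser = configargparse.ArgumentParser(
    config_file_parser_class=configargparse.YAMLConfigFileParser)
parser.add_argument("--configs", is_config_file=True, help="path to a YAML config")
parser.add_argument("--dataset", type=str)
parser.add_argument("--runs", type=int, default=1)

# File values fill the matching flags; explicit CLI flags still take precedence.
args = parser.parse_args(["--configs", str(cfg), "--runs", "20"])
print(args.dataset, args.runs)  # NCI1 20
```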
31 | ## Results 32 | | Dataset | Model | Valid | Test | Config | 33 | |:--|:--|:--:|:--:|:--:| 34 | | [OGBG-Code2](https://ogb.stanford.edu/docs/leader_graphprop/#ogbg-code2) | GraphTrans (GCN) | 0.1599±0.0009 | 0.1751±0.0015 | [Config](configs/code2/gnn-transformer/no-virtual/pooling=cls+norm_input.yml) | 35 | | | GraphTrans (PNA) | 0.1622±0.0025 | 0.1765±0.0033 | [Config](configs/code2/pna-transformer/pooling=cls+norm_input.yml) | 36 | | | GraphTrans (GCN-Virtual) | 0.1661±0.0012 | 0.1830±0.0024 | [Config](configs/code2/gnn-transformer/JK=cat/pooling=cls+norm_input.yml) | 37 | | [OGBG-Molpcba](https://ogb.stanford.edu/docs/leader_graphprop/#ogbg-molpcba) | GraphTrans (GIN) | 0.2893±0.0050 | 0.2756±0.0039 | [Config](configs/molpcba/gnn-transformer/no-virtual/JK=cat/pooling=cls+gin+norm_input.yml) | 38 | | | GraphTrans (GIN-Virtual) | 0.2867±0.0022 | 0.2761±0.0029 | [Config](configs/molpcba/gnn-transformer/JK=cat/pooling=cls+gin+norm_input.yml) | 39 | | [NCI1](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) | GraphTrans (small, GCN) | — | 81.3±1.9 | [Config](configs/NCI1/gnn-transformer/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml) | 40 | | | GraphTrans (large, GIN) | — | 82.6±1.2 | [Config](configs/NCI1/gnn-transformer/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml) | 41 | | [NCI109](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) | GraphTrans (small, GCN) | — | 79.2±2.2 | [Config](configs/NCI109/gnn-transformer/no-virtual/gd=128+gdp=0.1+tdp=0.1+l=3+cosine.yml) | 42 | | | GraphTrans (large, GIN) | — | 82.3±2.6 | [Config](configs/NCI109/gnn-transformer/no-virtual/gin+gdp=0.1+tdp=0.1+l=4+cosine.yml) | 43 | 44 | 45 |
-------------------------------------------------------------------------------- /requirement.yml: -------------------------------------------------------------------------------- 1 | name: graph-aug 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2021.1.19=h06a4308_1 9 | - certifi=2020.12.5=py38h06a4308_0 10 | - cudatoolkit=10.2.89=hfd86e86_1 11 | - freetype=2.10.4=h5ab3b9f_0 12 | - intel-openmp=2020.2=254 13 | - jpeg=9b=h024ee3a_2 14 | - lcms2=2.11=h396b838_0 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20191231=h14c3975_1 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libpng=1.6.37=hbc83047_0 20 | - libstdcxx-ng=9.1.0=hdf63c60_0 21 | - libtiff=4.1.0=h2733197_1 22 | - libuv=1.40.0=h7b6447c_0 23 | - lz4-c=1.9.3=h2531618_0 24 | - mkl=2020.2=256 25 | - mkl-service=2.3.0=py38he904b0f_0 26 | - mkl_fft=1.3.0=py38h54f3939_0 27 | - mkl_random=1.1.1=py38h0573a6f_0 28 | - ncurses=6.2=he6710b0_1 29 | - ninja=1.10.2=py38hff7bd54_0 30 | - numpy=1.19.2=py38h54aff64_0 31 | - numpy-base=1.19.2=py38hfa32c7d_0 32 | - olefile=0.46=py_0 33 | - openssl=1.1.1j=h27cfd23_0 34 | - pillow=8.1.2=py38he98fc37_0 35 | - pip=21.0.1=py38h06a4308_0 36 | - python=3.8.8=hdb3f193_4 37 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0 38 | - readline=8.1=h27cfd23_0 39 | - setuptools=52.0.0=py38h06a4308_0 40 | - six=1.15.0=py38h06a4308_0 41 | - sqlite=3.33.0=h62c20be_0 42 | - tk=8.6.10=hbc83047_0 43 | - torchaudio=0.7.2=py38 44 | - torchvision=0.8.2=py38_cu102 45 | - typing_extensions=3.7.4.3=pyha847dfd_0 46 | - wheel=0.36.2=pyhd3eb1b0_0 47 | - xz=5.2.5=h7b6447c_0 48 | - zlib=1.2.11=h7b6447c_3 49 | - zstd=1.4.5=h9ceee32_0 50 | - pip: 51 | - ase==3.21.1 52 | - chardet==4.0.0 53 | - click==7.1.2 54 | - configargparse==1.3 55 | - configparser==5.0.2 56 | - cycler==0.10.0 
57 | - decorator==4.4.2 58 | - docker-pycreds==0.4.0 59 | - gitdb==4.0.5 60 | - gitpython==3.1.14 61 | - googledrivedownloader==0.4 62 | - h5py==3.2.1 63 | - idna==2.10 64 | - isodate==0.6.0 65 | - jinja2==2.11.3 66 | - joblib==1.0.1 67 | - kiwisolver==1.3.1 68 | - littleutils==0.2.2 69 | - llvmlite==0.35.0 70 | - loguru==0.5.3 71 | - markupsafe==1.1.1 72 | - matplotlib==3.3.4 73 | - networkx==2.5 74 | - numba==0.52.0 75 | - ogb==1.2.6 76 | - outdated==0.2.0 77 | - pandas==1.2.3 78 | - pathtools==0.1.2 79 | - promise==2.3 80 | - protobuf==3.15.5 81 | - psutil==5.8.0 82 | - pyparsing==2.4.7 83 | - python-dateutil==2.8.1 84 | - python-louvain==0.15 85 | - pytz==2021.1 86 | - pyyaml==5.4.1 87 | - rdflib==5.0.0 88 | - requests==2.25.1 89 | - scikit-learn==0.24.1 90 | - scipy==1.6.1 91 | - sentry-sdk==1.0.0 92 | - shortuuid==1.0.1 93 | - smmap==3.0.5 94 | - subprocess32==3.5.4 95 | - threadpoolctl==2.1.0 96 | - torch-cluster==1.5.9 97 | - torch-geometric==1.6.3 98 | - torch-scatter==2.0.6 99 | - torch-sparse==0.6.9 100 | - torch-spline-conv==1.2.1 101 | - tqdm==4.59.0 102 | - urllib3==1.26.3 103 | - wandb==0.10.22 104 | prefix: /data/zhwu/miniconda3/envs/graph-aug 105 | -------------------------------------------------------------------------------- /dataset/mol.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from loguru import logger 6 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 7 | from torch_geometric.utils import degree 8 | from tqdm import tqdm 9 | 10 | 11 | class MolUtil: 12 | @staticmethod 13 | def add_args(parser): 14 | parser.add_argument("--feature", type=str, default="full", help="full feature or simple feature") 15 | parser.set_defaults(batch_size=32) 16 | parser.set_defaults(epochs=100) 17 | parser.set_defaults(gnn_dropout=0.5) 18 | 19 | @staticmethod 20 | def loss_fn(task_type): 21 | cls_criterion = torch.nn.BCEWithLogitsLoss() 22 | reg_criterion = torch.nn.MSELoss() 23 | 24 | def calc_loss(pred, batch, m=1.0): 25 | is_labeled = batch.y == batch.y 26 | if "classification" in task_type: 27 | loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 28 | else: 29 | loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 30 | loss /= m 31 | return loss 32 | 33 | return calc_loss 34 | 35 | @staticmethod 36 | def eval(model, device, loader, evaluator): 37 | model.eval() 38 | y_true = [] 39 | y_pred = [] 40 | 41 | for step, batch in enumerate(tqdm(loader, desc="Eval")): 42 | batch = batch.to(device) 43 | 44 | if batch.x.shape[0] == 1: 45 | pass 46 | else: 47 | with torch.no_grad(): 48 | pred = model(batch) 49 | 50 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 51 | y_pred.append(pred.detach().cpu()) 52 | 53 | y_true = torch.cat(y_true, dim=0).numpy() 54 | y_pred = torch.cat(y_pred, dim=0).numpy() 55 | 56 | input_dict = {"y_true": y_true, "y_pred": y_pred} 57 | 58 | return evaluator.eval(input_dict) 59 | 60 | @staticmethod 61 | def preprocess(dataset, dataset_eval, model_cls, args): 62 | split_idx = dataset.get_idx_split() 63 | if args.feature == "full": 64 | pass 65 | elif args.feature == "simple": 66 | logger.debug("using simple feature") 67 | # only retain the top two node/edge features 68 | dataset.data.x = dataset.data.x[:, :2] 69 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 70 | # Compute in-degree histogram over 
training data. 71 | deg = torch.zeros(10, dtype=torch.long) 72 | num_nodes = 0.0 73 | num_graphs = 0 74 | for data in dataset_eval[split_idx["train"]]: 75 | d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) 76 | deg += torch.bincount(d, minlength=deg.numel()) 77 | num_nodes += data.num_nodes 78 | num_graphs += 1 79 | args.deg = deg 80 | logger.debug("Avg num nodes: {}", num_nodes / num_graphs) 81 | logger.debug("Avg deg: {}", deg) 82 | 83 | node_encoder_cls = lambda: AtomEncoder(model_cls.get_emb_dim(args)) 84 | edge_encoder_cls = lambda emb_dim: BondEncoder(emb_dim=emb_dim) 85 | return dataset.num_tasks, node_encoder_cls, edge_encoder_cls, deg 86 | -------------------------------------------------------------------------------- /models/pna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ogb.graphproppred.mol_encoder import AtomEncoder 5 | from torch_geometric.nn import ( 6 | BatchNorm, 7 | GlobalAttention, 8 | PNAConv, 9 | Set2Set, 10 | global_add_pool, 11 | global_max_pool, 12 | global_mean_pool, 13 | ) 14 | 15 | from modules.pna.pna_module import PNANodeEmbedding 16 | 17 | from .base_model import BaseModel 18 | 19 | 20 | class PNANet(BaseModel): 21 | @staticmethod 22 | def get_emb_dim(args): 23 | return args.gnn_emb_dim 24 | 25 | @staticmethod 26 | def need_deg(): 27 | return True 28 | 29 | @staticmethod 30 | def add_args(parser): 31 | PNANodeEmbedding.add_args(parser) 32 | 33 | @staticmethod 34 | def name(args): 35 | name = f"{args.model_type}" 36 | return name 37 | 38 | def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args): 39 | super().__init__() 40 | self.num_layer = args.gnn_num_layer 41 | self.num_tasks = num_tasks 42 | self.max_seq_len = args.max_seq_len 43 | self.aggregators = args.aggregators 44 | self.scalers = args.scalers 45 | self.residual = args.gnn_residual 46 | self.drop_ratio = args.gnn_dropout 47 | self.graph_pooling = args.graph_pooling 48 | 49 | self.node_encoder = node_encoder 50 | 51 | self.pna_module = PNANodeEmbedding(node_encoder, args) 52 | 53 | if self.max_seq_len is None: 54 | self.mlp = nn.Sequential( 55 | nn.Linear(args.gnn_emb_dim, 35, bias=True), 56 | nn.ReLU(), 57 | nn.Linear(35, 17, bias=True), 58 | nn.ReLU(), 59 | nn.Linear(17, self.num_tasks, bias=True), 60 | ) 61 | 62 | else: 63 | self.graph_pred_linear_list = torch.nn.ModuleList() 64 | for i in range(self.max_seq_len): 65 | self.graph_pred_linear_list.append( 66 | nn.Sequential( 67 | nn.Linear(args.gnn_emb_dim, args.gnn_emb_dim), 68 | nn.ReLU(), 69 | nn.Linear(args.gnn_emb_dim, self.num_tasks), 70 | ) 71 | ) 72 | 73 | ### Pooling function to generate whole-graph embeddings 74 | if self.graph_pooling == "sum": 75 | self.pool = global_add_pool 76 | elif self.graph_pooling == "mean": 77 | self.pool = global_mean_pool 78 | elif self.graph_pooling == "max": 79 | self.pool = global_max_pool 80 | elif self.graph_pooling == "attention": 81 | self.pool = GlobalAttention( 82 | gate_nn=torch.nn.Sequential( 83 | torch.nn.Linear(args.gnn_emb_dim, 2 * args.gnn_emb_dim), 84 | torch.nn.BatchNorm1d(2 * args.gnn_emb_dim), 85 | torch.nn.ReLU(), 86 | torch.nn.Linear(2 * args.gnn_emb_dim, 1), 87 | ) 88 | ) 89 | elif self.graph_pooling == "set2set": 90 | self.pool = Set2Set(args.gnn_emb_dim, processing_steps=2) 91 | else: 92 | raise ValueError("Invalid graph pooling type.") 93 | 94 | def forward(self, batched_data, perturb=None): 95 | x = 
self.pna_module(batched_data, perturb) 96 | 97 | h_graph = self.pool(x, batched_data.batch) 98 | 99 | if self.max_seq_len is None: 100 | return self.mlp(h_graph) 101 | pred_list = [] 102 | for i in range(self.max_seq_len): 103 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 104 | return pred_list 105 | -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import ( 4 | GlobalAttention, 5 | Set2Set, 6 | global_add_pool, 7 | global_max_pool, 8 | global_mean_pool, 9 | ) 10 | 11 | from modules.gnn_module import GNNNodeEmbedding 12 | 13 | from .base_model import BaseModel 14 | 15 | 16 | class GNN(BaseModel): 17 | @staticmethod 18 | def get_emb_dim(args): 19 | return args.gnn_emb_dim 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | return 24 | 25 | @staticmethod 26 | def name(args): 27 | name = f"{args.model_type}+{args.gnn_type}" 28 | name += "-virtual" if args.gnn_virtual_node else "" 29 | return name 30 | 31 | def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args): 32 | """ 33 | num_tasks (int): number of labels to be predicted 34 | virtual_node (bool): whether to add virtual node or not 35 | """ 36 | 37 | super(GNN, self).__init__() 38 | 39 | self.num_layer = args.gnn_num_layer 40 | self.drop_ratio = args.gnn_dropout 41 | self.JK = args.gnn_JK 42 | self.emb_dim = args.gnn_emb_dim 43 | self.num_tasks = num_tasks 44 | self.max_seq_len = args.max_seq_len 45 | self.graph_pooling = args.graph_pooling 46 | 47 | if self.num_layer < 2: 48 | raise ValueError("Number of GNN layers must be greater than 1.") 49 | 50 | ### GNN to generate node embeddings 51 | self.gnn_node = GNNNodeEmbedding( 52 | args.gnn_virtual_node, 53 | self.num_layer, 54 | self.emb_dim, 55 | node_encoder, 56 | edge_encoder_cls, 57 | JK=self.JK, 58 | drop_ratio=self.drop_ratio, 59 | residual=args.gnn_residual, 60 | gnn_type=args.gnn_type, 61 | ) 62 | 63 | ### Pooling function to generate whole-graph embeddings 64 | if self.graph_pooling == "sum": 65 | self.pool = global_add_pool 66 | elif self.graph_pooling == "mean": 67 | self.pool = global_mean_pool 68 | elif self.graph_pooling == "max": 69 | self.pool = global_max_pool 70 | elif self.graph_pooling == "attention": 71 | self.pool = GlobalAttention( 72 | gate_nn=torch.nn.Sequential( 73 | torch.nn.Linear(self.emb_dim, 2 * self.emb_dim), 74 | torch.nn.BatchNorm1d(2 * self.emb_dim), 75 | torch.nn.ReLU(), 76 | torch.nn.Linear(2 * self.emb_dim, 1), 77 | ) 78 | ) 79 | elif self.graph_pooling == "set2set": 80 | self.pool = Set2Set(self.emb_dim, processing_steps=2) 81 | else: 82 | raise ValueError("Invalid graph pooling type.") 83 | 84 | if self.max_seq_len is None: 85 | if self.graph_pooling == "set2set": 86 | self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks) 87 | else: 88 | self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks) 89 | else: 90 | self.graph_pred_linear_list = torch.nn.ModuleList() 91 | if self.graph_pooling == "set2set": 92 | for i in range(self.max_seq_len): 93 | self.graph_pred_linear_list.append(torch.nn.Linear(2 * self.emb_dim, self.num_tasks)) 94 | else: 95 | for i in range(self.max_seq_len): 96 | self.graph_pred_linear_list.append(torch.nn.Linear(self.emb_dim, self.num_tasks)) 97 | 98 | def forward(self, batched_data, perturb=None): 99 | """ 100 | Return: 101 | A (list of) predictions. 
102 | i-th element represents prediction at i-th position of the sequence. 103 | """ 104 | 105 | h_node = self.gnn_node(batched_data, perturb) 106 | 107 | h_graph = self.pool(h_node, batched_data.batch) 108 | 109 | if self.max_seq_len is None: 110 | return self.graph_pred_linear(h_graph) 111 | pred_list = [] 112 | for i in range(self.max_seq_len): 113 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 114 | 115 | return pred_list 116 | 117 | 118 | if __name__ == "__main__": 119 | pass 120 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import ( 6 | GlobalAttention, 7 | Set2Set, 8 | global_add_pool, 9 | global_max_pool, 10 | global_mean_pool, 11 | ) 12 | 13 | from modules.gnn_module import GNNNodeEmbedding 14 | from modules.transformer_encoder import TransformerNodeEncoder 15 | from modules.utils import pad_batch, unpad_batch 16 | 17 | from .base_model import BaseModel 18 | 19 | 20 | class Transformer(BaseModel): 21 | @staticmethod 22 | def get_emb_dim(args): 23 | return args.d_model 24 | 25 | @staticmethod 26 | def add_args(parser): 27 | TransformerNodeEncoder.add_args(parser) 28 | 29 | @staticmethod 30 | def name(args): 31 | name = f"{args.model_type}-pooling={args.graph_pooling}" 32 | name += f"+{args.gnn_type}" 33 | name += "-virtual" if args.gnn_virtual_node else "" 34 | name += f"-d={args.d_model}" 35 | name += f"-tdp={args.transformer_dropout}" 36 | return name 37 | 38 | def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args): 39 | super().__init__() 40 | self.transformer = TransformerNodeEncoder(args) 41 | 42 | self.node_encoder = node_encoder 43 | 44 | self.emb_dim = args.d_model 45 | self.num_tasks = num_tasks 46 | self.max_seq_len = args.max_seq_len 47 | self.graph_pooling = args.graph_pooling 48 | 49 | ### Pooling function to generate whole-graph embeddings 50 | if self.graph_pooling == "sum": 51 | self.pool = global_add_pool 52 | elif self.graph_pooling == "mean": 53 | self.pool = global_mean_pool 54 | elif self.graph_pooling == "max": 55 | self.pool = global_max_pool 56 | elif self.graph_pooling == "attention": 57 | self.pool = GlobalAttention( 58 | gate_nn=torch.nn.Sequential( 59 | torch.nn.Linear(self.emb_dim, 2 * self.emb_dim), 60 | torch.nn.BatchNorm1d(2 * self.emb_dim), 61 | torch.nn.ReLU(), 62 | torch.nn.Linear(2 * self.emb_dim, 1), 63 | ) 64 | ) 65 | elif self.graph_pooling == "set2set": 66 | self.pool = Set2Set(self.emb_dim, processing_steps=2) 67 | elif self.graph_pooling == "cls": 68 | self.pool = None 69 | else: 70 | raise ValueError("Invalid graph pooling type.") 71 | 72 | if self.max_seq_len is None: 73 | if self.graph_pooling == "set2set": 74 | self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks) 75 | else: 76 | self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks) 77 | else: 78 | self.graph_pred_linear_list = torch.nn.ModuleList() 79 | if self.graph_pooling == "set2set": 80 | for i in range(self.max_seq_len): 81 | self.graph_pred_linear_list.append(torch.nn.Linear(2 * self.emb_dim, self.num_tasks)) 82 | else: 83 | for i in range(self.max_seq_len): 84 | self.graph_pred_linear_list.append(torch.nn.Linear(self.emb_dim, self.num_tasks)) 85 | 86 | def forward(self, batched_data, perturb=None): 87 | x, edge_index, edge_attr, batch = 
batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 88 | node_depth = batched_data.node_depth if hasattr(batched_data, "node_depth") else None 89 | encoded_node = ( 90 | self.node_encoder(x) 91 | if node_depth is None 92 | else self.node_encoder( 93 | x, 94 | node_depth.view( 95 | -1, 96 | ), 97 | ) 98 | ) 99 | tmp = encoded_node + perturb if perturb is not None else encoded_node 100 | 101 | h_node, src_key_padding_mask, num_nodes, mask, max_num_nodes = pad_batch(tmp, batch, self.transformer.max_input_len, get_mask=True) 102 | h_node, src_key_padding_mask = self.transformer(h_node, src_key_padding_mask) 103 | if self.graph_pooling == "cls": 104 | h_graph = h_node[-1] 105 | else: 106 | h_node = unpad_batch(h_node, tmp, num_nodes, mask, max_num_nodes) 107 | h_graph = self.pool(h_node, batched_data.batch) 108 | 109 | if self.max_seq_len is None: 110 | return self.graph_pred_linear(h_graph) 111 | pred_list = [] 112 | for i in range(self.max_seq_len): 113 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 114 | 115 | return pred_list 116 | -------------------------------------------------------------------------------- /models/pna_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from loguru import logger 6 | 7 | from modules.gnn_module import GNNNodeEmbedding 8 | from modules.pna.pna_module import PNANodeEmbedding 9 | from modules.transformer_encoder import TransformerNodeEncoder 10 | from modules.utils import pad_batch 11 | 12 | from .base_model import BaseModel 13 | 14 | 15 | class PNATransformer(BaseModel): 16 | @staticmethod 17 | def get_emb_dim(args): 18 | return args.gnn_emb_dim 19 | 20 | @staticmethod 21 | def need_deg(): 22 | return True 23 | 24 | @staticmethod 25 | def add_args(parser): 26 | TransformerNodeEncoder.add_args(parser) 27 | PNANodeEmbedding.add_args(parser) 28 | 29 | group = parser.add_argument_group("GNNTransformer - Training Config") 30 | group.add_argument("--pretrained_gnn", type=str, default=None, help="pretrained gnn_node node embedding path") 31 | # group.add_argument('--drop_last_pretrained', action='store_true', default=False, help='drop the last layer for the pretrained model') 32 | group.add_argument("--freeze_gnn", type=int, default=None, help="Freeze gnn_node weight from epoch `freeze_gnn`") 33 | 34 | @staticmethod 35 | def name(args): 36 | name = f"{args.model_type}-pooling={args.graph_pooling}" 37 | name += "-norm_input" if args.transformer_norm_input else "" 38 | name += f"+{args.gnn_type}" 39 | name += "-virtual" if args.gnn_virtual_node else "" 40 | name += f"-JK={args.gnn_JK}" 41 | name += f"-enc_layer={args.num_encoder_layers}" 42 | name += f"-d={args.d_model}" 43 | name += f"-act={args.transformer_activation}" 44 | name += f"-tdrop={args.transformer_dropout}" 45 | name += f"-gdrop={args.gnn_dropout}" 46 | name += "-pretrained_gnn" if args.pretrained_gnn else "" 47 | name += f"-freeze_gnn={args.freeze_gnn}" if args.freeze_gnn is not None else "" 48 | return name 49 | 50 | def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args): 51 | super().__init__() 52 | self.gnn_node = PNANodeEmbedding(node_encoder, args) 53 | if args.pretrained_gnn: 54 | # logger.info(self.gnn_node) 55 | state_dict = torch.load(args.pretrained_gnn) 56 | state_dict = self._gnn_node_state(state_dict["model"]) 57 | logger.info("Load GNN state from: {}", state_dict.keys()) 58 | 
self.gnn_node.load_state_dict(state_dict) 59 | self.freeze_gnn = args.freeze_gnn 60 | 61 | gnn_emb_dim = 2 * args.gnn_emb_dim if args.gnn_JK == "cat" else args.gnn_emb_dim 62 | self.gnn2transformer = nn.Linear(gnn_emb_dim, args.d_model) 63 | self.transformer_encoder = TransformerNodeEncoder(args) 64 | 65 | self.num_tasks = num_tasks 66 | self.pooling = args.graph_pooling 67 | self.graph_pred_linear_list = torch.nn.ModuleList() 68 | 69 | self.max_seq_len = args.max_seq_len 70 | output_dim = args.d_model 71 | 72 | if args.max_seq_len is None: 73 | self.graph_pred_linear = torch.nn.Linear(output_dim, self.num_tasks) 74 | else: 75 | for i in range(args.max_seq_len): 76 | self.graph_pred_linear_list.append(torch.nn.Linear(output_dim, self.num_tasks)) 77 | 78 | def forward(self, batched_data, perturb=None): 79 | h_node = self.gnn_node(batched_data, perturb) 80 | h_node = self.gnn2transformer(h_node) # [s, b, d_model] 81 | 82 | padded_h_node, src_padding_mask = pad_batch(h_node, batched_data.batch, self.transformer_encoder.max_input_len) # Pad in the front 83 | 84 | transformer_out, mask = self.transformer_encoder(padded_h_node, src_padding_mask) # [s, b, h], [b, s] 85 | 86 | if self.pooling in ["last", "cls"]: 87 | h_graph = transformer_out[-1] 88 | elif self.pooling == "mean": 89 | h_graph = transformer_out.sum(0) / (~mask).sum(-1, keepdim=True) 90 | else: 91 | raise NotImplementedError 92 | 93 | if self.max_seq_len is None: 94 | out = self.graph_pred_linear(h_graph) 95 | return out 96 | pred_list = [] 97 | for i in range(self.max_seq_len): 98 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 99 | 100 | return pred_list 101 | 102 | def epoch_callback(self, epoch): 103 | # TODO: maybe unfreeze the gnn at the end. 104 | if self.freeze_gnn is not None and epoch >= self.freeze_gnn: 105 | logger.info(f"Freeze GNN weight after epoch: {epoch}") 106 | for param in self.gnn_node.parameters(): 107 | param.requires_grad = False 108 | 109 | def _gnn_node_state(self, state_dict): 110 | module_name = "gnn_node" 111 | new_state_dict = dict() 112 | for k, v in state_dict.items(): 113 | if module_name in k: 114 | new_key = k.split(".") 115 | module_index = new_key.index(module_name) 116 | new_key = ".".join(new_key[module_index + 1 :]) 117 | new_state_dict[new_key] = v 118 | return new_state_dict 119 | -------------------------------------------------------------------------------- /dataset/code.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch_geometric.utils import degree 9 | from torchvision import transforms 10 | from tqdm import tqdm 11 | 12 | from loguru import logger 13 | # for data transform 14 | # importing utils 15 | from .utils import ( 16 | ASTNodeEncoder, 17 | augment_edge, 18 | decode_arr_to_seq, 19 | encode_y_to_arr, 20 | get_vocab_mapping, 21 | ) 22 | 23 | 24 | class CodeUtil: 25 | def __init__(self): 26 | self.arr_to_seq = None 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | parser.add_argument( 31 | "--num_vocab", type=int, default=5000, help="the number of vocabulary used for sequence prediction (default: 5000)" 32 | ) 33 | parser.set_defaults(max_seq_len=5) 34 | 35 | @staticmethod 36 | def loss_fn(_): 37 | multicls_criterion = torch.nn.CrossEntropyLoss() 38 | 39 | def calc_loss(pred_list, batch, m=1.0): 40 | loss = 0 41 | for i in range(len(pred_list)): 42 | loss += 
multicls_criterion(pred_list[i].to(torch.float32), batch.y_arr[:, i])
43 |             loss = loss / len(pred_list)
44 |             loss /= m
45 |             return loss
46 | 
47 |         return calc_loss
48 | 
49 |     def eval(self, model, device, loader, evaluator):
50 |         model.eval()
51 |         seq_ref_list = []
52 |         seq_pred_list = []
53 | 
54 |         for step, batch in enumerate(tqdm(loader, desc="Eval")):
55 |             batch = batch.to(device)
56 | 
57 |             if batch.x.shape[0] == 1:
58 |                 pass
59 |             else:
60 |                 with torch.no_grad():
61 |                     pred_list = model(batch)
62 | 
63 |                 mat = []
64 |                 for i in range(len(pred_list)):
65 |                     mat.append(torch.argmax(pred_list[i], dim=1).view(-1, 1))
66 |                 mat = torch.cat(mat, dim=1)
67 | 
68 |                 seq_pred = [self.arr_to_seq(arr) for arr in mat]
69 | 
70 |                 # PyG >= 1.5.0
71 |                 seq_ref = [batch.y[i] for i in range(len(batch.y))]
72 | 
73 |                 seq_ref_list.extend(seq_ref)
74 |                 seq_pred_list.extend(seq_pred)
75 | 
76 |         input_dict = {"seq_ref": seq_ref_list, "seq_pred": seq_pred_list}
77 | 
78 |         return evaluator.eval(input_dict)
79 | 
80 |     def preprocess(self, dataset, dataset_eval, model_cls, args):
81 |         split_idx = dataset.get_idx_split()
82 |         seq_len_list = np.array([len(seq) for seq in dataset.data.y])
83 |         print(
84 |             "Target sequence length less than or equal to {}: {:.2f}%.".format(
85 |                 args.max_seq_len, 100.0 * np.sum(seq_len_list <= args.max_seq_len) / len(seq_len_list)
86 |             )
87 |         )
88 | 
89 |         # building vocabulary for sequence prediction. Only use training data.
90 |         vocab2idx, idx2vocab = get_vocab_mapping([dataset.data.y[i] for i in split_idx["train"]], args.num_vocab)
91 | 
92 |         self.arr_to_seq = lambda arr: decode_arr_to_seq(arr, idx2vocab)
93 | 
94 |         # set the transform function
95 |         # augment_edge: add next-token edge as well as inverse edges. add edge attributes.
96 |         # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence.
97 |         dataset_transform = [augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)]
98 |         dataset_eval.transform = transforms.Compose(dataset_transform)
99 |         if dataset.transform is not None:
100 |             dataset_transform.append(dataset.transform)
101 |         dataset.transform = transforms.Compose(dataset_transform)
102 | 
103 |         nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, "mapping", "typeidx2type.csv.gz"))
104 |         nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, "mapping", "attridx2attr.csv.gz"))
105 | 
106 |         # Encoding node features into gnn_emb_dim vectors.
107 |         # The following three node features are used.
108 |         # 1. node type
109 |         # 2. node attribute
110 |         # 3. node depth
111 |         node_encoder_cls = lambda: ASTNodeEncoder(
112 |             args.gnn_emb_dim,
113 |             num_nodetypes=len(nodetypes_mapping["type"]),
114 |             num_nodeattributes=len(nodeattributes_mapping["attr"]),
115 |             max_depth=20,
116 |         )
117 |         edge_encoder_cls = lambda emb_dim: nn.Linear(2, emb_dim)
118 | 
119 |         deg = None
120 |         # Compute in-degree histogram over training data.
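        # Only model classes whose need_deg() returns True (the PNA variants in
        # models/pna.py and models/pna_transformer.py) use this histogram: PNA's
        # degree scalers normalize aggregated messages by node degree, so they
        # need in-degree statistics of the training split before training starts.
        # After the loop below, deg[i] holds the number of training nodes with
        # in-degree i.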
121 | if model_cls.need_deg(): 122 | deg = torch.zeros(800, dtype=torch.long) 123 | num_nodes = 0.0 124 | num_graphs = 0 125 | for data in dataset_eval[split_idx["train"]]: 126 | d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) 127 | deg += torch.bincount(d, minlength=deg.numel()) 128 | num_nodes += data.num_nodes 129 | num_graphs += 1 130 | args.deg = deg 131 | logger.debug("Avg num nodes: {}", num_nodes / num_graphs) 132 | logger.debug("Avg deg: {}", deg) 133 | return len(vocab2idx), node_encoder_cls, edge_encoder_cls, deg 134 | -------------------------------------------------------------------------------- /modules/masked_transformer_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from loguru import logger 8 | 9 | 10 | class CausalSelfAttention(nn.Module): 11 | """ 12 | A vanilla multi-head masked self-attention layer with a projection at the end. 13 | It is possible to use torch.nn.MultiheadAttention here but I am including an 14 | explicit implementation here to show that there is nothing too scary here. 15 | """ 16 | 17 | def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): 18 | super().__init__() 19 | assert n_embd % n_head == 0 20 | # key, query, value projections for all heads 21 | self.key = nn.Linear(n_embd, n_embd) 22 | self.query = nn.Linear(n_embd, n_embd) 23 | self.value = nn.Linear(n_embd, n_embd) 24 | # regularization 25 | self.attn_drop = nn.Dropout(attn_pdrop) 26 | self.resid_drop = nn.Dropout(resid_pdrop) 27 | # output projection 28 | self.proj = nn.Linear(n_embd, n_embd) 29 | 30 | self.n_head = n_head 31 | 32 | def forward(self, x, attn_mask: torch.Tensor = None, valid_input_mask: torch.Tensor = None, mask_value=-1e6): 33 | """mask should be a 3D tensor of shape (B, T, T)""" 34 | B, T, C = x.size() 35 | 36 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 37 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 38 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 39 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 40 | 41 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 42 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 43 | 44 | if attn_mask is not None: 45 | att = att.masked_fill(attn_mask.unsqueeze(1) == 0, mask_value) 46 | if valid_input_mask is not None: 47 | att = att.masked_fill(valid_input_mask.unsqueeze(1).unsqueeze(2) == 0, mask_value) 48 | 49 | att = F.softmax(att, dim=-1) 50 | att = self.attn_drop(att) 51 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 52 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 53 | 54 | # output projection 55 | y = self.resid_drop(self.proj(y)) 56 | return y 57 | 58 | 59 | class Block(nn.Module): 60 | def __init__(self, n_embd, n_ff, n_head, attn_pdrop, resid_pdrop, prenorm=True): 61 | super().__init__() 62 | self.prenorm = prenorm 63 | self.ln1 = nn.LayerNorm(n_embd) 64 | self.ln2 = nn.LayerNorm(n_embd) 65 | self.attn = CausalSelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) 66 | self.mlp = nn.Sequential( 67 | nn.Linear(n_embd, n_ff), 68 | nn.GELU(), 69 | nn.Linear(n_ff, n_embd), 70 | nn.Dropout(resid_pdrop), 71 | ) 72 | 73 | def 
forward(self, x, attn_mask=None, valid_input_mask=None): 74 | if self.prenorm: 75 | x = x + self.attn(self.ln1(x), attn_mask, valid_input_mask) 76 | x = x + self.mlp(self.ln2(x)) 77 | else: 78 | x = self.ln1(x + self.attn(x, attn_mask, valid_input_mask)) 79 | x = self.ln2(x + self.mlp(x)) 80 | return x 81 | 82 | 83 | class MaskedTransformerBlock(nn.Module): 84 | def __init__(self, n_layer, n_embd, n_ff, n_head, attn_pdrop, resid_pdrop, prenorm=True): 85 | super().__init__() 86 | self.blocks = nn.ModuleList([Block(n_embd, n_ff, n_head, attn_pdrop, resid_pdrop, prenorm) for _ in range(n_layer)]) 87 | # self.apply(self._init_weights) 88 | 89 | # def _init_weights(self, module): 90 | # if isinstance(module, (nn.Linear, nn.Embedding)): 91 | # module.weight.data.normal_(mean=0.0, std=0.02) 92 | # if isinstance(module, nn.Linear) and module.bias is not None: 93 | # module.bias.data.zero_() 94 | # elif isinstance(module, nn.LayerNorm): 95 | # module.bias.data.zero_() 96 | # module.weight.data.fill_(1.0) 97 | 98 | def forward(self, x, attn_mask=None, valid_input_mask=None): 99 | for block in self.blocks: 100 | x = block(x, attn_mask, valid_input_mask) 101 | return x 102 | 103 | 104 | class MaskedOnlyTransformerEncoder(nn.Module): 105 | @staticmethod 106 | def add_args(parser): 107 | group = parser.add_argument_group("Masked Transformer Encoder -- architecture config") 108 | group.add_argument("--num_encoder_layers_masked", type=int, default=0) 109 | group.add_argument("--transformer_prenorm", action="store_true", default=False) 110 | 111 | def __init__(self, args): 112 | super().__init__() 113 | self.max_input_len = args.max_input_len 114 | self.masked_transformer = MaskedTransformerBlock( 115 | args.num_encoder_layers_masked, 116 | args.d_model, 117 | args.dim_feedforward, 118 | args.nhead, 119 | args.transformer_dropout, 120 | args.transformer_dropout, 121 | ) 122 | logger.info("number of parameters: %e" % sum(p.numel() for p in self.parameters())) 123 | 124 | def forward(self, x, attn_mask=None, valid_input_mask=None): 125 | """ 126 | padded_h_node: n_b x B x h_d 127 | src_key_padding_mask: B x n_b 128 | """ 129 | x = self.masked_transformer(x, attn_mask=attn_mask, valid_input_mask=valid_input_mask) 130 | return x 131 | -------------------------------------------------------------------------------- /models/gnn_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from loguru import logger 6 | 7 | import math 8 | from modules.gnn_module import GNNNodeEmbedding 9 | from modules.masked_transformer_encoder import MaskedOnlyTransformerEncoder 10 | from modules.transformer_encoder import TransformerNodeEncoder 11 | from modules.utils import pad_batch 12 | 13 | from .base_model import BaseModel 14 | 15 | 16 | class GNNTransformer(BaseModel): 17 | @staticmethod 18 | def get_emb_dim(args): 19 | return args.gnn_emb_dim 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | TransformerNodeEncoder.add_args(parser) 24 | MaskedOnlyTransformerEncoder.add_args(parser) 25 | group = parser.add_argument_group("GNNTransformer - Training Config") 26 | group.add_argument("--pos_encoder", default=False, action='store_true') 27 | group.add_argument("--pretrained_gnn", type=str, default=None, help="pretrained gnn_node node embedding path") 28 | group.add_argument("--freeze_gnn", type=int, default=None, help="Freeze gnn_node weight from epoch `freeze_gnn`") 29 | 30 | 
@staticmethod 31 | def name(args): 32 | name = f"{args.model_type}-pooling={args.graph_pooling}" 33 | name += "-norm_input" if args.transformer_norm_input else "" 34 | name += f"+{args.gnn_type}" 35 | name += "-virtual" if args.gnn_virtual_node else "" 36 | name += f"-JK={args.gnn_JK}" 37 | name += f"-enc_layer={args.num_encoder_layers}" 38 | name += f"-enc_layer_masked={args.num_encoder_layers_masked}" 39 | name += f"-d={args.d_model}" 40 | name += f"-act={args.transformer_activation}" 41 | name += f"-tdrop={args.transformer_dropout}" 42 | name += f"-gdrop={args.gnn_dropout}" 43 | name += "-pretrained_gnn" if args.pretrained_gnn else "" 44 | name += f"-freeze_gnn={args.freeze_gnn}" if args.freeze_gnn is not None else "" 45 | name += "-prenorm" if args.transformer_prenorm else "-postnorm" 46 | return name 47 | 48 | def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args): 49 | super().__init__() 50 | self.gnn_node = GNNNodeEmbedding( 51 | args.gnn_virtual_node, 52 | args.gnn_num_layer, 53 | args.gnn_emb_dim, 54 | node_encoder, 55 | edge_encoder_cls, 56 | JK=args.gnn_JK, 57 | drop_ratio=args.gnn_dropout, 58 | residual=args.gnn_residual, 59 | gnn_type=args.gnn_type, 60 | ) 61 | if args.pretrained_gnn: 62 | # logger.info(self.gnn_node) 63 | state_dict = torch.load(args.pretrained_gnn) 64 | state_dict = self._gnn_node_state(state_dict["model"]) 65 | logger.info("Load GNN state from: {}", state_dict.keys()) 66 | self.gnn_node.load_state_dict(state_dict) 67 | self.freeze_gnn = args.freeze_gnn 68 | 69 | gnn_emb_dim = 2 * args.gnn_emb_dim if args.gnn_JK == "cat" else args.gnn_emb_dim 70 | self.gnn2transformer = nn.Linear(gnn_emb_dim, args.d_model) 71 | self.pos_encoder = PositionalEncoding(args.d_model, dropout=0) if args.pos_encoder else None 72 | self.transformer_encoder = TransformerNodeEncoder(args) 73 | self.masked_transformer_encoder = MaskedOnlyTransformerEncoder(args) 74 | self.num_encoder_layers = args.num_encoder_layers 75 | self.num_encoder_layers_masked = args.num_encoder_layers_masked 76 | 77 | self.num_tasks = num_tasks 78 | self.pooling = args.graph_pooling 79 | self.graph_pred_linear_list = torch.nn.ModuleList() 80 | 81 | self.max_seq_len = args.max_seq_len 82 | output_dim = args.d_model 83 | 84 | if args.max_seq_len is None: 85 | self.graph_pred_linear = torch.nn.Linear(output_dim, self.num_tasks) 86 | else: 87 | for i in range(args.max_seq_len): 88 | self.graph_pred_linear_list.append(torch.nn.Linear(output_dim, self.num_tasks)) 89 | 90 | def forward(self, batched_data, perturb=None): 91 | h_node = self.gnn_node(batched_data, perturb) 92 | h_node = self.gnn2transformer(h_node) # [s, b, d_model] 93 | 94 | padded_h_node, src_padding_mask, num_nodes, mask, max_num_nodes = pad_batch( 95 | h_node, batched_data.batch, self.transformer_encoder.max_input_len, get_mask=True 96 | ) # Pad in the front 97 | 98 | # TODO(paras): implement mask 99 | transformer_out = padded_h_node 100 | if self.pos_encoder is not None: 101 | transformer_out = self.pos_encoder(transformer_out) 102 | if self.num_encoder_layers_masked > 0: 103 | adj_list = batched_data.adj_list 104 | padded_adj_list = torch.zeros((len(adj_list), max_num_nodes, max_num_nodes), device=h_node.device) 105 | for idx, adj_list_item in enumerate(adj_list): 106 | N, _ = adj_list_item.shape 107 | padded_adj_list[idx, 0:N, 0:N] = torch.from_numpy(adj_list_item) 108 | transformer_out = self.masked_transformer_encoder( 109 | transformer_out.transpose(0, 1), attn_mask=padded_adj_list, valid_input_mask=src_padding_mask 110 | 
).transpose(0, 1)
111 |         if self.num_encoder_layers > 0:
112 |             transformer_out, _ = self.transformer_encoder(transformer_out, src_padding_mask)  # [s, b, h], [b, s]
113 | 
114 |         if self.pooling in ["last", "cls"]:
115 |             h_graph = transformer_out[-1]
116 |         elif self.pooling == "mean":
117 |             h_graph = transformer_out.sum(0) / (~src_padding_mask).sum(-1, keepdim=True)  # average over non-padding positions
118 |         else:
119 |             raise NotImplementedError
120 | 
121 |         if self.max_seq_len is None:
122 |             out = self.graph_pred_linear(h_graph)
123 |             return out
124 |         pred_list = []
125 |         for i in range(self.max_seq_len):
126 |             pred_list.append(self.graph_pred_linear_list[i](h_graph))
127 | 
128 |         return pred_list
129 | 
130 |     def epoch_callback(self, epoch):
131 |         # TODO: maybe unfreeze the gnn at the end.
132 |         if self.freeze_gnn is not None and epoch >= self.freeze_gnn:
133 |             logger.info(f"Freeze GNN weight after epoch: {epoch}")
134 |             for param in self.gnn_node.parameters():
135 |                 param.requires_grad = False
136 | 
137 |     def _gnn_node_state(self, state_dict):
138 |         module_name = "gnn_node"
139 |         new_state_dict = dict()
140 |         for k, v in state_dict.items():
141 |             if module_name in k:
142 |                 new_key = k.split(".")
143 |                 module_index = new_key.index(module_name)
144 |                 new_key = ".".join(new_key[module_index + 1 :])
145 |                 new_state_dict[new_key] = v
146 |         return new_state_dict
147 | 
148 | 
149 | class PositionalEncoding(nn.Module):
150 | 
151 |     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
152 |         super().__init__()
153 |         self.dropout = nn.Dropout(p=dropout)
154 | 
155 |         position = torch.arange(max_len).unsqueeze(1)
156 |         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
157 |         pe = torch.zeros(max_len, 1, d_model)
158 |         pe[:, 0, 0::2] = torch.sin(position * div_term)
159 |         pe[:, 0, 1::2] = torch.cos(position * div_term)
160 |         self.register_buffer('pe', pe)
161 | 
162 |     def forward(self, x):
163 |         """
164 |         Args:
165 |             x: Tensor, shape [seq_len, batch_size, embedding_dim]
166 |         """
167 |         x = x + self.pe[:x.size(0)]
168 |         return self.dropout(x)
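

# A minimal shape sketch for PositionalEncoding (the values are illustrative,
# not taken from any config in this repo):
#
#     pos = PositionalEncoding(d_model=64, dropout=0.0)
#     x = torch.zeros(50, 8, 64)   # [seq_len, batch_size, embedding_dim]
#     y = pos(x)                   # same shape; pe[:50] broadcasts over the batch
#
# Only the first x.size(0) rows of the precomputed `pe` buffer are added, so any
# sequence length up to max_len=5000 is handled without recomputation.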
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | 
3 | import numpy as np
4 | import torch
5 | from loguru import logger
6 | 
7 | 
8 | class ASTNodeEncoder(torch.nn.Module):
9 |     """
10 |     Input:
11 |         x: default node feature. The first and second columns represent the node type and node attribute, respectively.
12 |         depth: The depth of the node in the AST.
13 | 
14 |     Output:
15 |         emb_dim-dimensional vector
16 | 
17 |     """
18 | 
19 |     def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
20 |         super(ASTNodeEncoder, self).__init__()
21 | 
22 |         self.max_depth = max_depth
23 | 
24 |         self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
25 |         self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
26 |         self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)
27 | 
28 |     def forward(self, x, depth):
29 |         depth[depth > self.max_depth] = self.max_depth
30 |         return self.type_encoder(x[:, 0]) + self.attribute_encoder(x[:, 1]) + self.depth_encoder(depth)
31 | 
32 | 
33 | def get_vocab_mapping(seq_list, num_vocab):
34 |     """
35 |     Input:
36 |         seq_list: a list of sequences
37 |         num_vocab: vocabulary size
38 |     Output:
39 |         vocab2idx:
40 |             A dictionary that maps vocabulary into integer index.
41 |             Additionally, we also index '__UNK__' and '__EOS__'
42 |             '__UNK__' : out-of-vocabulary term
43 |             '__EOS__' : end-of-sentence
44 | 
45 |         idx2vocab:
46 |             A list that maps idx to actual vocabulary.
47 | 
48 |     """
49 | 
50 |     vocab_cnt = {}
51 |     vocab_list = []
52 |     for seq in seq_list:
53 |         for w in seq:
54 |             if w in vocab_cnt:
55 |                 vocab_cnt[w] += 1
56 |             else:
57 |                 vocab_cnt[w] = 1
58 |                 vocab_list.append(w)
59 | 
60 |     cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
61 |     topvocab = np.argsort(-cnt_list, kind="stable")[:num_vocab]
62 | 
63 |     logger.info("Coverage of top {} vocabulary: {:.4f}", num_vocab, float(np.sum(cnt_list[topvocab])) / np.sum(cnt_list))
64 | 
65 |     vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)}
66 |     idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]
67 | 
68 |     # logger.info(topvocab)
69 |     # logger.info([vocab_list[v] for v in topvocab[:10]])
70 |     # logger.info([vocab_list[v] for v in topvocab[-10:]])
71 | 
72 |     vocab2idx["__UNK__"] = num_vocab
73 |     idx2vocab.append("__UNK__")
74 | 
75 |     vocab2idx["__EOS__"] = num_vocab + 1
76 |     idx2vocab.append("__EOS__")
77 | 
78 |     # test the correspondence between vocab2idx and idx2vocab
79 |     for idx, vocab in enumerate(idx2vocab):
80 |         assert idx == vocab2idx[vocab]
81 | 
82 |     # test that the idx of '__EOS__' is len(idx2vocab) - 1.
83 |     # This fact will be used in decode_arr_to_seq, when finding __EOS__
84 |     assert vocab2idx["__EOS__"] == len(idx2vocab) - 1
85 | 
86 |     return vocab2idx, idx2vocab
87 | 
88 | 
89 | def augment_edge(data):
90 |     """
91 |     Input:
92 |         data: PyG data object
93 |     Output:
94 |         data (edges are augmented in the following ways):
95 |             data.edge_index: Added next-token edge. The inverse edges were also added.
96 |             data.edge_attr (torch.Long):
97 |                 data.edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1)
98 |                 data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
99 |     """
100 | 
101 |     ##### AST edge
102 |     edge_index_ast = data.edge_index
103 |     edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))
104 | 
105 |     ##### Inverse AST edge
106 |     edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim=0)
107 |     edge_attr_ast_inverse = torch.cat(
108 |         [torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim=1
109 |     )
110 | 
111 |     ##### Next-token edge
112 | 
113 |     ## Obtain attributed nodes and get their indices in dfs order
114 |     # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
115 |     # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))]
116 | 
117 |     ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.
118 | attributed_node_idx_in_dfs_order = torch.where( 119 | data.node_is_attributed.view( 120 | -1, 121 | ) 122 | == 1 123 | )[0] 124 | 125 | ## build next token edge 126 | # Given: attributed_node_idx_in_dfs_order 127 | # [1, 3, 4, 5, 8, 9, 12] 128 | # Output: 129 | # [[1, 3, 4, 5, 8, 9] 130 | # [3, 4, 5, 8, 9, 12] 131 | edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim=0) 132 | edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim=1) 133 | 134 | ##### Inverse next-token edge 135 | edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim=0) 136 | edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2)) 137 | 138 | data.edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim=1) 139 | data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse], dim=0) 140 | 141 | return data 142 | 143 | 144 | def encode_y_to_arr(data, vocab2idx, max_seq_len): 145 | """ 146 | Input: 147 | data: PyG graph object 148 | output: add y_arr to data 149 | """ 150 | 151 | # PyG >= 1.5.0 152 | seq = data.y 153 | 154 | # PyG = 1.4.3 155 | # seq = data.y[0] 156 | 157 | data.y_arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len) 158 | 159 | return data 160 | 161 | 162 | def encode_seq_to_arr(seq, vocab2idx, max_seq_len): 163 | """ 164 | Input: 165 | seq: A list of words 166 | output: add y_arr (torch.Tensor) 167 | """ 168 | 169 | augmented_seq = seq[:max_seq_len] + ["__EOS__"] * max(0, max_seq_len - len(seq)) 170 | return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx["__UNK__"] for w in augmented_seq]], dtype=torch.long) 171 | 172 | 173 | def decode_arr_to_seq(arr, idx2vocab): 174 | """ 175 | Input: torch 1d array: y_arr 176 | Output: a sequence of words. 
177 |     """
178 | 
179 |     eos_idx_list = (arr == len(idx2vocab) - 1).nonzero()  # find the positions of __EOS__ (the last vocab in idx2vocab)
180 |     if len(eos_idx_list) > 0:
181 |         clipped_arr = arr[: torch.min(eos_idx_list)]  # clip at the earliest __EOS__
182 |     else:
183 |         clipped_arr = arr
184 | 
185 |     return list(map(lambda x: idx2vocab[x], clipped_arr.cpu()))
186 | 
187 | 
188 | def test():
189 |     seq_list = [["a", "b"], ["a", "b", "c", "df", "f", "2edea", "a"], ["eraea", "a", "c"], ["d"], ["4rq4f", "f", "a", "a", "g"]]
190 |     vocab2idx, idx2vocab = get_vocab_mapping(seq_list, 4)
191 |     logger.debug(vocab2idx)
192 |     logger.debug(idx2vocab)
193 |     assert len(vocab2idx) == len(idx2vocab)
194 | 
195 |     for vocab, idx in vocab2idx.items():
196 |         assert idx2vocab[idx] == vocab
197 | 
198 |     for seq in seq_list:
199 |         logger.debug(seq)
200 |         arr = encode_seq_to_arr(seq, vocab2idx, max_seq_len=4)[0]
201 |         # Test the effect of predicting __EOS__
202 |         # arr[2] = vocab2idx['__EOS__']
203 |         logger.debug(arr)
204 |         seq_dec = decode_arr_to_seq(arr, idx2vocab)
205 | 
206 |         logger.debug(arr)
207 |         logger.debug(seq_dec)
208 | 
209 | 
210 | if __name__ == "__main__":
211 |     test()
--------------------------------------------------------------------------------
/models/transformer_gnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from loguru import logger
6 | 
7 | from modules.gnn_module import GNNNodeEmbedding
8 | from modules.masked_transformer_encoder import MaskedOnlyTransformerEncoder
9 | from modules.transformer_encoder import TransformerNodeEncoder
10 | from modules.utils import pad_batch, unpad_batch
11 | 
12 | from .base_model import BaseModel
13 | from torch_geometric.nn import (
14 |     GlobalAttention,
15 |     MessagePassing,
16 |     Set2Set,
17 |     global_add_pool,
18 |     global_max_pool,
19 |     global_mean_pool,
20 | )
21 | 
22 | class TransformerGNN(BaseModel):
23 |     @staticmethod
24 |     def get_emb_dim(args):
25 |         return args.gnn_emb_dim
26 | 
27 |     @staticmethod
28 |     def add_args(parser):
29 |         TransformerNodeEncoder.add_args(parser)
30 |         MaskedOnlyTransformerEncoder.add_args(parser)
31 |         group = parser.add_argument_group("TransformerGNN - Training Config")
32 |         group.add_argument("--pretrained_gnn", type=str, default=None, help="pretrained gnn_node node embedding path")
33 |         group.add_argument("--freeze_gnn", type=int, default=None, help="Freeze gnn_node weight from epoch `freeze_gnn`")
34 |         group.add_argument("--graph_input_dim", type=int, default=None)
35 | 
36 |     @staticmethod
37 |     def name(args):
38 |         name = f"{args.model_type}-pooling={args.graph_pooling}"
39 |         name += "-norm_input" if args.transformer_norm_input else ""
40 |         name += f"+{args.gnn_type}"
41 |         name += "-virtual" if args.gnn_virtual_node else ""
42 |         name += f"-JK={args.gnn_JK}"
43 |         name += f"-enc_layer={args.num_encoder_layers}"
44 |         name += f"-enc_layer_masked={args.num_encoder_layers_masked}"
45 |         name += f"-d={args.d_model}"
46 |         name += f"-act={args.transformer_activation}"
47 |         name += f"-tdrop={args.transformer_dropout}"
48 |         name += f"-gdrop={args.gnn_dropout}"
49 |         name += "-pretrained_gnn" if args.pretrained_gnn else ""
50 |         name += f"-freeze_gnn={args.freeze_gnn}" if args.freeze_gnn is not None else ""
51 |         name += "-prenorm" if args.transformer_prenorm else "-postnorm"
52 |         return name
53 | 
54 |     def __init__(self, num_tasks, node_encoder, edge_encoder_cls, args):
55 |         super().__init__()
56 | 
self.node_encoder = node_encoder 57 | self.input2transformer = nn.Linear(args.graph_input_dim, args.d_model) if args.graph_input_dim is not None else None 58 | self.transformer_encoder = TransformerNodeEncoder(args) 59 | self.masked_transformer_encoder = MaskedOnlyTransformerEncoder(args) 60 | gnn_emb_dim = args.gnn_emb_dim 61 | self.transformer2gnn = nn.Linear(args.d_model, gnn_emb_dim) 62 | self.gnn_node = GNNNodeEmbedding( 63 | args.gnn_virtual_node, 64 | args.gnn_num_layer, 65 | args.gnn_emb_dim, 66 | node_encoder=None, 67 | edge_encoder_cls=edge_encoder_cls, 68 | JK=args.gnn_JK, 69 | drop_ratio=args.gnn_dropout, 70 | residual=args.gnn_residual, 71 | gnn_type=args.gnn_type, 72 | ) 73 | if args.pretrained_gnn: 74 | # logger.info(self.gnn_node) 75 | state_dict = torch.load(args.pretrained_gnn) 76 | state_dict = self._gnn_node_state(state_dict["model"]) 77 | logger.info("Load GNN state from: {}", state_dict.keys()) 78 | self.gnn_node.load_state_dict(state_dict) 79 | self.freeze_gnn = args.freeze_gnn 80 | 81 | self.num_encoder_layers = args.num_encoder_layers 82 | self.num_encoder_layers_masked = args.num_encoder_layers_masked 83 | 84 | self.num_tasks = num_tasks 85 | self.pooling = args.graph_pooling 86 | self.graph_pred_linear_list = torch.nn.ModuleList() 87 | 88 | self.max_seq_len = args.max_seq_len 89 | 90 | ### Pooling function to generate whole-graph embeddings 91 | if self.pooling == "sum": 92 | self.pool = global_add_pool 93 | elif self.pooling == "mean": 94 | self.pool = global_mean_pool 95 | elif self.pooling == "max": 96 | self.pool = global_max_pool 97 | elif self.pooling == "attention": 98 | self.pool = GlobalAttention( 99 | gate_nn=torch.nn.Sequential( 100 | torch.nn.Linear(gnn_emb_dim, 2 * gnn_emb_dim), 101 | torch.nn.BatchNorm1d(2 * gnn_emb_dim), 102 | torch.nn.ReLU(), 103 | torch.nn.Linear(2 * gnn_emb_dim, 1), 104 | ) 105 | ) 106 | elif self.pooling == "set2set": 107 | self.pool = Set2Set(gnn_emb_dim, processing_steps=2) 108 | else: 109 | raise ValueError(f"Invalid graph pooling type. 
{self.pooling}")
110 | 
111 |         if self.max_seq_len is None:
112 |             if self.pooling == "set2set":
113 |                 self.graph_pred_linear = torch.nn.Linear(2 * gnn_emb_dim, self.num_tasks)
114 |             else:
115 |                 self.graph_pred_linear = torch.nn.Linear(gnn_emb_dim, self.num_tasks)
116 |         else:
117 |             self.graph_pred_linear_list = torch.nn.ModuleList()
118 |             if self.pooling == "set2set":
119 |                 for i in range(self.max_seq_len):
120 |                     self.graph_pred_linear_list.append(torch.nn.Linear(2 * gnn_emb_dim, self.num_tasks))
121 |             else:
122 |                 for i in range(self.max_seq_len):
123 |                     if args.gnn_JK == 'cat':
124 |                         self.graph_pred_linear_list.append(torch.nn.Linear(2 * gnn_emb_dim, self.num_tasks))
125 |                     else:
126 |                         self.graph_pred_linear_list.append(torch.nn.Linear(gnn_emb_dim, self.num_tasks))
127 | 
128 |     def forward(self, batched_data, perturb=None):
129 |         x = batched_data.x
130 |         node_depth = batched_data.node_depth if hasattr(batched_data, "node_depth") else None
131 |         encoded_node = (
132 |             self.node_encoder(x)
133 |             if node_depth is None
134 |             else self.node_encoder(
135 |                 x,
136 |                 node_depth.view(
137 |                     -1,
138 |                 ),
139 |             )
140 |         )
141 |         tmp = encoded_node + perturb if perturb is not None else encoded_node
142 |         if self.input2transformer is not None:
143 |             tmp = self.input2transformer(tmp)
144 |         padded_h_node, src_padding_mask, num_nodes, mask, max_num_nodes = pad_batch(
145 |             tmp, batched_data.batch, self.transformer_encoder.max_input_len, get_mask=True
146 |         )  # Pad in the front
147 |         # TODO(paras): implement mask
148 |         transformer_out = padded_h_node
149 |         if self.num_encoder_layers_masked > 0:
150 |             adj_list = batched_data.adj_list
151 |             padded_adj_list = torch.zeros((len(adj_list), max_num_nodes, max_num_nodes), device=tmp.device)
152 |             for idx, adj_list_item in enumerate(adj_list):
153 |                 N, _ = adj_list_item.shape
154 |                 padded_adj_list[idx, 0:N, 0:N] = torch.from_numpy(adj_list_item)
155 |             transformer_out = self.masked_transformer_encoder(
156 |                 transformer_out.transpose(0, 1), attn_mask=padded_adj_list, valid_input_mask=src_padding_mask
157 |             ).transpose(0, 1)
158 |         if self.num_encoder_layers > 0:
159 |             transformer_out, _ = self.transformer_encoder(transformer_out, src_padding_mask)  # [s, b, h], [b, s]
160 | 
161 |         h_node = unpad_batch(transformer_out, tmp, num_nodes, mask, max_num_nodes)
162 |         batched_data.x = self.transformer2gnn(h_node)
163 |         h_node = self.gnn_node(batched_data, None)
164 | 
165 |         h_graph = self.pool(h_node, batched_data.batch)
166 | 
167 |         if self.max_seq_len is None:
168 |             out = self.graph_pred_linear(h_graph)
169 |             return out
170 |         pred_list = []
171 |         for i in range(self.max_seq_len):
172 |             pred_list.append(self.graph_pred_linear_list[i](h_graph))
173 | 
174 |         return pred_list
175 | 
176 |     def epoch_callback(self, epoch):
177 |         # TODO: maybe unfreeze the gnn at the end.
178 |         if self.freeze_gnn is not None and epoch >= self.freeze_gnn:
179 |             logger.info(f"Freeze GNN weight after epoch: {epoch}")
180 |             for param in self.gnn_node.parameters():
181 |                 param.requires_grad = False
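
    # Sketch of the key remapping done by _gnn_node_state below (the checkpoint
    # key is hypothetical):
    #     "model.gnn_node.convs.0.mlp.0.weight"  ->  "convs.0.mlp.0.weight"
    # Everything up to and including the "gnn_node" component is dropped, so a
    # GNN pretrained inside a larger model can be loaded straight into
    # self.gnn_node via load_state_dict.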
182 | 
183 |     def _gnn_node_state(self, state_dict):
184 |         module_name = "gnn_node"
185 |         new_state_dict = dict()
186 |         for k, v in state_dict.items():
187 |             if module_name in k:
188 |                 new_key = k.split(".")
189 |                 module_index = new_key.index(module_name)
190 |                 new_key = ".".join(new_key[module_index + 1 :])
191 |                 new_state_dict[new_key] = v
192 |         return new_state_dict
193 | 
--------------------------------------------------------------------------------
/modules/gnn_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.nn import (
4 |     GlobalAttention,
5 |     MessagePassing,
6 |     Set2Set,
7 |     global_add_pool,
8 |     global_max_pool,
9 |     global_mean_pool,
10 | )
11 | from torch_geometric.nn.inits import uniform
12 | 
13 | from modules.conv import GCNConv, GINConv
14 | from modules.utils import pad_batch
15 | 
16 | 
17 | ### GNN to generate node embeddings
18 | class GNN_node(torch.nn.Module):
19 |     """
20 |     Output:
21 |         node representations
22 |     """
23 | 
24 |     @staticmethod
25 |     def need_deg():
26 |         return False
27 | 
28 |     def __init__(self, num_layer, emb_dim, node_encoder, edge_encoder_cls, drop_ratio=0.5, JK="last", residual=False, gnn_type="gin"):
29 |         """
30 |         emb_dim (int): node embedding dimensionality
31 |         num_layer (int): number of GNN message passing layers
32 |         """
33 | 
34 |         super(GNN_node, self).__init__()
35 |         self.num_layer = num_layer
36 |         self.drop_ratio = drop_ratio
37 |         self.JK = JK
38 |         ### add residual connection or not
39 |         self.residual = residual
40 | 
41 |         if self.num_layer < 2:
42 |             raise ValueError("Number of GNN layers must be greater than 1.")
43 | 
44 |         self.node_encoder = node_encoder
45 | 
46 |         ###List of GNNs
47 |         self.convs = torch.nn.ModuleList()
48 |         self.batch_norms = torch.nn.ModuleList()
49 | 
50 |         for layer in range(num_layer):
51 |             if gnn_type == "gin":
52 |                 self.convs.append(GINConv(emb_dim, edge_encoder_cls))
53 |             elif gnn_type == "gcn":
54 |                 self.convs.append(GCNConv(emb_dim, edge_encoder_cls))
55 |             else:
56 |                 raise ValueError("Undefined GNN type called {}".format(gnn_type))
57 | 
58 |             self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
59 | 
60 |     def forward(self, batched_data, perturb=None):
61 |         x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
62 |         node_depth = batched_data.node_depth if hasattr(batched_data, "node_depth") else None
63 | 
64 |         ### computing input node embedding
65 |         if self.node_encoder is not None:
66 |             encoded_node = (
67 |                 self.node_encoder(x)
68 |                 if node_depth is None
69 |                 else self.node_encoder(
70 |                     x,
71 |                     node_depth.view(
72 |                         -1,
73 |                     ),
74 |                 )
75 |             )
76 |         else:
77 |             encoded_node = x
78 |         tmp = encoded_node + perturb if perturb is not None else encoded_node
79 |         h_list = [tmp]
80 | 
81 |         for layer in range(self.num_layer):
82 | 
83 |             h = self.convs[layer](h_list[layer], edge_index, edge_attr)
84 |             h = self.batch_norms[layer](h)
85 | 
86 |             if layer == self.num_layer - 1:
87 |                 # remove relu for the last layer
88 |                 h = F.dropout(h, self.drop_ratio, training=self.training)
89 |             else:
90 |                 h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
91 | 
92 |             if self.residual:
93 |                 h += h_list[layer]
94 | 
95 |             h_list.append(h)
96 | 
97 |         ### Different implementations of JK-concat
98 |         if self.JK == "last":
99 |             node_representation = h_list[-1]
100 |         elif self.JK == "sum":
101 |             node_representation = 0
102 |             for layer in range(self.num_layer):
103 |                 node_representation += h_list[layer]
104 |         elif self.JK == "cat":
105 |             node_representation = torch.cat([h_list[0], h_list[-1]], dim=-1)
106 | 
107 |         return node_representation
108 | 
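
# Note on JK="cat": GNN_node returns torch.cat([h_list[0], h_list[-1]], dim=-1),
# so the representation width doubles to 2 * emb_dim. This is why the hybrid
# models above size their projections as
# `2 * args.gnn_emb_dim if args.gnn_JK == "cat" else args.gnn_emb_dim`.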
109 | 
110 | ### Virtual GNN to generate node embedding
111 | class GNN_node_Virtualnode(torch.nn.Module):
112 |     """
113 |     Output:
114 |         node representations
115 |     """
116 | 
117 |     @staticmethod
118 |     def need_deg():
119 |         return False
120 | 
121 |     def __init__(self, num_layer, emb_dim, node_encoder, edge_encoder_cls, drop_ratio=0.5, JK="last", residual=False, gnn_type="gin"):
122 |         """
123 |         emb_dim (int): node embedding dimensionality
124 |         """
125 | 
126 |         super(GNN_node_Virtualnode, self).__init__()
127 |         self.num_layer = num_layer
128 |         self.drop_ratio = drop_ratio
129 |         self.JK = JK
130 |         ### add residual connection or not
131 |         self.residual = residual
132 | 
133 |         if self.num_layer < 2:
134 |             raise ValueError("Number of GNN layers must be greater than 1.")
135 | 
136 |         self.node_encoder = node_encoder
137 | 
138 |         ### set the initial virtual node embedding to 0.
139 |         self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
140 |         torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
141 | 
142 |         ### List of GNNs
143 |         self.convs = torch.nn.ModuleList()
144 |         ### batch norms applied to node embeddings
145 |         self.batch_norms = torch.nn.ModuleList()
146 | 
147 |         ### List of MLPs to transform virtual node at every layer
148 |         self.mlp_virtualnode_list = torch.nn.ModuleList()
149 | 
150 |         for layer in range(num_layer):
151 |             if gnn_type == "gin":
152 |                 self.convs.append(GINConv(emb_dim, edge_encoder_cls))
153 |             elif gnn_type == "gcn":
154 |                 self.convs.append(GCNConv(emb_dim, edge_encoder_cls))
155 |             else:
156 |                 raise ValueError("Undefined GNN type called {}".format(gnn_type))
157 | 
158 |             self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
159 | 
160 |         for layer in range(num_layer - 1):
161 |             self.mlp_virtualnode_list.append(
162 |                 torch.nn.Sequential(
163 |                     torch.nn.Linear(emb_dim, 2 * emb_dim),
164 |                     torch.nn.BatchNorm1d(2 * emb_dim),
165 |                     torch.nn.ReLU(),
166 |                     torch.nn.Linear(2 * emb_dim, emb_dim),
167 |                     torch.nn.BatchNorm1d(emb_dim),
168 |                     torch.nn.ReLU(),
169 |                 )
170 |             )
171 | 
172 |     def forward(self, batched_data, perturb=None):
173 | 
174 |         x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
175 |         node_depth = batched_data.node_depth if hasattr(batched_data, "node_depth") else None
176 | 
177 |         ### computing input node embedding
178 |         if self.node_encoder is not None:
179 |             encoded_node = (
180 |                 self.node_encoder(x)
181 |                 if node_depth is None
182 |                 else self.node_encoder(
183 |                     x,
184 |                     node_depth.view(
185 |                         -1,
186 |                     ),
187 |                 )
188 |             )
189 |         else:
190 |             encoded_node = x
191 |         tmp = encoded_node + perturb if perturb is not None else encoded_node
192 |         h_list = [tmp]
193 | 
194 |         ### virtual node embeddings for graphs
195 |         virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
196 | 
197 |         for layer in range(self.num_layer):
198 |             ### add message from virtual nodes to graph nodes
199 |             h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
200 | 
201 |             ### Message passing among graph nodes
202 |             h = self.convs[layer](h_list[layer], edge_index, edge_attr)
203 | 
204 |             h = self.batch_norms[layer](h)
205 |             if layer == self.num_layer - 1:
206 |                 # remove relu for the last layer
207 |                 h = F.dropout(h, self.drop_ratio, training=self.training)
208 |             else:
209 |                 h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
210 | 
211 |             if self.residual:
212 |                 h = h + h_list[layer]
213 | 
214 |             h_list.append(h)
215 | 
216 |             ### update the virtual nodes
217 |             if layer < self.num_layer - 1:
218 |                 ### add message from graph nodes to virtual nodes
219 |                 virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
220 |                 ### transform virtual nodes using MLP
221 | 
222 |                 if self.residual:
223 |                     virtualnode_embedding = virtualnode_embedding + F.dropout(
224 |                         self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training=self.training
225 |                     )
226 |                 else:
227 |                     virtualnode_embedding = F.dropout(
228 |                         self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training=self.training
229 |                     )
230 | 
231 |         ### Different implementations of JK-concat
232 |         if self.JK == "last":
233 |             node_representation = h_list[-1]
234 |         elif self.JK == "sum":
235 |             node_representation = 0
236 |             for layer in range(self.num_layer):
237 |                 node_representation += h_list[layer]
238 |         elif self.JK == "cat":
239 |             node_representation = torch.cat([h_list[0], h_list[-1]], dim=-1)
240 | 
241 |         return node_representation
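
# Construction sketch for the GNNNodeEmbedding factory below (mirrors the call
# in models/gnn.py; the argument values are illustrative, not taken from a config):
#
#     gnn_node = GNNNodeEmbedding(
#         True,                     # virtual_node -> GNN_node_Virtualnode
#         5, 300,                   # num_layer, emb_dim
#         node_encoder, edge_encoder_cls,
#         JK="last", drop_ratio=0.5, residual=False, gnn_type="gin",
#     )
#     h = gnn_node(batched_data)    # [num_nodes, emb_dim]; 2 * emb_dim if JK="cat"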
242 | 
243 | 
244 | def GNNNodeEmbedding(virtual_node, *args, **kwargs):
245 |     if virtual_node:
246 |         return GNN_node_Virtualnode(*args, **kwargs)
247 |     else:
248 |         return GNN_node(*args, **kwargs)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/modules/pna_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from ogb.graphproppred.mol_encoder import BondEncoder
7 | from torch import Tensor
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.nn.inits import reset
10 | from torch_geometric.typing import Adj, OptTensor
11 | from torch_geometric.utils import degree
12 | 
13 | from .pna.aggregators import AGGREGATORS
14 | from .pna.scalers import SCALERS
15 | 
16 | # Implemented with the help of Matthias Fey, author of PyTorch Geometric
17 | # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py
18 | 
19 | 
20 | class PNAConv(MessagePassing):
21 |     r"""The Principal Neighbourhood Aggregation graph convolution operator
22 |     from the `"Principal Neighbourhood Aggregation for Graph Nets"
23 |     <https://arxiv.org/abs/2004.05718>`_ paper
24 |     .. math::
25 |         \bigoplus = \underbrace{\begin{bmatrix}I \\ S(D, \alpha=1) \\
26 |         S(D, \alpha=-1) \end{bmatrix} }_{\text{scalers}}
27 |         \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min
28 |         \end{bmatrix}}_{\text{aggregators}},
29 |     in:
30 |     .. math::
31 |         X_i^{(t+1)} = U \left( X_i^{(t)}, \underset{(j,i) \in E}{\bigoplus}
32 |         M \left( X_i^{(t)}, X_j^{(t)} \right) \right)
33 |     where :math:`M` and :math:`U` denote the MLP referred to with pretrans
34 |     and posttrans respectively.
35 |     Args:
36 |         in_channels (int): Size of each input sample.
37 |         out_channels (int): Size of each output sample.
38 |         aggregators (list of str): Set of aggregation function identifiers,
39 |             namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
40 |             :obj:`"var"` and :obj:`"std"`.
41 |         scalers (list of str): Set of scaling function identifiers, namely
42 |             :obj:`"identity"`, :obj:`"amplification"`,
43 |             :obj:`"attenuation"`, :obj:`"linear"` and
44 |             :obj:`"inverse_linear"`.
45 |         deg (Tensor): Histogram of in-degrees of nodes in the training set,
46 |             used by scalers to normalize.
47 |         edge_dim (int, optional): Edge feature dimensionality (in case
48 |             there are any). (default :obj:`None`)
49 |         towers (int, optional): Number of towers (default: :obj:`1`).
50 |         pre_layers (int, optional): Number of transformation layers before
51 |             aggregation (default: :obj:`1`).
52 |         post_layers (int, optional): Number of transformation layers after
53 |             aggregation (default: :obj:`1`).
54 |         divide_input (bool, optional): Whether the input features should
55 |             be split between towers or not (default: :obj:`False`).
56 |         **kwargs (optional): Additional arguments of
57 |             :class:`torch_geometric.nn.conv.MessagePassing`.
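    Example (illustrative sketch; :obj:`deg` is assumed to be an in-degree
    histogram over the training split, e.g. from
    :obj:`torch_geometric.utils.degree`)::

        deg = torch.tensor([0, 41, 163, 81, 15], dtype=torch.float)
        conv = PNAConv(75, 75,
                       aggregators=["mean", "min", "max", "std"],
                       scalers=["identity", "amplification", "attenuation"],
                       deg=deg, towers=5)
        out = conv(x, edge_index)  # x: [num_nodes, 75] -> out: [num_nodes, 75]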
58 |     """
59 | 
60 |     def __init__(
61 |         self,
62 |         in_channels: int,
63 |         out_channels: int,
64 |         aggregators: List[str],
65 |         scalers: List[str],
66 |         deg: Tensor,
67 |         edge_dim: Optional[int] = None,
68 |         towers: int = 1,
69 |         pre_layers: int = 1,
70 |         post_layers: int = 1,
71 |         divide_input: bool = False,
72 |         **kwargs,
73 |     ):
74 | 
75 |         super(PNAConv, self).__init__(aggr=None, node_dim=0, **kwargs)
76 | 
77 |         if divide_input:
78 |             assert in_channels % towers == 0
79 |         assert out_channels % towers == 0
80 | 
81 |         self.in_channels = in_channels
82 |         self.out_channels = out_channels
83 |         self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
84 |         self.scalers = [SCALERS[scale] for scale in scalers]
85 |         self.edge_dim = edge_dim
86 |         self.towers = towers
87 |         self.divide_input = divide_input
88 | 
89 |         self.F_in = in_channels // towers if divide_input else in_channels
90 |         self.F_out = self.out_channels // towers
91 | 
92 |         deg = deg.to(torch.float)
93 |         self.avg_deg: Dict[str, float] = {
94 |             "lin": deg.mean().item(),
95 |             "log": (deg + 1).log().mean().item(),
96 |             "exp": deg.exp().mean().item(),
97 |         }
98 | 
99 |         if self.edge_dim is not None:
100 |             self.edge_encoder = BondEncoder(emb_dim=in_channels)
101 | 
102 |         self.pre_nns = nn.ModuleList()
103 |         self.post_nns = nn.ModuleList()
104 |         for _ in range(towers):
105 |             modules = [nn.Linear((3 if edge_dim else 2) * self.F_in, self.F_in)]
106 |             for _ in range(pre_layers - 1):
107 |                 modules += [nn.ReLU()]
108 |                 modules += [nn.Linear(self.F_in, self.F_in)]
109 |             self.pre_nns.append(nn.Sequential(*modules))
110 | 
111 |             in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in
112 |             modules = [nn.Linear(in_channels, self.F_out)]
113 |             for _ in range(post_layers - 1):
114 |                 modules += [nn.ReLU()]
115 |                 modules += [nn.Linear(self.F_out, self.F_out)]
116 |             self.post_nns.append(nn.Sequential(*modules))
117 | 
118 |         self.lin = nn.Linear(out_channels, out_channels)
119 | 
120 |         self.reset_parameters()
121 | 
122 |     def reset_parameters(self):
123 |         if self.edge_dim is not None:
124 |             self.edge_encoder.reset_parameters()
125 |         for nn in self.pre_nns:
126 |             reset(nn)
127 |         for nn in self.post_nns:
128 |             reset(nn)
129 |         self.lin.reset_parameters()
130 | 
131 |     def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None) -> Tensor:
132 | 
133 |         if self.divide_input:
134 |             x = x.view(-1, self.towers, self.F_in)
135 |         else:
136 |             x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1)
137 | 
138 |         # propagate_type: (x: Tensor, edge_attr: OptTensor)
139 |         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
140 | 
141 |         out = torch.cat([x, out], dim=-1)
142 |         outs = [nn(out[:, i]) for i, nn in enumerate(self.post_nns)]
143 |         out = torch.cat(outs, dim=1)
144 | 
145 |         return self.lin(out)
146 | 
147 |     def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
148 | 
149 |         h: Tensor = x_i  # Dummy.
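        # NOTE: the "dummy" assignment above gives `h` a concrete Tensor type
        # before the branch below; this mirrors the TorchScript-friendly
        # annotation style used in upstream PyTorch Geometric.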
150 |         if edge_attr is not None:
151 |             edge_attr = self.edge_encoder(edge_attr)
152 |             edge_attr = edge_attr.view(-1, 1, self.F_in)
153 |             edge_attr = edge_attr.repeat(1, self.towers, 1)
154 |             h = torch.cat([x_i, x_j, edge_attr], dim=-1)
155 |         else:
156 |             h = torch.cat([x_i, x_j], dim=-1)
157 | 
158 |         hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
159 |         return torch.stack(hs, dim=1)
160 | 
161 |     def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
162 |         outs = [aggr(inputs, index, dim_size) for aggr in self.aggregators]
163 |         out = torch.cat(outs, dim=-1)
164 | 
165 |         deg = degree(index, dim_size, dtype=inputs.dtype).view(-1, 1, 1)
166 |         outs = [scaler(out, deg, self.avg_deg) for scaler in self.scalers]
167 |         return torch.cat(outs, dim=-1)
168 | 
169 |     def __repr__(self):
170 |         return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels}, towers={self.towers}, edge_dim={self.edge_dim})"
171 | 
172 | 
173 | 
174 | class PNAConvSimple(MessagePassing):
175 |     r"""The Principal Neighbourhood Aggregation graph convolution operator
176 |     from the `"Principal Neighbourhood Aggregation for Graph Nets"
177 |     <https://arxiv.org/abs/2004.05718>`_ paper
178 |     .. math::
179 |         \bigoplus = \underbrace{\begin{bmatrix}I \\ S(D, \alpha=1) \\
180 |         S(D, \alpha=-1) \end{bmatrix} }_{\text{scalers}}
181 |         \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min
182 |         \end{bmatrix}}_{\text{aggregators}},
183 |     in:
184 |     .. math::
185 |         X_i^{(t+1)} = U \left( \underset{(j,i) \in E}{\bigoplus}
186 |         M \left(X_j^{(t)} \right) \right)
187 |     where :math:`U` denotes the MLP referred to with posttrans.
188 |     Args:
189 |         in_channels (int): Size of each input sample.
190 |         out_channels (int): Size of each output sample.
191 |         aggregators (list of str): Set of aggregation function identifiers,
192 |             namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
193 |             :obj:`"var"` and :obj:`"std"`.
194 |         scalers (list of str): Set of scaling function identifiers, namely
195 |             :obj:`"identity"`, :obj:`"amplification"`,
196 |             :obj:`"attenuation"`, :obj:`"linear"` and
197 |             :obj:`"inverse_linear"`.
198 |         deg (Tensor): Histogram of in-degrees of nodes in the training set,
199 |             used by scalers to normalize.
200 |         post_layers (int, optional): Number of transformation layers after
201 |             aggregation (default: :obj:`1`).
202 |         **kwargs (optional): Additional arguments of
203 |             :class:`torch_geometric.nn.conv.MessagePassing`.
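    Example (illustrative sketch; :obj:`deg` is assumed to be an in-degree
    histogram over the training split, and :obj:`edge_encoder_cls` may be
    :obj:`None` when :obj:`add_edge="none"`, since it is never called)::

        deg = torch.tensor([0, 10, 30, 15, 5], dtype=torch.float)
        conv = PNAConvSimple(64, 64, edge_encoder_cls=None,
                             aggregators=["mean", "min", "max", "std"],
                             scalers=["identity", "amplification", "attenuation"],
                             deg=deg)
        out = conv(x, edge_index)  # x: [num_nodes, 64] -> out: [num_nodes, 64]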
204 |     """
205 | 
206 |     def __init__(
207 |         self,
208 |         in_channels: int,
209 |         out_channels: int,
210 |         edge_encoder_cls,
211 |         aggregators: List[str],
212 |         scalers: List[str],
213 |         deg: Tensor,
214 |         drop_ratio: Optional[float] = None,
215 |         post_layers: int = 1,
216 |         add_edge="none",
217 |         **kwargs,
218 |     ):
219 | 
220 |         super(PNAConvSimple, self).__init__(aggr=None, node_dim=0, **kwargs)
221 | 
222 |         self.in_channels = in_channels
223 |         self.out_channels = out_channels
224 |         self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
225 |         self.scalers = [SCALERS[scale] for scale in scalers]
226 | 
227 |         self.add_edge = add_edge
228 |         if add_edge != "none":
229 |             self.edge_encoder = edge_encoder_cls(in_channels)
230 | 
231 |         self.F_in = in_channels
232 |         self.F_out = self.out_channels
233 | 
234 |         deg = deg.to(torch.float)
235 |         self.avg_deg: Dict[str, float] = {
236 |             "lin": deg.mean().item(),
237 |             "log": (deg + 1).log().mean().item(),
238 |             "exp": deg.exp().mean().item(),
239 |         }
240 | 
241 |         in_channels = (len(aggregators) * len(scalers)) * self.F_in
242 |         modules = [nn.Linear(in_channels, self.F_out)]
243 |         # modules += [nn.Dropout(drop_ratio)] if drop_ratio is not None else []
244 |         for _ in range(post_layers - 1):
245 |             modules += [nn.ReLU()]
246 |             # modules += [nn.Dropout(drop_ratio)] if drop_ratio is not None else []
247 |             modules += [nn.Linear(self.F_out, self.F_out)]
248 |         self.post_nn = nn.Sequential(*modules)
249 |         if self.add_edge == "gincat":
250 |             self.pre_nn = nn.Linear(self.F_in * 2, self.F_in)
251 | 
252 |         self.reset_parameters()
253 | 
254 |     def reset_parameters(self):
255 |         reset(self.post_nn)
256 | 
257 |     def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None) -> Tensor:
258 | 
259 |         # propagate_type: (x: Tensor)
260 |         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
261 |         return self.post_nn(out)
262 | 
263 |     def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
264 |         if self.add_edge == "gin":
265 |             edge_attr = self.edge_encoder(edge_attr)
266 |             return F.relu(x_j + edge_attr)
267 |         elif self.add_edge == "gincat":
268 |             edge_attr = self.edge_encoder(edge_attr)
269 |             x_j = torch.cat([x_j, edge_attr], dim=-1)
270 |             return self.pre_nn(x_j)
271 |         return x_j
272 | 
273 |     def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
274 |         outs = [aggr(inputs, index, dim_size) for aggr in self.aggregators]
275 |         out = torch.cat(outs, dim=-1)
276 | 
277 |         deg = degree(index, dim_size, dtype=inputs.dtype).view(-1, 1)
278 |         outs = [scaler(out, deg, self.avg_deg) for scaler in self.scalers]
279 |         return torch.cat(outs, dim=-1)
280 | 
281 |     def __repr__(self):
282 |         return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels})"
283 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from datetime import datetime
4 | 
5 | import configargparse
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import wandb
12 | from loguru import logger
13 | from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
14 | from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau
15 | from torch_geometric.data import DataLoader
16 | from tqdm import tqdm
17 | 
18 | import utils
19 | from data.adj_list import compute_adjacency_list_cached
20 | from dataset import DATASET_UTILS
21 | from models import get_model_and_parser
22 | from trainers import get_trainer_and_parser
23 | 
24 | wandb.init(project="graph-aug")
25 | now = datetime.now()
26 | now = now.strftime("%m_%d-%H_%M_%S")
27 | 
28 | 
29 | def main():
30 |     # fmt: off
31 |     parser = configargparse.ArgumentParser(allow_abbrev=False,
32 |                                            description='GNN baselines on ogbg-code data with PyTorch Geometric')
33 |     parser.add_argument('--configs', required=False, is_config_file=True)
34 |     parser.add_argument('--wandb_run_idx', type=str, default=None)
35 | 
36 | 
37 |     parser.add_argument('--data_root', type=str, default='/data/zhwu/ogb')
38 |     parser.add_argument('--dataset', type=str, default="ogbg-code",
39 |                         help='dataset name (default: ogbg-code)')
40 | 
41 |     parser.add_argument('--aug', type=str, default='baseline',
42 |                         help='augmentation method to use [baseline|flag|augment]')
43 | 
44 |     parser.add_argument('--max_seq_len', type=int, default=None,
45 |                         help='maximum sequence length to predict (default: None)')
46 | 
47 |     group = parser.add_argument_group('model')
48 |     group.add_argument('--model_type', type=str, default='gnn', help='gnn|pna|gnn-transformer')
49 |     group.add_argument('--graph_pooling', type=str, default='mean')
50 |     group = parser.add_argument_group('gnn')
51 |     group.add_argument('--gnn_type', type=str, default='gcn')
52 |     group.add_argument('--gnn_virtual_node', action='store_true')
53 |     group.add_argument('--gnn_dropout', type=float, default=0)
54 |     group.add_argument('--gnn_num_layer', type=int, default=5,
55 |                        help='number of GNN message passing layers (default: 5)')
56 |     group.add_argument('--gnn_emb_dim', type=int, default=300,
57 |                        help='dimensionality of hidden units in GNNs (default: 300)')
58 |     group.add_argument('--gnn_JK', type=str, default='last')
59 |     group.add_argument('--gnn_residual', action='store_true', default=False)
60 | 
61 |     group = parser.add_argument_group('training')
62 |     group.add_argument('--devices', type=str, default="0",
63 |                        help='which gpu to use if any (default: 0)')
64 |     group.add_argument('--batch_size', type=int, default=128,
65 |                        help='input batch size for training (default: 128)')
66 |     group.add_argument('--eval_batch_size', type=int, default=None,
67 |                        help='input batch size for evaluation (default: train batch size)')
68 |     group.add_argument('--epochs', type=int, default=30,
69 |                        help='number of epochs to train (default: 30)')
70 |     group.add_argument('--num_workers', type=int, default=0,
71 |                        help='number of workers (default: 0)')
72 |     group.add_argument('--scheduler', type=str, default=None)
73 |     group.add_argument('--pct_start', type=float, default=0.3)
74 |     group.add_argument('--weight_decay', type=float, default=0.0)
75 |     group.add_argument('--grad_clip', type=float, default=None)
76 |     group.add_argument('--lr', type=float, default=0.001)
77 |     group.add_argument('--max_lr', type=float, default=0.001)
78 |     group.add_argument('--runs', type=int, default=10)
79 |     group.add_argument('--test-freq', type=int, default=1)
80 |     group.add_argument('--start-eval', type=int, default=15)
81 |     group.add_argument('--resume', type=str, default=None)
82 |     group.add_argument('--seed', type=int, default=None)
83 |     # fmt: on
84 | 
85 |     args, _ = parser.parse_known_args()
86 | 
87 |     dataset_util = DATASET_UTILS[args.dataset]()
88 |     dataset_util.add_args(parser)
89 |     args, _ = parser.parse_known_args()
90 | 
91 |     # Setup Trainer and add customized args
92 |     trainer = get_trainer_and_parser(args, parser)
93 |     train = trainer.train
94 |     model_cls = get_model_and_parser(args, parser)
95 |     args = parser.parse_args()
96 |     data_transform = trainer.transform(args)
97 | 
98 |     run_name = f"{args.dataset}+{model_cls.name(args)}"
99 |     run_name += f"+{trainer.name(args)}+lr={args.lr}+wd={args.weight_decay}"
100 |     if args.scheduler is not None:
101 |         run_name = run_name + f"+sch={args.scheduler}"
102 |     if args.seed:
103 |         run_name = run_name + f"+seed{args.seed}"
104 |     if args.wandb_run_idx is not None:
105 |         run_name = args.wandb_run_idx + "_" + run_name
106 | 
107 |     wandb.run.name = run_name
108 | 
109 |     device = torch.device("cuda") if torch.cuda.is_available() and args.devices else torch.device("cpu")
110 |     args.save_path = f"exps/{run_name}-{now}"
111 |     os.makedirs(args.save_path, exist_ok=True)
112 |     if args.resume is not None:
113 |         args.save_path = args.resume
114 |     logger.info(args)
115 |     wandb.config.update(args)
116 | 
117 |     if args.seed is not None:
118 |         random.seed(args.seed)
119 |         np.random.seed(args.seed)
120 |         torch.manual_seed(args.seed)
121 | 
122 |         if torch.cuda.is_available():
123 |             # cudnn.deterministic = True
124 |             torch.cuda.manual_seed(args.seed)
125 | 
126 |     if "ogb" in args.dataset:
127 |         # automatic dataloading and splitting
128 |         dataset_ = PygGraphPropPredDataset(name=args.dataset, root=args.data_root, transform=data_transform)
129 |         dataset_eval_ = PygGraphPropPredDataset(name=args.dataset, root=args.data_root)
130 |         num_tasks, node_encoder_cls, edge_encoder_cls, deg = dataset_util.preprocess(dataset_, dataset_eval_, model_cls, args)
131 |         evaluator = Evaluator(args.dataset)  # automatic evaluator. takes dataset name as input
132 |     else:
133 |         dataset_, num_tasks, node_encoder_cls, edge_encoder_cls, deg = dataset_util.preprocess(args)
134 |         dataset_eval_ = dataset_
135 |         evaluator = None
136 | 
137 |     task_type = dataset_.task_type
138 |     split_idx = dataset_.get_idx_split()
139 |     calc_loss = dataset_util.loss_fn(task_type)
140 |     eval = dataset_util.eval
141 | 
142 |     def create_loader(dataset, dataset_eval):
143 |         test_data = compute_adjacency_list_cached(dataset[split_idx["test"]], key=f"{args.dataset}_test")
144 |         valid_data = compute_adjacency_list_cached(dataset_eval[split_idx["valid"]], key=f"{args.dataset}_valid")
145 |         train_data = compute_adjacency_list_cached(dataset[split_idx["train"]], key=f"{args.dataset}_train")
146 |         logger.debug("Finished computing adjacency list")
147 | 
148 |         eval_bs = args.batch_size if args.eval_batch_size is None else args.eval_batch_size
149 |         train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
150 |         train_loader_eval = DataLoader(train_data, batch_size=eval_bs, shuffle=False, num_workers=args.num_workers, pin_memory=True)
151 |         valid_loader = DataLoader(valid_data, batch_size=eval_bs, shuffle=False, num_workers=args.num_workers, pin_memory=True)
152 |         test_loader = DataLoader(test_data, batch_size=eval_bs, shuffle=False, num_workers=args.num_workers, pin_memory=True)
153 |         return train_loader, train_loader_eval, valid_loader, test_loader
154 | 
155 |     train_loader_, train_loader_eval_, valid_loader_, test_loader_ = create_loader(dataset_, dataset_eval_)
156 | 
157 |     def count_parameters(model):
158 |         return sum(p.numel() for p in model.parameters() if p.requires_grad)
159 |     def run(run_id):
160 |         if "ogb" not in args.dataset:
161 |             dataset, _, _, _, _ = dataset_util.preprocess(args)
162 |             dataset_eval = dataset
163 |             train_loader, train_loader_eval, valid_loader, test_loader = create_loader(dataset, dataset_eval)
164 |         else:
165 |             train_loader, train_loader_eval, valid_loader, test_loader = train_loader_, train_loader_eval_, valid_loader_, test_loader_
166 |             dataset = dataset_
167 |         node_encoder = node_encoder_cls()
168 | 
169 |         os.makedirs(os.path.join(args.save_path, str(run_id)), exist_ok=True)
170 |         best_val, final_test = 0, 0
171 |         model = model_cls(num_tasks=num_tasks, args=args, node_encoder=node_encoder, edge_encoder_cls=edge_encoder_cls).to(device)
172 |         print("Model Parameters: ", count_parameters(model))
173 |         # exit(-1)
174 |         # model = nn.DataParallel(model)
175 | 
176 |         wandb.watch(model)
177 | 
178 |         optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
179 |         if args.scheduler == "plateau":
180 |             # NOTE(ajayjain): For Molhiv config, this min_lr is too high -- means that lr does not decay.
181 |             scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20, min_lr=0.0001, verbose=False)
182 |         elif args.scheduler == "cosine":
183 |             scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader), verbose=False)
184 |         elif args.scheduler == "onecycle":
185 |             scheduler = OneCycleLR(
186 |                 optimizer,
187 |                 max_lr=args.max_lr,
188 |                 epochs=args.epochs,
189 |                 steps_per_epoch=len(train_loader),
190 |                 pct_start=args.pct_start,
191 |                 verbose=False,
192 |             )
193 |         elif args.scheduler is None:
194 |             scheduler = None
195 |         else:
196 |             raise NotImplementedError
197 | 
198 |         # Load resume model, if any
199 |         start_epoch = 1
200 |         last_model_path = os.path.join(args.save_path, str(run_id), "last_model.pt")
201 |         if os.path.exists(last_model_path):
202 |             state_dict = torch.load(last_model_path)
203 |             start_epoch = state_dict["epoch"] + 1
204 |             model.load_state_dict(state_dict["model"])
205 |             optimizer.load_state_dict(state_dict["optimizer"])
206 |             if args.scheduler:
207 |                 scheduler.load_state_dict(state_dict["scheduler"])
208 |             logger.info(f"[Resume] Loaded: {last_model_path} epoch: {start_epoch}")
209 | 
210 |         model.epoch_callback(epoch=start_epoch - 1)
211 |         for epoch in range(start_epoch, args.epochs + 1):
212 |             logger.info(f"=====Epoch {epoch}=====")
213 |             logger.info("Training...")
214 |             logger.info("Total parameters: {}", utils.num_total_parameters(model))
215 |             logger.info("Trainable parameters: {}", utils.num_trainable_parameters(model))
216 |             loss = train(model, device, train_loader, optimizer, args, calc_loss, scheduler if args.scheduler != "plateau" else None)
217 | 
218 |             model.epoch_callback(epoch)
219 |             wandb.log({f"train/loss-runs{run_id}": loss, "train/lr": optimizer.param_groups[0]["lr"], "epoch": epoch})
220 | 
221 |             if args.scheduler == "plateau":
222 |                 valid_perf = eval(model, device, valid_loader, evaluator)
223 |                 valid_metric = valid_perf[dataset.eval_metric]
224 |                 scheduler.step(valid_metric)
225 |             if (epoch > args.start_eval and epoch % args.test_freq == 0) or epoch in [1, args.epochs]:
226 |                 logger.info("Evaluating...")
227 |                 with torch.no_grad():
228 |                     train_perf = eval(model, device, train_loader_eval, evaluator)
229 |                     if args.scheduler != "plateau":
230 |                         valid_perf = eval(model, device, valid_loader, evaluator)
231 |                     test_perf = eval(model, device, test_loader, evaluator)
232 | 
233 |                 train_metric, valid_metric, test_metric = (
234 |                     train_perf[dataset.eval_metric],
235 |                     valid_perf[dataset.eval_metric],
236 |                     test_perf[dataset.eval_metric],
237 |                 )
238 |                 wandb.log(
239 |                     {
240 |                         f"train/{dataset.eval_metric}-runs{run_id}": train_metric,
241 |                         f"valid/{dataset.eval_metric}-runs{run_id}": valid_metric,
242 |                         f"test/{dataset.eval_metric}-runs{run_id}": test_metric,
243 |                         "epoch": epoch,
244 |                     }
245 |                 )
246 |                 logger.info(f"Running: {run_name} (runs {run_id})")
247 |                 logger.info(f"Run {run_id} - train: {train_metric}, val: {valid_metric}, test: {test_metric}")
248 | 
249 |                 # Save checkpoints
250 |                 state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
251 |                 state_dict["scheduler"] = scheduler.state_dict() if args.scheduler else None
252 |                 torch.save(state_dict, os.path.join(args.save_path, str(run_id), "last_model.pt"))
253 |                 logger.info("[Save] Save model: {}", os.path.join(args.save_path, str(run_id), "last_model.pt"))
254 |                 if best_val < valid_metric:
255 |                     best_val = valid_metric
256 |                     final_test = test_metric
257 |                     wandb.run.summary[f"best/valid/{dataset.eval_metric}-runs{run_id}"] = valid_metric
258 |                     wandb.run.summary[f"best/test/{dataset.eval_metric}-runs{run_id}"] = test_metric
259 |                     torch.save(state_dict, os.path.join(args.save_path, str(run_id), "best_model.pt"))
260 |                     logger.info("[Best Model] Save model: {}", os.path.join(args.save_path, str(run_id), "best_model.pt"))
261 | 
262 |         state_dict = torch.load(os.path.join(args.save_path, str(run_id), "best_model.pt"))
263 |         logger.info("[Evaluate] Loaded from {}", os.path.join(args.save_path, str(run_id), "best_model.pt"))
264 |         model.load_state_dict(state_dict["model"])
265 |         best_valid_perf = eval(model, device, valid_loader, evaluator)
266 |         best_test_perf = eval(model, device, test_loader, evaluator)
267 |         return best_valid_perf[dataset.eval_metric], best_test_perf[dataset.eval_metric]
268 | 
269 |     vals, tests = [], []
270 |     for run_id in range(args.runs):
271 |         best_val, final_test = run(run_id)
272 |         vals.append(best_val)
273 |         tests.append(final_test)
274 |         logger.info(f"Run {run_id} - val: {best_val}, test: {final_test}")
275 |     logger.info(f"Average val accuracy: {np.mean(vals)} ± {np.std(vals)}")
276 |     logger.info(f"Average test accuracy: {np.mean(tests)} ± {np.std(tests)}")
277 | 
278 | 
279 | if __name__ == "__main__":
280 |     main()
281 | 
--------------------------------------------------------------------------------