├── .gitignore ├── README.md ├── configs └── GSSC │ ├── cifar10-GSSC-tune.yaml │ ├── cifar10-GSSC.yaml │ ├── cluster-GSSC-tune.yaml │ ├── cluster-GSSC.yaml │ ├── mnist-GSSC-tune.yaml │ ├── mnist-GSSC.yaml │ ├── ogbg-molhiv-GSSC-tune.yaml │ ├── ogbg-molhiv-GSSC.yaml │ ├── pattern-GSSC-tune.yaml │ ├── pattern-GSSC.yaml │ ├── peptides-func-GSSC-tune.yaml │ ├── peptides-func-GSSC.yaml │ ├── peptides-struct-GSSC-tune.yaml │ ├── peptides-struct-GSSC.yaml │ ├── vocsuperpixels-GSSC-tune.yaml │ ├── vocsuperpixels-GSSC.yaml │ ├── zinc-GSSC-tune.yaml │ ├── zinc-GSSC.yaml │ ├── zincfull-GSSC-tune.yaml │ └── zincfull-GSSC.yaml ├── datasets └── .gitignore ├── environment.yml ├── gssc ├── __init__.py ├── act │ ├── __init__.py │ └── example.py ├── agg_runs.py ├── config │ ├── __init__.py │ ├── custom_gnn_config.py │ ├── data_preprocess_config.py │ ├── dataset_config.py │ ├── defaults_config.py │ ├── example.py │ ├── gt_config.py │ ├── optimizers_config.py │ ├── posenc_config.py │ ├── pretrained_config.py │ ├── split_config.py │ └── wandb_config.py ├── encoder │ ├── ER_edge_encoder.py │ ├── ER_node_encoder.py │ ├── __init__.py │ ├── ast_encoder.py │ ├── composed_encoders.py │ ├── dummy_edge_encoder.py │ ├── equivstable_laplace_pos_encoder.py │ ├── example.py │ ├── exp_edge_fixer.py │ ├── kernel_pos_encoder.py │ ├── laplace_pos_encoder.py │ ├── linear_edge_encoder.py │ ├── linear_node_encoder.py │ ├── ppa_encoder.py │ ├── signnet_pos_encoder.py │ ├── type_dict_encoder.py │ └── voc_superpixels_encoder.py ├── finetuning.py ├── head │ ├── __init__.py │ ├── example.py │ ├── inductive_edge.py │ ├── inductive_node.py │ ├── ogb_code_graph.py │ └── san_graph.py ├── layer │ ├── ETransformer.py │ ├── Exphormer.py │ ├── __init__.py │ ├── bigbird_layer.py │ ├── example.py │ ├── gatedgcn_layer.py │ ├── gine_conv_layer.py │ ├── gps_layer.py │ ├── gssc_layer.py │ ├── multi_model_layer.py │ ├── performer_layer.py │ ├── san2_layer.py │ └── san_layer.py ├── loader │ ├── __init__.py │ ├── dataset │ │ 
├── __init__.py │ │ ├── aqsol_molecules.py │ │ ├── coco_superpixels.py │ │ ├── count_cycle.py │ │ ├── malnet_tiny.py │ │ ├── pcqm4mv2_contact.py │ │ ├── peptides_functional.py │ │ ├── peptides_structural.py │ │ └── voc_superpixels.py │ ├── master_loader.py │ ├── ogbg_code2_utils.py │ ├── planetoid.py │ └── split_generator.py ├── logger.py ├── loss │ ├── __init__.py │ ├── l1.py │ ├── multilabel_classification_loss.py │ ├── subtoken_prediction_loss.py │ └── weighted_cross_entropy.py ├── metric_wrapper.py ├── metrics_ogb.py ├── network │ ├── MaskedReduce.py │ ├── __init__.py │ ├── big_bird.py │ ├── custom_gnn.py │ ├── example.py │ ├── gps_model.py │ ├── multi_model.py │ ├── norm.py │ ├── performer.py │ ├── san_transformer.py │ └── utils.py ├── optimizer │ ├── __init__.py │ └── extra_optimizers.py ├── pooling │ ├── __init__.py │ └── example.py ├── stage │ ├── __init__.py │ └── example.py ├── train │ ├── __init__.py │ ├── custom_train.py │ └── example.py ├── transform │ ├── __init__.py │ ├── dist_transforms.py │ ├── expander_edges.py │ ├── posenc_stats.py │ └── transforms.py └── utils.py ├── main.py └── wandb └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | # CUSTOM 2 | .vscode/ 3 | scripts/pcqm4m/**/*.zip 4 | scripts/pcqm4m/**/*.sdf 5 | scripts/pcqm4m/**/*.xyz 6 | scripts/pcqm4m/**/*.csv 7 | !scripts/pcqm4m/**/periodic_table.csv 8 | scripts/pcqm4m/**/*.gz 9 | scripts/pcqm4m/**/*.tsv 10 | scripts/pcqm4m/pcqm4m-v2/ 11 | slurm_history/ 12 | pretrained/ 13 | results/ 14 | vocprep/benchmark_RELEASE/ 15 | vocprep/voc_viz_files/ 16 | vocprep/VOC/benchmark_RELEASE/ 17 | vocprep/VOC/*.tgz 18 | vocprep/VOC/*.pickle 19 | vocprep/VOC/*.pkl 20 | vocprep/VOC/*.zip 21 | splits/ 22 | __pycache__/ 23 | .idea 24 | *.log 25 | *.bak 26 | *.png 27 | *.txt 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 
| .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | pip-wheel-metadata/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | target/ 104 | 105 | # Jupyter Notebook 106 | .ipynb_checkpoints 107 | 108 | # IPython 109 | profile_default/ 110 | ipython_config.py 111 | 112 | # pyenv 113 | .python-version 114 | 115 | # pipenv 116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 119 | # install all needed dependencies. 120 | #Pipfile.lock 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # vim edit buffer 160 | *.swp 161 | 162 | 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Graph State Space Convolution (GSSC)

2 |

3 | Paper 4 | Github 5 |

6 | 7 | 8 | This repository contains the official implementation of GSSC as described in the paper: [What Can We Learn from State Space Models for Machine Learning on Graphs?](https://arxiv.org/abs/2406.05815) by Yinan Huang*, Siqi Miao*, and Pan Li. 9 | 10 | (*Equal contribution, listed in alphabetical order) 11 | 12 | ## Installation 13 | All required packages are listed in `environment.yml`. 14 | 15 | ## Running the code 16 | Pass the path to a configuration file via `--cfg` and the GPU device number via `--device`, as shown below: 17 | ``` 18 | python main.py --cfg configs/GSSC/peptides-func-GSSC.yaml --device 0 wandb.use False 19 | ``` 20 | This command will train the model on the `peptides-func` dataset using the GSSC method with default hyperparameters. 21 | 22 | ## Reproducing the results 23 | We use wandb to log and sweep the results. To reproduce the reported results, one needs to create a wandb account and log in to it. Then, one can launch the sweep using the configuration files in the `configs` directory. 24 | For example, to reproduce the tuned results of GSSC on the `peptides-func` dataset, one can launch the sweep using `configs/GSSC/peptides-func-GSSC-tune.yaml`. 25 | 26 | ## Acknowledgement 27 | This repository is built upon [GraphGPS (Rampasek et al., 2022)](https://github.com/rampasek/GraphGPS). 
28 | -------------------------------------------------------------------------------- /configs/GSSC/cifar10-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: cifar10 3 | entity: anonymity 4 | name: cifar10-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_accuracy 9 | parameters: 10 | dropout_res: 11 | value: 0.1 12 | dropout_local: 13 | value: 0.1 14 | dropout_ff: 15 | value: 0.1 16 | base_lr: 17 | value: 0.005 18 | weight_decay: 19 | value: 0.01 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/cifar10-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/cifar10-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | wandb: 4 | use: True 5 | project: cifar10 6 | entity: anonymity 7 | dataset: 8 | format: PyG-GNNBenchmarkDataset 9 | name: CIFAR10 10 | task: graph 11 | task_type: classification 12 | transductive: False 13 | node_encoder: True 14 | node_encoder_name: LapPE 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_name: LinearEdge 18 | edge_encoder_bn: False 19 | posenc_LapPE: 20 | enable: True 21 | eigen: 22 | laplacian_norm: none 23 | eigvec_norm: L2 24 | max_freqs: 33 25 | model: DeepSet 26 | dim_pe: 16 27 | layers: 2 28 | raw_norm_type: none 29 | train: 30 | mode: custom 31 | batch_size: 16 32 | eval_period: 1 33 | ckpt_period: 100 34 | model: 35 | type: GPSModel 36 | loss_fun: cross_entropy 37 | edge_decoding: dot 38 | graph_pooling: mean 39 | gt: # Hyperparameters optimized for ~100k budget. 
40 | layer_type: CustomGatedGCN+GSSC 41 | layers: 3 42 | n_heads: 4 43 | dim_hidden: 52 # `gt.dim_hidden` must match `gnn.dim_inner` 44 | dropout: 0.0 45 | attn_dropout: 0.5 46 | layer_norm: False 47 | batch_norm: True 48 | gnn: 49 | head: default 50 | layers_pre_mp: 0 51 | layers_post_mp: 2 52 | dim_inner: 52 # `gt.dim_hidden` must match `gnn.dim_inner` 53 | batchnorm: False 54 | act: relu 55 | dropout: 0.0 56 | agg: mean 57 | normalize_adj: False 58 | optim: 59 | clip_grad_norm: True 60 | optimizer: adamW 61 | weight_decay: 1e-5 62 | base_lr: 0.001 63 | max_epoch: 300 64 | scheduler: cosine_with_warmup 65 | num_warmup_epochs: 5 66 | seed: 0 67 | name_tag: "random" 68 | -------------------------------------------------------------------------------- /configs/GSSC/cluster-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: CLUSTER 3 | entity: anonymity 4 | name: cluster-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_accuracy-SBM 9 | parameters: 10 | dropout_res: 11 | value: 0.3 12 | dropout_local: 13 | value: 0.3 14 | dropout_ff: 15 | value: 0.1 16 | base_lr: 17 | values: [0.001, 0.002, 0.003] 18 | weight_decay: 19 | values: [0.1, 0.2, 0.001] 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 1 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/cluster-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/cluster-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy-SBM 3 | wandb: 4 | use: True 5 | project: CLUSTER 6 | entity: anonymity 7 | dataset: 8 | format: PyG-GNNBenchmarkDataset 9 | name: CLUSTER 10 | task: graph 11 | task_type: 
classification 12 | transductive: False 13 | split_mode: standard 14 | node_encoder: True 15 | node_encoder_name: RWSE 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_name: DummyEdge 19 | edge_encoder_bn: False 20 | posenc_RWSE: 21 | enable: True 22 | kernel: 23 | times_func: range(1,33) 24 | model: Linear 25 | dim_pe: 18 26 | raw_norm_type: BatchNorm 27 | posenc_LapPE: 28 | enable: True 29 | eigen: 30 | laplacian_norm: none 31 | eigvec_norm: L2 32 | max_freqs: 33 33 | model: DeepSet 34 | dim_pe: 18 35 | layers: 2 36 | n_heads: 4 # Only used when `posenc.model: Transformer` 37 | raw_norm_type: none 38 | train: 39 | mode: custom 40 | batch_size: 16 41 | eval_period: 1 42 | ckpt_period: 100 43 | model: 44 | type: GPSModel 45 | loss_fun: weighted_cross_entropy 46 | edge_decoding: dot 47 | gt: # Hyperparameters optimized for ~500k budget. 48 | layer_type: CustomGatedGCN+GSSC 49 | layers: 24 50 | n_heads: 8 51 | dim_hidden: 36 # `gt.dim_hidden` must match `gnn.dim_inner` 52 | dropout: 0.1 53 | attn_dropout: 0.5 54 | layer_norm: False 55 | batch_norm: True 56 | gnn: 57 | head: inductive_node 58 | layers_pre_mp: 0 59 | layers_post_mp: 3 60 | dim_inner: 36 # `gt.dim_hidden` must match `gnn.dim_inner` 61 | batchnorm: True 62 | act: relu 63 | dropout: 0.0 64 | agg: mean 65 | normalize_adj: False 66 | optim: 67 | clip_grad_norm: True 68 | optimizer: adamW 69 | weight_decay: 0.1 70 | base_lr: 0.001 71 | max_epoch: 300 72 | scheduler: cosine_with_warmup 73 | num_warmup_epochs: 5 74 | seed: 0 75 | name_tag: "random" 76 | -------------------------------------------------------------------------------- /configs/GSSC/mnist-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: MNIST 3 | entity: anonymity 4 | name: MNIST-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_accuracy 9 | parameters: 10 | dropout_res: 11 | value: 0.1 12 | dropout_local: 13 | value: 0.1 
14 | dropout_ff: 15 | value: 0.1 16 | base_lr: 17 | value: 0.005 18 | weight_decay: 19 | value: 0.01 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/mnist-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/mnist-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | wandb: 4 | use: True 5 | project: MNIST 6 | entity: anonymity 7 | dataset: 8 | format: PyG-GNNBenchmarkDataset 9 | name: MNIST 10 | task: graph 11 | task_type: classification 12 | transductive: False 13 | node_encoder: True 14 | node_encoder_name: LapPE 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_name: LinearEdge 18 | edge_encoder_bn: False 19 | posenc_LapPE: 20 | enable: True 21 | eigen: 22 | laplacian_norm: none 23 | eigvec_norm: L2 24 | max_freqs: 33 25 | model: DeepSet 26 | dim_pe: 16 27 | layers: 2 28 | raw_norm_type: none 29 | train: 30 | mode: custom 31 | batch_size: 16 32 | eval_period: 1 33 | ckpt_period: 100 34 | model: 35 | type: GPSModel 36 | loss_fun: cross_entropy 37 | edge_decoding: dot 38 | graph_pooling: mean 39 | gt: # Hyperparameters optimized for ~100k budget. 
40 | layer_type: CustomGatedGCN+GSSC 41 | layers: 3 42 | n_heads: 4 43 | dim_hidden: 52 # `gt.dim_hidden` must match `gnn.dim_inner` 44 | dropout: 0.0 45 | attn_dropout: 0.5 46 | layer_norm: False 47 | batch_norm: True 48 | gnn: 49 | head: default 50 | layers_pre_mp: 0 51 | layers_post_mp: 3 52 | dim_inner: 52 # `gt.dim_hidden` must match `gnn.dim_inner` 53 | batchnorm: False 54 | act: relu 55 | dropout: 0.0 56 | agg: mean 57 | normalize_adj: False 58 | optim: 59 | clip_grad_norm: True 60 | optimizer: adamW 61 | weight_decay: 1e-5 62 | base_lr: 0.001 63 | max_epoch: 300 64 | scheduler: cosine_with_warmup 65 | num_warmup_epochs: 5 66 | seed: 0 67 | name_tag: "random" 68 | -------------------------------------------------------------------------------- /configs/GSSC/ogbg-molhiv-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: molhiv 3 | entity: anonymity 4 | name: molhiv-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_auc 9 | parameters: 10 | dropout_res: 11 | value: 0.0 12 | dropout_local: 13 | value: 0.3 14 | dropout_ff: 15 | value: 0.0 16 | weight_decay: 17 | values: [0.1, 1.0e-3, 1.0e-5] 18 | base_lr: 19 | values: [0.001, 0.002, 0.0005] 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/ogbg-molhiv-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/ogbg-molhiv-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | wandb: 4 | use: True 5 | project: molhiv 6 | entity: anonymity 7 | dataset: 8 | format: OGB 9 | name: ogbg-molhiv 10 | task: graph 11 | task_type: classification 12 | 
transductive: False 13 | node_encoder: True 14 | node_encoder_name: Atom+RWSE 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_name: Bond 18 | edge_encoder_bn: False 19 | posenc_RWSE: 20 | enable: True 21 | kernel: 22 | times_func: range(1,21) 23 | model: Linear 24 | dim_pe: 28 25 | raw_norm_type: BatchNorm 26 | posenc_LapPE: 27 | enable: True 28 | eigen: 29 | laplacian_norm: none 30 | eigvec_norm: L2 31 | max_freqs: 17 32 | model: DeepSet 33 | dim_pe: 16 34 | layers: 2 35 | raw_norm_type: none 36 | train: 37 | mode: custom 38 | batch_size: 32 39 | eval_period: 1 40 | ckpt_period: 100 41 | model: 42 | type: GPSModel 43 | loss_fun: cross_entropy 44 | edge_decoding: dot 45 | graph_pooling: mean 46 | gt: 47 | layer_type: CustomGatedGCN+GSSC 48 | layers: 6 49 | n_heads: 4 50 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 51 | dropout: 0.0 52 | attn_dropout: 0.5 53 | layer_norm: False 54 | batch_norm: True 55 | gnn: 56 | head: san_graph 57 | layers_pre_mp: 0 58 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 59 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 60 | batchnorm: True 61 | act: relu 62 | dropout: 0.0 63 | agg: mean 64 | normalize_adj: False 65 | optim: 66 | clip_grad_norm: True 67 | optimizer: adamW 68 | weight_decay: 1e-5 69 | base_lr: 0.0001 70 | max_epoch: 100 71 | scheduler: cosine_with_warmup 72 | num_warmup_epochs: 5 73 | seed: 0 74 | name_tag: "random" 75 | -------------------------------------------------------------------------------- /configs/GSSC/pattern-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: PATTERN 3 | entity: anonymity 4 | name: pattern-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_accuracy-SBM 9 | parameters: 10 | dropout_res: 11 | value: 0.5 12 | dropout_local: 13 | value: 0.1 14 | dropout_ff: 15 | value: 0.1 16 | base_lr: 17 | value: 0.001 18 | weight_decay: 19 | 
value: 0.1 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 1 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/pattern-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/pattern-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy-SBM 3 | wandb: 4 | use: True 5 | project: PATTERN 6 | entity: anonymity 7 | dataset: 8 | format: PyG-GNNBenchmarkDataset 9 | name: PATTERN 10 | task: graph 11 | task_type: classification 12 | transductive: False 13 | split_mode: standard 14 | node_encoder: True 15 | node_encoder_name: RWSE 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_name: DummyEdge 19 | edge_encoder_bn: False 20 | posenc_RWSE: 21 | enable: True 22 | kernel: 23 | times_func: range(1,22) 24 | model: Linear 25 | dim_pe: 18 26 | raw_norm_type: BatchNorm 27 | posenc_LapPE: 28 | enable: True 29 | eigen: 30 | laplacian_norm: none 31 | eigvec_norm: L2 32 | max_freqs: 33 33 | model: DeepSet 34 | dim_pe: 18 35 | layers: 2 36 | n_heads: 4 # Only used when `posenc.model: Transformer` 37 | raw_norm_type: none 38 | train: 39 | mode: custom 40 | batch_size: 32 41 | eval_period: 1 42 | ckpt_period: 100 43 | model: 44 | type: GPSModel 45 | loss_fun: weighted_cross_entropy 46 | edge_decoding: dot 47 | gt: # Hyperparameters optimized for up to ~500k budget. 
48 | layer_type: CustomGatedGCN+GSSC 49 | layers: 24 50 | n_heads: 4 51 | dim_hidden: 36 # `gt.dim_hidden` must match `gnn.dim_inner` 52 | dropout: 0.0 53 | attn_dropout: 0.5 54 | layer_norm: False 55 | batch_norm: True 56 | gnn: 57 | head: inductive_node 58 | layers_pre_mp: 0 59 | layers_post_mp: 3 60 | dim_inner: 36 # `gt.dim_hidden` must match `gnn.dim_inner` 61 | batchnorm: True 62 | act: relu 63 | dropout: 0.0 64 | agg: mean 65 | normalize_adj: False 66 | optim: 67 | clip_grad_norm: True 68 | optimizer: adamW 69 | weight_decay: 0.0001 70 | base_lr: 0.001 71 | max_epoch: 200 72 | scheduler: cosine_with_warmup 73 | num_warmup_epochs: 5 74 | seed: 0 75 | name_tag: "random" 76 | -------------------------------------------------------------------------------- /configs/GSSC/peptides-func-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: LRGB 3 | entity: anonymity 4 | name: peptides-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_ap 9 | parameters: 10 | dropout_res: 11 | value: 0.1 12 | dropout_local: 13 | value: 0.1 14 | dropout_ff: 15 | values: [0.1, 0.0] 16 | base_lr: 17 | value: 0.003 18 | weight_decay: 19 | value: 0.1 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/peptides-func-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/peptides-func-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: ap 3 | wandb: 4 | use: True 5 | project: LRGB-Benchmark-RandomSeed 6 | entity: anonymity 7 | dataset: 8 | format: OGB 9 | name: peptides-functional 10 | task: graph 11 | task_type: 
classification_multilabel 12 | transductive: False 13 | node_encoder: True 14 | node_encoder_name: Atom+LapPE 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_name: Bond 18 | edge_encoder_bn: False 19 | split_mode: standard 20 | posenc_LapPE: 21 | enable: True 22 | eigen: 23 | laplacian_norm: none 24 | eigvec_norm: L2 25 | max_freqs: 32 26 | model: DeepSet 27 | dim_pe: 16 28 | layers: 2 29 | raw_norm_type: none 30 | train: 31 | mode: custom 32 | batch_size: 128 33 | eval_period: 1 34 | ckpt_period: 100 35 | model: 36 | type: GPSModel 37 | loss_fun: cross_entropy 38 | graph_pooling: mean 39 | gt: 40 | layer_type: CustomGatedGCN+GSSC 41 | n_heads: 4 42 | dim_hidden: 100 # `gt.dim_hidden` must match `gnn.dim_inner` 43 | dropout: 0.0 44 | attn_dropout: 0.5 45 | layer_norm: False 46 | batch_norm: True 47 | gnn: 48 | head: default 49 | layers_pre_mp: 0 50 | layers_post_mp: 1 # Not used when `gnn.head: san_graph` 51 | dim_inner: 100 # `gt.dim_hidden` must match `gnn.dim_inner` 52 | batchnorm: True 53 | act: relu 54 | dropout: 0.0 55 | optim: 56 | clip_grad_norm: True 57 | optimizer: adamW 58 | weight_decay: 0.1 59 | base_lr: 0.003 60 | max_epoch: 200 61 | scheduler: cosine_with_warmup 62 | num_warmup_epochs: 10 63 | seed: 0 64 | name_tag: "random" 65 | -------------------------------------------------------------------------------- /configs/GSSC/peptides-struct-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: LRGB 3 | entity: anonymity 4 | name: peptides-final 5 | method: grid 6 | metric: 7 | goal: minimize 8 | name: best/test_mae 9 | parameters: 10 | dropout_res: 11 | values: [0.3, 0.1] 12 | dropout_local: 13 | values: [0.1, 0.3] 14 | dropout_ff: 15 | value: 0.1 16 | base_lr: 17 | value: 0.001 18 | weight_decay: 19 | value: 0.1 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 1 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | 
value: configs/GSSC/peptides-struct-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/peptides-struct-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: mae 3 | metric_agg: argmin 4 | wandb: 5 | use: True 6 | project: LRGB-Benchmark-RandomSeed 7 | entity: anonymity 8 | dataset: 9 | format: OGB 10 | name: peptides-structural 11 | task: graph 12 | task_type: regression 13 | transductive: False 14 | node_encoder: True 15 | node_encoder_name: Atom+LapPE 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_name: Bond 19 | edge_encoder_bn: False 20 | posenc_LapPE: 21 | enable: True 22 | eigen: 23 | laplacian_norm: none 24 | eigvec_norm: L2 25 | max_freqs: 32 26 | model: DeepSet 27 | dim_pe: 16 28 | layers: 2 29 | raw_norm_type: none 30 | train: 31 | mode: custom 32 | batch_size: 128 33 | eval_period: 1 34 | ckpt_period: 100 35 | model: 36 | type: GPSModel 37 | loss_fun: l1 38 | graph_pooling: mean 39 | gt: 40 | layer_type: CustomGatedGCN+GSSC 41 | layers: 3 42 | n_heads: 4 43 | dim_hidden: 100 # `gt.dim_hidden` must match `gnn.dim_inner` 44 | dropout: 0.0 45 | attn_dropout: 0.5 46 | layer_norm: False 47 | batch_norm: True 48 | gnn: 49 | head: san_graph 50 | layers_pre_mp: 0 51 | layers_post_mp: 2 # Not used when `gnn.head: san_graph` 52 | dim_inner: 100 # `gt.dim_hidden` must match `gnn.dim_inner` 53 | batchnorm: True 54 | act: relu 55 | dropout: 0.0 56 | optim: 57 | clip_grad_norm: True 58 | optimizer: adamW 59 | weight_decay: 0.1 60 | base_lr: 0.003 61 | max_epoch: 200 62 | scheduler: cosine_with_warmup 63 | num_warmup_epochs: 10 64 | seed: 0 65 | name_tag: "random" 66 | -------------------------------------------------------------------------------- 
/configs/GSSC/vocsuperpixels-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: LRGB 3 | entity: anonymity 4 | name: vocsuperpixels-final 5 | method: grid 6 | metric: 7 | goal: maximize 8 | name: best/test_f1 9 | parameters: 10 | dropout_res: 11 | value: 0.5 12 | dropout_local: 13 | value: 0.0 14 | dropout_ff: 15 | value: 0.0 16 | weight_decay: 17 | value: 0.1 18 | base_lr: 19 | value: 0.002 20 | 21 | reweigh_self: 22 | value: 0 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 0 29 | cfg: 30 | value: configs/GSSC/vocsuperpixels-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/vocsuperpixels-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: f1 3 | wandb: 4 | project: LRGB-Benchmark-RandomSeed 5 | entity: anonymity 6 | use: True 7 | dataset: 8 | format: PyG-VOCSuperpixels 9 | name: edge_wt_region_boundary 10 | slic_compactness: 30 11 | task: graph # Even if VOC is node-level task, this needs to be set as 'graph' 12 | task_type: classification 13 | transductive: False 14 | node_encoder: True 15 | node_encoder_name: VOCNode+LapPE 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_name: VOCEdge 19 | edge_encoder_bn: False 20 | posenc_LapPE: 21 | enable: True 22 | eigen: 23 | laplacian_norm: none 24 | eigvec_norm: L2 25 | max_freqs: 64 26 | model: DeepSet 27 | dim_pe: 16 28 | layers: 2 29 | raw_norm_type: none 30 | train: 31 | mode: custom 32 | batch_size: 32 33 | eval_period: 1 34 | ckpt_period: 100 35 | model: 36 | type: GPSModel 37 | loss_fun: weighted_cross_entropy 38 | gt: 39 | layer_type: CustomGatedGCN+GSSC 40 | layers: 4 41 | n_heads: 8 42 | dim_hidden: 96 # 
`gt.dim_hidden` must match `gnn.dim_inner` 43 | dropout: 0.0 44 | attn_dropout: 0.5 45 | layer_norm: False 46 | batch_norm: True 47 | gnn: 48 | head: inductive_node 49 | layers_pre_mp: 0 50 | layers_post_mp: 3 51 | dim_inner: 96 # `gt.dim_hidden` must match `gnn.dim_inner` 52 | batchnorm: True 53 | act: relu 54 | dropout: 0.0 55 | agg: mean 56 | normalize_adj: False 57 | optim: 58 | clip_grad_norm: True 59 | optimizer: adamW 60 | weight_decay: 0.1 61 | base_lr: 0.002 62 | max_epoch: 300 63 | scheduler: cosine_with_warmup 64 | num_warmup_epochs: 10 65 | seed: 0 66 | name_tag: "random" 67 | -------------------------------------------------------------------------------- /configs/GSSC/zinc-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: ZINC 3 | entity: anonymity 4 | name: zinc-final 5 | method: grid 6 | metric: 7 | goal: minimize 8 | name: best/test_mae 9 | parameters: 10 | dropout_res: 11 | values: [0.5, 0.6] 12 | dropout_local: 13 | value: 0.0 14 | dropout_ff: 15 | value: 0.1 16 | weight_decay: 17 | value: 1.0e-5 18 | base_lr: 19 | value: 1.0e-3 20 | 21 | reweigh_self: 22 | value: 2 23 | jk: 24 | value: 1 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/zinc-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2, 3, 4] 39 | -------------------------------------------------------------------------------- /configs/GSSC/zinc-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: mae 3 | metric_agg: argmin 4 | wandb: 5 | use: True 6 | project: ZINC 7 | entity: anonymity 8 | dataset: 9 | format: PyG-ZINC 10 | name: subset 11 | task: graph 12 | task_type: regression 13 | transductive: False 14 | node_encoder: True 15 | node_encoder_name: TypeDictNode+RWSE 16 | node_encoder_num_types: 21 
17 | node_encoder_bn: False 18 | edge_encoder: True 19 | edge_encoder_name: TypeDictEdge 20 | edge_encoder_num_types: 4 21 | edge_encoder_bn: False 22 | posenc_RWSE: 23 | enable: True 24 | kernel: 25 | times_func: range(1,21) 26 | model: Linear 27 | dim_pe: 28 28 | raw_norm_type: BatchNorm 29 | posenc_LapPE: 30 | enable: True 31 | eigen: 32 | laplacian_norm: none 33 | eigvec_norm: L2 34 | max_freqs: 17 35 | model: DeepSet 36 | dim_pe: 16 37 | layers: 2 38 | raw_norm_type: none 39 | train: 40 | mode: custom 41 | batch_size: 32 42 | eval_period: 1 43 | ckpt_period: 100 44 | model: 45 | type: GPSModel 46 | loss_fun: l1 47 | edge_decoding: dot 48 | graph_pooling: add 49 | gt: 50 | layer_type: GINE+GSSC # CustomGatedGCN+Performer 51 | layers: 10 52 | n_heads: 4 53 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 54 | dropout: 0.0 55 | attn_dropout: 0.5 56 | layer_norm: False 57 | batch_norm: True 58 | gnn: 59 | head: san_graph 60 | layers_pre_mp: 0 61 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 62 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 63 | batchnorm: True 64 | act: relu 65 | dropout: 0.0 66 | agg: mean 67 | normalize_adj: False 68 | optim: 69 | clip_grad_norm: True 70 | optimizer: adamW 71 | weight_decay: 1.0e-5 72 | base_lr: 1.0e-3 73 | max_epoch: 2000 74 | scheduler: cosine_with_warmup 75 | num_warmup_epochs: 200 76 | seed: 0 77 | name_tag: "random" 78 | -------------------------------------------------------------------------------- /configs/GSSC/zincfull-GSSC-tune.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: ZINC 3 | entity: anonymity 4 | name: zincfull-final 5 | method: grid 6 | metric: 7 | goal: minimize 8 | name: best/test_mae 9 | parameters: 10 | dropout_res: 11 | values: [0.1, 0.3] 12 | dropout_local: 13 | value: 0.0 14 | dropout_ff: 15 | value: 0.1 16 | weight_decay: 17 | values: [1.0e-3, 1.0e-05] 18 | base_lr: 19 | value: 0.002 20 | 21 
| reweigh_self: 22 | value: 1 23 | jk: 24 | value: 0 25 | init_pe_dim: 26 | value: 32 27 | more_mapping: 28 | value: 1 29 | cfg: 30 | value: configs/GSSC/zincfull-GSSC.yaml 31 | name_tag: 32 | value: random 33 | log_code: 34 | value: 0 35 | device: 36 | value: 0 37 | seed: 38 | values: [0, 1, 2] 39 | -------------------------------------------------------------------------------- /configs/GSSC/zincfull-GSSC.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: mae 3 | metric_agg: argmin 4 | wandb: 5 | use: True 6 | project: ZINC 7 | entity: anonymity 8 | dataset: 9 | format: PyG-ZINC 10 | name: full 11 | task: graph 12 | task_type: regression 13 | transductive: False 14 | node_encoder: True 15 | node_encoder_name: TypeDictNode+RWSE 16 | node_encoder_num_types: 28 17 | node_encoder_bn: False 18 | edge_encoder: True 19 | edge_encoder_name: TypeDictEdge 20 | edge_encoder_num_types: 4 21 | edge_encoder_bn: False 22 | posenc_RWSE: 23 | enable: True 24 | kernel: 25 | times_func: range(1,21) 26 | model: Linear 27 | dim_pe: 28 28 | raw_norm_type: BatchNorm 29 | posenc_LapPE: 30 | enable: True 31 | eigen: 32 | laplacian_norm: none 33 | eigvec_norm: L2 34 | max_freqs: 17 35 | model: DeepSet 36 | dim_pe: 16 37 | layers: 2 38 | raw_norm_type: none 39 | train: 40 | mode: custom 41 | batch_size: 128 42 | eval_period: 1 43 | ckpt_period: 100 44 | model: 45 | type: GPSModel 46 | loss_fun: l1 47 | edge_decoding: dot 48 | graph_pooling: add 49 | gt: 50 | layer_type: GINE+GSSC # CustomGatedGCN+Performer 51 | layers: 10 52 | n_heads: 4 53 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 54 | dropout: 0.0 55 | attn_dropout: 0.5 56 | layer_norm: False 57 | batch_norm: True 58 | gnn: 59 | head: san_graph 60 | layers_pre_mp: 0 61 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 62 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 63 | batchnorm: True 64 | act: relu 65 | dropout: 0.0 66 
| agg: mean 67 | normalize_adj: False 68 | optim: 69 | clip_grad_norm: True 70 | optimizer: adamW 71 | weight_decay: 1.0e-3 72 | base_lr: 2.0e-3 73 | max_epoch: 2000 74 | scheduler: cosine_with_warmup 75 | num_warmup_epochs: 50 76 | seed: 0 77 | name_tag: "random" 78 | extra: 79 | check_if_pe_done: 1 80 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gssc 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - asttokens=2.4.1=pyhd8ed1ab_0 11 | - blas=1.0=mkl 12 | - brotli-python=1.0.9=py39h6a678d5_7 13 | - bzip2=1.0.8=h5eee18b_5 14 | - ca-certificates=2024.3.11=h06a4308_0 15 | - certifi=2024.2.2=pyhd8ed1ab_0 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - comm=0.2.2=pyhd8ed1ab_0 18 | - cuda-cudart=11.7.99=0 19 | - cuda-cupti=11.7.101=0 20 | - cuda-libraries=11.7.1=0 21 | - cuda-nvrtc=11.7.99=0 22 | - cuda-nvtx=11.7.91=0 23 | - cuda-runtime=11.7.1=0 24 | - debugpy=1.6.7=py39h6a678d5_0 25 | - decorator=5.1.1=pyhd8ed1ab_0 26 | - exceptiongroup=1.2.0=pyhd8ed1ab_2 27 | - executing=2.0.1=pyhd8ed1ab_0 28 | - ffmpeg=4.3=hf484d3e_0 29 | - filelock=3.13.1=py39h06a4308_0 30 | - freetype=2.12.1=h4a9f257_0 31 | - gmp=6.2.1=h295c915_3 32 | - gmpy2=2.1.2=py39heeb90bb_0 33 | - gnutls=3.6.15=he1e5248_0 34 | - idna=3.4=py39h06a4308_0 35 | - importlib-metadata=7.1.0=pyha770c72_0 36 | - importlib_metadata=7.1.0=hd8ed1ab_0 37 | - intel-openmp=2023.1.0=hdb19cb5_46306 38 | - ipykernel=6.29.3=pyhd33586a_0 39 | - ipython=8.18.1=pyh707e725_3 40 | - jedi=0.19.1=pyhd8ed1ab_0 41 | - jinja2=3.1.3=py39h06a4308_0 42 | - jpeg=9e=h5eee18b_1 
43 | - jupyter_client=8.6.1=pyhd8ed1ab_0 44 | - jupyter_core=5.7.2=py39hf3d152e_0 45 | - lame=3.100=h7b6447c_0 46 | - lcms2=2.12=h3be6417_0 47 | - ld_impl_linux-64=2.38=h1181459_1 48 | - lerc=3.0=h295c915_0 49 | - libcublas=11.10.3.66=0 50 | - libcufft=10.7.2.124=h4fbf590_0 51 | - libcufile=1.9.0.20=0 52 | - libcurand=10.3.5.119=0 53 | - libcusolver=11.4.0.1=0 54 | - libcusparse=11.7.4.91=0 55 | - libdeflate=1.17=h5eee18b_1 56 | - libffi=3.4.4=h6a678d5_0 57 | - libgcc-ng=13.2.0=h807b86a_5 58 | - libgomp=13.2.0=h807b86a_5 59 | - libiconv=1.16=h7f8727e_2 60 | - libidn2=2.3.4=h5eee18b_0 61 | - libnpp=11.7.4.75=0 62 | - libnvjpeg=11.8.0.2=0 63 | - libpng=1.6.39=h5eee18b_0 64 | - libsodium=1.0.18=h36c2ea0_1 65 | - libstdcxx-ng=11.2.0=h1234567_1 66 | - libtasn1=4.19.0=h5eee18b_0 67 | - libtiff=4.5.1=h6a678d5_0 68 | - libunistring=0.9.10=h27cfd23_0 69 | - libwebp-base=1.3.2=h5eee18b_0 70 | - lz4-c=1.9.4=h6a678d5_0 71 | - markupsafe=2.1.3=py39h5eee18b_0 72 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 73 | - mkl=2023.1.0=h213fc3f_46344 74 | - mkl-service=2.4.0=py39h5eee18b_1 75 | - mkl_fft=1.3.8=py39h5eee18b_0 76 | - mkl_random=1.2.4=py39hdb19cb5_0 77 | - mpc=1.1.0=h10f8cd9_1 78 | - mpfr=4.0.2=hb69a4c5_1 79 | - mpmath=1.3.0=py39h06a4308_0 80 | - ncurses=6.4=h6a678d5_0 81 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 82 | - nettle=3.7.3=hbbd107a_1 83 | - networkx=3.1=py39h06a4308_0 84 | - numpy=1.26.4=py39h5f9d8c6_0 85 | - numpy-base=1.26.4=py39hb5e798b_0 86 | - openh264=2.1.1=h4ff587b_0 87 | - openjpeg=2.4.0=h3ad879b_0 88 | - openssl=3.2.1=hd590300_1 89 | - packaging=24.0=pyhd8ed1ab_0 90 | - parso=0.8.3=pyhd8ed1ab_0 91 | - pexpect=4.9.0=pyhd8ed1ab_0 92 | - pickleshare=0.7.5=py_1003 93 | - pillow=10.2.0=py39h5eee18b_0 94 | - pip=23.3.1=py39h06a4308_0 95 | - platformdirs=4.2.0=pyhd8ed1ab_0 96 | - prompt-toolkit=3.0.42=pyha770c72_0 97 | - psutil=5.9.8=py39hd1e30aa_0 98 | - ptyprocess=0.7.0=pyhd3deb0d_0 99 | - pure_eval=0.2.2=pyhd8ed1ab_0 100 | - pygments=2.17.2=pyhd8ed1ab_0 101 | - 
pysocks=1.7.1=py39h06a4308_0 102 | - python=3.9.19=h955ad1f_0 103 | - python_abi=3.9=2_cp39 104 | - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0 105 | - pytorch-cuda=11.7=h778d358_5 106 | - pytorch-mutex=1.0=cuda 107 | - pyzmq=25.1.2=py39h6a678d5_0 108 | - readline=8.2=h5eee18b_0 109 | - requests=2.31.0=py39h06a4308_1 110 | - setuptools=68.2.2=py39h06a4308_0 111 | - six=1.16.0=pyh6c4a22f_0 112 | - sqlite=3.41.2=h5eee18b_0 113 | - stack_data=0.6.2=pyhd8ed1ab_0 114 | - sympy=1.12=py39h06a4308_0 115 | - tbb=2021.8.0=hdb19cb5_0 116 | - tk=8.6.12=h1ccaba5_0 117 | - torchaudio=2.0.0=py39_cu117 118 | - torchtriton=2.0.0=py39 119 | - torchvision=0.15.0=py39_cu117 120 | - tornado=6.4=py39hd1e30aa_0 121 | - traitlets=5.14.2=pyhd8ed1ab_0 122 | - typing_extensions=4.9.0=py39h06a4308_1 123 | - urllib3=2.1.0=py39h06a4308_1 124 | - wcwidth=0.2.13=pyhd8ed1ab_0 125 | - wheel=0.41.2=py39h06a4308_0 126 | - xz=5.4.6=h5eee18b_0 127 | - zeromq=4.3.5=h6a678d5_0 128 | - zipp=3.17.0=pyhd8ed1ab_0 129 | - zlib=1.2.13=h5eee18b_0 130 | - zstd=1.5.5=hc292b87_0 131 | - pip: 132 | - aiohttp==3.9.3 133 | - aiosignal==1.3.1 134 | - annotated-types==0.6.0 135 | - appdirs==1.4.4 136 | - argparse==1.4.0 137 | - async-timeout==4.0.3 138 | - attrs==23.2.0 139 | - automat==22.10.0 140 | - axial-positional-embedding==0.2.1 141 | - buildtools==1.0.6 142 | - causal-conv1d==1.2.0.post2 143 | - click==8.1.7 144 | - cmake==3.29.0.1 145 | - constantly==23.10.4 146 | - deepspeed==0.14.0 147 | - docker-pycreds==0.4.0 148 | - docopt==0.6.2 149 | - einops==0.7.0 150 | - frozenlist==1.4.1 151 | - fsspec==2024.3.1 152 | - furl==2.1.3 153 | - gitdb==4.0.11 154 | - gitpython==3.1.43 155 | - greenlet==3.0.3 156 | - hjson==3.1.0 157 | - huggingface-hub==0.22.2 158 | - hyperlink==21.0.0 159 | - incremental==22.10.0 160 | - joblib==1.3.2 161 | - lightning-utilities==0.11.2 162 | - lit==18.1.2 163 | - littleutils==0.2.2 164 | - local-attention==1.9.0 165 | - mamba-ssm==1.2.0.post1 166 | - multidict==6.0.5 167 | - 
ninja==1.11.1.1 168 | - ogb==1.3.6 169 | - openbabel-wheel==3.1.1.19 170 | - orderedmultidict==1.0.1 171 | - outdated==0.2.2 172 | - pandas==2.2.1 173 | - performer-pytorch==1.1.4 174 | - protobuf==4.25.3 175 | - py-cpuinfo==9.0.0 176 | - pydantic==2.6.4 177 | - pydantic-core==2.16.3 178 | - pynvml==11.5.0 179 | - pyparsing==3.1.2 180 | - python-dateutil==2.9.0.post0 181 | - pytorch-lightning==2.2.1 182 | - pytz==2024.1 183 | - pyyaml==6.0.1 184 | - rdkit==2023.9.5 185 | - redo==2.0.4 186 | - regex==2023.12.25 187 | - safetensors==0.4.2 188 | - scikit-learn==1.4.1.post1 189 | - scipy==1.12.0 190 | - sentry-sdk==1.44.0 191 | - setproctitle==1.3.3 192 | - simplejson==3.19.2 193 | - smmap==5.0.1 194 | - sqlalchemy==2.0.29 195 | - tensorboardx==2.6.2.2 196 | - threadpoolctl==3.4.0 197 | - tokenizers==0.15.2 198 | - torchmetrics==0.9.3 199 | - tqdm==4.66.2 200 | - transformers==4.39.3 201 | - twisted==24.3.0 202 | - tzdata==2024.1 203 | - wandb==0.16.5 204 | - yacs==0.1.8 205 | - yarl==1.9.4 206 | - zope-interface==6.2 207 | # pip install torch_geometric==2.0.4 208 | # pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html 209 | -------------------------------------------------------------------------------- /gssc/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * # noqa 2 | from .config import * # noqa 3 | from .encoder import * # noqa 4 | from .head import * # noqa 5 | from .layer import * # noqa 6 | from .loader import * # noqa 7 | from .loss import * # noqa 8 | from .network import * # noqa 9 | from .optimizer import * # noqa 10 | from .pooling import * # noqa 11 | from .stage import * # noqa 12 | from .train import * # noqa 13 | from .transform import * # noqa -------------------------------------------------------------------------------- /gssc/act/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import 
class SWISH(nn.Module):
    """Swish activation function: ``x * sigmoid(x)``.

    Args:
        inplace: If True, multiply into the input tensor in place and
            return it (saves one allocation; mutates the caller's tensor).
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        # Gate is computed from the pre-activation value in both branches.
        gate = torch.sigmoid(x)
        if self.inplace:
            # mul_ returns the (mutated) input tensor itself.
            return x.mul_(gate)
        return x * gate
def agg_dict_list(dict_list):
    """Aggregate a list of per-run stats dictionaries into mean + std.

    Args:
        dict_list: List of dicts that all share the same keys; each must
            contain an ``'epoch'`` entry, which is copied through from the
            first dict rather than averaged.

    Returns:
        Dict with the same keys where every non-``'epoch'`` value is the
        mean over ``dict_list`` (rounded to ``cfg.round`` decimals), plus a
        companion ``'<key>_std'`` entry holding the standard deviation.
    """
    dict_agg = {'epoch': dict_list[0]['epoch']}
    for key in dict_list[0]:
        if key == 'epoch':
            continue
        # FIX: the original comprehension used `dict` as the loop variable,
        # shadowing the builtin; use a distinct name for the per-run dict.
        values = np.array([stats[key] for stats in dict_list])
        dict_agg[key] = np.mean(values).round(cfg.round)
        dict_agg[f'{key}_std'] = np.std(values).round(cfg.round)
    return dict_agg
90 | 91 | ''' 92 | results = {'train': None, 'val': None, 'test': None} 93 | results_best = {'train': None, 'val': None, 'test': None} 94 | for seed in os.listdir(dir): 95 | if is_seed(seed): 96 | dir_seed = os.path.join(dir, seed) 97 | 98 | split = 'val' 99 | if split in os.listdir(dir_seed): 100 | dir_split = os.path.join(dir_seed, split) 101 | fname_stats = os.path.join(dir_split, 'stats.json') 102 | stats_list = json_to_dict_list(fname_stats) 103 | if metric_best == 'auto': 104 | metric = 'auc' if 'auc' in stats_list[0] else 'accuracy' 105 | else: 106 | metric = metric_best 107 | performance_np = np.array( # noqa 108 | [stats[metric] for stats in stats_list]) 109 | best_epoch = \ 110 | stats_list[ 111 | eval("performance_np.{}()".format(cfg.metric_agg))][ 112 | 'epoch'] 113 | print(best_epoch) 114 | 115 | for split in os.listdir(dir_seed): 116 | if is_split(split): 117 | dir_split = os.path.join(dir_seed, split) 118 | fname_stats = os.path.join(dir_split, 'stats.json') 119 | stats_list = json_to_dict_list(fname_stats) 120 | stats_best = [ 121 | stats for stats in stats_list 122 | if stats['epoch'] == best_epoch 123 | ][0] 124 | print(stats_best) 125 | stats_list = [[stats] for stats in stats_list] 126 | if results[split] is None: 127 | results[split] = stats_list 128 | else: 129 | results[split] = join_list(results[split], stats_list) 130 | if results_best[split] is None: 131 | results_best[split] = [stats_best] 132 | else: 133 | results_best[split] += [stats_best] 134 | results = {k: v for k, v in results.items() if v is not None} # rm None 135 | results_best = {k: v 136 | for k, v in results_best.items() 137 | if v is not None} # rm None 138 | for key in results: 139 | for i in range(len(results[key])): 140 | results[key][i] = agg_dict_list(results[key][i]) 141 | for key in results_best: 142 | results_best[key] = agg_dict_list(results_best[key]) 143 | # save aggregated results 144 | for key, value in results.items(): 145 | dir_out = os.path.join(dir, 'agg', 
key) 146 | makedirs_rm_exist(dir_out) 147 | fname = os.path.join(dir_out, 'stats.json') 148 | dict_list_to_json(value, fname) 149 | 150 | if cfg.tensorboard_agg: 151 | if SummaryWriter is None: 152 | raise ImportError( 153 | 'Tensorboard support requires `tensorboardX`.') 154 | writer = SummaryWriter(dir_out) 155 | dict_list_to_tb(value, writer) 156 | writer.close() 157 | for key, value in results_best.items(): 158 | dir_out = os.path.join(dir, 'agg', key) 159 | fname = os.path.join(dir_out, 'best.json') 160 | dict_to_json(value, fname) 161 | logging.info('Results aggregated across runs saved in {}'.format( 162 | os.path.join(dir, 'agg'))) 163 | -------------------------------------------------------------------------------- /gssc/config/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /gssc/config/custom_gnn_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('custom_gnn') 5 | def custom_gnn_cfg(cfg): 6 | """Extending config group of GraphGym's built-in GNN for purposes of our 7 | CustomGNN network model. 8 | """ 9 | 10 | # Use residual connections between the GNN layers. 
def set_cfg_preprocess(cfg):
    """Extend the GraphGym configuration with graph-preprocessing options.

    All options live under the ``cfg.prep`` group. NOTE: attribute creation
    order is preserved by yacs when the config is serialized, so keep the
    order of the assignments below unchanged.
    """

    cfg.prep = CN()

    # Argument group for adding expander edges

    # If enabled, expander edges become available on the data object,
    # e.g. as `data.expander_edges`.
    cfg.prep.exp = False
    cfg.prep.exp_algorithm = 'Random-d'  # Other option is 'Hamiltonian'
    cfg.prep.use_exp_edges = True
    # Degree of the generated expander graph -- TODO confirm against the
    # expander-edge generator that consumes this value.
    cfg.prep.exp_deg = 5
    # Max iterations for the expander construction algorithm.
    cfg.prep.exp_max_num_iters = 100
    cfg.prep.add_edge_index = True
    # Number of virtual nodes to add (0 disables virtual nodes).
    cfg.prep.num_virt_node = 0
    cfg.prep.exp_count = 1
    cfg.prep.add_self_loops = False
    cfg.prep.add_reverse_edges = True
    cfg.prep.train_percent = 0.6
    cfg.prep.layer_edge_indices_dir = None



    # Argument group for adding node distances
    cfg.prep.dist_enable = False
    cfg.prep.dist_cutoff = 510


register_config('preprocess', set_cfg_preprocess)
@register_config('overwrite_defaults')
def overwrite_defaults_cfg(cfg):
    """Overwrite default config values that are first set by GraphGym in
    ``torch_geometric.graphgym.config.set_cfg``.

    WARNING: At the time of writing, the order in which custom config-setting
    functions like this one are executed is random; see the referenced
    `set_cfg`. Therefore never reset here config options that are custom
    added; only change those that exist in core GraphGym.
    """

    # Overwrite default dataset name
    cfg.dataset.name = 'none'

    # Overwrite default rounding precision (decimals used when logging stats)
    cfg.round = 5
@register_config('cfg_gt')
def set_cfg_gt(cfg):
    """Configuration for Graph Transformer-style models, e.g.:
    - Spectral Attention Network (SAN) Graph Transformer.
    - "vanilla" Transformer / Performer.
    - General Powerful Scalable (GPS) Model.
    """

    # Graph Transformer argument group
    cfg.gt = CN()

    # Type of Graph Transformer layer to use
    cfg.gt.layer_type = 'SANLayer'

    # Number of Transformer layers in the model
    cfg.gt.layers = 3

    # Number of attention heads in the Graph Transformer
    cfg.gt.n_heads = 8

    # Size of the hidden node and edge representation
    cfg.gt.dim_hidden = 64

    # Size of the edge embedding
    cfg.gt.dim_edge = None

    # Full attention SAN transformer including all possible pairwise edges
    cfg.gt.full_graph = True

    # Type of extra edges used for transformer
    cfg.gt.secondary_edges = 'full_graph'

    # SAN real vs fake edge attention weighting coefficient
    cfg.gt.gamma = 1e-5

    # Histogram of in-degrees of nodes in the training set used by PNAConv.
    # Used when `gt.layer_type: PNAConv+...`. If empty it is precomputed during
    # the dataset loading process.
    cfg.gt.pna_degrees = []

    # Dropout in feed-forward module.
    cfg.gt.dropout = 0.0

    # Dropout in self-attention.
    cfg.gt.attn_dropout = 0.0

    # Use LayerNorm inside the transformer layers.
    cfg.gt.layer_norm = False

    # Use BatchNorm inside the transformer layers.
    cfg.gt.batch_norm = True

    # Use residual connections around the transformer layers.
    cfg.gt.residual = True

    # Activation function used inside the transformer layers.
    cfg.gt.activation = 'relu'

    # BigBird model/GPS-BigBird layer.
    # Field names mirror the HuggingFace BigBird configuration.
    cfg.gt.bigbird = CN()

    cfg.gt.bigbird.attention_type = "block_sparse"

    cfg.gt.bigbird.chunk_size_feed_forward = 0

    cfg.gt.bigbird.is_decoder = False

    cfg.gt.bigbird.add_cross_attention = False

    cfg.gt.bigbird.hidden_act = "relu"

    cfg.gt.bigbird.max_position_embeddings = 128

    cfg.gt.bigbird.use_bias = False

    cfg.gt.bigbird.num_random_blocks = 3

    cfg.gt.bigbird.block_size = 3

    cfg.gt.bigbird.layer_norm_eps = 1e-6
@register_config('posenc')
def set_cfg_posenc(cfg):
    """Extend configuration with positional encoding options.

    Creates one yacs group per PE type, fills in shared defaults, then adds
    type-specific options. Ordering matters: per-type overrides below rely on
    the shared-default loops having run first.
    """

    # Argument group for each Positional Encoding class.
    cfg.posenc_LapPE = CN()
    cfg.posenc_SignNet = CN()
    cfg.posenc_RWSE = CN()
    cfg.posenc_HKdiagSE = CN()
    cfg.posenc_ElstaticSE = CN()
    cfg.posenc_EquivStableLapPE = CN()

    # Effective Resistance Embeddings
    cfg.posenc_ERN = CN()  # Effective Resistance for Nodes
    cfg.posenc_ERE = CN()  # Effective Resistance for Edges

    # Common arguments to all PE types.
    # NOTE: posenc_EquivStableLapPE is deliberately absent from this list;
    # it only receives `enable` and `raw_norm_type`, set explicitly below.
    for name in ['posenc_LapPE', 'posenc_SignNet',
                 'posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE',
                 'posenc_ERN', 'posenc_ERE']:
        pecfg = getattr(cfg, name)

        # Use extended positional encodings
        pecfg.enable = False

        # Neural-net model type within the PE encoder:
        # 'DeepSet', 'Transformer', 'Linear', 'none', ...
        pecfg.model = 'none'

        # Size of Positional Encoding embedding
        pecfg.dim_pe = 16

        # Number of layers in PE encoder model
        pecfg.layers = 3

        # Number of attention heads in PE encoder when model == 'Transformer'
        pecfg.n_heads = 4

        # Number of layers to apply in LapPE encoder post its pooling stage
        pecfg.post_layers = 0

        # Choice of normalization applied to raw PE stats: 'none', 'BatchNorm'
        pecfg.raw_norm_type = 'none'

        # In addition to appending PE to the node features, pass them also as
        # a separate variable in the PyG graph batch object.
        pecfg.pass_as_var = False

    # Config for EquivStable LapPE
    cfg.posenc_EquivStableLapPE.enable = False
    cfg.posenc_EquivStableLapPE.raw_norm_type = 'none'

    # Config for Laplacian Eigen-decomposition for PEs that use it.
    for name in ['posenc_LapPE', 'posenc_SignNet', 'posenc_EquivStableLapPE']:
        pecfg = getattr(cfg, name)
        pecfg.eigen = CN()

        # The normalization scheme for the graph Laplacian: 'none', 'sym', or 'rw'
        pecfg.eigen.laplacian_norm = 'sym'

        # The normalization scheme for the eigen vectors of the Laplacian
        pecfg.eigen.eigvec_norm = 'L2'

        # Maximum number of top smallest frequencies & eigenvectors to use
        pecfg.eigen.max_freqs = 10

    # Config for SignNet-specific options.
    cfg.posenc_SignNet.phi_out_dim = 4
    cfg.posenc_SignNet.phi_hidden_dim = 64

    for name in ['posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE']:
        pecfg = getattr(cfg, name)

        # Config for Kernel-based PE specific options.
        pecfg.kernel = CN()

        # List of times to compute the heat kernel for (the time is equivalent to
        # the variance of the kernel) / the number of steps for random walk kernel
        # Can be overridden by `posenc.kernel.times_func`
        pecfg.kernel.times = []

        # Python snippet to generate `posenc.kernel.times`, e.g. 'range(1, 17)'
        # If set, it will be executed via `eval()` and override posenc.kernel.times
        pecfg.kernel.times_func = ''

    # Override default: the electrostatic kernel has a fixed set of 10 measures.
    cfg.posenc_ElstaticSE.kernel.times_func = 'range(10)'

    # Setting accuracy for Effective Resistance Calculations:
    cfg.posenc_ERN.accuracy = 0.1
    cfg.posenc_ERE.accuracy = 0.1

    # To be set during the calculations:
    cfg.posenc_ERN.er_dim = 'none'
10 | """ 11 | 12 | # Default to selecting the standard split that ships with the dataset 13 | cfg.dataset.split_mode = 'standard' 14 | 15 | # Choose a particular split to use if multiple splits are available 16 | cfg.dataset.split_index = 0 17 | 18 | # Dir to cache cross-validation splits 19 | cfg.dataset.split_dir = './splits' 20 | 21 | # Choose to run multiple splits in one program execution, if set, 22 | # takes the precedence over cfg.dataset.split_index for split selection 23 | cfg.run_multiple_splits = [] 24 | -------------------------------------------------------------------------------- /gssc/config/wandb_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_wandb') 6 | def set_cfg_wandb(cfg): 7 | """Weights & Biases tracker configuration. 8 | """ 9 | 10 | # WandB group 11 | cfg.wandb = CN() 12 | 13 | # Use wandb or not 14 | cfg.wandb.use = False 15 | 16 | # Wandb entity name, should exist beforehand 17 | cfg.wandb.entity = "gtransformers" 18 | 19 | # Wandb project name, will be created in your team if doesn't exist already 20 | cfg.wandb.project = "gtblueprint" 21 | 22 | # Optional run name 23 | cfg.wandb.name = "" 24 | -------------------------------------------------------------------------------- /gssc/encoder/ER_edge_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_edge_encoder 5 | 6 | 7 | @register_edge_encoder('ERE') 8 | class EREdgeEncoder(torch.nn.Module): 9 | def __init__(self, emb_dim, use_edge_attr=False, expand_edge_attr=False): 10 | super().__init__() 11 | 12 | dim_in = cfg.gt.dim_edge # Expected final edge_dim 13 | 14 | pecfg = cfg.posenc_ERE 15 | n_layers = pecfg.layers 
import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_edge_encoder


@register_edge_encoder('ERE')
class EREdgeEncoder(torch.nn.Module):
    """Effective-Resistance (ER) edge encoder.

    Encodes the per-edge ER value (``batch.er_edge``) with a small MLP and
    either replaces the edge features with it or concatenates it to them.
    """

    def __init__(self, emb_dim, use_edge_attr=False, expand_edge_attr=False):
        super().__init__()

        dim_in = cfg.gt.dim_edge  # expected final edge_attr dimension

        pecfg = cfg.posenc_ERE
        n_layers = pecfg.layers  # num. layers in the PE encoder MLP
        self.pass_as_var = pecfg.pass_as_var  # also keep PE as batch.er_edge

        self.use_edge_attr = use_edge_attr
        self.expand_edge_attr = expand_edge_attr
        if expand_edge_attr:
            # Shrink original edge features so cat([attr, ere]) has dim_in.
            self.linear_x = nn.Linear(dim_in, dim_in - emb_dim)

        if not self.use_edge_attr:
            # The ER embedding alone must fill the full edge dimension.
            assert emb_dim == dim_in

        # MLP: 1 -> emb_dim -> ... -> emb_dim, ReLU after every Linear.
        mlp = [nn.Linear(1, emb_dim), nn.ReLU()]
        for _ in range(max(n_layers - 1, 0)):
            mlp.append(nn.Linear(emb_dim, emb_dim))
            mlp.append(nn.ReLU())
        self.er_encoder = nn.Sequential(*mlp)

    def forward(self, batch):
        ere = self.er_encoder(batch.er_edge)

        if self.expand_edge_attr:
            batch.edge_attr = self.linear_x(batch.edge_attr)

        if self.use_edge_attr:
            batch.edge_attr = torch.cat([batch.edge_attr, ere], dim=1)
        else:
            batch.edge_attr = ere

        if self.pass_as_var:
            batch.er_edge = ere

        return batch
@register_node_encoder('ERN')
class ERNodeEncoder(torch.nn.Module):
    """Effective Resistance (ER) node encoder.

    An ER embedding of size dim_pe gets appended to each node feature
    vector. If `expand_x` is True, original node features are first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with the PE.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_ERN
        dim_pe = pecfg.dim_pe      # Size of the ER PE embedding
        model_type = pecfg.model   # Encoder NN model type for the PE
        if model_type not in ['Transformer', 'DeepSet', 'Linear']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers              # Num. layers in PE encoder model
        n_heads = pecfg.n_heads              # Num. attention heads (Transformer)
        post_n_layers = pecfg.post_layers    # Num. layers applied after pooling
        self.pass_as_var = pecfg.pass_as_var  # Also keep PE as batch.pe_ern

        er_dim = pecfg.er_dim  # Num. columns of the raw ER embedding

        if dim_emb - dim_pe < 1:
            raise ValueError(f"ER_Node size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        if model_type == 'Linear':
            self.pe_encoder = nn.Linear(er_dim, dim_pe)
        else:
            # BUGFIX: forward() calls self.linear_A on every non-Linear path,
            # but it was previously only created for 'Transformer' and for
            # 'DeepSet' with n_layers > 1, so a 1-layer DeepSet crashed with
            # AttributeError. Define it unconditionally here (same pattern as
            # LapPENodeEncoder, which defines linear_A before branching).
            self.linear_A = nn.Linear(1, dim_pe)
            if model_type == 'Transformer':
                # Transformer over the er_dim positions of each node.
                encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
                                                           nhead=n_heads,
                                                           batch_first=True)
                self.pe_encoder = nn.TransformerEncoder(encoder_layer,
                                                        num_layers=n_layers)
            else:
                # DeepSet: ReLU, then (n_layers - 1) x [Linear, ReLU].
                layers = [nn.ReLU()]
                for _ in range(n_layers - 1):
                    layers.append(nn.Linear(dim_pe, dim_pe))
                    layers.append(nn.ReLU())
                self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP applied after sum pooling: dim_pe -> ... -> dim_pe.
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)

    def forward(self, batch):
        if not hasattr(batch, 'er_emb'):
            raise ValueError("Precomputed ER embeddings required for "
                             "calculating ER Node Encodings")

        pos_enc = batch.er_emb  # N x er_dim

        if self.training:
            # Train-time augmentation: randomly permute the ER columns.
            pos_enc = pos_enc[:, torch.randperm(pos_enc.size()[1])]

        if self.model_type == 'Linear':
            pos_enc = self.pe_encoder(pos_enc)  # N x er_dim -> N x dim_pe
        else:
            pos_enc = torch.unsqueeze(pos_enc, 2)
            pos_enc = self.linear_A(pos_enc)  # N x er_dim x dim_pe

            # PE encoder: a Transformer or DeepSet model.
            if self.model_type == 'Transformer':
                pos_enc = self.pe_encoder(src=pos_enc)
            else:
                pos_enc = self.pe_encoder(pos_enc)

            # Sum pooling over the er_dim axis.
            pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # N x dim_pe

        # MLP post pooling.
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # N x dim_pe

        # Expand node features if needed.
        if self.expand_x:
            h = self.linear_x(batch.x)
        else:
            h = batch.x
        # Concatenate final PEs to input embedding.
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate (e.g. for skip connections to input).
        if self.pass_as_var:
            batch.pe_ern = pos_enc
        return batch
# Vocabulary sizes for ogbg-code2 (see the OGB code2 example for provenance).
num_nodetypes = 98
num_nodeattributes = 10030
max_depth = 20


@register_node_encoder('ASTNode')
class ASTNodeEncoder(torch.nn.Module):
    """The Abstract Syntax Tree (AST) Node Encoder used for ogbg-code2.

    Input:
        x: Default node feature; column 0 is the node type, column 1 the
            node attribute.
        node_depth: The depth of the node in the AST.
    Output:
        emb_dim-dimensional vector (sum of type, attribute and depth
        embeddings).
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        # +1 so that depth == max_depth is a valid index after clamping.
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)

    def forward(self, batch):
        x = batch.x
        depth = batch.node_depth.view(-1, )
        # NOTE(review): this clamp writes through the view into
        # batch.node_depth (in-place). It is idempotent, but confirm no
        # other consumer needs the raw, unclamped depths.
        depth[depth > self.max_depth] = self.max_depth
        batch.x = (self.type_encoder(x[:, 0])
                   + self.attribute_encoder(x[:, 1])
                   + self.depth_encoder(depth))
        return batch
@register_edge_encoder('ASTEdge')
class ASTEdgeEncoder(torch.nn.Module):
    """The Abstract Syntax Tree (AST) Edge Encoder used for ogbg-code2.

    Edge attributes are generated dynamically by the `augment_edge`
    function and are expected to be:
        edge_attr[:, 0]: AST edge (0) or next-token edge (1)
        edge_attr[:, 1]: original direction (0) or inverse direction (1)

    Args:
        emb_dim (int): Output edge embedding dimension
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.embedding_type = torch.nn.Embedding(2, emb_dim)
        self.embedding_direction = torch.nn.Embedding(2, emb_dim)

    def forward(self, batch):
        # Sum of the edge-kind and edge-direction embeddings.
        kind_emb = self.embedding_type(batch.edge_attr[:, 0])
        direction_emb = self.embedding_direction(batch.edge_attr[:, 1])
        batch.edge_attr = kind_emb + direction_emb
        return batch
def concat_node_encoders(encoder_classes, pe_enc_names):
    """Factory creating an encoder class that chains two or three encoders.

    The first encoder is expected to be a dataset-specific node encoder;
    the remaining ones are positional-encoding (PE) encoders whose output
    dims are read from the loaded config at construction time.

    Args:
        encoder_classes: List of node encoder classes (two or three).
        pe_enc_names: List of PE Encoder names, used to look up
            `cfg.posenc_<name>.dim_pe`. That lookup can only happen at
            runtime, once the config is loaded.

    Returns:
        new node encoder class
    """

    class Concat2NodeEncoder(torch.nn.Module):
        """Chains a dataset encoder with one PE encoder."""
        enc1_cls = None
        enc2_cls = None
        enc2_name = None

        def __init__(self, dim_emb):
            super().__init__()

            if cfg.posenc_EquivStableLapPE.enable:
                # Special case: EquivStableLapPE keeps node feats and PE
                # separate (no concatenation), so both get the full dim.
                self.encoder1 = self.enc1_cls(dim_emb)
                self.encoder2 = self.enc2_cls(dim_emb)
            else:
                # PE dims can only be gathered once the cfg is loaded.
                dim_pe2 = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe
                self.encoder1 = self.enc1_cls(dim_emb - dim_pe2)
                self.encoder2 = self.enc2_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            return self.encoder2(self.encoder1(batch))

    class Concat3NodeEncoder(torch.nn.Module):
        """Chains a dataset encoder with two PE encoders."""
        enc1_cls = None
        enc2_cls = None
        enc2_name = None
        enc3_cls = None
        enc3_name = None

        def __init__(self, dim_emb):
            super().__init__()
            # PE dims can only be gathered once the cfg is loaded.
            dim_pe2 = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe
            dim_pe3 = getattr(cfg, f"posenc_{self.enc3_name}").dim_pe
            self.encoder1 = self.enc1_cls(dim_emb - dim_pe2 - dim_pe3)
            self.encoder2 = self.enc2_cls(dim_emb - dim_pe3, expand_x=False)
            self.encoder3 = self.enc3_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            return self.encoder3(self.encoder2(self.encoder1(batch)))

    # Configure the correct concatenation class and return it.
    if len(encoder_classes) == 2:
        Concat2NodeEncoder.enc1_cls = encoder_classes[0]
        Concat2NodeEncoder.enc2_cls = encoder_classes[1]
        Concat2NodeEncoder.enc2_name = pe_enc_names[0]
        return Concat2NodeEncoder
    if len(encoder_classes) == 3:
        Concat3NodeEncoder.enc1_cls = encoder_classes[0]
        Concat3NodeEncoder.enc2_cls = encoder_classes[1]
        Concat3NodeEncoder.enc3_cls = encoder_classes[2]
        Concat3NodeEncoder.enc2_name = pe_enc_names[0]
        Concat3NodeEncoder.enc3_name = pe_enc_names[1]
        return Concat3NodeEncoder
    raise ValueError(f"Does not support concatenation of "
                     f"{len(encoder_classes)} encoder classes.")


# Dataset-specific node encoders.
ds_encs = {'Atom': AtomEncoder,
           'ASTNode': ASTNodeEncoder,
           'PPANode': PPANodeEncoder,
           'TypeDictNode': TypeDictNodeEncoder,
           'VOCNode': VOCNodeEncoder,
           'LinearNode': LinearNodeEncoder}

# Positional Encoding node encoders.
pe_encs = {'LapPE': LapPENodeEncoder,
           'RWSE': RWSENodeEncoder,
           'HKdiagSE': HKdiagSENodeEncoder,
           'ElstaticSE': ElstaticSENodeEncoder,
           'SignNet': SignNetNodeEncoder,
           'EquivStableLapPE': EquivStableLapPENodeEncoder,
           'ERN': ERNodeEncoder}
# Register every "<dataset>+<PE>" pairing of dataset and PE node encoders.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    for pe_enc_name, pe_enc_cls in pe_encs.items():
        register_node_encoder(
            f"{ds_enc_name}+{pe_enc_name}",
            concat_node_encoders([ds_enc_cls, pe_enc_cls],
                                 [pe_enc_name])
        )

# Register combinations of LapPE and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+LapPE+RWSE",
        concat_node_encoders([ds_enc_cls, LapPENodeEncoder, RWSENodeEncoder],
                             ['LapPE', 'RWSE'])
    )

# Register combinations of SignNet and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+SignNet+RWSE",
        concat_node_encoders([ds_enc_cls, SignNetNodeEncoder, RWSENodeEncoder],
                             ['SignNet', 'RWSE'])
    )


@register_edge_encoder('DummyEdge')
class DummyEdgeEncoder(torch.nn.Module):
    """Edge encoder that assigns one shared learned embedding to every edge."""

    def __init__(self, emb_dim):
        super().__init__()
        self.encoder = torch.nn.Embedding(num_embeddings=1,
                                          embedding_dim=emb_dim)

    def forward(self, batch):
        # Index 0 for every edge -> a single shared embedding vector.
        dummy_attr = batch.edge_index.new_zeros(batch.edge_index.shape[1])
        batch.edge_attr = self.encoder(dummy_attr)
        return batch
@register_node_encoder('EquivStableLapPE')
class EquivStableLapPENodeEncoder(torch.nn.Module):
    """Equivariant and Stable Laplace Positional Embedding node encoder.

    Transforms the k-dim node LapPE to d-dim; the result is stored
    separately on the batch (as `pe_EquivStableLapPE`) for later use as
    edge weights in the local GNN module.
    Based on the approach proposed in https://openreview.net/pdf?id=e95i1IHcWj

    Args:
        dim_emb: Size of final node embedding
    """

    def __init__(self, dim_emb):
        super().__init__()

        pecfg = cfg.posenc_EquivStableLapPE
        max_freqs = pecfg.eigen.max_freqs        # num. eigenvectors kept
        norm_type = pecfg.raw_norm_type.lower()  # raw PE normalization type

        self.raw_norm = (nn.BatchNorm1d(max_freqs)
                         if norm_type == 'batchnorm' else None)
        self.linear_encoder_eigenvec = nn.Linear(max_freqs, dim_emb)

    def forward(self, batch):
        if not (hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; set "
                             f"config 'posenc_EquivStableLapPE.enable' to True")
        pos_enc = batch.EigVecs

        # NOTE(review): zeroing the NaN padding here writes through to
        # batch.EigVecs (in-place) — confirm no later consumer needs NaNs.
        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors)
        pos_enc[empty_mask] = 0.

        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)

        pos_enc = self.linear_encoder_eigenvec(pos_enc)

        # Keep PE separate in a variable (not concatenated with batch.x).
        batch.pe_EquivStableLapPE = pos_enc

        return batch


@register_node_encoder('example')
class ExampleNodeEncoder(torch.nn.Module):
    """Embedding encoder for integer node features.

    Parameters:
        num_classes - the number of classes for the embedding mapping to learn
    """

    def __init__(self, emb_dim, num_classes=None):
        super().__init__()
        self.encoder = torch.nn.Embedding(num_classes, emb_dim)
        torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        # Encode just the first feature column if more exist.
        batch.x = self.encoder(batch.x[:, 0])
        return batch


@register_edge_encoder('example')
class ExampleEdgeEncoder(torch.nn.Module):
    """Sums one embedding per OGB bond-feature column."""

    def __init__(self, emb_dim):
        super().__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for dim in get_bond_feature_dims():
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, batch):
        bond_embedding = 0
        # NOTE(review): iterates batch.edge_feature's columns but indexes
        # batch.edge_attr — looks inconsistent; confirm both attributes
        # exist and have matching column counts.
        for col in range(batch.edge_feature.shape[1]):
            bond_embedding += \
                self.bond_embedding_list[col](batch.edge_attr[:, col])

        batch.edge_attr = bond_embedding
        return batch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer


class ExpanderEdgeFixer(nn.Module):
    '''
    Gets the batch and sets new edge indices + global nodes.

    Offsets each graph's precomputed expander edges into batched node
    indexing and stores the combined result as `batch.expander_edge_index`
    / `batch.expander_edge_attr`. With `num_virt_node > 0`, also creates
    per-graph virtual ("global") node embeddings plus in/out edges linking
    every real node to each virtual node of its graph.
    '''
    def __init__(self, add_edge_index=False, num_virt_node=0):

        super().__init__()

        # Fall back to the hidden dim when no explicit edge dim is configured.
        if not hasattr(cfg.gt, 'dim_edge') or cfg.gt.dim_edge is None:
            cfg.gt.dim_edge = cfg.gt.dim_hidden

        self.add_edge_index = add_edge_index  # include original graph edges
        self.num_virt_node = num_virt_node    # virtual nodes per graph
        # Single learned attribute shared by all expander edges (index 0).
        self.exp_edge_attr = nn.Embedding(1, cfg.gt.dim_edge)
        self.use_exp_edges = cfg.prep.use_exp_edges and cfg.prep.exp

        if self.num_virt_node > 0:
            # Embeddings for virtual nodes and their in/out edge attributes.
            self.virt_node_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_hidden)
            self.virt_edge_out_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_edge)
            self.virt_edge_in_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_edge)


    def forward(self, batch):
        # Edge-type codes used below: 0 = original, 1 = expander, 2 = virtual.
        edge_types = []
        device = self.exp_edge_attr.weight.device
        edge_index_sets = []
        edge_attr_sets = []
        if self.add_edge_index:
            edge_index_sets.append(batch.edge_index)
            edge_attr_sets.append(batch.edge_attr)
            edge_types.append(torch.zeros(batch.edge_index.shape[1], dtype=torch.long))


        num_node = batch.batch.shape[0]
        num_graphs = batch.num_graphs

        if self.use_exp_edges:
            if not hasattr(batch, 'expander_edges'):
                raise ValueError('expander edges not stored in data')

            # Offset each graph's expander edges by the node count of the
            # preceding graphs so they index into the batched node set.
            data_list = batch.to_data_list()
            exp_edges = []
            cumulative_num_nodes = 0
            for data in data_list:
                exp_edges.append(data.expander_edges + cumulative_num_nodes)
                cumulative_num_nodes += data.num_nodes

            exp_edges = torch.cat(exp_edges, dim=0).t()
            edge_index_sets.append(exp_edges)
            # Every expander edge shares the single learned attribute.
            edge_attr_sets.append(self.exp_edge_attr(torch.zeros(exp_edges.shape[1], dtype=torch.long).to(device)))
            edge_types.append(torch.zeros(exp_edges.shape[1], dtype=torch.long) + 1)

        if self.num_virt_node > 0:
            # Virtual nodes are appended after the real ones; the idx-th
            # virtual node of graph g gets index num_node + idx*num_graphs + g.
            global_h = []
            virt_edges = []
            virt_edge_attrs = []
            for idx in range(self.num_virt_node):
                global_h.append(self.virt_node_emb(torch.zeros(num_graphs, dtype=torch.long).to(device)+idx))
                # Incoming edges: every real node -> its graph's virtual node.
                virt_edge_index = torch.cat([torch.arange(num_node).view(1, -1).to(device),
                                             (batch.batch+(num_node+idx*num_graphs)).view(1, -1)], dim=0)
                virt_edges.append(virt_edge_index)
                virt_edge_attrs.append(self.virt_edge_in_emb(torch.zeros(virt_edge_index.shape[1], dtype=torch.long).to(device)+idx))

                # Outgoing edges: virtual node -> every real node of its graph.
                virt_edge_index = torch.cat([(batch.batch+(num_node+idx*num_graphs)).view(1, -1),
                                             torch.arange(num_node).view(1, -1).to(device)], dim=0)
                virt_edges.append(virt_edge_index)
                virt_edge_attrs.append(self.virt_edge_out_emb(torch.zeros(virt_edge_index.shape[1], dtype=torch.long).to(device)+idx))

            batch.virt_h = torch.cat(global_h, dim=0)
            batch.virt_edge_index = torch.cat(virt_edges, dim=1)
            batch.virt_edge_attr = torch.cat(virt_edge_attrs, dim=0)
            edge_types.append(torch.zeros(batch.virt_edge_index.shape[1], dtype=torch.long)+2)

        if len(edge_index_sets) > 1:
            edge_index = torch.cat(edge_index_sets, dim=1)
            edge_attr = torch.cat(edge_attr_sets, dim=0)
            edge_types = torch.cat(edge_types)
        else:
            # NOTE(review): if add_edge_index and use_exp_edges are both off
            # and num_virt_node == 0, these lists are empty and the indexing
            # below raises IndexError — confirm callers enable at least one.
            edge_index = edge_index_sets[0]
            edge_attr = edge_attr_sets[0]
            edge_types = edge_types[0]

        # NOTE(review): edge_types is computed but never attached to the
        # batch, and this delete runs even when use_exp_edges is False —
        # confirm the batch always carries `expander_edges` at this point.
        del batch.expander_edges
        batch.expander_edge_index = edge_index
        batch.expander_edge_attr = edge_attr

        return batch
class KernelPENodeEncoder(torch.nn.Module):
    """Configurable kernel-based Positional Encoding node encoder.

    The choice of which kernel-based statistics to use is configurable
    through setting of `kernel_type`. Based on this, the appropriate config
    is selected, and also the appropriate variable with precomputed kernel
    stats is then selected from PyG Data graphs in the `forward` function.
    E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'.

    PE of size `dim_pe` will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with PE.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        if self.kernel_type is None:
            # BUGFIX: the implicitly-concatenated f-string parts previously
            # rendered "classvariable" (missing space); restored the space.
            raise ValueError(f"{self.__class__.__name__} has to be "
                             f"preconfigured by setting 'kernel_type' class "
                             f"variable before calling the constructor.")

        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = getattr(cfg, f"posenc_{self.kernel_type}")
        dim_pe = pecfg.dim_pe                 # Size of the kernel-based PE embedding
        num_rw_steps = len(pecfg.kernel.times)
        model_type = pecfg.model.lower()      # Encoder NN model type for PEs
        n_layers = pecfg.layers               # Num. layers in PE encoder model
        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization layer type
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 1:
            raise ValueError(f"PE dim size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        if norm_type == 'batchnorm':
            self.raw_norm = nn.BatchNorm1d(num_rw_steps)
        else:
            self.raw_norm = None

        if model_type == 'mlp':
            layers = []
            if n_layers == 1:
                layers.append(nn.Linear(num_rw_steps, dim_pe))
                layers.append(nn.ReLU())
            else:
                # Wide hidden layers (2*dim_pe), then project down to dim_pe.
                layers.append(nn.Linear(num_rw_steps, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(num_rw_steps, dim_pe)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, batch):
        pestat_var = f"pestat_{self.kernel_type}"
        if not hasattr(batch, pestat_var):
            raise ValueError(f"Precomputed '{pestat_var}' variable is "
                             f"required for {self.__class__.__name__}; set "
                             f"config 'posenc_{self.kernel_type}.enable' to "
                             f"True, and also set 'posenc.kernel.times' values")

        pos_enc = getattr(batch, pestat_var)  # (Num nodes) x (Num kernel times)
        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed.
        if self.expand_x:
            h = self.linear_x(batch.x.to(torch.float32))
        else:
            h = batch.x.to(torch.float32)
        # Concatenate final PEs to input embedding.
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections).
        if self.pass_as_var:
            setattr(batch, f'pe_{self.kernel_type}', pos_enc)
        return batch


@register_node_encoder('RWSE')
class RWSENodeEncoder(KernelPENodeEncoder):
    """Random Walk Structural Encoding node encoder.
    """
    kernel_type = 'RWSE'


@register_node_encoder('HKdiagSE')
class HKdiagSENodeEncoder(KernelPENodeEncoder):
    """Heat kernel (diagonal) Structural Encoding node encoder.
    """
    kernel_type = 'HKdiagSE'


@register_node_encoder('ElstaticSE')
class ElstaticSENodeEncoder(KernelPENodeEncoder):
    """Electrostatic interactions Structural Encoding node encoder.
    """
    kernel_type = 'ElstaticSE'
@register_node_encoder('LapPE')
class LapPENodeEncoder(torch.nn.Module):
    """Laplace Positional Embedding node encoder.

    LapPE of size dim_pe will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with LapPE.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_LapPE
        dim_pe = pecfg.dim_pe  # Size of Laplace PE embedding
        model_type = pecfg.model  # Encoder NN model type for PEs
        if model_type not in ['Transformer', 'DeepSet']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        n_heads = pecfg.n_heads  # Num. attention heads in Trf PE encoder
        post_n_layers = pecfg.post_layers  # Num. layers to apply after pooling
        max_freqs = pecfg.eigen.max_freqs  # Num. eigenvectors (frequencies)
        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization layer type
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 1:
            raise ValueError(f"LapPE size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        # Initial projection of [eigenvalue, eigenvector value] pairs.
        self.linear_A = nn.Linear(2, dim_pe)
        if norm_type == 'batchnorm':
            # BatchNorm over the eigenvector (channel) axis of the raw
            # (N, max_freqs, 2) PE tensor.
            self.raw_norm = nn.BatchNorm1d(max_freqs)
        else:
            self.raw_norm = None

        if model_type == 'Transformer':
            # Transformer model for LapPE
            encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
                                                       nhead=n_heads,
                                                       batch_first=True)
            self.pe_encoder = nn.TransformerEncoder(encoder_layer,
                                                    num_layers=n_layers)
        else:
            # DeepSet model for LapPE
            layers = []
            if n_layers == 1:
                layers.append(nn.ReLU())
            else:
                # Deeper DeepSet: widen the initial projection to 2*dim_pe.
                self.linear_A = nn.Linear(2, 2 * dim_pe)
                layers.append(nn.ReLU())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling: dim_pe -> ... -> dim_pe.
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)


    def forward(self, batch):
        if not (hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; "
                             "set config 'posenc_LapPE.enable' to True")
        EigVals = batch.EigVals
        EigVecs = batch.EigVecs

        if self.training:
            # Train-time augmentation: randomly flip eigenvector signs
            # (eigenvectors are only defined up to sign).
            sign_flip = torch.rand(EigVecs.size(1), device=EigVecs.device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            EigVecs = EigVecs * sign_flip.unsqueeze(0)

        # assumes EigVals is (Num nodes) x (Num Eigenvectors) x 1 so that the
        # cat yields pairs [vec value, eigenvalue] — TODO confirm upstream.
        pos_enc = torch.cat((EigVecs.unsqueeze(2), EigVals), dim=2)  # (Num nodes) x (Num Eigenvectors) x 2
        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2

        # Padding positions (graphs with fewer nodes than max_freqs) are NaN.
        pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 2
        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.linear_A(pos_enc)  # (Num nodes) x (Num Eigenvectors) x dim_pe

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == 'Transformer':
            pos_enc = self.pe_encoder(src=pos_enc,
                                      src_key_padding_mask=empty_mask[:, :, 0])
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Remove masked sequences; must clone before overwriting masked elements
        pos_enc = pos_enc.clone().masked_fill_(empty_mask[:, :, 0].unsqueeze(2),
                                               0.)

        # Sum pooling
        pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x dim_pe

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(batch.x.to(torch.float32))
        else:
            h = batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_LapPE = pos_enc
        return batch
self.pe_encoder(pos_enc) 120 | 121 | # Remove masked sequences; must clone before overwriting masked elements 122 | pos_enc = pos_enc.clone().masked_fill_(empty_mask[:, :, 0].unsqueeze(2), 123 | 0.) 124 | 125 | # Sum pooling 126 | pos_enc = torch.sum(pos_enc, 1, keepdim=False) # (Num nodes) x dim_pe 127 | 128 | # MLP post pooling 129 | if self.post_mlp is not None: 130 | pos_enc = self.post_mlp(pos_enc) # (Num nodes) x dim_pe 131 | 132 | # Expand node features if needed 133 | if self.expand_x: 134 | h = self.linear_x(batch.x.to(torch.float32)) 135 | else: 136 | h = batch.x 137 | # Concatenate final PEs to input embedding 138 | batch.x = torch.cat((h, pos_enc), 1) 139 | # Keep PE also separate in a variable (e.g. for skip connections to input) 140 | if self.pass_as_var: 141 | batch.pe_LapPE = pos_enc 142 | return batch 143 | -------------------------------------------------------------------------------- /gssc/encoder/linear_edge_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym import cfg 3 | from torch_geometric.graphgym.register import register_edge_encoder 4 | 5 | 6 | @register_edge_encoder('LinearEdge') 7 | class LinearEdgeEncoder(torch.nn.Module): 8 | def __init__(self, emb_dim): 9 | super().__init__() 10 | if cfg.dataset.name in ['MNIST', 'CIFAR10']: 11 | self.in_dim = 1 12 | elif cfg.dataset.name == "ogbn-proteins": 13 | self.in_dim = 8 14 | else: 15 | raise ValueError("Input edge feature dim is required to be hardset " 16 | "or refactored to use a cfg option.") 17 | self.encoder = torch.nn.Linear(self.in_dim, emb_dim) 18 | 19 | def forward(self, batch): 20 | batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) 21 | return batch 22 | -------------------------------------------------------------------------------- /gssc/encoder/linear_node_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder('LinearNode')
class LinearNodeEncoder(torch.nn.Module):
    """Linear projection of raw node features (cast to float) to `emb_dim`."""
    def __init__(self, emb_dim):
        super().__init__()
        # (A dead experimental variant — Embedding for the first integer
        # feature plus Linear/BatchNorm for the rest — was removed here.)
        self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim)

    def forward(self, batch):
        batch.x = self.encoder(batch.x.float())
        return batch
--------------------------------------------------------------------------------
/gssc/encoder/ppa_encoder.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.graphgym.register import (register_node_encoder,
                                               register_edge_encoder)


@register_node_encoder('PPANode')
class PPANodeEncoder(torch.nn.Module):
    """
    Uniform input node embedding for PPA that has no node features.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # A single shared embedding row: every node receives the same
        # learned vector.
        self.encoder = torch.nn.Embedding(1, emb_dim)

    def forward(self, batch):
        batch.x = self.encoder(batch.x)
        return batch


@register_edge_encoder('PPAEdge')
class PPAEdgeEncoder(torch.nn.Module):
    """Linear projection of PPA's 7 raw edge features to `emb_dim`."""
    def __init__(self, emb_dim):
        super().__init__()
        self.encoder = torch.nn.Linear(7, emb_dim)

    def forward(self, batch):
        batch.edge_attr = self.encoder(batch.edge_attr)
        return batch
--------------------------------------------------------------------------------
/gssc/encoder/type_dict_encoder.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import (register_node_encoder,
                                               register_edge_encoder)

"""
Generic Node and Edge encoders for datasets with node/edge features that
consist of only one type dictionary thus require a single nn.Embedding layer.

The number of possible Node and Edge types must be set by cfg options:
1) cfg.dataset.node_encoder_num_types
2) cfg.dataset.edge_encoder_num_types

In case of a more complex feature set, use a data-specific encoder.

These generic encoders can be used e.g. for:
* ZINC
cfg.dataset.node_encoder_num_types: 28
cfg.dataset.edge_encoder_num_types: 4

* AQSOL
cfg.dataset.node_encoder_num_types: 65
cfg.dataset.edge_encoder_num_types: 5


=== Description of the ZINC dataset ===
https://github.com/graphdeeplearning/benchmarking-gnns/issues/42
The node labels are atom types and the edge labels atom bond types.

Node labels:
'C': 0
'O': 1
'N': 2
'F': 3
'C H1': 4
'S': 5
'Cl': 6
'O -': 7
'N H1 +': 8
'Br': 9
'N H3 +': 10
'N H2 +': 11
'N +': 12
'N -': 13
'S -': 14
'I': 15
'P': 16
'O H1 +': 17
'N H1 -': 18
'O +': 19
'S +': 20
'P H1': 21
'P H2': 22
'C H2 -': 23
'P +': 24
'S H1 +': 25
'C H1 -': 26
'P H1 +': 27

Edge labels:
'NONE': 0
'SINGLE': 1
'DOUBLE': 2
'TRIPLE': 3


=== Description of the AQSOL dataset ===
Node labels:
'Br': 0, 'C': 1, 'N': 2, 'O': 3, 'Cl': 4, 'Zn': 5, 'F': 6, 'P': 7, 'S': 8, 'Na': 9, 'Al': 10,
'Si': 11, 'Mo': 12, 'Ca': 13, 'W': 14, 'Pb': 15, 'B': 16, 'V': 17, 'Co': 18, 'Mg': 19, 'Bi': 20, 'Fe': 21,
'Ba': 22, 'K': 23, 'Ti': 24, 'Sn': 25, 'Cd': 26, 'I': 27, 'Re': 28, 'Sr': 29, 'H': 30, 'Cu': 31, 'Ni': 32,
'Lu': 33, 'Pr': 34, 'Te': 35, 'Ce': 36, 'Nd': 37, 'Gd': 38, 'Zr': 39, 'Mn': 40, 'As': 41, 'Hg': 42, 'Sb':
43, 'Cr': 44, 'Se': 45, 'La': 46, 'Dy': 47, 'Y': 48, 'Pd': 49, 'Ag': 50, 'In': 51, 'Li': 52, 'Rh': 53,
'Nb': 54, 'Hf': 55, 'Cs': 56, 'Ru': 57, 'Au': 58, 'Sm': 59, 'Ta': 60, 'Pt': 61, 'Ir': 62, 'Be': 63, 'Ge': 64

Edge labels:
'NONE': 0, 'SINGLE': 1, 'DOUBLE': 2, 'AROMATIC': 3, 'TRIPLE': 4
"""


@register_node_encoder('TypeDictNode')
class TypeDictNodeEncoder(torch.nn.Module):
    """Embed a single categorical node-type index into `emb_dim`."""
    def __init__(self, emb_dim):
        super().__init__()

        num_types = cfg.dataset.node_encoder_num_types
        if num_types < 1:
            raise ValueError(f"Invalid 'node_encoder_num_types': {num_types}")

        self.encoder = torch.nn.Embedding(num_embeddings=num_types,
                                          embedding_dim=emb_dim)
        # torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        # Encode just the first dimension if more exist
        batch.x = self.encoder(batch.x[:, 0])

        return batch


@register_edge_encoder('TypeDictEdge')
class TypeDictEdgeEncoder(torch.nn.Module):
    """Embed a categorical edge-type index into `emb_dim`."""
    def __init__(self, emb_dim):
        super().__init__()

        num_types = cfg.dataset.edge_encoder_num_types
        if num_types < 1:
            raise ValueError(f"Invalid 'edge_encoder_num_types': {num_types}")

        self.encoder = torch.nn.Embedding(num_embeddings=num_types,
                                          embedding_dim=emb_dim)
        # torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        batch.edge_attr = self.encoder(batch.edge_attr)
        return batch
--------------------------------------------------------------------------------
/gssc/encoder/voc_superpixels_encoder.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import (register_node_encoder,
                                               register_edge_encoder)

"""
=== Description of the VOCSuperpixels dataset ===
Each graph is a tuple (x, edge_attr, edge_index, y)
Shape of x : [num_nodes, 14]
Shape of edge_attr : [num_edges, 1] or [num_edges, 2]
Shape of edge_index : [2, num_edges]
Shape of y : [num_nodes]
"""

VOC_node_input_dim = 14
# VOC_edge_input_dim = 1 or 2; defined in class VOCEdgeEncoder

@register_node_encoder('VOCNode')
class VOCNodeEncoder(torch.nn.Module):
    """Linear projection of the 14 raw VOC superpixel node features."""
    def __init__(self, emb_dim):
        super().__init__()

        self.encoder = torch.nn.Linear(VOC_node_input_dim, emb_dim)
        # torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        batch.x = self.encoder(batch.x)

        return batch


@register_edge_encoder('VOCEdge')
class VOCEdgeEncoder(torch.nn.Module):
    """Linear projection of raw VOC superpixel edge features."""
    def __init__(self, emb_dim):
        super().__init__()

        # 'edge_wt_region_boundary' graphs carry 2 raw edge features, else 1.
        VOC_edge_input_dim = 2 if cfg.dataset.name == 'edge_wt_region_boundary' else 1
self.encoder = torch.nn.Linear(VOC_edge_input_dim, emb_dim) 39 | # torch.nn.init.xavier_uniform_(self.encoder.weight.data) 40 | 41 | def forward(self, batch): 42 | batch.edge_attr = self.encoder(batch.edge_attr) 43 | return batch 44 | -------------------------------------------------------------------------------- /gssc/finetuning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | 5 | import torch 6 | from torch_geometric.graphgym.config import set_cfg 7 | from yacs.config import CfgNode 8 | 9 | 10 | def get_final_pretrained_ckpt(ckpt_dir): 11 | if osp.exists(ckpt_dir): 12 | names = os.listdir(ckpt_dir) 13 | epochs = [int(name.split('.')[0]) for name in names] 14 | final_epoch = max(epochs) 15 | else: 16 | raise FileNotFoundError(f"Pretrained model dir not found: {ckpt_dir}") 17 | return osp.join(ckpt_dir, f'{final_epoch}.ckpt') 18 | 19 | 20 | def compare_cfg(cfg_main, cfg_secondary, field_name, strict=False): 21 | main_val, secondary_val = cfg_main, cfg_secondary 22 | for f in field_name.split('.'): 23 | main_val = main_val[f] 24 | secondary_val = secondary_val[f] 25 | if main_val != secondary_val: 26 | if strict: 27 | raise ValueError(f"Main and pretrained configs must match on " 28 | f"'{field_name}'") 29 | else: 30 | logging.warning(f"Pretrained models '{field_name}' differs, " 31 | f"using: {main_val}") 32 | 33 | 34 | def set_new_cfg_allowed(config, is_new_allowed): 35 | """ Set YACS config (and recursively its subconfigs) to allow merging 36 | new keys from other configs. 
37 | """ 38 | config.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed 39 | # Recursively set new_allowed state 40 | for v in config.__dict__.values(): 41 | if isinstance(v, CfgNode): 42 | set_new_cfg_allowed(v, is_new_allowed) 43 | for v in config.values(): 44 | if isinstance(v, CfgNode): 45 | set_new_cfg_allowed(v, is_new_allowed) 46 | 47 | 48 | def load_pretrained_model_cfg(cfg): 49 | pretrained_cfg_fname = osp.join(cfg.pretrained.dir, 'config.yaml') 50 | if not os.path.isfile(pretrained_cfg_fname): 51 | FileNotFoundError(f"Pretrained model config not found: " 52 | f"{pretrained_cfg_fname}") 53 | 54 | logging.info(f"[*] Updating cfg from pretrained model: " 55 | f"{pretrained_cfg_fname}") 56 | 57 | pretrained_cfg = CfgNode() 58 | set_cfg(pretrained_cfg) 59 | set_new_cfg_allowed(pretrained_cfg, True) 60 | pretrained_cfg.merge_from_file(pretrained_cfg_fname) 61 | 62 | assert cfg.model.type == 'GPSModel', \ 63 | "Fine-tuning regime is untested for other model types." 64 | compare_cfg(cfg, pretrained_cfg, 'model.type', strict=True) 65 | compare_cfg(cfg, pretrained_cfg, 'model.graph_pooling') 66 | compare_cfg(cfg, pretrained_cfg, 'model.edge_decoding') 67 | compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder', strict=True) 68 | compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder_name', strict=True) 69 | compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder_bn', strict=True) 70 | compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder', strict=True) 71 | compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder_name', strict=True) 72 | compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder_bn', strict=True) 73 | 74 | # Copy over all PE/SE configs 75 | for key in cfg.keys(): 76 | if key.startswith('posenc_'): 77 | cfg[key] = pretrained_cfg[key] 78 | 79 | # Copy over GT config 80 | cfg.gt = pretrained_cfg.gt 81 | 82 | # Copy over GNN cfg but not those for the prediction head 83 | compare_cfg(cfg, pretrained_cfg, 'gnn.head') 84 | compare_cfg(cfg, pretrained_cfg, 
'gnn.layers_post_mp') 85 | compare_cfg(cfg, pretrained_cfg, 'gnn.act') 86 | compare_cfg(cfg, pretrained_cfg, 'gnn.dropout') 87 | head = cfg.gnn.head 88 | post_mp = cfg.gnn.layers_post_mp 89 | act = cfg.gnn.act 90 | drp = cfg.gnn.dropout 91 | cfg.gnn = pretrained_cfg.gnn 92 | cfg.gnn.head = head 93 | cfg.gnn.layers_post_mp = post_mp 94 | cfg.gnn.act = act 95 | cfg.gnn.dropout = drp 96 | return cfg 97 | 98 | 99 | def init_model_from_pretrained(model, pretrained_dir, 100 | freeze_main=False, reset_prediction_head=True): 101 | """ Copy model parameters from pretrained model except the prediction head. 102 | 103 | Args: 104 | model: Initialized model with random weights. 105 | pretrained_dir: Root directory of saved pretrained model. 106 | freeze_main: If True, do not finetune the loaded pretrained parameters 107 | of the `main body` (train the prediction head only), else train all. 108 | reset_prediction_head: If True, reset parameters of the prediction head, 109 | else keep the pretrained weights. 110 | 111 | Returns: 112 | Updated pytorch model object. 113 | """ 114 | from torch_geometric.graphgym.checkpoint import MODEL_STATE 115 | 116 | ckpt_file = get_final_pretrained_ckpt(osp.join(pretrained_dir, '0', 'ckpt')) 117 | logging.info(f"[*] Loading from pretrained model: {ckpt_file}") 118 | 119 | ckpt = torch.load(ckpt_file) 120 | pretrained_dict = ckpt[MODEL_STATE] 121 | model_dict = model.state_dict() 122 | 123 | # print('>>>> pretrained dict: ') 124 | # print(pretrained_dict.keys()) 125 | # print('>>>> model dict: ') 126 | # print(model_dict.keys()) 127 | 128 | if reset_prediction_head: 129 | # Filter out prediction head parameter keys. 130 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 131 | if not k.startswith('post_mp')} 132 | # Overwrite entries in the existing state dict. 133 | model_dict.update(pretrained_dict) 134 | # Load the new state dict. 
135 | model.load_state_dict(model_dict) 136 | 137 | if freeze_main: 138 | for key, param in model.named_parameters(): 139 | if not key.startswith('post_mp'): 140 | param.requires_grad = False 141 | return model 142 | -------------------------------------------------------------------------------- /gssc/head/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /gssc/head/example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_geometric.graphgym.register import register_head 4 | 5 | 6 | @register_head('head') 7 | class ExampleNodeHead(nn.Module): 8 | '''Head of GNN, node prediction''' 9 | def __init__(self, dim_in, dim_out): 10 | super().__init__() 11 | self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True) 12 | 13 | def _apply_index(self, batch): 14 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]: 15 | return batch.x[batch.node_label_index], batch.node_label 16 | else: 17 | return batch.x[batch.node_label_index], \ 18 | batch.node_label[batch.node_label_index] 19 | 20 | def forward(self, batch): 21 | batch = self.layer_post_mp(batch) 22 | pred, label = self._apply_index(batch) 23 | return pred, label 24 | -------------------------------------------------------------------------------- /gssc/head/inductive_node.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.models.layer import new_layer_config, MLP 4 | from torch_geometric.graphgym.register import 
import torch


@register_head('inductive_node')
class GNNInductiveNodeHead(nn.Module):
    """
    GNN prediction head for inductive node prediction tasks.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super(GNNInductiveNodeHead, self).__init__()
        self.layer_post_mp = MLP(
            new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp,
                             has_act=False, has_bias=True, cfg=cfg))
        # Optional jumping-knowledge-style MLP over concatenated per-layer
        # node features; assumes batch.all_x is populated upstream — TODO confirm.
        if cfg.extra.jk:
            self.jk_mlp = nn.Sequential(nn.Linear(cfg.gnn.dim_inner*cfg.gt.layers, cfg.gnn.dim_inner), nn.SiLU(), nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner))

    def _apply_index(self, batch):
        return batch.x, batch.y

    def forward(self, batch):
        if cfg.extra.jk:
            batch.x = self.jk_mlp(torch.cat(batch.all_x, dim=-1))
        batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        return pred, label
--------------------------------------------------------------------------------
/gssc/head/ogb_code_graph.py:
--------------------------------------------------------------------------------
import torch.nn as nn

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('ogb_code_graph')
class OGBCodeGraphHead(nn.Module):
    """
    Sequence prediction head for ogbg-code2 graph-level prediction tasks.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): IGNORED, kept for GraphGym framework compatibility
        L (int): Number of hidden layers.
    """

    def __init__(self, dim_in, dim_out, L=1):
        super().__init__()
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
        self.L = L
        # ogbg-code2 vocabulary size (incl. special tokens) and the fixed
        # length of the predicted token sequence.
        num_vocab = 5002
        self.max_seq_len = 5

        if self.L != 1:
            raise ValueError(f"Multilayer prediction heads are not supported.")

        # One independent linear classifier per output sequence position.
        self.graph_pred_linear_list = nn.ModuleList()
        for i in range(self.max_seq_len):
            self.graph_pred_linear_list.append(nn.Linear(dim_in, num_vocab))

    def _apply_index(self, batch):
        return batch.pred_list, {'y_arr': batch.y_arr, 'y': batch.y}

    def forward(self, batch):
        graph_emb = self.pooling_fun(batch.x, batch.batch)

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](graph_emb))
        batch.pred_list = pred_list

        pred, label = self._apply_index(batch)
        return pred, label
--------------------------------------------------------------------------------
/gssc/head/san_graph.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('san_graph')
class SANGraphHead(nn.Module):
    """
    SAN prediction head for graph prediction tasks.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
        L (int): Number of hidden layers.
    """

    def __init__(self, dim_in, dim_out, L=2):
        super().__init__()
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
        # Funnel MLP: each hidden layer halves the feature dimension.
        list_FC_layers = [
            nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True)
            for l in range(L)]
        list_FC_layers.append(
            nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L

    def _apply_index(self, batch):
        return batch.graph_feature, batch.y

    def forward(self, batch):
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        for l in range(self.L):
            graph_emb = self.FC_layers[l](graph_emb)
            graph_emb = F.relu(graph_emb)
        # Final projection has no activation (raw logits / regression output).
        graph_emb = self.FC_layers[self.L](graph_emb)
        batch.graph_feature = graph_emb
        pred, label = self._apply_index(batch)
        return pred, label
--------------------------------------------------------------------------------
/gssc/layer/ETransformer.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer


class ETransformer(nn.Module):
    """Mostly Multi-Head Graph Attention Layer.

    Ported to PyG from original repo:
    https://github.com/DevinKreuzer/SAN/blob/main/layers/graph_transformer_layer.py
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, edge_index='edge_index', use_edge_attr=False, edge_attr='edge_attr'):
        super().__init__()

        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads')
        self.out_dim = out_dim // num_heads  # per-head feature dimension
        self.num_heads = num_heads
        # Attribute names on the batch from which to read edges (and optional
        # edge features), so the layer can operate on alternative graph views.
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.use_edge_attr = use_edge_attr

        self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        if self.use_edge_attr:
            self.E = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)

    def propagate_attention(self, batch):
        # Computes attention-weighted messages; writes batch.wV (aggregated
        # messages) and batch.Z (normalization coefficients) as side effects.
        edge_index = getattr(batch, self.edge_index)

        src = batch.K_h[edge_index[0]]  # (num real edges) x num_heads x out_dim
        dest = batch.Q_h[edge_index[1]]  # (num real edges) x num_heads x out_dim
        score = torch.mul(src, dest)  # element-wise multiplication

        # Scale scores by sqrt(d)
        score = score / np.sqrt(self.out_dim)

        # Use available edge features to modify the scores for edges
        if self.use_edge_attr:
            score = torch.mul(score, batch.E)  # (num real edges) x num_heads x out_dim
        # Clamp the summed score before exp() for numerical stability.
        score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5))  # (num real edges) x num_heads x 1

        # Apply attention score to each source node to create edge messages
        msg = batch.V_h[edge_index[0]] * score  # (num real edges) x num_heads x out_dim
        # Add-up real msgs in destination nodes as given by batch.edge_index[1]
        batch.wV = torch.zeros_like(batch.V_h)  # (num nodes in batch) x num_heads x out_dim
        scatter(msg, edge_index[1], dim=0, out=batch.wV, reduce='add')

        # Compute attention normalization coefficient
        batch.Z = score.new_zeros(batch.size(0), self.num_heads, 1)  # (num nodes in batch) x num_heads x 1
        scatter(score, edge_index[1], dim=0, out=batch.Z, reduce='add')

    def forward(self, batch):
        edge_index = getattr(batch, self.edge_index)

        if edge_index is None:
            # NOTE(review): 'f{self.edge_index}' inside this f-string looks
            # like a typo for '{self.edge_index}' — message prints a stray 'f'.
            raise ValueError(f'edge index: f{self.edge_index} not found')

        # Accept (num_edges, 2) input and transpose to PyG's (2, num_edges).
        if edge_index.shape[0] != 2 and edge_index.shape[1] == 2:
            edge_index = torch.t(edge_index)
            setattr(batch, self.edge_index, edge_index)

        if self.use_edge_attr:
            edge_attr = getattr(batch, self.edge_attr)
            # NOTE(review): this mutates module state inside forward (edge
            # features stay disabled for all later calls) and warns via
            # print(); consider logging and a per-call flag instead.
            if edge_attr is None or edge_attr.shape[0] != edge_index.shape[1]:
                print('edge_attr shape does not match edge_index shape, ignoring edge_attr')
                self.use_edge_attr = False

        Q_h = self.Q(batch.x)
        K_h = self.K(batch.x)
        if self.use_edge_attr:
            E = self.E(edge_attr)
        V_h = self.V(batch.x)

        # Reshaping into [num_nodes, num_heads, feat_dim] to
        # get projections for multi-head attention
        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        if self.use_edge_attr:
            batch.E = E.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(batch)

        # Normalize aggregated messages; epsilon guards nodes with no
        # incoming edges (Z == 0).
        h_out = batch.wV / (batch.Z + 1e-6)

        h_out = h_out.view(-1, self.out_dim * self.num_heads)

        return h_out


register_layer('etransformer', ETransformer)
--------------------------------------------------------------------------------
/gssc/layer/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Auto-export every module in this package so GraphGym's register hooks run.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and
    not f.endswith('__init__.py')
]
--------------------------------------------------------------------------------
/gssc/layer/example.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Parameter

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros

# Note: A registered GNN layer should take 'batch' as input
# and 'batch' as output


# Example 1: Directly define a GraphGym format Conv
# take 'batch' as input and 'batch' as output
@register_layer('exampleconv1')
class ExampleConv1(MessagePassing):
    r"""Example GNN layer
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        # Aggregation type ('add'/'mean'/...) comes from the global config.
        super().__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, batch):
        """"""
        x, edge_index = batch.x, batch.edge_index
        x = torch.matmul(x, self.weight)

        batch.x = self.propagate(edge_index, x=x)

        return batch

    def message(self, x_j):
        # Identity message: pass neighbor features through unchanged.
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out


# Example 2: First define a PyG format Conv layer
# Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super().__init__(aggr=cfg.gnn.agg, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index):
        """"""
        x = torch.matmul(x, self.weight)

        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out


@register_layer('exampleconv2')
class ExampleConv2(nn.Module):
    # GraphGym-style wrapper around the PyG-style ExampleConv2Layer.
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
--------------------------------------------------------------------------------
/gssc/layer/gatedgcn_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_scatter import scatter

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer


class GatedGCNLayer(pyg_nn.conv.MessagePassing):
    """
    GatedGCN layer
    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf
    """
    def __init__(self, in_dim, out_dim, dropout, residual,
                 equivstable_pe=False, **kwargs):
        super().__init__(**kwargs)
        # Five linear maps of the GatedGCN update (A: self, B: neighbor value,
        # C: edge, D/E: edge-gate terms), per the reference implementation.
        self.A = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.B = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.C = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.D = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.E = pyg_nn.Linear(in_dim, out_dim, bias=True)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        self.EquivStablePE = equivstable_pe
        if self.EquivStablePE:
            self.mlp_r_ij = nn.Sequential(
                nn.Linear(1, out_dim), nn.ReLU(),
                nn.Linear(out_dim, 1),
                nn.Sigmoid())

        self.bn_node_x = nn.BatchNorm1d(out_dim)
        self.bn_edge_e = nn.BatchNorm1d(out_dim)
        self.dropout = dropout
        self.residual = residual
        # Side channel: message() stashes updated edge features here so that
        # update() can return them alongside node features.
        self.e = None

    def forward(self, batch):
        x, e, edge_index = batch.x, batch.edge_attr, batch.edge_index

        """
        x               : [n_nodes, in_dim]
        e               : [n_edges, in_dim]
        edge_index      : [2, n_edges]
        """
        if self.residual:
            x_in = x
            e_in = e

        Ax = self.A(x)
        Bx = self.B(x)
        Ce = self.C(e)
        Dx = self.D(x)
        Ex = self.E(x)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        pe_LapPE = batch.pe_EquivStableLapPE if self.EquivStablePE else None

        x, e = self.propagate(edge_index,
                              Bx=Bx, Dx=Dx, Ex=Ex, Ce=Ce,
                              e=e, Ax=Ax,
                              PE=pe_LapPE)

        x = self.bn_node_x(x)
        e = self.bn_edge_e(e)

        x = F.relu(x)
        e = F.relu(e)

        x = F.dropout(x, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        if self.residual:
            x = x_in + x
            e = e_in + e

        batch.x = x
        batch.edge_attr = e

        return batch

    def message(self, Dx_i, Ex_j, PE_i, PE_j, Ce):
        """
        {}x_i           : [n_edges, out_dim]
        {}x_j           : [n_edges, out_dim]
        {}e             : [n_edges, out_dim]
        """
        e_ij = Dx_i + Ex_j + Ce
        sigma_ij = torch.sigmoid(e_ij)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        if self.EquivStablePE:
            r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
            r_ij = self.mlp_r_ij(r_ij)  # the MLP is 1 dim --> hidden_dim --> 1 dim
            sigma_ij = sigma_ij * r_ij

        # Stash the pre-sigmoid gate as the new edge feature for update().
        self.e = e_ij
        return sigma_ij

    def aggregate(self, sigma_ij, index, Bx_j, Bx):
        """
        sigma_ij        : [n_edges, out_dim]  ; is the output from message() function
        index           : [n_edges]
        {}x_j           : [n_edges, out_dim]
        """
        dim_size = Bx.shape[0]  # or None ??   <--- Double check this

        # Gated mean aggregation: sum(sigma * Bx_j) / sum(sigma).
        sum_sigma_x = sigma_ij * Bx_j
        numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size,
                                   reduce='sum')

        sum_sigma = sigma_ij
        denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size,
                                     reduce='sum')

        out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
        return out

    def update(self, aggr_out, Ax):
        """
        aggr_out        : [n_nodes, out_dim] ; is the output from aggregate() function after the aggregation
        {}x             : [n_nodes, out_dim]
        """
        x = Ax + aggr_out
        # Retrieve and clear the edge features stashed by message().
        e_out = self.e
        del self.e
        return x, e_out


@register_layer('gatedgcnconv')
class GatedGCNGraphGymLayer(nn.Module):
    """GatedGCN layer.
    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf
    """
    def __init__(self, layer_config: LayerConfig, **kwargs):
        super().__init__()
        self.model = GatedGCNLayer(in_dim=layer_config.dim_in,
                                   out_dim=layer_config.dim_out,
                                   dropout=0.,  # Dropout is handled by GraphGym's `GeneralLayer` wrapper
                                   residual=False,  # Residual connections are handled by GraphGym's `GNNStackStage` wrapper
                                   **kwargs)

    def forward(self, batch):
        return self.model(batch)
--------------------------------------------------------------------------------
/gssc/layer/gine_conv_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.register import register_layer
from torch_geometric.nn import Linear as Linear_pyg


class GINEConvESLapPE(pyg_nn.conv.MessagePassing):
    """GINEConv Layer with EquivStableLapPE implementation.

    Modified torch_geometric.nn.conv.GINEConv layer to perform message scaling
    according to equiv.
stable PEG-layer with Laplacian Eigenmap (LapPE): 16 | ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj 17 | """ 18 | def __init__(self, nn, eps=0., train_eps=False, edge_dim=None, **kwargs): 19 | kwargs.setdefault('aggr', 'add') 20 | super().__init__(**kwargs) 21 | self.nn = nn 22 | self.initial_eps = eps 23 | if train_eps: 24 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 25 | else: 26 | self.register_buffer('eps', torch.Tensor([eps])) 27 | if edge_dim is not None: 28 | if hasattr(self.nn[0], 'in_features'): 29 | in_channels = self.nn[0].in_features 30 | else: 31 | in_channels = self.nn[0].in_channels 32 | self.lin = pyg_nn.Linear(edge_dim, in_channels) 33 | else: 34 | self.lin = None 35 | self.reset_parameters() 36 | 37 | if hasattr(self.nn[0], 'in_features'): 38 | out_dim = self.nn[0].out_features 39 | else: 40 | out_dim = self.nn[0].out_channels 41 | 42 | # Handling for Equivariant and Stable PE using LapPE 43 | # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj 44 | self.mlp_r_ij = torch.nn.Sequential( 45 | torch.nn.Linear(1, out_dim), torch.nn.ReLU(), 46 | torch.nn.Linear(out_dim, 1), 47 | torch.nn.Sigmoid()) 48 | 49 | def reset_parameters(self): 50 | pyg_nn.inits.reset(self.nn) 51 | self.eps.data.fill_(self.initial_eps) 52 | if self.lin is not None: 53 | self.lin.reset_parameters() 54 | pyg_nn.inits.reset(self.mlp_r_ij) 55 | 56 | def forward(self, x, edge_index, edge_attr=None, pe_LapPE=None, size=None): 57 | # if isinstance(x, Tensor): 58 | # x: OptPairTensor = (x, x) 59 | 60 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) 61 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, 62 | PE=pe_LapPE, size=size) 63 | 64 | x_r = x[1] 65 | if x_r is not None: 66 | out += (1 + self.eps) * x_r 67 | 68 | return self.nn(out) 69 | 70 | def message(self, x_j, edge_attr, PE_i, PE_j): 71 | if self.lin is None and x_j.size(-1) != edge_attr.size(-1): 72 | raise ValueError("Node and edge feature dimensionalities do not " 73 | "match. 
Consider setting the 'edge_dim' " 74 | "attribute of 'GINEConv'") 75 | 76 | if self.lin is not None: 77 | edge_attr = self.lin(edge_attr) 78 | 79 | # Handling for Equivariant and Stable PE using LapPE 80 | # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj 81 | r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True) 82 | r_ij = self.mlp_r_ij(r_ij) # the MLP is 1 dim --> hidden_dim --> 1 dim 83 | 84 | return ((x_j + edge_attr).relu()) * r_ij 85 | 86 | def __repr__(self): 87 | return f'{self.__class__.__name__}(nn={self.nn})' 88 | 89 | 90 | class GINEConvLayer(nn.Module): 91 | """Graph Isomorphism Network with Edge features (GINE) layer. 92 | """ 93 | def __init__(self, dim_in, dim_out, dropout, residual): 94 | super().__init__() 95 | self.dim_in = dim_in 96 | self.dim_out = dim_out 97 | self.dropout = dropout 98 | self.residual = residual 99 | 100 | gin_nn = nn.Sequential( 101 | pyg_nn.Linear(dim_in, dim_out), nn.ReLU(), 102 | pyg_nn.Linear(dim_out, dim_out)) 103 | self.model = pyg_nn.GINEConv(gin_nn) 104 | 105 | def forward(self, batch): 106 | x_in = batch.x 107 | 108 | batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr) 109 | 110 | batch.x = F.relu(batch.x) 111 | batch.x = F.dropout(batch.x, p=self.dropout, training=self.training) 112 | 113 | if self.residual: 114 | batch.x = x_in + batch.x # residual connection 115 | 116 | return batch 117 | 118 | 119 | @register_layer('gineconv') 120 | class GINEConvGraphGymLayer(nn.Module): 121 | """Graph Isomorphism Network with Edge features (GINE) layer. 
122 | """ 123 | def __init__(self, layer_config: LayerConfig, **kwargs): 124 | super().__init__() 125 | gin_nn = nn.Sequential( 126 | Linear_pyg(layer_config.dim_in, layer_config.dim_out), nn.ReLU(), 127 | Linear_pyg(layer_config.dim_out, layer_config.dim_out)) 128 | self.model = pyg_nn.GINEConv(gin_nn) 129 | 130 | def forward(self, batch): 131 | batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr) 132 | return batch 133 | -------------------------------------------------------------------------------- /gssc/layer/gssc_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from torch_geometric.utils import to_dense_batch 5 | import torch.nn.functional as F 6 | from torch_geometric.graphgym.config import cfg 7 | 8 | 9 | class GSSC(nn.Module): 10 | def __init__(self, *args, **kwargs) -> None: 11 | super().__init__(*args, **kwargs) 12 | self.deg_coef = nn.Parameter(torch.zeros(1, 1, cfg.gnn.dim_inner, 2)) 13 | nn.init.xavier_normal_(self.deg_coef) 14 | if cfg.extra.more_mapping: 15 | self.x_head_mapping = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=False) 16 | self.q_pe_head_mapping = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False) 17 | self.k_pe_head_mapping = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False) 18 | self.out_mapping = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=True) 19 | else: 20 | self.x_head_mapping = nn.Identity() 21 | self.q_pe_head_mapping = nn.Identity() 22 | self.k_pe_head_mapping = nn.Identity() 23 | self.out_mapping = nn.Identity() 24 | if cfg.extra.reweigh_self: 25 | self.reweigh_pe = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False) 26 | self.reweigh_x = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=False) 27 | if cfg.extra.reweigh_self == 2: 28 | self.reweigh_pe_2 = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False) 
# (Continuation of class GSSC from gssc_layer.py — class header above.)
    def forward(self, x, batch):
        """Globally mix dense node features through PE-based projections.

        x is dense per-graph node features; init_pe / log_deg come from the
        batch. Assumes x is (batch, n_nodes, dim) — TODO confirm caller.
        """
        init_pe = batch.init_pe
        # Degree scaling: learned per-channel mix of x and x * log(deg).
        log_deg = to_dense_batch(batch.log_deg, batch.batch)[0][..., None]
        x = torch.stack([x, x * log_deg], dim=-1)
        x = (x * self.deg_coef).sum(dim=-1)

        if cfg.extra.reweigh_self:
            pe_reweigh = self.reweigh_pe(init_pe)
            x_reweigh = self.reweigh_x(x)
            pe_reweigh_2 = self.reweigh_pe_2(init_pe) if cfg.extra.reweigh_self == 2 else pe_reweigh

        x = self.x_head_mapping(x)
        q_pe = self.q_pe_head_mapping(init_pe)
        k_pe = self.k_pe_head_mapping(init_pe)
        # Low-rank global mixing (linear-attention style): contract over
        # nodes with k_pe, then expand back over nodes with q_pe.
        first = torch.einsum("bnrd, bnl -> brdl", k_pe, x)
        x = torch.einsum("bnrd, brdl -> bnl", q_pe, first)
        x = self.out_mapping(x)

        if cfg.extra.reweigh_self:
            x = x + (pe_reweigh * pe_reweigh_2).sum(dim=(-1, -2))[..., None] * x_reweigh
        return x
--------------------------------------------------------------------------------
/gssc/loader/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Auto-export every sibling .py module (except this package __init__) so a
# wildcard import of the package registers all loaders.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules
    if isfile(f) and not f.endswith('__init__.py')
]
--------------------------------------------------------------------------------
/gssc/loader/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-COM/GSSC/ebb0ce4c52b6904ff95514b68f4a72aa48dfeb0f/gssc/loader/dataset/__init__.py
--------------------------------------------------------------------------------
/gssc/loader/dataset/aqsol_molecules.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import shutil
import pickle

import torch
from tqdm import tqdm
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_zip)
from torch_geometric.utils import add_self_loops


class AQSOL(InMemoryDataset):
    r"""The AQSOL dataset from Benchmarking GNNs (Dwivedi et al., 2020) is based on AqSolDB
    (Sorkun et al., 2019) which is a standardized database of 9,982 molecular graphs with
    their aqueous solubility values, collected from 9 different data sources.

    The aqueous solubility targets are collected from experimental measurements and standardized
    to LogS units in AqSolDB. These final values as the property to regress in the AQSOL dataset
    which is the resultant collection in 'Benchmarking GNNs' after filtering out few graphs
    with no bonds/edges and a small number of graphs with missing node feature values.

    Thus, the total molecular graphs are 9,823. For each molecular graph, the node features are the
    types of heavy atoms and the edge features are the types of bonds between them, similar as ZINC.

    Size of Dataset: 9,982 molecules.
    Split: Scaffold split (8:1:1) following same code as OGB.
    After cleaning: 7,831 train / 996 val / 996 test
    Number of (unique) atoms: 65
    Number of (unique) bonds: 5
    Performance Metric: MAE, same as ZINC

    Atom Dict: {'Br': 0, 'C': 1, 'N': 2, 'O': 3, 'Cl': 4, 'Zn': 5, 'F': 6, 'P': 7, 'S': 8, 'Na': 9, 'Al': 10,
    'Si': 11, 'Mo': 12, 'Ca': 13, 'W': 14, 'Pb': 15, 'B': 16, 'V': 17, 'Co': 18, 'Mg': 19, 'Bi': 20, 'Fe': 21,
    'Ba': 22, 'K': 23, 'Ti': 24, 'Sn': 25, 'Cd': 26, 'I': 27, 'Re': 28, 'Sr': 29, 'H': 30, 'Cu': 31, 'Ni': 32,
    'Lu': 33, 'Pr': 34, 'Te': 35, 'Ce': 36, 'Nd': 37, 'Gd': 38, 'Zr': 39, 'Mn': 40, 'As': 41, 'Hg': 42, 'Sb':
    43, 'Cr': 44, 'Se': 45, 'La': 46, 'Dy': 47, 'Y': 48, 'Pd': 49, 'Ag': 50, 'In': 51, 'Li': 52, 'Rh': 53,
    'Nb': 54, 'Hf': 55, 'Cs': 56, 'Ru': 57, 'Au': 58, 'Sm': 59, 'Ta': 60, 'Pt': 61, 'Ir': 62, 'Be': 63, 'Ge': 64}

    Bond Dict: {'NONE': 0, 'SINGLE': 1, 'DOUBLE': 2, 'AROMATIC': 3, 'TRIPLE': 4}

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1'

    def __init__(self, root, split='train', transform=None, pre_transform=None,
                 pre_filter=None):
        self.name = "AQSOL"
        assert split in ['train', 'val', 'test']
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, f'{split}.pt')
        self.data, self.slices = torch.load(path)


    @property
    def raw_file_names(self):
        return ['train.pickle', 'val.pickle', 'test.pickle']

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self):
        # Wipe any stale raw dir, fetch the zip, and rename the extracted
        # folder (note the upstream archive's 'asqol' typo) into raw_dir.
        shutil.rmtree(self.raw_dir)
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.rename(osp.join(self.root, 'asqol_graph_raw'), self.raw_dir)
        os.unlink(path)

    def process(self):
        for split in ['train', 'val', 'test']:
            with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
                graphs = pickle.load(f)

            indices = range(len(graphs))

            pbar = tqdm(total=len(indices))
            pbar.set_description(f'Processing {split} dataset')

            data_list = []
            for idx in indices:
                graph = graphs[idx]

                """
                Each `graph` is a tuple (x, edge_attr, edge_index, y)
                    Shape of x : [num_nodes, 1]
                    Shape of edge_attr : [num_edges]
                    Shape of edge_index : [2, num_edges]
                    Shape of y : [1]
                """

                x = torch.LongTensor(graph[0]).unsqueeze(-1)
                edge_attr = torch.LongTensor(graph[1])#.unsqueeze(-1)
                edge_index = torch.LongTensor(graph[2])
                y = torch.tensor(graph[3])

                # Temporary Data object only to let PyG infer num_nodes from
                # edge_index for the consistency check below.
                data = Data(edge_index=edge_index)

                if edge_index.shape[1] == 0:
                    continue  # skipping for graphs with no bonds/edges

                if data.num_nodes != len(x):
                    continue  # cleaning <10 graphs with this discrepancy

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                            y=y)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)
                pbar.update(1)

            pbar.close()
            torch.save(self.collate(data_list),
                       osp.join(self.processed_dir, f'{split}.pt'))
--------------------------------------------------------------------------------
/gssc/loader/dataset/count_cycle.py:
--------------------------------------------------------------------------------
"""
modified from https://github.com/GraphPKU/I2GNN/blob/master/data_processing.py
"""

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
import os
# NOTE(review): duplicate import of Data/InMemoryDataset kept from upstream.
from torch_geometric.data import Data, InMemoryDataset
import scipy.io as scio


class CountCycle(InMemoryDataset):
    """Cycle-counting dataset (I2GNN format) with optional log-scale targets.

    Targets are normalised as (y - ymean) / ystd; a log10(y + 1) variant is
    kept in data.log_y and replaces data.y when `replace` is True.
    """
    def __init__(
        self,
        dataname="count_cycle",
        root="dataset",
        processed_name="processed",
        split="train",
        yidx: int = 0,
        ymean: float = 0,
        ystd: float = 1,
        ymean_log: float = 0,
        ystd_log: float = 1,
        replace=False,
        transform=None,
    ):
        self.root = root
        self.dataname = dataname
        self.raw = os.path.join(root, dataname)
        self.processed = os.path.join(root, dataname, processed_name)
        super(CountCycle, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
        split_id = 0 if split == "train" else 1 if split == "val" else 2
        data, slices = torch.load(self.processed_paths[split_id])

        # Log-scale target: log10(y + 1), then standardise with the log stats.
        data.log_y = torch.log10(data.y + 1.0)
        data.log_y = (data.log_y[:, [yidx]] - ymean_log) / ystd_log
        data.log_y = data.log_y.reshape(-1)

        data.y = (data.y[:, [yidx]] - ymean) / ystd
        data.y = data.y.reshape(-1)

        if replace:
            data.y = data.log_y

        self.data, self.slices = data, slices
        self.mean, self.std, self.log_mean, self.log_std = ymean, ystd, ymean_log, ystd_log
        # ((10**((data.log_y * ystd_log) + ymean_log) - 1) - ymean) / ystd

    @property
    def raw_dir(self):
        name = "raw"
        return os.path.join(self.root, self.dataname, name)

    @property
    def processed_dir(self):
        return self.processed

    @property
    def raw_file_names(self):
        names = ["data"]
        return ["{}.mat".format(name) for name in names]

    @property
    def processed_file_names(self):
        return ["data_tr.pt", "data_val.pt", "data_te.pt"]

    def adj2data(self, A, y):
        """Convert a dense adjacency matrix and label vector into a PyG Data."""
        # x: (n, d), A: (e, n, n)
        # begin, end = np.where(np.sum(A, axis=0) == 1.)
        begin, end = np.where(A == 1.0)
        edge_index = torch.tensor(np.array([begin, end]))
        num_nodes = A.shape[0]
        if y.ndim == 1:
            y = y.reshape([1, -1])
        # All-ones node features; counting tasks carry no node attributes.
        x = torch.ones((num_nodes, 1), dtype=torch.long)
        return Data(x=x, edge_index=edge_index, y=torch.tensor(y), num_nodes=torch.tensor([num_nodes]))

    def process(self):
        # process npy data into pyg.Data
        print("Processing data from " + self.raw_dir + "...")
        raw_data = scio.loadmat(self.raw_paths[0])
        # Two .mat layouts exist upstream; branch on how F is packed.
        if raw_data["F"].shape[0] == 1:
            data_list_all = [
                [self.adj2data(raw_data["A"][0][i], raw_data["F"][0][i]) for i in idx]
                for idx in [raw_data["train_idx"][0], raw_data["val_idx"][0], raw_data["test_idx"][0]]
            ]
        else:
            data_list_all = [
                [self.adj2data(A, y) for A, y in zip(raw_data["A"][0][idx][0], raw_data["F"][idx][0])]
                for idx in [raw_data["train_idx"], raw_data["val_idx"], raw_data["test_idx"]]
            ]
        for save_path, data_list in zip(self.processed_paths, data_list_all):
            print("pre-transforming for data at" + save_path)
            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]
            if self.pre_transform is not None:
                temp = []
                for i, data in enumerate(data_list):
                    if i % 100 == 0:
                        print("Pre-processing %d/%d" % (i, len(data_list)))
                    temp.append(self.pre_transform(data))
                data_list = temp
                # data_list = [self.pre_transform(data) for data in data_list]
            data, slices = self.collate(data_list)
            torch.save((data, slices), save_path)
--------------------------------------------------------------------------------
/gssc/loader/dataset/malnet_tiny.py:
--------------------------------------------------------------------------------
from typing import Optional, Callable, List

import os
import glob
import os.path as osp

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar, extract_zip)
from torch_geometric.utils import remove_isolated_nodes

"""
This is a local copy of MalNetTiny class from PyG
https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/malnet_tiny.py

TODO: Delete and use PyG's version once it is part of a released version.
At the time of writing this class is in the main PyG github branch but is not
included in the current latest released version 2.0.2.
"""

class MalNetTiny(InMemoryDataset):
    r"""The MalNet Tiny dataset from the
    `"A Large-Scale Database for Graph Representation Learning"
    `_ paper.
    :class:`MalNetTiny` contains 5,000 malicious and benign software function
    call graphs across 5 different types. Each graph contains at most 5k nodes.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'http://malnet.cc.gatech.edu/graph-data/malnet-graphs-tiny.tar.gz'
    # 70/10/20 train, val, test split by type
    split_url = 'http://malnet.cc.gatech.edu/split-info/split_info_tiny.zip'

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        folders = ['addisplay', 'adware', 'benign', 'downloader', 'trojan']
        return [osp.join('malnet-graphs-tiny', folder) for folder in folders]

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt', 'split_dict.pt']

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_tar(path, self.raw_dir)
        os.unlink(path)
        path = download_url(self.split_url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        data_list = []
        split_dict = {'train': [], 'valid': [], 'test': []}

        # Split files list one graph path per line; keep only the basename.
        parse = lambda f: set([x.split('/')[-1]
                               for x in f.read().split('\n')[:-1]])  # -1 for empty line at EOF
        split_dir = osp.join(self.raw_dir, 'split_info_tiny', 'type')
        with open(osp.join(split_dir, 'train.txt'), 'r') as f:
            train_names = parse(f)
            assert len(train_names) == 3500
        with open(osp.join(split_dir, 'val.txt'), 'r') as f:
            val_names = parse(f)
            assert len(val_names) == 500
        with open(osp.join(split_dir, 'test.txt'), 'r') as f:
            test_names = parse(f)
            assert len(test_names) == 1000

        for y, raw_path in enumerate(self.raw_paths):
            raw_path = osp.join(raw_path, os.listdir(raw_path)[0])
            filenames = glob.glob(osp.join(raw_path, '*.edgelist'))

            for filename in filenames:
                with open(filename, 'r') as f:
                    # First 5 lines are header; last entry is the empty EOF line.
                    edges = f.read().split('\n')[5:-1]
                edge_index = [[int(s) for s in edge.split()] for edge in edges]
                edge_index = torch.tensor(edge_index).t().contiguous()
                # Remove isolated nodes, including those with only a self-loop
                edge_index = remove_isolated_nodes(edge_index)[0]
                num_nodes = int(edge_index.max()) + 1
                data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes)
                data_list.append(data)

                ind = len(data_list) - 1
                graph_id = osp.splitext(osp.basename(filename))[0]
                if graph_id in train_names:
                    split_dict['train'].append(ind)
                elif graph_id in val_names:
                    split_dict['valid'].append(ind)
                elif graph_id in test_names:
                    split_dict['test'].append(ind)
                else:
                    raise ValueError(f'No split assignment for "{graph_id}".')

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])
        torch.save(split_dict, self.processed_paths[1])

    def get_idx_split(self):
        return torch.load(self.processed_paths[1])
--------------------------------------------------------------------------------
/gssc/loader/dataset/peptides_functional.py:
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | 6 | import pandas as pd 7 | import torch 8 | from ogb.utils import smiles2graph 9 | from ogb.utils.torch_util import replace_numpy_with_torchtensor 10 | from ogb.utils.url import decide_download 11 | from torch_geometric.data import Data, InMemoryDataset, download_url 12 | from tqdm import tqdm 13 | 14 | 15 | class PeptidesFunctionalDataset(InMemoryDataset): 16 | def __init__(self, root='datasets', smiles2graph=smiles2graph, 17 | transform=None, pre_transform=None): 18 | """ 19 | PyG dataset of 15,535 peptides represented as their molecular graph 20 | (SMILES) with 10-way multi-task binary classification of their 21 | functional classes. 22 | 23 | The goal is use the molecular representation of peptides instead 24 | of amino acid sequence representation ('peptide_seq' field in the file, 25 | provided for possible baseline benchmarking but not used here) to test 26 | GNNs' representation capability. 27 | 28 | The 10 classes represent the following functional classes (in order): 29 | ['antifungal', 'cell_cell_communication', 'anticancer', 30 | 'drug_delivery_vehicle', 'antimicrobial', 'antiviral', 31 | 'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic'] 32 | 33 | Args: 34 | root (string): Root directory where the dataset should be saved. 35 | smiles2graph (callable): A callable function that converts a SMILES 36 | string into a graph object. We use the OGB featurization. 
37 | * The default smiles2graph requires rdkit to be installed * 38 | """ 39 | 40 | self.original_root = root 41 | self.smiles2graph = smiles2graph 42 | self.folder = osp.join(root, 'peptides-functional') 43 | 44 | self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1' 45 | self.version = '701eb743e899f4d793f0e13c8fa5a1b4' # MD5 hash of the intended dataset file 46 | self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1' 47 | self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061' 48 | 49 | # Check version and update if necessary. 50 | release_tag = osp.join(self.folder, self.version) 51 | if osp.isdir(self.folder) and (not osp.exists(release_tag)): 52 | print(f"{self.__class__.__name__} has been updated.") 53 | if input("Will you update the dataset now? (y/N)\n").lower() == 'y': 54 | shutil.rmtree(self.folder) 55 | 56 | super().__init__(self.folder, transform, pre_transform) 57 | self.data, self.slices = torch.load(self.processed_paths[0]) 58 | 59 | @property 60 | def raw_file_names(self): 61 | return 'peptide_multi_class_dataset.csv.gz' 62 | 63 | @property 64 | def processed_file_names(self): 65 | return 'geometric_data_processed.pt' 66 | 67 | def _md5sum(self, path): 68 | hash_md5 = hashlib.md5() 69 | with open(path, 'rb') as f: 70 | buffer = f.read() 71 | hash_md5.update(buffer) 72 | return hash_md5.hexdigest() 73 | 74 | def download(self): 75 | if decide_download(self.url): 76 | path = download_url(self.url, self.raw_dir) 77 | # Save to disk the MD5 hash of the downloaded file. 78 | hash = self._md5sum(path) 79 | if hash != self.version: 80 | raise ValueError("Unexpected MD5 hash of the downloaded file") 81 | open(osp.join(self.root, hash), 'w').close() 82 | # Download train/val/test splits. 
83 | path_split1 = download_url(self.url_stratified_split, self.root) 84 | assert self._md5sum(path_split1) == self.md5sum_stratified_split 85 | else: 86 | print('Stop download.') 87 | exit(-1) 88 | 89 | def process(self): 90 | data_df = pd.read_csv(osp.join(self.raw_dir, 91 | 'peptide_multi_class_dataset.csv.gz')) 92 | smiles_list = data_df['smiles'] 93 | 94 | print('Converting SMILES strings into graphs...') 95 | data_list = [] 96 | for i in tqdm(range(len(smiles_list))): 97 | data = Data() 98 | 99 | smiles = smiles_list[i] 100 | graph = self.smiles2graph(smiles) 101 | 102 | assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) 103 | assert (len(graph['node_feat']) == graph['num_nodes']) 104 | 105 | data.__num_nodes__ = int(graph['num_nodes']) 106 | data.edge_index = torch.from_numpy(graph['edge_index']).to( 107 | torch.int64) 108 | data.edge_attr = torch.from_numpy(graph['edge_feat']).to( 109 | torch.int64) 110 | data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) 111 | data.y = torch.Tensor([eval(data_df['labels'].iloc[i])]) 112 | 113 | data_list.append(data) 114 | 115 | if self.pre_transform is not None: 116 | data_list = [self.pre_transform(data) for data in data_list] 117 | 118 | data, slices = self.collate(data_list) 119 | 120 | print('Saving...') 121 | torch.save((data, slices), self.processed_paths[0]) 122 | 123 | def get_idx_split(self): 124 | """ Get dataset splits. 125 | 126 | Returns: 127 | Dict with 'train', 'val', 'test', splits indices. 
128 | """ 129 | split_file = osp.join(self.root, 130 | "splits_random_stratified_peptide.pickle") 131 | with open(split_file, 'rb') as f: 132 | splits = pickle.load(f) 133 | split_dict = replace_numpy_with_torchtensor(splits) 134 | return split_dict 135 | 136 | 137 | if __name__ == '__main__': 138 | dataset = PeptidesFunctionalDataset() 139 | print(dataset) 140 | print(dataset.data.edge_index) 141 | print(dataset.data.edge_index.shape) 142 | print(dataset.data.x.shape) 143 | print(dataset[100]) 144 | print(dataset[100].y) 145 | print(dataset.get_idx_split()) 146 | -------------------------------------------------------------------------------- /gssc/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /gssc/loss/l1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('l1_losses') 7 | def l1_losses(pred, true): 8 | if cfg.model.loss_fun == 'l1': 9 | l1_loss = nn.L1Loss() 10 | loss = l1_loss(pred, true) 11 | return loss, pred 12 | elif cfg.model.loss_fun == 'smoothl1': 13 | l1_loss = nn.SmoothL1Loss() 14 | loss = l1_loss(pred, true) 15 | return loss, pred 16 | elif cfg.model.loss_fun == 'mse': 17 | l1_loss = nn.MSELoss() 18 | loss = l1_loss(pred, true) 19 | return loss, pred 20 | -------------------------------------------------------------------------------- /gssc/loss/multilabel_classification_loss.py: -------------------------------------------------------------------------------- 1 | import 
torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('multilabel_cross_entropy') 7 | def multilabel_cross_entropy(pred, true): 8 | """Multilabel cross-entropy loss. 9 | """ 10 | if cfg.dataset.task_type == 'classification_multilabel': 11 | if cfg.model.loss_fun != 'cross_entropy': 12 | raise ValueError("Only 'cross_entropy' loss_fun supported with " 13 | "'classification_multilabel' task_type.") 14 | bce_loss = nn.BCEWithLogitsLoss() 15 | is_labeled = true == true # Filter our nans. 16 | return bce_loss(pred[is_labeled], true[is_labeled].float()), pred 17 | -------------------------------------------------------------------------------- /gssc/loss/subtoken_prediction_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('subtoken_cross_entropy') 7 | def subtoken_cross_entropy(pred_list, true): 8 | """Subtoken prediction cross-entropy loss for ogbg-code2. 
9 | """ 10 | if cfg.dataset.task_type == 'subtoken_prediction': 11 | if cfg.model.loss_fun != 'cross_entropy': 12 | raise ValueError("Only 'cross_entropy' loss_fun supported with " 13 | "'subtoken_prediction' task_type.") 14 | multicls_criterion = torch.nn.CrossEntropyLoss() 15 | loss = 0 16 | for i in range(len(pred_list)): 17 | loss += multicls_criterion(pred_list[i].to(torch.float32), true['y_arr'][:, i]) 18 | loss = loss / len(pred_list) 19 | 20 | return loss, pred_list 21 | -------------------------------------------------------------------------------- /gssc/loss/weighted_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_loss 5 | 6 | 7 | @register_loss('weighted_cross_entropy') 8 | def weighted_cross_entropy(pred, true): 9 | """Weighted cross-entropy for unbalanced classes. 
10 | """ 11 | if cfg.model.loss_fun == 'weighted_cross_entropy': 12 | # calculating label weights for weighted loss computation 13 | V = true.size(0) 14 | n_classes = pred.shape[1] if pred.ndim > 1 else 2 15 | label_count = torch.bincount(true) 16 | label_count = label_count[label_count.nonzero(as_tuple=True)].squeeze() 17 | cluster_sizes = torch.zeros(n_classes, device=pred.device).long() 18 | cluster_sizes[torch.unique(true)] = label_count 19 | weight = (V - cluster_sizes).float() / V 20 | weight *= (cluster_sizes > 0).float() 21 | # multiclass 22 | if pred.ndim > 1: 23 | pred = F.log_softmax(pred, dim=-1) 24 | return F.nll_loss(pred, true, weight=weight), pred 25 | # binary 26 | else: 27 | loss = F.binary_cross_entropy_with_logits(pred, true.float(), 28 | weight=weight[true]) 29 | return loss, torch.sigmoid(pred) 30 | -------------------------------------------------------------------------------- /gssc/metrics_ogb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score, average_precision_score 3 | 4 | """ 5 | Evaluation functions from OGB. 6 | https://github.com/snap-stanford/ogb/blob/master/ogb/graphproppred/evaluate.py 7 | """ 8 | 9 | def eval_rocauc(y_true, y_pred): 10 | ''' 11 | compute ROC-AUC averaged across tasks 12 | ''' 13 | 14 | rocauc_list = [] 15 | 16 | for i in range(y_true.shape[1]): 17 | # AUC is only defined when there is at least one positive data. 18 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 19 | # ignore nan values 20 | is_labeled = y_true[:, i] == y_true[:, i] 21 | rocauc_list.append( 22 | roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])) 23 | 24 | if len(rocauc_list) == 0: 25 | raise RuntimeError( 26 | 'No positively labeled data available. 
def eval_rmse(y_true, y_pred):
    """Compute RMSE averaged across tasks (columns), ignoring NaN targets."""
    rmses = []
    for task in range(y_true.shape[1]):
        # NaN != NaN, so this mask keeps only labeled rows.
        labeled = y_true[:, task] == y_true[:, task]
        diff = y_true[labeled, task] - y_pred[labeled, task]
        rmses.append(np.sqrt(np.mean(diff ** 2)))
    return {'rmse': sum(rmses) / len(rmses)}
import torch
from torch import Tensor, BoolTensor
'''
x (B, N, d)
mask (B, N, 1)
'''
def maskedSum(x: Tensor, mask: BoolTensor, dim: int):
    '''
    Sum over `dim`, treating masked (True) elements as 0.
    '''
    return torch.sum(torch.where(mask, 0, x), dim=dim)

def maskedMean(x: Tensor, mask: BoolTensor, dim: int, gsize: Tensor = None):
    '''
    Mean over `dim` of unmasked elements; masked (True) elements are ignored.
    gsize optionally supplies the per-group count of valid elements.
    '''
    if gsize is None:
        # Count of unmasked elements along the reduced dim.
        gsize = x.shape[dim] - torch.sum(mask, dim=dim)
    return torch.sum(torch.where(mask, 0, x), dim=dim)/gsize

def maskedMax(x: Tensor, mask: BoolTensor, dim: int):
    # Masked entries are set to -inf so they never win the max.
    return torch.max(torch.where(mask, -torch.inf, x), dim=dim)[0]

def maskedMin(x: Tensor, mask: BoolTensor, dim: int):
    # Masked entries are set to +inf so they never win the min.
    return torch.min(torch.where(mask, torch.inf, x), dim=dim)[0]

def maskednone(x: Tensor, mask: BoolTensor, dim: int):
    # Identity: no reduction applied.
    return x

reduce_dict = {
    "sum": maskedSum,
    "mean": maskedMean,
    "max": maskedMax,
    # Fix: maskedMin was defined but never registered, making the "min"
    # reduction unreachable from configuration.
    "min": maskedMin,
    "none": maskednone
}
@register_network('BigBird')
class BigBird(torch.nn.Module):
    """BigBird without edge features.

    This model disregards edge features and runs a linear transformer over a
    set of node features only. BigBird applies random sparse attention to the
    input sequence - the longer the sequence the closer it is to O(N).
    https://arxiv.org/abs/2007.14062

    Args:
        dim_in (int): Input node feature dimension.
        dim_out (int): Output dimension (passed to the prediction head).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        # Optional pre-message-passing GNN layers before the transformer.
        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # Copy main Transformer hyperparams to the BigBird config.
        # NOTE(review): this mutates the global `cfg` object as a side effect
        # of constructing the model.
        cfg.gt.bigbird.layers = cfg.gt.layers
        cfg.gt.bigbird.n_heads = cfg.gt.n_heads
        cfg.gt.bigbird.dim_hidden = cfg.gt.dim_hidden
        cfg.gt.bigbird.dropout = cfg.gt.dropout
        self.trf = BackboneBigBird(
            config=cfg.gt.bigbird,
        )

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Children run in registration order:
        # encoder -> (pre_mp, if any) -> trf -> post_mp.
        for module in self.children():
            batch = module(batch)
        return batch
@register_network('example')
class ExampleGNN(torch.nn.Module):
    """Minimal example GNN: a stack of identical conv layers plus a task head.

    Args:
        dim_in: Input node feature dimension (kept constant through convs).
        dim_out: Output dimension for the prediction head.
        num_layers: Number of stacked conv layers.
        model_type: One of 'GCN', 'GAT', 'GraphSage'.
    """

    def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):
        super().__init__()
        conv_cls = self.build_conv_model(model_type)
        # num_layers identical conv layers, all dim_in -> dim_in.
        self.convs = nn.ModuleList(
            [conv_cls(dim_in, dim_in) for _ in range(num_layers)])

        GNNHead = register.head_dict[cfg.dataset.task]
        self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Map a model-type name to its PyG conv class; raise on unknown names."""
        conv_by_name = {
            'GCN': pyg_nn.GCNConv,
            'GAT': pyg_nn.GATConv,
            'GraphSage': pyg_nn.SAGEConv,
        }
        if model_type not in conv_by_name:
            raise ValueError(f'Model {model_type} unavailable')
        return conv_by_name[model_type]

    def forward(self, batch):
        x, edge_index = batch.x, batch.edge_index

        # conv -> ReLU -> dropout for every layer.
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.1, training=self.training)

        batch.x = x
        batch = self.post_mp(batch)

        return batch
@register_network("GPSModel")
class GPSModel(torch.nn.Module):
    """Multi-scale graph x-former (GPS): feature encoder, optional pre-MP GNN,
    a stack of GPS layers combining a local GNN with a global attention model,
    and a task-specific head.

    Args:
        dim_in (int): Input node feature dimension.
        dim_out (int): Output dimension (passed to the prediction head).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, "The inner and hidden dims must match."

        # cfg.gt.layer_type is "<local>+<global>", e.g. "GatedGCN+GSSC".
        # Fix: the original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt and discarded the original cause.
        try:
            local_gnn_type, global_model_type = cfg.gt.layer_type.split("+")
        except Exception as e:
            raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}") from e

        if global_model_type == "GSSC":
            # Initial positional encodings consumed by the GSSC global model.
            self.get_init_pe = InitPEs(cfg.extra.init_pe_dim)

        layers = []
        for _ in range(cfg.gt.layers):
            layers.append(
                GPSLayer(
                    dim_h=cfg.gt.dim_hidden,
                    local_gnn_type=local_gnn_type,
                    global_model_type=global_model_type,
                    num_heads=cfg.gt.n_heads,
                    pna_degrees=cfg.gt.pna_degrees,
                    equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                    dropout=cfg.gt.dropout,
                    attn_dropout=cfg.gt.attn_dropout,
                    layer_norm=cfg.gt.layer_norm,
                    batch_norm=cfg.gt.batch_norm,
                    bigbird_cfg=cfg.gt.bigbird,
                )
            )
        self.layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Accumulator for per-layer node features (presumably appended to by
        # the layers -- TODO confirm against GPSLayer).
        batch.all_x = []
        # Children run in registration order:
        # encoder -> (pre_mp) -> (get_init_pe) -> layers -> post_mp.
        for module in self.children():
            batch = module(batch)
        return batch
""" 21 | def __init__(self, dim_in): 22 | super(FeatureEncoder, self).__init__() 23 | self.dim_in = dim_in 24 | if cfg.dataset.node_encoder: 25 | # Encode integer node features via nn.Embeddings 26 | NodeEncoder = register.node_encoder_dict[ 27 | cfg.dataset.node_encoder_name] 28 | self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) 29 | if cfg.dataset.node_encoder_bn: 30 | self.node_encoder_bn = BatchNorm1dNode( 31 | new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, 32 | has_bias=False, cfg=cfg)) 33 | # Update dim_in to reflect the new dimension fo the node features 34 | self.dim_in = cfg.gnn.dim_inner 35 | if cfg.dataset.edge_encoder: 36 | if not hasattr(cfg.gt, 'dim_edge') or cfg.gt.dim_edge is None: 37 | cfg.gt.dim_edge = cfg.gt.dim_hidden 38 | 39 | if cfg.dataset.edge_encoder_name == 'ER': 40 | self.edge_encoder = EREdgeEncoder(cfg.gt.dim_edge) 41 | elif cfg.dataset.edge_encoder_name.endswith('+ER'): 42 | EdgeEncoder = register.edge_encoder_dict[ 43 | cfg.dataset.edge_encoder_name[:-3]] 44 | self.edge_encoder = EdgeEncoder(cfg.gt.dim_edge - cfg.posenc_ERE.dim_pe) 45 | self.edge_encoder_er = EREdgeEncoder(cfg.posenc_ERE.dim_pe, use_edge_attr=True) 46 | else: 47 | EdgeEncoder = register.edge_encoder_dict[ 48 | cfg.dataset.edge_encoder_name] 49 | self.edge_encoder = EdgeEncoder(cfg.gt.dim_edge) 50 | 51 | if cfg.dataset.edge_encoder_bn: 52 | self.edge_encoder_bn = BatchNorm1dNode( 53 | new_layer_config(cfg.gt.dim_edge, -1, -1, has_act=False, 54 | has_bias=False, cfg=cfg)) 55 | 56 | if 'Exphormer' in cfg.gt.layer_type: 57 | self.exp_edge_fixer = ExpanderEdgeFixer(add_edge_index=cfg.prep.add_edge_index, 58 | num_virt_node=cfg.prep.num_virt_node) 59 | 60 | def forward(self, batch): 61 | for module in self.children(): 62 | batch = module(batch) 63 | return batch 64 | 65 | 66 | class MultiModel(torch.nn.Module): 67 | """Multiple layer types can be combined here. 
68 | """ 69 | 70 | def __init__(self, dim_in, dim_out): 71 | super().__init__() 72 | self.encoder = FeatureEncoder(dim_in) 73 | dim_in = self.encoder.dim_in 74 | 75 | if cfg.gnn.layers_pre_mp > 0: 76 | self.pre_mp = GNNPreMP( 77 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 78 | dim_in = cfg.gnn.dim_inner 79 | 80 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ 81 | "The inner and hidden dims must match." 82 | 83 | try: 84 | model_types = cfg.gt.layer_type.split('+') 85 | except: 86 | raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}") 87 | layers = [] 88 | for _ in range(cfg.gt.layers): 89 | layers.append(MultiLayer( 90 | dim_h=cfg.gt.dim_hidden, 91 | model_types=model_types, 92 | num_heads=cfg.gt.n_heads, 93 | pna_degrees=cfg.gt.pna_degrees, 94 | equivstable_pe=cfg.posenc_EquivStableLapPE.enable, 95 | dropout=cfg.gt.dropout, 96 | attn_dropout=cfg.gt.attn_dropout, 97 | layer_norm=cfg.gt.layer_norm, 98 | batch_norm=cfg.gt.batch_norm, 99 | bigbird_cfg=cfg.gt.bigbird, 100 | exp_edges_cfg=cfg.prep 101 | )) 102 | self.layers = torch.nn.Sequential(*layers) 103 | 104 | GNNHead = register.head_dict[cfg.gnn.head] 105 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) 106 | 107 | def forward(self, batch): 108 | for module in self.children(): 109 | batch = module(batch) 110 | return batch 111 | 112 | 113 | class SingleModel(torch.nn.Module): 114 | """A single layer type can be used without FFN between the layers. 115 | """ 116 | 117 | def __init__(self, dim_in, dim_out): 118 | super().__init__() 119 | self.encoder = FeatureEncoder(dim_in) 120 | dim_in = self.encoder.dim_in 121 | 122 | if cfg.gnn.layers_pre_mp > 0: 123 | self.pre_mp = GNNPreMP( 124 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 125 | dim_in = cfg.gnn.dim_inner 126 | 127 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ 128 | "The inner and hidden dims must match." 
def expandbatch(x: Tensor, batch: Tensor):
    """Flatten a replicated tensor (R, N, d) to (R*N, d) and expand `batch`
    so the graphs of each replica receive distinct, consecutive batch ids.

    If `batch` is None, only the flattening is performed.
    """
    flat = x.flatten(0, 1)
    if batch is None:
        return flat, None
    num_replicas = x.shape[0]
    # `batch` is sorted, so its last entry + 1 is the number of graphs.
    num_graphs = batch[-1] + 1
    replica_offsets = num_graphs * torch.arange(
        num_replicas, device=x.device).reshape(-1, 1)
    expanded = batch.unsqueeze(0) + replica_offsets
    return flat, expanded.flatten()
class BatchNorm(nn.Module):
    """BatchNorm1d wrapper that also accepts tensors with extra leading
    dimensions by flattening everything but the feature dim, normalizing,
    and restoring the original shape."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.BatchNorm1d(dim, momentum=normparam)

    def forward(self, x: Tensor):
        ndim = x.dim()
        if ndim < 2:
            # 0-D/1-D inputs are not supported.
            raise NotImplementedError
        if ndim == 2:
            return self.norm(x)
        # (d0, ..., dk, C) -> (d0*...*dk, C) -> normalize -> original shape.
        original_shape = x.shape
        return self.norm(x.flatten(0, -2)).reshape(original_shape)
@register_network('Performer')
class Performer(torch.nn.Module):
    """Performer without edge features.
    This model disregards edge features and runs a linear transformer over a set of node features only.
    https://arxiv.org/abs/2009.14794
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        # Optional GNN layers before the transformer stack.
        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        hidden = cfg.gt.dim_hidden
        heads = cfg.gt.n_heads
        self.trf = BackbonePerformer(
            dim=hidden,
            depth=cfg.gt.layers,
            heads=heads,
            dim_head=hidden // heads
        )

        head_cls = register.head_dict[cfg.gnn.head]
        self.post_mp = head_cls(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Children run in registration order:
        # encoder -> (pre_mp) -> trf -> post_mp.
        for child in self.children():
            batch = child(batch)
        return batch
@register_network('SANTransformer')
class SANTransformer(torch.nn.Module):
    """Spectral Attention Network (SAN) Graph Transformer.
    https://arxiv.org/abs/2106.03893

    Args:
        dim_in (int): Input node feature dimension.
        dim_out (int): Output dimension (passed to the prediction head).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # Shared embedding used where real ("fake") edges carry no features.
        fake_edge_emb = torch.nn.Embedding(1, cfg.gt.dim_hidden)
        # torch.nn.init.xavier_uniform_(fake_edge_emb.weight.data)
        layer_types = {
            'SANLayer': SANLayer,
            'SAN2Layer': SAN2Layer
        }
        # Fix: `.get()` previously returned None for unknown layer types,
        # which crashed later with an opaque "'NoneType' is not callable".
        if cfg.gt.layer_type not in layer_types:
            raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}")
        Layer = layer_types[cfg.gt.layer_type]
        layers = []
        for _ in range(cfg.gt.layers):
            layers.append(Layer(gamma=cfg.gt.gamma,
                                in_dim=cfg.gt.dim_hidden,
                                out_dim=cfg.gt.dim_hidden,
                                num_heads=cfg.gt.n_heads,
                                secondary_edges=cfg.gt.secondary_edges,
                                fake_edge_emb=fake_edge_emb,
                                dropout=cfg.gt.dropout,
                                layer_norm=cfg.gt.layer_norm,
                                batch_norm=cfg.gt.batch_norm,
                                residual=cfg.gt.residual))
        self.trf_layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Children run in registration order:
        # encoder -> (pre_mp) -> trf_layers -> post_mp.
        for module in self.children():
            batch = module(batch)
        return batch
@register_pooling('example')
def global_example_pool(x, batch, size=None):
    """Sum-pool node features into one vector per graph.

    Args:
        x: Node features, shape (num_nodes, d).
        batch: Graph id per node, shape (num_nodes,).
        size: Number of graphs; inferred from `batch` when None.
    """
    if size is None:
        size = batch.max().item() + 1
    return scatter(x, batch, dim=0, dim_size=size, reduce='add')
@register_stage('example')
class GNNStackStage(nn.Module):
    """Simple stage that stacks ``num_layers`` GNN layers sequentially."""

    def __init__(self, dim_in, dim_out, num_layers):
        super().__init__()
        for idx in range(num_layers):
            # First layer maps dim_in -> dim_out; the rest keep dim_out.
            layer_in = dim_in if idx == 0 else dim_out
            self.add_module(f'layer{idx}', GNNLayer(layer_in, dim_out))
        self.dim_out = dim_out

    def forward(self, batch):
        # Children iterate in registration order: layer0, layer1, ...
        for gnn_layer in self.children():
            batch = gnn_layer(batch)
        if cfg.gnn.l2norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)
        return batch


def train_epoch(logger, loader, model, optimizer, scheduler):
    """Run one training epoch over ``loader`` and log per-batch stats."""
    model.train()
    tick = time.time()
    for batch in loader:
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        loss.backward()
        optimizer.step()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(),
                            loss=loss.item(),
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - tick,
                            params=cfg.params)
        tick = time.time()
    scheduler.step()
def eval_epoch(logger, loader, model):
    """Run one evaluation pass over ``loader`` and log per-batch stats."""
    model.eval()
    time_start = time.time()
    # Fix: disable autograd during evaluation. The original built a full
    # computation graph for every eval batch, wasting memory and time.
    with torch.no_grad():
        for batch in loader:
            batch.to(torch.device(cfg.device))
            pred, true = model(batch)
            loss, pred_score = compute_loss(pred, true)
            logger.update_stats(true=true.detach().cpu(),
                                pred=pred_score.detach().cpu(), loss=loss.item(),
                                lr=0, time_used=time.time() - time_start,
                                params=cfg.params)
            time_start = time.time()


@register_train('example')
def train_example(loggers, loaders, model, optimizer, scheduler):
    """Example customized training loop registered under 'example'.

    Args:
        loggers: One logger per split (train first, then eval splits).
        loaders: Data loaders, index-aligned with ``loggers``.
        model: The model to train.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    num_splits = len(loggers)
    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        loggers[0].write_epoch(cur_epoch)
        if is_eval_epoch(cur_epoch):
            # Evaluate every non-train split (val, test, ...).
            for i in range(1, num_splits):
                eval_epoch(loggers[i], loaders[i], model)
                loggers[i].write_epoch(cur_epoch)
        if is_ckpt_epoch(cur_epoch):
            save_ckpt(model, optimizer, scheduler, cur_epoch)
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in %s', cfg.run_dir)
def generate_random_regular_graph1(num_nodes, degree, rng=None):
    """Generate a random 2d-regular graph via the permutations algorithm.

    The returned edge list is symmetric: if (x, y) is an edge so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Desired degree.
        rng: numpy random generator (a fresh default_rng if None).

    Returns:
        (senders, receivers): numpy arrays with the tail/head of each edge.
    """
    if rng is None:
        rng = np.random.default_rng()

    senders = [*range(0, num_nodes)] * degree
    receivers = []
    for _ in range(degree):
        receivers.extend(rng.permutation(list(range(num_nodes))).tolist())

    # Symmetrize: append each edge's reverse.
    senders, receivers = [*senders, *receivers], [*receivers, *senders]

    return np.array(senders), np.array(receivers)


def generate_random_regular_graph2(num_nodes, degree, rng=None):
    """Generate a random 2d-regular graph via a simpler permutation variant.

    The returned edge list is symmetric: if (x, y) is an edge so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Desired degree.
        rng: numpy random generator (a fresh default_rng if None).

    Returns:
        (senders, receivers): lists with the tail/head of each edge.
    """
    if rng is None:
        rng = np.random.default_rng()

    senders = [*range(0, num_nodes)] * degree
    receivers = rng.permutation(senders).tolist()

    senders, receivers = [*senders, *receivers], [*receivers, *senders]

    return senders, receivers


def generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng=None):
    """Generate a 2d-regular graph as the union of d random Hamiltonian cycles.

    The returned edge list is symmetric: if (x, y) is an edge so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Number of random cycles to overlay.
        rng: numpy random generator (a fresh default_rng if None).

    Returns:
        (senders, receivers): numpy arrays with the tail/head of each edge.
    """
    if rng is None:
        rng = np.random.default_rng()

    senders = []
    receivers = []
    for _ in range(degree):
        permutation = rng.permutation(list(range(num_nodes))).tolist()
        for idx, v in enumerate(permutation):
            # idx - 1 wraps to the last element, closing the cycle.
            u = permutation[idx - 1]
            senders.extend([v, u])
            receivers.extend([u, v])

    return np.array(senders), np.array(receivers)


def generate_random_expander(data, degree, algorithm, rng=None, max_num_iters=100, exp_index=0):
    """Attach a random d-regular expander edge set to ``data``.

    Repeatedly samples random regular graphs and keeps the one whose smallest
    Laplacian eigenvalue is largest, stopping early once it clears the
    Ramanujan-style lower bound. Self-loops are removed from the result.

    Args:
        data: PyG-style data object; reads ``num_nodes``, writes
            ``expander_edges`` (or ``expander_edges{exp_index}``).
        degree: Desired degree (clamped to num_nodes - 1 for tiny graphs).
        algorithm: 'Random-d', 'Random-d-2', or 'Hamiltonian'.
        rng: numpy random generator (a fresh default_rng if None).
        max_num_iters: Maximum number of sampling attempts.
        exp_index: Suffix index for the attribute name when nonzero.

    Returns:
        The same ``data`` object with the expander edge tensor attached.
    """
    num_nodes = data.num_nodes

    if rng is None:
        rng = np.random.default_rng()

    eig_val = -1
    eig_val_lower_bound = max(0, 2 * degree - 2 * math.sqrt(2 * degree - 1) - 0.1)

    max_eig_val_so_far = -1
    max_senders = []
    max_receivers = []
    cur_iter = 1

    if num_nodes <= degree:
        degree = num_nodes - 1

    # If there are too few nodes, random graph generation will fail; in this
    # case we simply add the complete graph.
    if num_nodes <= 10:
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j:
                    max_senders.append(i)
                    max_receivers.append(j)
    else:
        while eig_val < eig_val_lower_bound and cur_iter <= max_num_iters:
            if algorithm == 'Random-d':
                senders, receivers = generate_random_regular_graph1(num_nodes, degree, rng)
            elif algorithm == 'Random-d-2':
                senders, receivers = generate_random_regular_graph2(num_nodes, degree, rng)
            elif algorithm == 'Hamiltonian':
                senders, receivers = generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng)
            else:
                raise ValueError('prep.exp_algorithm should be one of the Random-d or Hamiltonian')
            [eig_val, _] = laplacian_eigenv(senders, receivers, k=1, n=num_nodes)
            if len(eig_val) == 0:
                print("num_nodes = %d, degree = %d, cur_iter = %d, mmax_iters = %d, senders = %d, receivers = %d" % (num_nodes, degree, cur_iter, max_num_iters, len(senders), len(receivers)))
                eig_val = 0
            else:
                eig_val = eig_val[0]

            if eig_val > max_eig_val_so_far:
                max_eig_val_so_far = eig_val
                max_senders = senders
                max_receivers = receivers

            cur_iter += 1

    # Eliminate self loops.
    non_loops = [
        *filter(lambda i: max_senders[i] != max_receivers[i], range(0, len(max_senders)))
    ]
    senders = np.array(max_senders)[non_loops]
    receivers = np.array(max_receivers)[non_loops]

    # Fix: build the output tensors from the de-looped arrays. The original
    # computed the filtered arrays above and then discarded them, converting
    # the unfiltered lists instead, so self-loops were never removed.
    senders_t = torch.as_tensor(senders, dtype=torch.long).view(-1, 1)
    receivers_t = torch.as_tensor(receivers, dtype=torch.long).view(-1, 1)
    expander_edges = torch.cat([senders_t, receivers_t], dim=1)

    if exp_index == 0:
        data.expander_edges = expander_edges
    else:
        setattr(data, f"expander_edges{exp_index}", expander_edges)

    return data


def _apply_in_memory_transform(dataset, transform_func, show_progress):
    """Apply ``transform_func`` to every example, drop Nones, and re-collate
    the dataset in place. Returns the transformed data list."""
    data_list = [transform_func(dataset.get(i))
                 for i in tqdm(range(len(dataset)),
                               disable=not show_progress,
                               mininterval=10,
                               miniters=len(dataset) // 20)]
    data_list = list(filter(None, data_list))

    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return data_list


def pre_transform_in_memory(dataset, transform_func, show_progress=False, check_if_pe_done=False):
    """Pre-transform already loaded PyG dataset object.

    Apply transform function to a loaded PyG dataset object so that
    the transformed result is persistent for the lifespan of the object.
    This means the result is not saved to disk, as what PyG's `pre_transform`
    would do, but also the transform is applied only once and not at each
    data access as what PyG's `transform` hook does.

    Implementation is based on torch_geometric.data.in_memory_dataset.copy

    Args:
        dataset: PyG dataset object to modify.
        transform_func: transformation function to apply to each data example.
        show_progress: show tqdm progress bar.
        check_if_pe_done: cache the transformed dataset under the dataset's
            processed dir and reuse it on subsequent runs.
    """
    if transform_func is None:
        return dataset

    if check_if_pe_done:
        file_name = dataset.processed_dir + "/pe_transformed.pt"
        if os.path.exists(file_name):
            print(f"Loading pre-transformed PE from {file_name}")
            data_list, collate_data, collate_slices = torch.load(file_name)
            dataset._indices, dataset._data_list, dataset.data, dataset.slices = \
                None, data_list, collate_data, collate_slices
            print(f"Loaded pre-transformed PE from {file_name}")
        else:
            print(f"Pre-transforming PE and saving to {file_name}")
            data_list = _apply_in_memory_transform(dataset, transform_func, show_progress)
            torch.save((data_list, dataset.data, dataset.slices), file_name)
            print(f"Saved pre-transformed PE to {file_name}")
    else:
        # Refactor: this branch previously duplicated the caching branch's
        # transform-and-collate code verbatim.
        _apply_in_memory_transform(dataset, transform_func, show_progress)


def generate_splits(data, g_split):
    """Randomly split nodes into train/val/test masks.

    ``g_split`` is the train fraction; the remainder is halved between
    val and test.
    """
    n_nodes = len(data.x)
    train_mask = torch.zeros(n_nodes, dtype=bool)
    valid_mask = torch.zeros(n_nodes, dtype=bool)
    test_mask = torch.zeros(n_nodes, dtype=bool)
    idx = torch.randperm(n_nodes)
    val_num = test_num = int(n_nodes * (1 - g_split) / 2)
    train_mask[idx[val_num + test_num:]] = True
    valid_mask[idx[:val_num]] = True
    test_mask[idx[val_num:val_num + test_num]] = True
    data.train_mask = train_mask
    data.val_mask = valid_mask
    data.test_mask = test_mask
    return data
def typecast_x(data, type_str):
    """Cast ``data.x`` to float or long; raise for any other type string."""
    if type_str == 'float':
        data.x = data.x.float()
    elif type_str == 'long':
        data.x = data.x.long()
    else:
        raise ValueError(f"Unexpected type '{type_str}'.")
    return data


def concat_x_and_pos(data):
    """Append node coordinates ``data.pos`` to node features along dim 1."""
    data.x = torch.cat((data.x, data.pos), 1)
    return data


def move_node_feat_to_x(data):
    """For ogbn-proteins, move the attribute node_species to attribute x."""
    data.x = data.node_species
    return data


def clip_graphs_to_size(data, size_limit=5000):
    """Truncate a graph to its first ``size_limit`` nodes.

    Graphs already within the limit are returned unchanged. Otherwise node
    features, the induced edge set, and ogbg-code2 node attributes are
    clipped to the first ``size_limit`` nodes.
    """
    if hasattr(data, 'num_nodes'):
        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
    else:
        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
    if N <= size_limit:
        return data
    logging.info(f' ...clip to {size_limit} a graph of size: {N}')
    # Cleanup: hand-rolled hasattr/else chain replaced with getattr default.
    edge_attr = getattr(data, 'edge_attr', None)
    edge_index, edge_attr = subgraph(list(range(size_limit)),
                                     data.edge_index, edge_attr)
    if hasattr(data, 'x'):
        data.x = data.x[:size_limit]
    # Cleanup: was assigned identically in both branches of the if above.
    data.num_nodes = size_limit
    if hasattr(data, 'node_is_attributed'):  # for ogbg-code2 dataset
        data.node_is_attributed = data.node_is_attributed[:size_limit]
        data.node_dfs_order = data.node_dfs_order[:size_limit]
        data.node_depth = data.node_depth[:size_limit]
    data.edge_index = edge_index
    if hasattr(data, 'edge_attr'):
        data.edge_attr = edge_attr
    return data


def negate_edge_index(edge_index, batch=None):
    """Negate batched sparse adjacency matrices given by edge indices.

    Returns batched sparse adjacency matrices with exactly those edges that
    are not in the input `edge_index` while ignoring self-loops.

    Implementation inspired by `torch_geometric.utils.to_dense_adj`

    Args:
        edge_index: The edge indices.
        batch: Batch vector, which assigns each node to a specific example.

    Returns:
        Complementary edge index.
    """
    # NOTE: logic kept byte-for-byte; the scatter/offset bookkeeping below is
    # order-sensitive.
    if batch is None:
        batch = edge_index.new_zeros(edge_index.max().item() + 1)

    batch_size = batch.max().item() + 1
    one = batch.new_ones(batch.size(0))
    # Number of nodes per graph in the batch.
    num_nodes = scatter(one, batch,
                        dim=0, dim_size=batch_size, reduce='add')
    # Node-index offset of each graph within the batch.
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Graph id of each edge, and edge endpoints in graph-local coordinates.
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

    negative_index_list = []
    for i in range(batch_size):
        n = num_nodes[i].item()
        size = [n, n]
        adj = torch.ones(size, dtype=torch.short,
                         device=edge_index.device)

        # Remove existing edges from the full N x N adjacency matrix
        flattened_size = n * n
        adj = adj.view([flattened_size])
        _idx1 = idx1[idx0 == i]
        _idx2 = idx2[idx0 == i]
        idx = _idx1 * n + _idx2
        zero = torch.zeros(_idx1.numel(), dtype=torch.short,
                           device=edge_index.device)
        scatter(zero, idx, dim=0, out=adj, reduce='mul')

        # Convert to edge index format
        adj = adj.view(size)
        _edge_index = adj.nonzero(as_tuple=False).t().contiguous()
        _edge_index, _ = remove_self_loops(_edge_index)
        negative_index_list.append(_edge_index + cum_nodes[i])

    edge_index_negative = torch.cat(negative_index_list, dim=1).contiguous()
    return edge_index_negative
def flatten_dict(metrics):
    """Flatten a list of train/val/test metrics into one dict to send to wandb.

    Args:
        metrics: List (train/val/test order) of lists of per-epoch stat dicts.

    Returns:
        A flat dictionary with names prefixed with "train/" , "val/" , "test/";
        only the most recent stats of each split are kept.
    """
    prefixes = ['train', 'val', 'test']
    result = {}
    for i in range(len(metrics)):
        # Take the latest metrics.
        stats = metrics[i][-1]
        result.update({f"{prefixes[i]}/{k}": v for k, v in stats.items()})
    return result


def cfg_to_dict(cfg_node, key_list=None):
    """Convert a config node to dictionary.

    Yacs doesn't have a default function to convert the cfg object to plain
    python dict. The following function was taken from
    https://github.com/rbgirshick/yacs/issues/19

    Args:
        cfg_node: The yacs CfgNode (or leaf value) to convert.
        key_list: Internal recursion path, used only for warning messages.
    """
    # Fix: mutable default argument ([]) replaced with a None sentinel.
    if key_list is None:
        key_list = []
    _VALID_TYPES = {tuple, list, str, int, float, bool}

    if not isinstance(cfg_node, CfgNode):
        if type(cfg_node) not in _VALID_TYPES:
            logging.warning(f"Key {'.'.join(key_list)} with "
                            f"value {type(cfg_node)} is not "
                            f"a valid type; valid types: {_VALID_TYPES}")
        return cfg_node
    else:
        cfg_dict = dict(cfg_node)
        for k, v in cfg_dict.items():
            cfg_dict[k] = cfg_to_dict(v, key_list + [k])
        return cfg_dict


def make_wandb_name(cfg):
    """Compose a short wandb run name from the dataset, model, and run id."""
    # Format dataset name.
    dataset_name = cfg.dataset.format
    if dataset_name.startswith('OGB'):
        dataset_name = dataset_name[3:]
    if dataset_name.startswith('PyG-'):
        dataset_name = dataset_name[4:]
    if dataset_name in ['GNNBenchmarkDataset', 'TUDataset']:
        # Shorten some verbose dataset naming schemes.
        dataset_name = ""
    if cfg.dataset.name != 'none':
        dataset_name += "-" if dataset_name != "" else ""
        if cfg.dataset.name == 'LocalDegreeProfile':
            dataset_name += 'LDP'
        else:
            dataset_name += cfg.dataset.name
    # Format model name.
    model_name = cfg.model.type
    if cfg.model.type in ['gnn', 'custom_gnn']:
        model_name += f".{cfg.gnn.layer_type}"
    elif cfg.model.type == 'GPSModel':
        model_name = f"GPS.{cfg.gt.layer_type}"
    model_name += f".{cfg.name_tag}" if cfg.name_tag else ""
    # Compose wandb run name.
    name = f"{dataset_name}.{model_name}.r{cfg.run_id}"
    return name