├── .gitignore
├── README.md
├── configs
└── GSSC
│ ├── cifar10-GSSC-tune.yaml
│ ├── cifar10-GSSC.yaml
│ ├── cluster-GSSC-tune.yaml
│ ├── cluster-GSSC.yaml
│ ├── mnist-GSSC-tune.yaml
│ ├── mnist-GSSC.yaml
│ ├── ogbg-molhiv-GSSC-tune.yaml
│ ├── ogbg-molhiv-GSSC.yaml
│ ├── pattern-GSSC-tune.yaml
│ ├── pattern-GSSC.yaml
│ ├── peptides-func-GSSC-tune.yaml
│ ├── peptides-func-GSSC.yaml
│ ├── peptides-struct-GSSC-tune.yaml
│ ├── peptides-struct-GSSC.yaml
│ ├── vocsuperpixels-GSSC-tune.yaml
│ ├── vocsuperpixels-GSSC.yaml
│ ├── zinc-GSSC-tune.yaml
│ ├── zinc-GSSC.yaml
│ ├── zincfull-GSSC-tune.yaml
│ └── zincfull-GSSC.yaml
├── datasets
└── .gitignore
├── environment.yml
├── gssc
├── __init__.py
├── act
│ ├── __init__.py
│ └── example.py
├── agg_runs.py
├── config
│ ├── __init__.py
│ ├── custom_gnn_config.py
│ ├── data_preprocess_config.py
│ ├── dataset_config.py
│ ├── defaults_config.py
│ ├── example.py
│ ├── gt_config.py
│ ├── optimizers_config.py
│ ├── posenc_config.py
│ ├── pretrained_config.py
│ ├── split_config.py
│ └── wandb_config.py
├── encoder
│ ├── ER_edge_encoder.py
│ ├── ER_node_encoder.py
│ ├── __init__.py
│ ├── ast_encoder.py
│ ├── composed_encoders.py
│ ├── dummy_edge_encoder.py
│ ├── equivstable_laplace_pos_encoder.py
│ ├── example.py
│ ├── exp_edge_fixer.py
│ ├── kernel_pos_encoder.py
│ ├── laplace_pos_encoder.py
│ ├── linear_edge_encoder.py
│ ├── linear_node_encoder.py
│ ├── ppa_encoder.py
│ ├── signnet_pos_encoder.py
│ ├── type_dict_encoder.py
│ └── voc_superpixels_encoder.py
├── finetuning.py
├── head
│ ├── __init__.py
│ ├── example.py
│ ├── inductive_edge.py
│ ├── inductive_node.py
│ ├── ogb_code_graph.py
│ └── san_graph.py
├── layer
│ ├── ETransformer.py
│ ├── Exphormer.py
│ ├── __init__.py
│ ├── bigbird_layer.py
│ ├── example.py
│ ├── gatedgcn_layer.py
│ ├── gine_conv_layer.py
│ ├── gps_layer.py
│ ├── gssc_layer.py
│ ├── multi_model_layer.py
│ ├── performer_layer.py
│ ├── san2_layer.py
│ └── san_layer.py
├── loader
│ ├── __init__.py
│ ├── dataset
│ │ ├── __init__.py
│ │ ├── aqsol_molecules.py
│ │ ├── coco_superpixels.py
│ │ ├── count_cycle.py
│ │ ├── malnet_tiny.py
│ │ ├── pcqm4mv2_contact.py
│ │ ├── peptides_functional.py
│ │ ├── peptides_structural.py
│ │ └── voc_superpixels.py
│ ├── master_loader.py
│ ├── ogbg_code2_utils.py
│ ├── planetoid.py
│ └── split_generator.py
├── logger.py
├── loss
│ ├── __init__.py
│ ├── l1.py
│ ├── multilabel_classification_loss.py
│ ├── subtoken_prediction_loss.py
│ └── weighted_cross_entropy.py
├── metric_wrapper.py
├── metrics_ogb.py
├── network
│ ├── MaskedReduce.py
│ ├── __init__.py
│ ├── big_bird.py
│ ├── custom_gnn.py
│ ├── example.py
│ ├── gps_model.py
│ ├── multi_model.py
│ ├── norm.py
│ ├── performer.py
│ ├── san_transformer.py
│ └── utils.py
├── optimizer
│ ├── __init__.py
│ └── extra_optimizers.py
├── pooling
│ ├── __init__.py
│ └── example.py
├── stage
│ ├── __init__.py
│ └── example.py
├── train
│ ├── __init__.py
│ ├── custom_train.py
│ └── example.py
├── transform
│ ├── __init__.py
│ ├── dist_transforms.py
│ ├── expander_edges.py
│ ├── posenc_stats.py
│ └── transforms.py
└── utils.py
├── main.py
└── wandb
└── .gitignore
/.gitignore:
--------------------------------------------------------------------------------
1 | # CUSTOM
2 | .vscode/
3 | scripts/pcqm4m/**/*.zip
4 | scripts/pcqm4m/**/*.sdf
5 | scripts/pcqm4m/**/*.xyz
6 | scripts/pcqm4m/**/*.csv
7 | !scripts/pcqm4m/**/periodic_table.csv
8 | scripts/pcqm4m/**/*.gz
9 | scripts/pcqm4m/**/*.tsv
10 | scripts/pcqm4m/pcqm4m-v2/
11 | slurm_history/
12 | pretrained/
13 | results/
14 | vocprep/benchmark_RELEASE/
15 | vocprep/voc_viz_files/
16 | vocprep/VOC/benchmark_RELEASE/
17 | vocprep/VOC/*.tgz
18 | vocprep/VOC/*.pickle
19 | vocprep/VOC/*.pkl
20 | vocprep/VOC/*.zip
21 | splits/
22 | __pycache__/
23 | .idea
24 | *.log
25 | *.bak
26 | *.png
27 | *.txt
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | pip-wheel-metadata/
52 | share/python-wheels/
53 | *.egg-info/
54 | .installed.cfg
55 | *.egg
56 | MANIFEST
57 |
58 | # PyInstaller
59 | # Usually these files are written by a python script from a template
60 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
61 | *.manifest
62 | *.spec
63 |
64 | # Installer logs
65 | pip-log.txt
66 | pip-delete-this-directory.txt
67 |
68 | # Unit test / coverage reports
69 | htmlcov/
70 | .tox/
71 | .nox/
72 | .coverage
73 | .coverage.*
74 | .cache
75 | nosetests.xml
76 | coverage.xml
77 | *.cover
78 | *.py,cover
79 | .hypothesis/
80 | .pytest_cache/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 | db.sqlite3-journal
91 |
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 |
96 | # Scrapy stuff:
97 | .scrapy
98 |
99 | # Sphinx documentation
100 | docs/_build/
101 |
102 | # PyBuilder
103 | target/
104 |
105 | # Jupyter Notebook
106 | .ipynb_checkpoints
107 |
108 | # IPython
109 | profile_default/
110 | ipython_config.py
111 |
112 | # pyenv
113 | .python-version
114 |
115 | # pipenv
116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
119 | # install all needed dependencies.
120 | #Pipfile.lock
121 |
122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
123 | __pypackages__/
124 |
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 |
129 | # SageMath parsed files
130 | *.sage.py
131 |
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # vim edit buffer
160 | *.swp
161 |
162 |
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
Graph State Space Convolution (GSSC)
2 |
3 |
4 |
5 |
6 |
7 |
8 | This repository contains the official implementation of GSSC as described in the paper: [What Can We Learn from State Space Models for Machine Learning on Graphs?](https://arxiv.org/abs/2406.05815) by Yinan Huang*, Siqi Miao*, and Pan Li.
9 |
10 | (*Equal contribution, listed in alphabetical order)
11 |
12 | ## Installation
13 | All required packages are listed in `environment.yml`.
14 |
15 | ## Running the code
16 | Replace `--cfg` with the path to the configuration file and `--device` with the GPU device number, as shown below:
17 | ```
18 | python main.py --cfg configs/GSSC/peptides-func-GSSC.yaml --device 0 wandb.use False
19 | ```
20 | This command will train the model on the `peptides-func` dataset using the GSSC method with default hyperparameters.
21 |
22 | ## Reproducing the results
23 | We use wandb to log and sweep the results. To reproduce the reported results, one needs to create and log in to a wandb account. Then, one can launch the sweep using the configuration files in the `configs` directory.
24 | For example, to reproduce the tuned results of GSSC on the `peptides-func` dataset, one can launch the sweep using `configs/GSSC/peptides-func-GSSC-tune.yaml`.
25 |
26 | ## Acknowledgement
27 | This repository is built upon [GraphGPS (Rampasek et al., 2022)](https://github.com/rampasek/GraphGPS).
28 |
--------------------------------------------------------------------------------
/configs/GSSC/cifar10-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: cifar10
3 | entity: anonymity
4 | name: cifar10-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_accuracy
9 | parameters:
10 | dropout_res:
11 | value: 0.1
12 | dropout_local:
13 | value: 0.1
14 | dropout_ff:
15 | value: 0.1
16 | base_lr:
17 | value: 0.005
18 | weight_decay:
19 | value: 0.01
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/cifar10-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/cifar10-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: accuracy
3 | wandb:
4 | use: True
5 | project: cifar10
6 | entity: anonymity
7 | dataset:
8 | format: PyG-GNNBenchmarkDataset
9 | name: CIFAR10
10 | task: graph
11 | task_type: classification
12 | transductive: False
13 | node_encoder: True
14 | node_encoder_name: LapPE
15 | node_encoder_bn: False
16 | edge_encoder: True
17 | edge_encoder_name: LinearEdge
18 | edge_encoder_bn: False
19 | posenc_LapPE:
20 | enable: True
21 | eigen:
22 | laplacian_norm: none
23 | eigvec_norm: L2
24 | max_freqs: 33
25 | model: DeepSet
26 | dim_pe: 16
27 | layers: 2
28 | raw_norm_type: none
29 | train:
30 | mode: custom
31 | batch_size: 16
32 | eval_period: 1
33 | ckpt_period: 100
34 | model:
35 | type: GPSModel
36 | loss_fun: cross_entropy
37 | edge_decoding: dot
38 | graph_pooling: mean
39 | gt: # Hyperparameters optimized for ~100k budget.
40 | layer_type: CustomGatedGCN+GSSC
41 | layers: 3
42 | n_heads: 4
43 | dim_hidden: 52 # `gt.dim_hidden` must match `gnn.dim_inner`
44 | dropout: 0.0
45 | attn_dropout: 0.5
46 | layer_norm: False
47 | batch_norm: True
48 | gnn:
49 | head: default
50 | layers_pre_mp: 0
51 | layers_post_mp: 2
52 | dim_inner: 52 # `gt.dim_hidden` must match `gnn.dim_inner`
53 | batchnorm: False
54 | act: relu
55 | dropout: 0.0
56 | agg: mean
57 | normalize_adj: False
58 | optim:
59 | clip_grad_norm: True
60 | optimizer: adamW
61 | weight_decay: 1e-5
62 | base_lr: 0.001
63 | max_epoch: 300
64 | scheduler: cosine_with_warmup
65 | num_warmup_epochs: 5
66 | seed: 0
67 | name_tag: "random"
68 |
--------------------------------------------------------------------------------
/configs/GSSC/cluster-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: CLUSTER
3 | entity: anonymity
4 | name: cluster-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_accuracy-SBM
9 | parameters:
10 | dropout_res:
11 | value: 0.3
12 | dropout_local:
13 | value: 0.3
14 | dropout_ff:
15 | value: 0.1
16 | base_lr:
17 | values: [0.001, 0.002, 0.003]
18 | weight_decay:
19 | values: [0.1, 0.2, 0.001]
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 1
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/cluster-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/cluster-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: accuracy-SBM
3 | wandb:
4 | use: True
5 | project: CLUSTER
6 | entity: anonymity
7 | dataset:
8 | format: PyG-GNNBenchmarkDataset
9 | name: CLUSTER
10 | task: graph
11 | task_type: classification
12 | transductive: False
13 | split_mode: standard
14 | node_encoder: True
15 | node_encoder_name: RWSE
16 | node_encoder_bn: False
17 | edge_encoder: True
18 | edge_encoder_name: DummyEdge
19 | edge_encoder_bn: False
20 | posenc_RWSE:
21 | enable: True
22 | kernel:
23 | times_func: range(1,33)
24 | model: Linear
25 | dim_pe: 18
26 | raw_norm_type: BatchNorm
27 | posenc_LapPE:
28 | enable: True
29 | eigen:
30 | laplacian_norm: none
31 | eigvec_norm: L2
32 | max_freqs: 33
33 | model: DeepSet
34 | dim_pe: 18
35 | layers: 2
36 | n_heads: 4 # Only used when `posenc.model: Transformer`
37 | raw_norm_type: none
38 | train:
39 | mode: custom
40 | batch_size: 16
41 | eval_period: 1
42 | ckpt_period: 100
43 | model:
44 | type: GPSModel
45 | loss_fun: weighted_cross_entropy
46 | edge_decoding: dot
47 | gt: # Hyperparameters optimized for ~500k budget.
48 | layer_type: CustomGatedGCN+GSSC
49 | layers: 24
50 | n_heads: 8
51 | dim_hidden: 36 # `gt.dim_hidden` must match `gnn.dim_inner`
52 | dropout: 0.1
53 | attn_dropout: 0.5
54 | layer_norm: False
55 | batch_norm: True
56 | gnn:
57 | head: inductive_node
58 | layers_pre_mp: 0
59 | layers_post_mp: 3
60 | dim_inner: 36 # `gt.dim_hidden` must match `gnn.dim_inner`
61 | batchnorm: True
62 | act: relu
63 | dropout: 0.0
64 | agg: mean
65 | normalize_adj: False
66 | optim:
67 | clip_grad_norm: True
68 | optimizer: adamW
69 | weight_decay: 0.1
70 | base_lr: 0.001
71 | max_epoch: 300
72 | scheduler: cosine_with_warmup
73 | num_warmup_epochs: 5
74 | seed: 0
75 | name_tag: "random"
76 |
--------------------------------------------------------------------------------
/configs/GSSC/mnist-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: MNIST
3 | entity: anonymity
4 | name: MNIST-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_accuracy
9 | parameters:
10 | dropout_res:
11 | value: 0.1
12 | dropout_local:
13 | value: 0.1
14 | dropout_ff:
15 | value: 0.1
16 | base_lr:
17 | value: 0.005
18 | weight_decay:
19 | value: 0.01
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/mnist-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/mnist-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: accuracy
3 | wandb:
4 | use: True
5 | project: MNIST
6 | entity: anonymity
7 | dataset:
8 | format: PyG-GNNBenchmarkDataset
9 | name: MNIST
10 | task: graph
11 | task_type: classification
12 | transductive: False
13 | node_encoder: True
14 | node_encoder_name: LapPE
15 | node_encoder_bn: False
16 | edge_encoder: True
17 | edge_encoder_name: LinearEdge
18 | edge_encoder_bn: False
19 | posenc_LapPE:
20 | enable: True
21 | eigen:
22 | laplacian_norm: none
23 | eigvec_norm: L2
24 | max_freqs: 33
25 | model: DeepSet
26 | dim_pe: 16
27 | layers: 2
28 | raw_norm_type: none
29 | train:
30 | mode: custom
31 | batch_size: 16
32 | eval_period: 1
33 | ckpt_period: 100
34 | model:
35 | type: GPSModel
36 | loss_fun: cross_entropy
37 | edge_decoding: dot
38 | graph_pooling: mean
39 | gt: # Hyperparameters optimized for ~100k budget.
40 | layer_type: CustomGatedGCN+GSSC
41 | layers: 3
42 | n_heads: 4
43 | dim_hidden: 52 # `gt.dim_hidden` must match `gnn.dim_inner`
44 | dropout: 0.0
45 | attn_dropout: 0.5
46 | layer_norm: False
47 | batch_norm: True
48 | gnn:
49 | head: default
50 | layers_pre_mp: 0
51 | layers_post_mp: 3
52 | dim_inner: 52 # `gt.dim_hidden` must match `gnn.dim_inner`
53 | batchnorm: False
54 | act: relu
55 | dropout: 0.0
56 | agg: mean
57 | normalize_adj: False
58 | optim:
59 | clip_grad_norm: True
60 | optimizer: adamW
61 | weight_decay: 1e-5
62 | base_lr: 0.001
63 | max_epoch: 300
64 | scheduler: cosine_with_warmup
65 | num_warmup_epochs: 5
66 | seed: 0
67 | name_tag: "random"
68 |
--------------------------------------------------------------------------------
/configs/GSSC/ogbg-molhiv-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: molhiv
3 | entity: anonymity
4 | name: molhiv-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_auc
9 | parameters:
10 | dropout_res:
11 | value: 0.0
12 | dropout_local:
13 | value: 0.3
14 | dropout_ff:
15 | value: 0.0
16 | weight_decay:
17 | values: [0.1, 1.0e-3, 1.0e-5]
18 | base_lr:
19 | values: [0.001, 0.002, 0.0005]
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/ogbg-molhiv-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/ogbg-molhiv-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: auc
3 | wandb:
4 | use: True
5 | project: molhiv
6 | entity: anonymity
7 | dataset:
8 | format: OGB
9 | name: ogbg-molhiv
10 | task: graph
11 | task_type: classification
12 | transductive: False
13 | node_encoder: True
14 | node_encoder_name: Atom+RWSE
15 | node_encoder_bn: False
16 | edge_encoder: True
17 | edge_encoder_name: Bond
18 | edge_encoder_bn: False
19 | posenc_RWSE:
20 | enable: True
21 | kernel:
22 | times_func: range(1,21)
23 | model: Linear
24 | dim_pe: 28
25 | raw_norm_type: BatchNorm
26 | posenc_LapPE:
27 | enable: True
28 | eigen:
29 | laplacian_norm: none
30 | eigvec_norm: L2
31 | max_freqs: 17
32 | model: DeepSet
33 | dim_pe: 16
34 | layers: 2
35 | raw_norm_type: none
36 | train:
37 | mode: custom
38 | batch_size: 32
39 | eval_period: 1
40 | ckpt_period: 100
41 | model:
42 | type: GPSModel
43 | loss_fun: cross_entropy
44 | edge_decoding: dot
45 | graph_pooling: mean
46 | gt:
47 | layer_type: CustomGatedGCN+GSSC
48 | layers: 6
49 | n_heads: 4
50 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
51 | dropout: 0.0
52 | attn_dropout: 0.5
53 | layer_norm: False
54 | batch_norm: True
55 | gnn:
56 | head: san_graph
57 | layers_pre_mp: 0
58 | layers_post_mp: 3 # Not used when `gnn.head: san_graph`
59 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
60 | batchnorm: True
61 | act: relu
62 | dropout: 0.0
63 | agg: mean
64 | normalize_adj: False
65 | optim:
66 | clip_grad_norm: True
67 | optimizer: adamW
68 | weight_decay: 1e-5
69 | base_lr: 0.0001
70 | max_epoch: 100
71 | scheduler: cosine_with_warmup
72 | num_warmup_epochs: 5
73 | seed: 0
74 | name_tag: "random"
75 |
--------------------------------------------------------------------------------
/configs/GSSC/pattern-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: PATTERN
3 | entity: anonymity
4 | name: pattern-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_accuracy-SBM
9 | parameters:
10 | dropout_res:
11 | value: 0.5
12 | dropout_local:
13 | value: 0.1
14 | dropout_ff:
15 | value: 0.1
16 | base_lr:
17 | value: 0.001
18 | weight_decay:
19 | value: 0.1
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 1
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/pattern-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/pattern-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: accuracy-SBM
3 | wandb:
4 | use: True
5 | project: PATTERN
6 | entity: anonymity
7 | dataset:
8 | format: PyG-GNNBenchmarkDataset
9 | name: PATTERN
10 | task: graph
11 | task_type: classification
12 | transductive: False
13 | split_mode: standard
14 | node_encoder: True
15 | node_encoder_name: RWSE
16 | node_encoder_bn: False
17 | edge_encoder: True
18 | edge_encoder_name: DummyEdge
19 | edge_encoder_bn: False
20 | posenc_RWSE:
21 | enable: True
22 | kernel:
23 | times_func: range(1,22)
24 | model: Linear
25 | dim_pe: 18
26 | raw_norm_type: BatchNorm
27 | posenc_LapPE:
28 | enable: True
29 | eigen:
30 | laplacian_norm: none
31 | eigvec_norm: L2
32 | max_freqs: 33
33 | model: DeepSet
34 | dim_pe: 18
35 | layers: 2
36 | n_heads: 4 # Only used when `posenc.model: Transformer`
37 | raw_norm_type: none
38 | train:
39 | mode: custom
40 | batch_size: 32
41 | eval_period: 1
42 | ckpt_period: 100
43 | model:
44 | type: GPSModel
45 | loss_fun: weighted_cross_entropy
46 | edge_decoding: dot
47 | gt: # Hyperparameters optimized for up to ~500k budget.
48 | layer_type: CustomGatedGCN+GSSC
49 | layers: 24
50 | n_heads: 4
51 | dim_hidden: 36 # `gt.dim_hidden` must match `gnn.dim_inner`
52 | dropout: 0.0
53 | attn_dropout: 0.5
54 | layer_norm: False
55 | batch_norm: True
56 | gnn:
57 | head: inductive_node
58 | layers_pre_mp: 0
59 | layers_post_mp: 3
60 | dim_inner: 36 # `gt.dim_hidden` must match `gnn.dim_inner`
61 | batchnorm: True
62 | act: relu
63 | dropout: 0.0
64 | agg: mean
65 | normalize_adj: False
66 | optim:
67 | clip_grad_norm: True
68 | optimizer: adamW
69 | weight_decay: 0.0001
70 | base_lr: 0.001
71 | max_epoch: 200
72 | scheduler: cosine_with_warmup
73 | num_warmup_epochs: 5
74 | seed: 0
75 | name_tag: "random"
76 |
--------------------------------------------------------------------------------
/configs/GSSC/peptides-func-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: LRGB
3 | entity: anonymity
4 | name: peptides-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_ap
9 | parameters:
10 | dropout_res:
11 | value: 0.1
12 | dropout_local:
13 | value: 0.1
14 | dropout_ff:
15 | values: [0.1, 0.0]
16 | base_lr:
17 | value: 0.003
18 | weight_decay:
19 | value: 0.1
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/peptides-func-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/peptides-func-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: ap
3 | wandb:
4 | use: True
5 | project: LRGB-Benchmark-RandomSeed
6 | entity: anonymity
7 | dataset:
8 | format: OGB
9 | name: peptides-functional
10 | task: graph
11 | task_type: classification_multilabel
12 | transductive: False
13 | node_encoder: True
14 | node_encoder_name: Atom+LapPE
15 | node_encoder_bn: False
16 | edge_encoder: True
17 | edge_encoder_name: Bond
18 | edge_encoder_bn: False
19 | split_mode: standard
20 | posenc_LapPE:
21 | enable: True
22 | eigen:
23 | laplacian_norm: none
24 | eigvec_norm: L2
25 | max_freqs: 32
26 | model: DeepSet
27 | dim_pe: 16
28 | layers: 2
29 | raw_norm_type: none
30 | train:
31 | mode: custom
32 | batch_size: 128
33 | eval_period: 1
34 | ckpt_period: 100
35 | model:
36 | type: GPSModel
37 | loss_fun: cross_entropy
38 | graph_pooling: mean
39 | gt:
40 | layer_type: CustomGatedGCN+GSSC
41 | n_heads: 4
42 | dim_hidden: 100 # `gt.dim_hidden` must match `gnn.dim_inner`
43 | dropout: 0.0
44 | attn_dropout: 0.5
45 | layer_norm: False
46 | batch_norm: True
47 | gnn:
48 | head: default
49 | layers_pre_mp: 0
50 | layers_post_mp: 1 # Not used when `gnn.head: san_graph`
51 | dim_inner: 100 # `gt.dim_hidden` must match `gnn.dim_inner`
52 | batchnorm: True
53 | act: relu
54 | dropout: 0.0
55 | optim:
56 | clip_grad_norm: True
57 | optimizer: adamW
58 | weight_decay: 0.1
59 | base_lr: 0.003
60 | max_epoch: 200
61 | scheduler: cosine_with_warmup
62 | num_warmup_epochs: 10
63 | seed: 0
64 | name_tag: "random"
65 |
--------------------------------------------------------------------------------
/configs/GSSC/peptides-struct-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: LRGB
3 | entity: anonymity
4 | name: peptides-final
5 | method: grid
6 | metric:
7 | goal: minimize
8 | name: best/test_mae
9 | parameters:
10 | dropout_res:
11 | values: [0.3, 0.1]
12 | dropout_local:
13 | values: [0.1, 0.3]
14 | dropout_ff:
15 | value: 0.1
16 | base_lr:
17 | value: 0.001
18 | weight_decay:
19 | value: 0.1
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 1
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/peptides-struct-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/peptides-struct-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: mae
3 | metric_agg: argmin
4 | wandb:
5 | use: True
6 | project: LRGB-Benchmark-RandomSeed
7 | entity: anonymity
8 | dataset:
9 | format: OGB
10 | name: peptides-structural
11 | task: graph
12 | task_type: regression
13 | transductive: False
14 | node_encoder: True
15 | node_encoder_name: Atom+LapPE
16 | node_encoder_bn: False
17 | edge_encoder: True
18 | edge_encoder_name: Bond
19 | edge_encoder_bn: False
20 | posenc_LapPE:
21 | enable: True
22 | eigen:
23 | laplacian_norm: none
24 | eigvec_norm: L2
25 | max_freqs: 32
26 | model: DeepSet
27 | dim_pe: 16
28 | layers: 2
29 | raw_norm_type: none
30 | train:
31 | mode: custom
32 | batch_size: 128
33 | eval_period: 1
34 | ckpt_period: 100
35 | model:
36 | type: GPSModel
37 | loss_fun: l1
38 | graph_pooling: mean
39 | gt:
40 | layer_type: CustomGatedGCN+GSSC
41 | layers: 3
42 | n_heads: 4
43 | dim_hidden: 100 # `gt.dim_hidden` must match `gnn.dim_inner`
44 | dropout: 0.0
45 | attn_dropout: 0.5
46 | layer_norm: False
47 | batch_norm: True
48 | gnn:
49 | head: san_graph
50 | layers_pre_mp: 0
51 | layers_post_mp: 2 # Not used when `gnn.head: san_graph`
52 | dim_inner: 100 # `gt.dim_hidden` must match `gnn.dim_inner`
53 | batchnorm: True
54 | act: relu
55 | dropout: 0.0
56 | optim:
57 | clip_grad_norm: True
58 | optimizer: adamW
59 | weight_decay: 0.1
60 | base_lr: 0.003
61 | max_epoch: 200
62 | scheduler: cosine_with_warmup
63 | num_warmup_epochs: 10
64 | seed: 0
65 | name_tag: "random"
66 |
--------------------------------------------------------------------------------
/configs/GSSC/vocsuperpixels-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: LRGB
3 | entity: anonymity
4 | name: vocsuperpixels-final
5 | method: grid
6 | metric:
7 | goal: maximize
8 | name: best/test_f1
9 | parameters:
10 | dropout_res:
11 | value: 0.5
12 | dropout_local:
13 | value: 0.0
14 | dropout_ff:
15 | value: 0.0
16 | weight_decay:
17 | value: 0.1
18 | base_lr:
19 | value: 0.002
20 |
21 | reweigh_self:
22 | value: 0
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 0
29 | cfg:
30 | value: configs/GSSC/vocsuperpixels-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/vocsuperpixels-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: f1
3 | wandb:
4 | project: LRGB-Benchmark-RandomSeed
5 | entity: anonymity
6 | use: True
7 | dataset:
8 | format: PyG-VOCSuperpixels
9 | name: edge_wt_region_boundary
10 | slic_compactness: 30
11 | task: graph # Even though VOC is a node-level task, this needs to be set to 'graph'
12 | task_type: classification
13 | transductive: False
14 | node_encoder: True
15 | node_encoder_name: VOCNode+LapPE
16 | node_encoder_bn: False
17 | edge_encoder: True
18 | edge_encoder_name: VOCEdge
19 | edge_encoder_bn: False
20 | posenc_LapPE:
21 | enable: True
22 | eigen:
23 | laplacian_norm: none
24 | eigvec_norm: L2
25 | max_freqs: 64
26 | model: DeepSet
27 | dim_pe: 16
28 | layers: 2
29 | raw_norm_type: none
30 | train:
31 | mode: custom
32 | batch_size: 32
33 | eval_period: 1
34 | ckpt_period: 100
35 | model:
36 | type: GPSModel
37 | loss_fun: weighted_cross_entropy
38 | gt:
39 | layer_type: CustomGatedGCN+GSSC
40 | layers: 4
41 | n_heads: 8
42 | dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
43 | dropout: 0.0
44 | attn_dropout: 0.5
45 | layer_norm: False
46 | batch_norm: True
47 | gnn:
48 | head: inductive_node
49 | layers_pre_mp: 0
50 | layers_post_mp: 3
51 | dim_inner: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
52 | batchnorm: True
53 | act: relu
54 | dropout: 0.0
55 | agg: mean
56 | normalize_adj: False
57 | optim:
58 | clip_grad_norm: True
59 | optimizer: adamW
60 | weight_decay: 0.1
61 | base_lr: 0.002
62 | max_epoch: 300
63 | scheduler: cosine_with_warmup
64 | num_warmup_epochs: 10
65 | seed: 0
66 | name_tag: "random"
67 |
--------------------------------------------------------------------------------
/configs/GSSC/zinc-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: ZINC
3 | entity: anonymity
4 | name: zinc-final
5 | method: grid
6 | metric:
7 | goal: minimize
8 | name: best/test_mae
9 | parameters:
10 | dropout_res:
11 | values: [0.5, 0.6]
12 | dropout_local:
13 | value: 0.0
14 | dropout_ff:
15 | value: 0.1
16 | weight_decay:
17 | value: 1.0e-5
18 | base_lr:
19 | value: 1.0e-3
20 |
21 | reweigh_self:
22 | value: 2
23 | jk:
24 | value: 1
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/zinc-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2, 3, 4]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/zinc-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: mae
3 | metric_agg: argmin
4 | wandb:
5 | use: True
6 | project: ZINC
7 | entity: anonymity
8 | dataset:
9 | format: PyG-ZINC
10 | name: subset
11 | task: graph
12 | task_type: regression
13 | transductive: False
14 | node_encoder: True
15 | node_encoder_name: TypeDictNode+RWSE
16 | node_encoder_num_types: 21
17 | node_encoder_bn: False
18 | edge_encoder: True
19 | edge_encoder_name: TypeDictEdge
20 | edge_encoder_num_types: 4
21 | edge_encoder_bn: False
22 | posenc_RWSE:
23 | enable: True
24 | kernel:
25 | times_func: range(1,21)
26 | model: Linear
27 | dim_pe: 28
28 | raw_norm_type: BatchNorm
29 | posenc_LapPE:
30 | enable: True
31 | eigen:
32 | laplacian_norm: none
33 | eigvec_norm: L2
34 | max_freqs: 17
35 | model: DeepSet
36 | dim_pe: 16
37 | layers: 2
38 | raw_norm_type: none
39 | train:
40 | mode: custom
41 | batch_size: 32
42 | eval_period: 1
43 | ckpt_period: 100
44 | model:
45 | type: GPSModel
46 | loss_fun: l1
47 | edge_decoding: dot
48 | graph_pooling: add
49 | gt:
50 | layer_type: GINE+GSSC # CustomGatedGCN+Performer
51 | layers: 10
52 | n_heads: 4
53 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
54 | dropout: 0.0
55 | attn_dropout: 0.5
56 | layer_norm: False
57 | batch_norm: True
58 | gnn:
59 | head: san_graph
60 | layers_pre_mp: 0
61 | layers_post_mp: 3 # Not used when `gnn.head: san_graph`
62 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
63 | batchnorm: True
64 | act: relu
65 | dropout: 0.0
66 | agg: mean
67 | normalize_adj: False
68 | optim:
69 | clip_grad_norm: True
70 | optimizer: adamW
71 | weight_decay: 1.0e-5
72 | base_lr: 1.0e-3
73 | max_epoch: 2000
74 | scheduler: cosine_with_warmup
75 | num_warmup_epochs: 200
76 | seed: 0
77 | name_tag: "random"
78 |
--------------------------------------------------------------------------------
/configs/GSSC/zincfull-GSSC-tune.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: ZINC
3 | entity: anonymity
4 | name: zincfull-final
5 | method: grid
6 | metric:
7 | goal: minimize
8 | name: best/test_mae
9 | parameters:
10 | dropout_res:
11 | values: [0.1, 0.3]
12 | dropout_local:
13 | value: 0.0
14 | dropout_ff:
15 | value: 0.1
16 | weight_decay:
17 | values: [1.0e-3, 1.0e-05]
18 | base_lr:
19 | value: 0.002
20 |
21 | reweigh_self:
22 | value: 1
23 | jk:
24 | value: 0
25 | init_pe_dim:
26 | value: 32
27 | more_mapping:
28 | value: 1
29 | cfg:
30 | value: configs/GSSC/zincfull-GSSC.yaml
31 | name_tag:
32 | value: random
33 | log_code:
34 | value: 0
35 | device:
36 | value: 0
37 | seed:
38 | values: [0, 1, 2]
39 |
--------------------------------------------------------------------------------
/configs/GSSC/zincfull-GSSC.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | metric_best: mae
3 | metric_agg: argmin
4 | wandb:
5 | use: True
6 | project: ZINC
7 | entity: anonymity
8 | dataset:
9 | format: PyG-ZINC
10 | name: full
11 | task: graph
12 | task_type: regression
13 | transductive: False
14 | node_encoder: True
15 | node_encoder_name: TypeDictNode+RWSE
16 | node_encoder_num_types: 28
17 | node_encoder_bn: False
18 | edge_encoder: True
19 | edge_encoder_name: TypeDictEdge
20 | edge_encoder_num_types: 4
21 | edge_encoder_bn: False
22 | posenc_RWSE:
23 | enable: True
24 | kernel:
25 | times_func: range(1,21)
26 | model: Linear
27 | dim_pe: 28
28 | raw_norm_type: BatchNorm
29 | posenc_LapPE:
30 | enable: True
31 | eigen:
32 | laplacian_norm: none
33 | eigvec_norm: L2
34 | max_freqs: 17
35 | model: DeepSet
36 | dim_pe: 16
37 | layers: 2
38 | raw_norm_type: none
39 | train:
40 | mode: custom
41 | batch_size: 128
42 | eval_period: 1
43 | ckpt_period: 100
44 | model:
45 | type: GPSModel
46 | loss_fun: l1
47 | edge_decoding: dot
48 | graph_pooling: add
49 | gt:
50 | layer_type: GINE+GSSC # CustomGatedGCN+Performer
51 | layers: 10
52 | n_heads: 4
53 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
54 | dropout: 0.0
55 | attn_dropout: 0.5
56 | layer_norm: False
57 | batch_norm: True
58 | gnn:
59 | head: san_graph
60 | layers_pre_mp: 0
61 | layers_post_mp: 3 # Not used when `gnn.head: san_graph`
62 | dim_inner: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
63 | batchnorm: True
64 | act: relu
65 | dropout: 0.0
66 | agg: mean
67 | normalize_adj: False
68 | optim:
69 | clip_grad_norm: True
70 | optimizer: adamW
71 | weight_decay: 1.0e-3
72 | base_lr: 2.0e-3
73 | max_epoch: 2000
74 | scheduler: cosine_with_warmup
75 | num_warmup_epochs: 50
76 | seed: 0
77 | name_tag: "random"
78 | extra:
79 | check_if_pe_done: 1
80 |
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gssc
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - asttokens=2.4.1=pyhd8ed1ab_0
11 | - blas=1.0=mkl
12 | - brotli-python=1.0.9=py39h6a678d5_7
13 | - bzip2=1.0.8=h5eee18b_5
14 | - ca-certificates=2024.3.11=h06a4308_0
15 | - certifi=2024.2.2=pyhd8ed1ab_0
16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
17 | - comm=0.2.2=pyhd8ed1ab_0
18 | - cuda-cudart=11.7.99=0
19 | - cuda-cupti=11.7.101=0
20 | - cuda-libraries=11.7.1=0
21 | - cuda-nvrtc=11.7.99=0
22 | - cuda-nvtx=11.7.91=0
23 | - cuda-runtime=11.7.1=0
24 | - debugpy=1.6.7=py39h6a678d5_0
25 | - decorator=5.1.1=pyhd8ed1ab_0
26 | - exceptiongroup=1.2.0=pyhd8ed1ab_2
27 | - executing=2.0.1=pyhd8ed1ab_0
28 | - ffmpeg=4.3=hf484d3e_0
29 | - filelock=3.13.1=py39h06a4308_0
30 | - freetype=2.12.1=h4a9f257_0
31 | - gmp=6.2.1=h295c915_3
32 | - gmpy2=2.1.2=py39heeb90bb_0
33 | - gnutls=3.6.15=he1e5248_0
34 | - idna=3.4=py39h06a4308_0
35 | - importlib-metadata=7.1.0=pyha770c72_0
36 | - importlib_metadata=7.1.0=hd8ed1ab_0
37 | - intel-openmp=2023.1.0=hdb19cb5_46306
38 | - ipykernel=6.29.3=pyhd33586a_0
39 | - ipython=8.18.1=pyh707e725_3
40 | - jedi=0.19.1=pyhd8ed1ab_0
41 | - jinja2=3.1.3=py39h06a4308_0
42 | - jpeg=9e=h5eee18b_1
43 | - jupyter_client=8.6.1=pyhd8ed1ab_0
44 | - jupyter_core=5.7.2=py39hf3d152e_0
45 | - lame=3.100=h7b6447c_0
46 | - lcms2=2.12=h3be6417_0
47 | - ld_impl_linux-64=2.38=h1181459_1
48 | - lerc=3.0=h295c915_0
49 | - libcublas=11.10.3.66=0
50 | - libcufft=10.7.2.124=h4fbf590_0
51 | - libcufile=1.9.0.20=0
52 | - libcurand=10.3.5.119=0
53 | - libcusolver=11.4.0.1=0
54 | - libcusparse=11.7.4.91=0
55 | - libdeflate=1.17=h5eee18b_1
56 | - libffi=3.4.4=h6a678d5_0
57 | - libgcc-ng=13.2.0=h807b86a_5
58 | - libgomp=13.2.0=h807b86a_5
59 | - libiconv=1.16=h7f8727e_2
60 | - libidn2=2.3.4=h5eee18b_0
61 | - libnpp=11.7.4.75=0
62 | - libnvjpeg=11.8.0.2=0
63 | - libpng=1.6.39=h5eee18b_0
64 | - libsodium=1.0.18=h36c2ea0_1
65 | - libstdcxx-ng=11.2.0=h1234567_1
66 | - libtasn1=4.19.0=h5eee18b_0
67 | - libtiff=4.5.1=h6a678d5_0
68 | - libunistring=0.9.10=h27cfd23_0
69 | - libwebp-base=1.3.2=h5eee18b_0
70 | - lz4-c=1.9.4=h6a678d5_0
71 | - markupsafe=2.1.3=py39h5eee18b_0
72 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
73 | - mkl=2023.1.0=h213fc3f_46344
74 | - mkl-service=2.4.0=py39h5eee18b_1
75 | - mkl_fft=1.3.8=py39h5eee18b_0
76 | - mkl_random=1.2.4=py39hdb19cb5_0
77 | - mpc=1.1.0=h10f8cd9_1
78 | - mpfr=4.0.2=hb69a4c5_1
79 | - mpmath=1.3.0=py39h06a4308_0
80 | - ncurses=6.4=h6a678d5_0
81 | - nest-asyncio=1.6.0=pyhd8ed1ab_0
82 | - nettle=3.7.3=hbbd107a_1
83 | - networkx=3.1=py39h06a4308_0
84 | - numpy=1.26.4=py39h5f9d8c6_0
85 | - numpy-base=1.26.4=py39hb5e798b_0
86 | - openh264=2.1.1=h4ff587b_0
87 | - openjpeg=2.4.0=h3ad879b_0
88 | - openssl=3.2.1=hd590300_1
89 | - packaging=24.0=pyhd8ed1ab_0
90 | - parso=0.8.3=pyhd8ed1ab_0
91 | - pexpect=4.9.0=pyhd8ed1ab_0
92 | - pickleshare=0.7.5=py_1003
93 | - pillow=10.2.0=py39h5eee18b_0
94 | - pip=23.3.1=py39h06a4308_0
95 | - platformdirs=4.2.0=pyhd8ed1ab_0
96 | - prompt-toolkit=3.0.42=pyha770c72_0
97 | - psutil=5.9.8=py39hd1e30aa_0
98 | - ptyprocess=0.7.0=pyhd3deb0d_0
99 | - pure_eval=0.2.2=pyhd8ed1ab_0
100 | - pygments=2.17.2=pyhd8ed1ab_0
101 | - pysocks=1.7.1=py39h06a4308_0
102 | - python=3.9.19=h955ad1f_0
103 | - python_abi=3.9=2_cp39
104 | - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0
105 | - pytorch-cuda=11.7=h778d358_5
106 | - pytorch-mutex=1.0=cuda
107 | - pyzmq=25.1.2=py39h6a678d5_0
108 | - readline=8.2=h5eee18b_0
109 | - requests=2.31.0=py39h06a4308_1
110 | - setuptools=68.2.2=py39h06a4308_0
111 | - six=1.16.0=pyh6c4a22f_0
112 | - sqlite=3.41.2=h5eee18b_0
113 | - stack_data=0.6.2=pyhd8ed1ab_0
114 | - sympy=1.12=py39h06a4308_0
115 | - tbb=2021.8.0=hdb19cb5_0
116 | - tk=8.6.12=h1ccaba5_0
117 | - torchaudio=2.0.0=py39_cu117
118 | - torchtriton=2.0.0=py39
119 | - torchvision=0.15.0=py39_cu117
120 | - tornado=6.4=py39hd1e30aa_0
121 | - traitlets=5.14.2=pyhd8ed1ab_0
122 | - typing_extensions=4.9.0=py39h06a4308_1
123 | - urllib3=2.1.0=py39h06a4308_1
124 | - wcwidth=0.2.13=pyhd8ed1ab_0
125 | - wheel=0.41.2=py39h06a4308_0
126 | - xz=5.4.6=h5eee18b_0
127 | - zeromq=4.3.5=h6a678d5_0
128 | - zipp=3.17.0=pyhd8ed1ab_0
129 | - zlib=1.2.13=h5eee18b_0
130 | - zstd=1.5.5=hc292b87_0
131 | - pip:
132 | - aiohttp==3.9.3
133 | - aiosignal==1.3.1
134 | - annotated-types==0.6.0
135 | - appdirs==1.4.4
136 | - argparse==1.4.0
137 | - async-timeout==4.0.3
138 | - attrs==23.2.0
139 | - automat==22.10.0
140 | - axial-positional-embedding==0.2.1
141 | - buildtools==1.0.6
142 | - causal-conv1d==1.2.0.post2
143 | - click==8.1.7
144 | - cmake==3.29.0.1
145 | - constantly==23.10.4
146 | - deepspeed==0.14.0
147 | - docker-pycreds==0.4.0
148 | - docopt==0.6.2
149 | - einops==0.7.0
150 | - frozenlist==1.4.1
151 | - fsspec==2024.3.1
152 | - furl==2.1.3
153 | - gitdb==4.0.11
154 | - gitpython==3.1.43
155 | - greenlet==3.0.3
156 | - hjson==3.1.0
157 | - huggingface-hub==0.22.2
158 | - hyperlink==21.0.0
159 | - incremental==22.10.0
160 | - joblib==1.3.2
161 | - lightning-utilities==0.11.2
162 | - lit==18.1.2
163 | - littleutils==0.2.2
164 | - local-attention==1.9.0
165 | - mamba-ssm==1.2.0.post1
166 | - multidict==6.0.5
167 | - ninja==1.11.1.1
168 | - ogb==1.3.6
169 | - openbabel-wheel==3.1.1.19
170 | - orderedmultidict==1.0.1
171 | - outdated==0.2.2
172 | - pandas==2.2.1
173 | - performer-pytorch==1.1.4
174 | - protobuf==4.25.3
175 | - py-cpuinfo==9.0.0
176 | - pydantic==2.6.4
177 | - pydantic-core==2.16.3
178 | - pynvml==11.5.0
179 | - pyparsing==3.1.2
180 | - python-dateutil==2.9.0.post0
181 | - pytorch-lightning==2.2.1
182 | - pytz==2024.1
183 | - pyyaml==6.0.1
184 | - rdkit==2023.9.5
185 | - redo==2.0.4
186 | - regex==2023.12.25
187 | - safetensors==0.4.2
188 | - scikit-learn==1.4.1.post1
189 | - scipy==1.12.0
190 | - sentry-sdk==1.44.0
191 | - setproctitle==1.3.3
192 | - simplejson==3.19.2
193 | - smmap==5.0.1
194 | - sqlalchemy==2.0.29
195 | - tensorboardx==2.6.2.2
196 | - threadpoolctl==3.4.0
197 | - tokenizers==0.15.2
198 | - torchmetrics==0.9.3
199 | - tqdm==4.66.2
200 | - transformers==4.39.3
201 | - twisted==24.3.0
202 | - tzdata==2024.1
203 | - wandb==0.16.5
204 | - yacs==0.1.8
205 | - yarl==1.9.4
206 | - zope-interface==6.2
207 | # pip install torch_geometric==2.0.4
208 | # pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
209 |
--------------------------------------------------------------------------------
/gssc/__init__.py:
--------------------------------------------------------------------------------
1 | from .act import * # noqa
2 | from .config import * # noqa
3 | from .encoder import * # noqa
4 | from .head import * # noqa
5 | from .layer import * # noqa
6 | from .loader import * # noqa
7 | from .loss import * # noqa
8 | from .network import * # noqa
9 | from .optimizer import * # noqa
10 | from .pooling import * # noqa
11 | from .stage import * # noqa
12 | from .train import * # noqa
13 | from .transform import * # noqa
--------------------------------------------------------------------------------
/gssc/act/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Collect every sibling Python module (except this package initializer)
# so `from gssc.act import *` imports them and triggers their
# self-registration side effects.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = []
for _mod_path in modules:
    if isfile(_mod_path) and not _mod_path.endswith('__init__.py'):
        __all__.append(basename(_mod_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/act/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch_geometric.graphgym.config import cfg
5 | from torch_geometric.graphgym.register import register_act
6 |
7 |
class SWISH(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    With ``inplace=True`` the multiplication is written into ``x``'s own
    storage; otherwise a new tensor is returned.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Whether forward() may overwrite its input tensor.
        self.inplace = inplace

    def forward(self, x):
        gate = torch.sigmoid(x)
        if not self.inplace:
            return x * gate
        x.mul_(gate)
        return x
19 |
20 |
21 | register_act('swish', SWISH(inplace=cfg.mem.inplace))
22 | register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace))
23 |
--------------------------------------------------------------------------------
/gssc/agg_runs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 |
6 | from torch_geometric.graphgym.config import cfg
7 | from torch_geometric.graphgym.utils.io import (
8 | dict_list_to_json,
9 | dict_list_to_tb,
10 | dict_to_json,
11 | json_to_dict_list,
12 | makedirs_rm_exist,
13 | string_to_python,
14 | )
15 |
16 | try:
17 | from tensorboardX import SummaryWriter
18 | except ImportError:
19 | SummaryWriter = None
20 |
21 |
def is_seed(s):
    """Return True if `s` (a results sub-directory name) parses as an
    integer random seed."""
    try:
        int(s)
    except (TypeError, ValueError):
        # Only the errors int() actually raises for bad input; a broad
        # `except Exception` would also mask unrelated bugs.
        return False
    return True
28 |
29 |
def is_split(s):
    """Return True if `s` names one of the dataset splits."""
    # Direct membership test instead of an if/else that returns booleans.
    return s in ('train', 'val', 'test')
35 |
36 |
def join_list(l1, l2):
    """Element-wise extend `l1`'s entries with `l2`'s, in place.

    Args:
        l1: list of lists; mutated in place and returned.
        l2: list of lists of the same length as `l1`.

    Returns:
        `l1`, where `l1[i]` has been extended by `l2[i]` for every i.
    """
    assert len(l1) == len(l2), \
        'Results with different seeds must have the same format'
    for i, extra in enumerate(l2):
        l1[i] += extra
    return l1
43 |
44 |
def agg_dict_list(dict_list):
    """Aggregate a list of result dictionaries into mean + std per key.

    Args:
        dict_list: list of dictionaries sharing the same keys; the
            'epoch' entry of the first dictionary is carried over as-is.

    Returns:
        Dict with each key mapped to its mean (rounded to `cfg.round`
        digits) plus a '<key>_std' entry with the rounded std deviation.
    """
    dict_agg = {'epoch': dict_list[0]['epoch']}
    for key in dict_list[0]:
        if key != 'epoch':
            # `d`, not `dict`: don't shadow the builtin.
            value = np.array([d[key] for d in dict_list])
            dict_agg[key] = np.mean(value).round(cfg.round)
            dict_agg['{}_std'.format(key)] = np.std(value).round(cfg.round)
    return dict_agg
59 |
60 |
def name_to_dict(run):
    """Parse a run-directory name of the form
    ``<prefix>-<key1>=<val1>-<key2>=<val2>-...`` into a {key: value} dict.
    """
    run = run.split('-', 1)[-1]  # drop the prefix before the first '-'
    cols = run.split('=')
    keys, vals = [], []
    keys.append(cols[0])
    for col in cols[1:-1]:
        # Each middle chunk is expected to look like '<val>-<nextkey>'.
        try:
            val, key = col.rsplit('-', 1)
        except Exception:
            # NOTE(review): on failure `val`/`key` keep their previous
            # values (or are unbound on the first iteration) yet are
            # still appended below — looks like a latent bug; confirm.
            print(col)
        keys.append(key)
        vals.append(string_to_python(val))
    vals.append(cols[-1])
    return dict(zip(keys, vals))
75 |
76 |
def rm_keys(dict, keys):
    """Delete every key in `keys` from `dict`, ignoring missing ones."""
    for unwanted in keys:
        dict.pop(unwanted, None)
80 |
81 |
def agg_runs(dir, metric_best='auto'):
    r'''
    Aggregate over different random seeds of a single experiment

    Args:
        dir (str): Directory of the results, containing 1 experiment
        metric_best (str, optional): The metric for selecting the best
        validation performance. Options: auto, accuracy, auc.

    '''
    results = {'train': None, 'val': None, 'test': None}
    results_best = {'train': None, 'val': None, 'test': None}
    for seed in os.listdir(dir):
        if is_seed(seed):
            dir_seed = os.path.join(dir, seed)

            # Determine the best epoch from the validation split.
            split = 'val'
            if split in os.listdir(dir_seed):
                dir_split = os.path.join(dir_seed, split)
                fname_stats = os.path.join(dir_split, 'stats.json')
                stats_list = json_to_dict_list(fname_stats)
                if metric_best == 'auto':
                    metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'
                else:
                    metric = metric_best
                performance_np = np.array(
                    [stats[metric] for stats in stats_list])
                # `cfg.metric_agg` names an ndarray method ('argmax' /
                # 'argmin'); resolve it via getattr instead of eval() so
                # a config string cannot execute arbitrary code.
                best_epoch = stats_list[
                    getattr(performance_np, cfg.metric_agg)()]['epoch']
                print(best_epoch)

            for split in os.listdir(dir_seed):
                if is_split(split):
                    dir_split = os.path.join(dir_seed, split)
                    fname_stats = os.path.join(dir_split, 'stats.json')
                    stats_list = json_to_dict_list(fname_stats)
                    # Stats at the best-validation epoch for this split.
                    stats_best = [
                        stats for stats in stats_list
                        if stats['epoch'] == best_epoch
                    ][0]
                    print(stats_best)
                    stats_list = [[stats] for stats in stats_list]
                    if results[split] is None:
                        results[split] = stats_list
                    else:
                        results[split] = join_list(results[split], stats_list)
                    if results_best[split] is None:
                        results_best[split] = [stats_best]
                    else:
                        results_best[split] += [stats_best]
    results = {k: v for k, v in results.items() if v is not None}  # rm None
    results_best = {k: v
                    for k, v in results_best.items()
                    if v is not None}  # rm None
    # Aggregate mean/std across seeds, per epoch and for the best epoch.
    for key in results:
        for i in range(len(results[key])):
            results[key][i] = agg_dict_list(results[key][i])
    for key in results_best:
        results_best[key] = agg_dict_list(results_best[key])
    # save aggregated results
    for key, value in results.items():
        dir_out = os.path.join(dir, 'agg', key)
        makedirs_rm_exist(dir_out)
        fname = os.path.join(dir_out, 'stats.json')
        dict_list_to_json(value, fname)

        if cfg.tensorboard_agg:
            if SummaryWriter is None:
                raise ImportError(
                    'Tensorboard support requires `tensorboardX`.')
            writer = SummaryWriter(dir_out)
            dict_list_to_tb(value, writer)
            writer.close()
    for key, value in results_best.items():
        dir_out = os.path.join(dir, 'agg', key)
        fname = os.path.join(dir_out, 'best.json')
        dict_to_json(value, fname)
    logging.info('Results aggregated across runs saved in {}'.format(
        os.path.join(dir, 'agg')))
163 |
--------------------------------------------------------------------------------
/gssc/config/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Expose every sibling config module (the package initializer excluded)
# through `from gssc.config import *` so each registers its options.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = []
for _mod_path in modules:
    if isfile(_mod_path) and not _mod_path.endswith('__init__.py'):
        __all__.append(basename(_mod_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/config/custom_gnn_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 |
3 |
@register_config('custom_gnn')
def custom_gnn_cfg(cfg):
    """Extend GraphGym's built-in GNN config group with the options
    required by our CustomGNN network model.
    """
    cfg.gnn.residual = False  # residual connections between GNN layers
12 |
--------------------------------------------------------------------------------
/gssc/config/data_preprocess_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
def set_cfg_preprocess(cfg):
    """Extend the configuration with graph preprocessing options."""
    cfg.prep = CN()

    # --- Expander-edge augmentation -----------------------------------
    # When enabled, expander edges are attached to the data object,
    # e.g. accessible as `data.expander_edges`.
    cfg.prep.exp = False
    cfg.prep.exp_algorithm = 'Random-d'  # alternative: 'Hamiltonian'
    cfg.prep.use_exp_edges = True
    cfg.prep.exp_deg = 5
    cfg.prep.exp_max_num_iters = 100
    cfg.prep.add_edge_index = True
    cfg.prep.num_virt_node = 0
    cfg.prep.exp_count = 1
    cfg.prep.add_self_loops = False
    cfg.prep.add_reverse_edges = True
    cfg.prep.train_percent = 0.6
    cfg.prep.layer_edge_indices_dir = None

    # --- Node-distance computation ------------------------------------
    cfg.prep.dist_enable = False
    cfg.prep.dist_cutoff = 510


register_config('preprocess', set_cfg_preprocess)
35 |
--------------------------------------------------------------------------------
/gssc/config/dataset_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 |
3 |
@register_config('dataset_cfg')
def dataset_cfg(cfg):
    """Dataset-specific config options."""
    # How many node types TypeDictNodeEncoder should expect.
    cfg.dataset.node_encoder_num_types = 0
    # How many edge types TypeDictEdgeEncoder should expect.
    cfg.dataset.edge_encoder_num_types = 0
    # VOC/COCO Superpixels dataset version, selected by the SLIC
    # compactness parameter.
    cfg.dataset.slic_compactness = 10
17 |
--------------------------------------------------------------------------------
/gssc/config/defaults_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 |
3 |
@register_config('overwrite_defaults')
def overwrite_defaults_cfg(cfg):
    """Overwrite default config values that GraphGym first sets in
    torch_geometric.graphgym.config.set_cfg

    WARNING: custom config-setting functions like this one run in an
    unspecified order (see the referenced `set_cfg`); only reset options
    that exist in core GraphGym here — never custom-added ones.
    """
    cfg.dataset.name = 'none'  # default dataset name
    cfg.round = 5  # default rounding precision
20 |
21 |
@register_config('extended_cfg')
def extended_cfg(cfg):
    """General extended config options."""
    # Extra tag used during `run_dir` / `wandb_name` auto generation.
    cfg.name_tag = ""
    # If True (and cfg.train.enable_ckpt is True), always checkpoint the
    # current best model by validation performance; if False, checkpoint
    # on the cfg.train.eval_period schedule instead.
    cfg.train.ckpt_best = False
34 |
--------------------------------------------------------------------------------
/gssc/config/example.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
@register_config('example')
def set_cfg_example(cfg):
    r'''Set the default config values for customized options.

    :return: customized configuration used by the experiment.
    '''
    # ------------------------------------------------------------------ #
    # Customized options
    # ------------------------------------------------------------------ #
    cfg.example_arg = 'example'  # example argument

    # Example argument group; arguments can be specified within it.
    cfg.example_group = CN()
    cfg.example_group.example_arg = 'example'
24 |
--------------------------------------------------------------------------------
/gssc/config/gt_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
@register_config('cfg_gt')
def set_cfg_gt(cfg):
    """Configuration for Graph Transformer-style models, e.g.:
    - Spectral Attention Network (SAN) Graph Transformer.
    - "vanilla" Transformer / Performer.
    - General Powerful Scalable (GPS) Model.
    """
    cfg.gt = CN()

    cfg.gt.layer_type = 'SANLayer'  # which Graph Transformer layer to use
    cfg.gt.layers = 3               # number of Transformer layers
    cfg.gt.n_heads = 8              # attention heads per layer
    cfg.gt.dim_hidden = 64          # hidden node/edge representation size
    cfg.gt.dim_edge = None          # edge embedding size

    # Full-attention SAN transformer over all possible pairwise edges.
    cfg.gt.full_graph = True
    # Type of extra edges used for the transformer.
    cfg.gt.secondary_edges = 'full_graph'
    # SAN real- vs fake-edge attention weighting coefficient.
    cfg.gt.gamma = 1e-5

    # In-degree histogram of training-set nodes, used by PNAConv when
    # `gt.layer_type: PNAConv+...`; precomputed during dataset loading
    # if left empty.
    cfg.gt.pna_degrees = []

    cfg.gt.dropout = 0.0       # dropout in the feed-forward module
    cfg.gt.attn_dropout = 0.0  # dropout in self-attention

    cfg.gt.layer_norm = False
    cfg.gt.batch_norm = True
    cfg.gt.residual = True
    cfg.gt.activation = 'relu'

    # BigBird model / GPS-BigBird layer options.
    cfg.gt.bigbird = CN()
    cfg.gt.bigbird.attention_type = "block_sparse"
    cfg.gt.bigbird.chunk_size_feed_forward = 0
    cfg.gt.bigbird.is_decoder = False
    cfg.gt.bigbird.add_cross_attention = False
    cfg.gt.bigbird.hidden_act = "relu"
    cfg.gt.bigbird.max_position_embeddings = 128
    cfg.gt.bigbird.use_bias = False
    cfg.gt.bigbird.num_random_blocks = 3
    cfg.gt.bigbird.block_size = 3
    cfg.gt.bigbird.layer_norm_eps = 1e-6
81 |
--------------------------------------------------------------------------------
/gssc/config/optimizers_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 |
3 |
@register_config('extended_optim')
def extended_optim_cfg(cfg):
    """Extend the optimizer config group first set by GraphGym in
    torch_geometric.graphgym.config.set_cfg
    """
    # Batches to accumulate gradients over before a parameter update;
    # requires the `custom` training loop (`train.mode: custom`).
    cfg.optim.batch_accumulation = 1

    # ReduceLROnPlateau: factor by which the LR is reduced.
    cfg.optim.reduce_factor = 0.1
    # ReduceLROnPlateau: epochs without improvement before reducing LR.
    cfg.optim.schedule_patience = 10
    # ReduceLROnPlateau: lower bound on the learning rate.
    cfg.optim.min_lr = 0.0

    # Warm-up epochs, for schedulers that have a warm-up phase.
    cfg.optim.num_warmup_epochs = 50

    # Clip gradient norms while training.
    cfg.optim.clip_grad_norm = False
28 |
--------------------------------------------------------------------------------
/gssc/config/posenc_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
@register_config('posenc')
def set_cfg_posenc(cfg):
    """Extend configuration with positional encoding options."""

    # One argument group per Positional Encoding class.
    cfg.posenc_LapPE = CN()
    cfg.posenc_SignNet = CN()
    cfg.posenc_RWSE = CN()
    cfg.posenc_HKdiagSE = CN()
    cfg.posenc_ElstaticSE = CN()
    cfg.posenc_EquivStableLapPE = CN()

    # Effective-Resistance embeddings, for nodes (ERN) and edges (ERE).
    cfg.posenc_ERN = CN()
    cfg.posenc_ERE = CN()

    # Options common to all PE types (EquivStableLapPE is configured
    # separately below).
    for pe_name in ('posenc_LapPE', 'posenc_SignNet', 'posenc_RWSE',
                    'posenc_HKdiagSE', 'posenc_ElstaticSE',
                    'posenc_ERN', 'posenc_ERE'):
        pe = getattr(cfg, pe_name)
        pe.enable = False   # use this extended positional encoding
        # PE-encoder model: 'DeepSet', 'Transformer', 'Linear', 'none', ...
        pe.model = 'none'
        pe.dim_pe = 16      # size of the PE embedding
        pe.layers = 3       # layers in the PE encoder model
        pe.n_heads = 4      # attention heads when model == 'Transformer'
        pe.post_layers = 0  # LapPE layers applied after its pooling stage
        # Normalization of the raw PE stats: 'none' or 'BatchNorm'.
        pe.raw_norm_type = 'none'
        # Besides appending the PE to node features, also pass it as a
        # separate variable on the PyG graph batch object.
        pe.pass_as_var = False

    # EquivStable LapPE options.
    cfg.posenc_EquivStableLapPE.enable = False
    cfg.posenc_EquivStableLapPE.raw_norm_type = 'none'

    # Laplacian eigen-decomposition options, for the PEs that use it.
    for pe_name in ('posenc_LapPE', 'posenc_SignNet',
                    'posenc_EquivStableLapPE'):
        pe = getattr(cfg, pe_name)
        pe.eigen = CN()
        # Graph-Laplacian normalization: 'none', 'sym', or 'rw'.
        pe.eigen.laplacian_norm = 'sym'
        # Normalization of the Laplacian eigenvectors.
        pe.eigen.eigvec_norm = 'L2'
        # Max number of top smallest frequencies & eigenvectors to keep.
        pe.eigen.max_freqs = 10

    # SignNet-specific options.
    cfg.posenc_SignNet.phi_out_dim = 4
    cfg.posenc_SignNet.phi_hidden_dim = 64

    # Kernel-based PE options.
    for pe_name in ('posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE'):
        pe = getattr(cfg, pe_name)
        pe.kernel = CN()
        # Heat-kernel times (i.e. kernel variances) / random-walk step
        # counts; may be overridden by `posenc.kernel.times_func`.
        pe.kernel.times = []
        # Python snippet generating `posenc.kernel.times`, e.g.
        # 'range(1, 17)'; evaluated via `eval()` and overrides the above.
        pe.kernel.times_func = ''

    # Override default: the electrostatic kernel has a fixed set of 10
    # measures.
    cfg.posenc_ElstaticSE.kernel.times_func = 'range(10)'

    # Accuracy of the Effective-Resistance calculations.
    cfg.posenc_ERN.accuracy = 0.1
    cfg.posenc_ERE.accuracy = 0.1

    # To be set during the calculations:
    cfg.posenc_ERN.er_dim = 'none'
100 |
--------------------------------------------------------------------------------
/gssc/config/pretrained_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
@register_config('cfg_pretrained')
def set_cfg_pretrained(cfg):
    """Configuration options for loading a pretrained model."""
    cfg.pretrained = CN()

    # Path to a saved experiment; if set, load that model and
    # fine-tune / run inference with it on a specified dataset.
    cfg.pretrained.dir = ""
    # Drop the pretrained prediction-head weights and reinitialize them.
    cfg.pretrained.reset_prediction_head = True
    # Freeze the pretrained model 'body'; learn only the new head.
    cfg.pretrained.freeze_main = False
21 |
--------------------------------------------------------------------------------
/gssc/config/split_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 |
3 |
@register_config('split')
def set_cfg_split(cfg):
    """Reconfigure the default config values for dataset split options.

    Returns:
        Reconfigured split configuration used by the experiment.
    """
    # Default: use the standard split shipped with the dataset.
    cfg.dataset.split_mode = 'standard'
    # Which split to use when several are available.
    cfg.dataset.split_index = 0
    # Cache directory for cross-validation splits.
    cfg.dataset.split_dir = './splits'
    # Run multiple splits in one program execution; when set, this takes
    # precedence over cfg.dataset.split_index for split selection.
    cfg.run_multiple_splits = []
24 |
--------------------------------------------------------------------------------
/gssc/config/wandb_config.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_config
2 | from yacs.config import CfgNode as CN
3 |
4 |
@register_config('cfg_wandb')
def set_cfg_wandb(cfg):
    """Weights & Biases tracker configuration."""
    cfg.wandb = CN()

    cfg.wandb.use = False  # whether to use wandb
    # Entity name; should exist beforehand.
    cfg.wandb.entity = "gtransformers"
    # Project name; created in your team if it doesn't exist already.
    cfg.wandb.project = "gtblueprint"
    # Optional run name.
    cfg.wandb.name = ""
24 |
--------------------------------------------------------------------------------
/gssc/encoder/ER_edge_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_edge_encoder
5 |
6 |
@register_edge_encoder('ERE')
class EREdgeEncoder(torch.nn.Module):
    """Effective Resistance (ER) edge encoder.

    Runs the per-edge ER value (`batch.er_edge`) through a small MLP and
    either replaces `batch.edge_attr` with the result or concatenates it
    to the (optionally shrunk) existing edge attributes.
    """

    def __init__(self, emb_dim, use_edge_attr=False, expand_edge_attr=False):
        super().__init__()

        dim_in = cfg.gt.dim_edge  # Expected final edge_dim

        pecfg = cfg.posenc_ERE
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        self.use_edge_attr = use_edge_attr
        self.expand_edge_attr = expand_edge_attr
        if expand_edge_attr:
            # Shrink original attrs so the concatenated result is dim_in wide.
            self.linear_x = nn.Linear(dim_in, dim_in - emb_dim)

        if not self.use_edge_attr:
            # The ER embedding alone must fill the full edge dimension.
            assert emb_dim == dim_in

        # MLP over the scalar ER value: Linear(1, emb_dim) followed by
        # (n_layers - 1) hidden Linear(emb_dim, emb_dim) layers, ReLU each.
        mlp = [nn.Linear(1, emb_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            mlp.extend([nn.Linear(emb_dim, emb_dim), nn.ReLU()])
        self.er_encoder = nn.Sequential(*mlp)

    def forward(self, batch):
        er_emb = self.er_encoder(batch.er_edge)

        if self.expand_edge_attr:
            batch.edge_attr = self.linear_x(batch.edge_attr)

        batch.edge_attr = (torch.cat([batch.edge_attr, er_emb], dim=1)
                           if self.use_edge_attr else er_emb)

        # Keep the encoded ER also as a separate variable if requested.
        if self.pass_as_var:
            batch.er_edge = er_emb

        return batch
--------------------------------------------------------------------------------
/gssc/encoder/ER_node_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_node_encoder
5 |
6 |
@register_node_encoder('ERN')
class ERNodeEncoder(torch.nn.Module):
    """Effective Resistance Node Encoder.

    ER of size dim_pe will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with ER.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_ERN
        dim_pe = pecfg.dim_pe  # Size of the ER PE embedding
        model_type = pecfg.model  # Encoder NN model type for ERs
        if model_type not in ['Transformer', 'DeepSet', 'Linear']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        n_heads = pecfg.n_heads  # Num. attention heads in Trf PE encoder
        post_n_layers = pecfg.post_layers  # Num. layers to apply after pooling
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        er_dim = pecfg.er_dim  # Length of the raw per-node ER embedding

        if dim_emb - dim_pe < 1:
            raise ValueError(f"ER_Node size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        if model_type == 'Linear':
            self.pe_encoder = nn.Linear(er_dim, dim_pe)
        else:
            # BUG FIX: `linear_A` used to be created only for 'Transformer'
            # and for 'DeepSet' with n_layers > 1, so a DeepSet encoder with
            # n_layers == 1 crashed in forward() with an AttributeError.
            # Create the initial per-value projection unconditionally for all
            # non-Linear models (mirrors LapPENodeEncoder).
            self.linear_A = nn.Linear(1, dim_pe)

            if model_type == 'Transformer':
                # Transformer model for ER_Node
                encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
                                                           nhead=n_heads,
                                                           batch_first=True)
                self.pe_encoder = nn.TransformerEncoder(encoder_layer,
                                                        num_layers=n_layers)
            else:
                # DeepSet model for ER_Node: ReLU followed by (n_layers - 1)
                # Linear+ReLU pairs.
                layers = [nn.ReLU()]
                for _ in range(n_layers - 1):
                    layers.append(nn.Linear(dim_pe, dim_pe))
                    layers.append(nn.ReLU())
                self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)

    def forward(self, batch):
        if not hasattr(batch, 'er_emb'):
            raise ValueError("Precomputed ER embeddings required for calculating ER Node Encodings")

        pos_enc = batch.er_emb  # N * er_dim

        # Randomly permute the ER columns during training — presumably to
        # make the encoder treat the embedding as an unordered set.
        if self.training:
            pos_enc = pos_enc[:, torch.randperm(pos_enc.size()[1])]

        if self.model_type == 'Linear':
            pos_enc = self.pe_encoder(pos_enc)  # N * er_dim -> N * dim_pe
        else:
            pos_enc = torch.unsqueeze(pos_enc, 2)
            pos_enc = self.linear_A(pos_enc)  # (Num nodes) x (er_dim) x dim_pe

            # PE encoder: a Transformer or DeepSet model
            if self.model_type == 'Transformer':
                pos_enc = self.pe_encoder(src=pos_enc)
            else:
                pos_enc = self.pe_encoder(pos_enc)

            # Sum pooling over the er_dim axis
            pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x dim_pe

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(batch.x)
        else:
            h = batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_ern = pos_enc
        return batch
129 |
--------------------------------------------------------------------------------
/gssc/encoder/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Export every sibling .py module (except this __init__) so that a star
# import of the package pulls in all encoder modules.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(path)[:-3]
    for path in modules
    if isfile(path) and not path.endswith('__init__.py')
]
9 |
--------------------------------------------------------------------------------
/gssc/encoder/ast_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.register import (register_node_encoder,
3 | register_edge_encoder)
4 |
5 | """
6 | === Description of the ogbg-code2 dataset ===
7 |
8 | * Node Encoder code based on OGB's:
9 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/utils.py
10 |
11 | Node Encoder config parameters are set based on the OGB example:
12 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/main_pyg.py
13 | where the following three node features are used:
14 | 1. node type
15 | 2. node attribute
16 | 3. node depth
17 |
18 | nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
19 | nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))
20 | num_nodetypes = len(nodetypes_mapping['type'])
21 | num_nodeattributes = len(nodeattributes_mapping['attr'])
22 | max_depth = 20
23 |
24 | * Edge attributes are generated by `augment_edge` function dynamically:
25 | edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1)
26 | edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
27 | """
28 |
# Vocabulary sizes for ogbg-code2 node types / attributes, taken from the
# dataset's mapping files (see the module notes above).
num_nodetypes = 98
num_nodeattributes = 10030
# AST depths beyond this value are clamped to max_depth by ASTNodeEncoder.
max_depth = 20
32 |
33 |
@register_node_encoder('ASTNode')
class ASTNodeEncoder(torch.nn.Module):
    """The Abstract Syntax Tree (AST) Node Encoder used for ogbg-code2.

    Sums three learned embeddings per node: node type (x[:, 0]), node
    attribute (x[:, 1]) and AST depth (clamped to `max_depth`).

    Input:
        x: Default node feature; columns 0 and 1 hold the node type and
            node attribute indices.
        node_depth: The depth of the node in the AST.
    Output:
        emb_dim-dimensional vector
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)

    def forward(self, batch):
        feats = batch.x
        depth = batch.node_depth.view(-1, )
        # In-place clamp: depths beyond the embedding table map to max_depth.
        depth[depth > self.max_depth] = self.max_depth
        type_emb = self.type_encoder(feats[:, 0])
        attr_emb = self.attribute_encoder(feats[:, 1])
        batch.x = type_emb + attr_emb + self.depth_encoder(depth)
        return batch
61 |
62 |
@register_edge_encoder('ASTEdge')
class ASTEdgeEncoder(torch.nn.Module):
    """The Abstract Syntax Tree (AST) Edge Encoder used for ogbg-code2.

    Edge attributes are generated dynamically by `augment_edge` and are
    expected to be:
        edge_attr[:, 0]: AST edge (0) or next-token edge (1)
        edge_attr[:, 1]: original direction (0) or inverse direction (1)

    Args:
        emb_dim (int): Output edge embedding dimension
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.embedding_type = torch.nn.Embedding(2, emb_dim)
        self.embedding_direction = torch.nn.Embedding(2, emb_dim)

    def forward(self, batch):
        kind_emb = self.embedding_type(batch.edge_attr[:, 0])
        dir_emb = self.embedding_direction(batch.edge_attr[:, 1])
        batch.edge_attr = kind_emb + dir_emb
        return batch
86 |
--------------------------------------------------------------------------------
/gssc/encoder/composed_encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.models.encoder import AtomEncoder
4 | from torch_geometric.graphgym.register import register_node_encoder
5 |
6 | from gssc.encoder.ast_encoder import ASTNodeEncoder
7 | from gssc.encoder.kernel_pos_encoder import RWSENodeEncoder, \
8 | HKdiagSENodeEncoder, ElstaticSENodeEncoder
9 | from gssc.encoder.laplace_pos_encoder import LapPENodeEncoder
10 | from gssc.encoder.ppa_encoder import PPANodeEncoder
11 | from gssc.encoder.signnet_pos_encoder import SignNetNodeEncoder
12 | from gssc.encoder.voc_superpixels_encoder import VOCNodeEncoder
13 | from gssc.encoder.type_dict_encoder import TypeDictNodeEncoder
14 | from gssc.encoder.linear_node_encoder import LinearNodeEncoder
15 | from gssc.encoder.equivstable_laplace_pos_encoder import EquivStableLapPENodeEncoder
16 | from gssc.encoder.ER_node_encoder import ERNodeEncoder
17 |
18 |
def concat_node_encoders(encoder_classes, pe_enc_names):
    """
    A factory that creates a new Encoder class that concatenates functionality
    of the given list of two or three Encoder classes. First Encoder is expected
    to be a dataset-specific encoder, and the rest PE Encoders.

    Args:
        encoder_classes: List of node encoder classes
        pe_enc_names: List of PE embedding Encoder names, used to query a dict
            with their desired PE embedding dims. That dict can only be created
            during the runtime, once the config is loaded.

    Returns:
        new node encoder class
    """

    class Concat2NodeEncoder(torch.nn.Module):
        """Encoder that concatenates two node encoders.
        """
        # Filled in by the factory below before the class is returned;
        # instances read these at construction time.
        enc1_cls = None
        enc2_cls = None
        enc2_name = None

        def __init__(self, dim_emb):
            super().__init__()

            if cfg.posenc_EquivStableLapPE.enable:  # Special handling for Equiv_Stable LapPE where node feats and PE are not concat
                self.encoder1 = self.enc1_cls(dim_emb)
                self.encoder2 = self.enc2_cls(dim_emb)
            else:
                # PE dims can only be gathered once the cfg is loaded.
                enc2_dim_pe = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe

                # Reserve dim_pe dims of the final embedding for the PE
                # encoder; the dataset encoder fills the remainder.
                self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe)
                self.encoder2 = self.enc2_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            # Dataset encoder first (writes batch.x), then the PE encoder
            # appends its encoding.
            batch = self.encoder1(batch)
            batch = self.encoder2(batch)
            return batch

    class Concat3NodeEncoder(torch.nn.Module):
        """Encoder that concatenates three node encoders.
        """
        # Filled in by the factory below before the class is returned.
        enc1_cls = None
        enc2_cls = None
        enc2_name = None
        enc3_cls = None
        enc3_name = None

        def __init__(self, dim_emb):
            super().__init__()
            # PE dims can only be gathered once the cfg is loaded.
            enc2_dim_pe = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe
            enc3_dim_pe = getattr(cfg, f"posenc_{self.enc3_name}").dim_pe
            # Each successive encoder is given the running total width so the
            # final concatenated embedding reaches dim_emb exactly.
            self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe - enc3_dim_pe)
            self.encoder2 = self.enc2_cls(dim_emb - enc3_dim_pe, expand_x=False)
            self.encoder3 = self.enc3_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            batch = self.encoder1(batch)
            batch = self.encoder2(batch)
            batch = self.encoder3(batch)
            return batch

    # Configure the correct concatenation class and return it.
    if len(encoder_classes) == 2:
        Concat2NodeEncoder.enc1_cls = encoder_classes[0]
        Concat2NodeEncoder.enc2_cls = encoder_classes[1]
        Concat2NodeEncoder.enc2_name = pe_enc_names[0]
        return Concat2NodeEncoder
    elif len(encoder_classes) == 3:
        Concat3NodeEncoder.enc1_cls = encoder_classes[0]
        Concat3NodeEncoder.enc2_cls = encoder_classes[1]
        Concat3NodeEncoder.enc3_cls = encoder_classes[2]
        Concat3NodeEncoder.enc2_name = pe_enc_names[0]
        Concat3NodeEncoder.enc3_name = pe_enc_names[1]
        return Concat3NodeEncoder
    else:
        raise ValueError(f"Does not support concatenation of "
                         f"{len(encoder_classes)} encoder classes.")
100 |
101 |
# Dataset-specific node encoders.
ds_encs = {'Atom': AtomEncoder,
           'ASTNode': ASTNodeEncoder,
           'PPANode': PPANodeEncoder,
           'TypeDictNode': TypeDictNodeEncoder,
           'VOCNode': VOCNodeEncoder,
           'LinearNode': LinearNodeEncoder}

# Positional Encoding node encoders.
pe_encs = {'LapPE': LapPENodeEncoder,
           'RWSE': RWSENodeEncoder,
           'HKdiagSE': HKdiagSENodeEncoder,
           'ElstaticSE': ElstaticSENodeEncoder,
           'SignNet': SignNetNodeEncoder,
           'EquivStableLapPE': EquivStableLapPENodeEncoder,
           'ERN': ERNodeEncoder}

# Concat dataset-specific and PE encoders.
# Registers every "<dataset>+<pe>" combination (e.g. "Atom+LapPE") at
# import time, so config files can select them by name.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    for pe_enc_name, pe_enc_cls in pe_encs.items():
        register_node_encoder(
            f"{ds_enc_name}+{pe_enc_name}",
            concat_node_encoders([ds_enc_cls, pe_enc_cls],
                                 [pe_enc_name])
        )

# Combine both LapPE and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+LapPE+RWSE",
        concat_node_encoders([ds_enc_cls, LapPENodeEncoder, RWSENodeEncoder],
                             ['LapPE', 'RWSE'])
    )

# Combine both SignNet and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+SignNet+RWSE",
        concat_node_encoders([ds_enc_cls, SignNetNodeEncoder, RWSENodeEncoder],
                             ['SignNet', 'RWSE'])
    )
143 |
--------------------------------------------------------------------------------
/gssc/encoder/dummy_edge_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.register import register_edge_encoder
3 |
4 |
@register_edge_encoder('DummyEdge')
class DummyEdgeEncoder(torch.nn.Module):
    """Edge encoder that gives every edge the same learnable embedding.

    Useful for graphs without real edge features: each edge receives the
    single learned `emb_dim`-dimensional vector.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # One-row embedding table; every edge indexes row 0.
        self.encoder = torch.nn.Embedding(num_embeddings=1,
                                          embedding_dim=emb_dim)

    def forward(self, batch):
        zero_idx = batch.edge_index.new_zeros(batch.edge_index.shape[1])
        batch.edge_attr = self.encoder(zero_idx)
        return batch
18 |
--------------------------------------------------------------------------------
/gssc/encoder/equivstable_laplace_pos_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_node_encoder
5 |
6 |
@register_node_encoder('EquivStableLapPE')
class EquivStableLapPENodeEncoder(torch.nn.Module):
    """Equivariant and Stable Laplace Positional Embedding node encoder.

    Transforms the k-dim node LapPE to d-dim and stores it separately on
    the batch (`pe_EquivStableLapPE`) for later use as edge weights in the
    local GNN module.
    Based on the approach proposed in paper https://openreview.net/pdf?id=e95i1IHcWj

    Args:
        dim_emb: Size of final node embedding
    """

    def __init__(self, dim_emb):
        super().__init__()

        pecfg = cfg.posenc_EquivStableLapPE
        max_freqs = pecfg.eigen.max_freqs  # Num. eigenvectors (frequencies)
        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization layer type

        self.raw_norm = nn.BatchNorm1d(max_freqs) if norm_type == 'batchnorm' else None
        self.linear_encoder_eigenvec = nn.Linear(max_freqs, dim_emb)

    def forward(self, batch):
        if not (hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; set "
                             f"config 'posenc_EquivStableLapPE.enable' to True")
        pos_enc = batch.EigVecs  # (Num nodes) x (Num Eigenvectors)

        # Padded entries are NaN; zero them. Note this mutates
        # batch.EigVecs in place, since pos_enc aliases it.
        nan_mask = torch.isnan(pos_enc)
        pos_enc[nan_mask] = 0.

        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)

        pos_enc = self.linear_encoder_eigenvec(pos_enc)

        # Keep PE separate in a variable
        batch.pe_EquivStableLapPE = pos_enc

        return batch
--------------------------------------------------------------------------------
/gssc/encoder/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ogb.utils.features import get_bond_feature_dims
3 |
4 | from torch_geometric.graphgym.register import (
5 | register_edge_encoder,
6 | register_node_encoder,
7 | )
8 |
9 |
@register_node_encoder('example')
class ExampleNodeEncoder(torch.nn.Module):
    """
    Provides an encoder for integer node features
    Parameters:
        num_classes - the number of classes for the embedding mapping to learn
    """
    def __init__(self, emb_dim, num_classes=None):
        super().__init__()
        self.encoder = torch.nn.Embedding(num_classes, emb_dim)
        torch.nn.init.xavier_uniform_(self.encoder.weight.data)

    def forward(self, batch):
        # Only the first feature column is embedded, even if more exist.
        batch.x = self.encoder(batch.x[:, 0])
        return batch
28 |
29 |
@register_edge_encoder('example')
class ExampleEdgeEncoder(torch.nn.Module):
    """Bond-feature edge encoder (OGB molecule style).

    Learns one embedding table per bond-feature dimension and sums the
    embeddings of all feature columns into a single edge representation
    stored back into `batch.edge_attr`.

    Args:
        emb_dim (int): Output edge embedding dimension
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.bond_embedding_list = torch.nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, batch):
        bond_embedding = 0
        # BUG FIX: iterate over the columns of `edge_attr`, the tensor that
        # is actually indexed below; the old code read
        # `batch.edge_feature.shape[1]` (a leftover GraphGym attribute name
        # that these batches do not carry), risking an AttributeError or a
        # mismatched column count.
        for i in range(batch.edge_attr.shape[1]):
            bond_embedding += \
                self.bond_embedding_list[i](batch.edge_attr[:, i])

        batch.edge_attr = bond_embedding
        return batch
51 |
--------------------------------------------------------------------------------
/gssc/encoder/exp_edge_fixer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_scatter import scatter
6 |
7 | from torch_geometric.graphgym.config import cfg
8 | from torch_geometric.graphgym.register import register_layer
9 |
10 |
class ExpanderEdgeFixer(nn.Module):
    '''
    Gets the batch and sets new edge indices + global nodes.

    Assembles `batch.expander_edge_index` / `batch.expander_edge_attr`
    from up to three sources:
      1. the original graph edges (when `add_edge_index` is True),
      2. precomputed expander-graph edges (when both cfg.prep.exp and
         cfg.prep.use_exp_edges are set),
      3. bidirectional edges to/from `num_virt_node` virtual (global)
         nodes per graph.
    '''
    def __init__(self, add_edge_index=False, num_virt_node=0):
        # add_edge_index: also include the batch's original edges/attrs.
        # num_virt_node: number of virtual nodes appended per graph.

        super().__init__()

        # Default the edge dim to the hidden dim when not configured.
        if not hasattr(cfg.gt, 'dim_edge') or cfg.gt.dim_edge is None:
            cfg.gt.dim_edge = cfg.gt.dim_hidden

        self.add_edge_index = add_edge_index
        self.num_virt_node = num_virt_node
        # Single shared learnable attribute vector for all expander edges.
        self.exp_edge_attr = nn.Embedding(1, cfg.gt.dim_edge)
        self.use_exp_edges = cfg.prep.use_exp_edges and cfg.prep.exp

        if self.num_virt_node > 0:
            # Per-virtual-node embeddings: the node itself plus distinct
            # embeddings for its incoming and outgoing edges.
            self.virt_node_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_hidden)
            self.virt_edge_out_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_edge)
            self.virt_edge_in_emb = nn.Embedding(self.num_virt_node, cfg.gt.dim_edge)


    def forward(self, batch):
        edge_types = []
        device = self.exp_edge_attr.weight.device
        edge_index_sets = []
        edge_attr_sets = []
        if self.add_edge_index:
            # Source 1: the graph's own edges (type id 0).
            edge_index_sets.append(batch.edge_index)
            edge_attr_sets.append(batch.edge_attr)
            edge_types.append(torch.zeros(batch.edge_index.shape[1], dtype=torch.long))


        num_node = batch.batch.shape[0]
        num_graphs = batch.num_graphs

        if self.use_exp_edges:
            if not hasattr(batch, 'expander_edges'):
                raise ValueError('expander edges not stored in data')

            # Source 2: per-graph expander edges, shifted to batch-level
            # node ids by the running node count (type id 1).
            data_list = batch.to_data_list()
            exp_edges = []
            cumulative_num_nodes = 0
            for data in data_list:
                exp_edges.append(data.expander_edges + cumulative_num_nodes)
                cumulative_num_nodes += data.num_nodes

            exp_edges = torch.cat(exp_edges, dim=0).t()
            edge_index_sets.append(exp_edges)
            # All expander edges share the single learned attribute (row 0).
            edge_attr_sets.append(self.exp_edge_attr(torch.zeros(exp_edges.shape[1], dtype=torch.long).to(device)))
            edge_types.append(torch.zeros(exp_edges.shape[1], dtype=torch.long) + 1)

        if self.num_virt_node > 0:
            # Source 3: node <-> virtual-node edges (type id 2). Virtual
            # node ids start at num_node, laid out in blocks of num_graphs
            # per virtual-node index.
            global_h = []
            virt_edges = []
            virt_edge_attrs = []
            for idx in range(self.num_virt_node):
                global_h.append(self.virt_node_emb(torch.zeros(num_graphs, dtype=torch.long).to(device)+idx))
                # node -> virtual-node edges
                virt_edge_index = torch.cat([torch.arange(num_node).view(1, -1).to(device),
                                             (batch.batch+(num_node+idx*num_graphs)).view(1, -1)], dim=0)
                virt_edges.append(virt_edge_index)
                virt_edge_attrs.append(self.virt_edge_in_emb(torch.zeros(virt_edge_index.shape[1], dtype=torch.long).to(device)+idx))

                # virtual-node -> node edges
                virt_edge_index = torch.cat([(batch.batch+(num_node+idx*num_graphs)).view(1, -1),
                                             torch.arange(num_node).view(1, -1).to(device)], dim=0)
                virt_edges.append(virt_edge_index)
                virt_edge_attrs.append(self.virt_edge_out_emb(torch.zeros(virt_edge_index.shape[1], dtype=torch.long).to(device)+idx))

            batch.virt_h = torch.cat(global_h, dim=0)
            batch.virt_edge_index = torch.cat(virt_edges, dim=1)
            batch.virt_edge_attr = torch.cat(virt_edge_attrs, dim=0)
            edge_types.append(torch.zeros(batch.virt_edge_index.shape[1], dtype=torch.long)+2)

        if len(edge_index_sets) > 1:
            edge_index = torch.cat(edge_index_sets, dim=1)
            edge_attr = torch.cat(edge_attr_sets, dim=0)
            edge_types = torch.cat(edge_types)
        else:
            # NOTE(review): if no edge source is enabled, edge_index_sets is
            # empty and this indexing raises IndexError — presumably configs
            # always enable at least one source; confirm.
            edge_index = edge_index_sets[0]
            edge_attr = edge_attr_sets[0]
            edge_types = edge_types[0]

        # NOTE(review): edge_types is computed above but never attached to
        # the batch; and this delete runs even when expander edges were not
        # used — if 'expander_edges' is absent on the batch this may raise.
        # Confirm against callers.
        del batch.expander_edges
        batch.expander_edge_index = edge_index
        batch.expander_edge_attr = edge_attr

        return batch
--------------------------------------------------------------------------------
/gssc/encoder/kernel_pos_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_node_encoder
5 |
6 |
class KernelPENodeEncoder(torch.nn.Module):
    """Configurable kernel-based Positional Encoding node encoder.

    The choice of which kernel-based statistics to use is configurable through
    setting of `kernel_type`. Based on this, the appropriate config is selected,
    and also the appropriate variable with precomputed kernel stats is then
    selected from PyG Data graphs in `forward` function.
    E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'.

    PE of size `dim_pe` will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and the concatenated with PE.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        if self.kernel_type is None:
            # FIX: the implicitly-concatenated f-strings previously rendered
            # as "...'kernel_type' classvariable..." (missing a space).
            raise ValueError(f"{self.__class__.__name__} has to be "
                             f"preconfigured by setting 'kernel_type' class "
                             f"variable before calling the constructor.")

        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = getattr(cfg, f"posenc_{self.kernel_type}")
        dim_pe = pecfg.dim_pe  # Size of the kernel-based PE embedding
        num_rw_steps = len(pecfg.kernel.times)
        model_type = pecfg.model.lower()  # Encoder NN model type for PEs
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization layer type
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 1:
            raise ValueError(f"PE dim size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        if norm_type == 'batchnorm':
            self.raw_norm = nn.BatchNorm1d(num_rw_steps)
        else:
            self.raw_norm = None

        if model_type == 'mlp':
            # MLP over the raw kernel statistics: for n_layers > 1, widen to
            # 2*dim_pe internally, then project down to dim_pe.
            layers = []
            if n_layers == 1:
                layers.append(nn.Linear(num_rw_steps, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(num_rw_steps, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(num_rw_steps, dim_pe)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, batch):
        pestat_var = f"pestat_{self.kernel_type}"
        if not hasattr(batch, pestat_var):
            raise ValueError(f"Precomputed '{pestat_var}' variable is "
                             f"required for {self.__class__.__name__}; set "
                             f"config 'posenc_{self.kernel_type}.enable' to "
                             f"True, and also set 'posenc.kernel.times' values")

        pos_enc = getattr(batch, pestat_var)  # (Num nodes) x (Num kernel times)
        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(batch.x.to(torch.float32))
        else:
            h = batch.x.to(torch.float32)
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            setattr(batch, f'pe_{self.kernel_type}', pos_enc)
        return batch
102 |
103 |
@register_node_encoder('RWSE')
class RWSENodeEncoder(KernelPENodeEncoder):
    """Random Walk Structural Encoding node encoder.

    Selects the 'RWSE' kernel statistics: reads precomputed `pestat_RWSE`
    from the batch and the `cfg.posenc_RWSE` config section; all behavior
    is inherited from KernelPENodeEncoder.
    """
    kernel_type = 'RWSE'
109 |
110 |
@register_node_encoder('HKdiagSE')
class HKdiagSENodeEncoder(KernelPENodeEncoder):
    """Heat kernel (diagonal) Structural Encoding node encoder.

    Selects the 'HKdiagSE' kernel statistics: reads precomputed
    `pestat_HKdiagSE` from the batch and the `cfg.posenc_HKdiagSE` config
    section; all behavior is inherited from KernelPENodeEncoder.
    """
    kernel_type = 'HKdiagSE'
116 |
117 |
@register_node_encoder('ElstaticSE')
class ElstaticSENodeEncoder(KernelPENodeEncoder):
    """Electrostatic interactions Structural Encoding node encoder.

    Selects the 'ElstaticSE' kernel statistics: reads precomputed
    `pestat_ElstaticSE` from the batch and the `cfg.posenc_ElstaticSE`
    config section; all behavior is inherited from KernelPENodeEncoder.
    """
    kernel_type = 'ElstaticSE'
123 |
--------------------------------------------------------------------------------
/gssc/encoder/laplace_pos_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_node_encoder
5 |
6 |
7 | @register_node_encoder('LapPE')
8 | class LapPENodeEncoder(torch.nn.Module):
9 | """Laplace Positional Embedding node encoder.
10 |
11 | LapPE of size dim_pe will get appended to each node feature vector.
12 | If `expand_x` set True, original node features will be first linearly
13 | projected to (dim_emb - dim_pe) size and the concatenated with LapPE.
14 |
15 | Args:
16 | dim_emb: Size of final node embedding
17 | expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
18 | """
19 |
    def __init__(self, dim_emb, expand_x=True):
        """Build the LapPE node encoder from `cfg.posenc_LapPE` options.

        Args:
            dim_emb: Size of the final node embedding; the PE occupies
                `dim_pe` of it, the node features the remaining
                (dim_emb - dim_pe).
            expand_x: If True, linearly project input node features `x`
                from cfg.share.dim_in to (dim_emb - dim_pe).

        Raises:
            ValueError: If the PE model type is unknown, or if dim_pe does
                not leave at least one channel for the node features.
        """
        super().__init__()
        dim_in = cfg.share.dim_in # Expected original input node features dim

        pecfg = cfg.posenc_LapPE
        dim_pe = pecfg.dim_pe # Size of Laplace PE embedding
        model_type = pecfg.model # Encoder NN model type for PEs
        if model_type not in ['Transformer', 'DeepSet']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers # Num. layers in PE encoder model
        n_heads = pecfg.n_heads # Num. attention heads in Trf PE encoder
        post_n_layers = pecfg.post_layers # Num. layers to apply after pooling
        max_freqs = pecfg.eigen.max_freqs # Num. eigenvectors (frequencies)
        norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type
        self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable

        if dim_emb - dim_pe < 1:
            raise ValueError(f"LapPE size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        # Initial projection of eigenvalue and the node's eigenvector value
        self.linear_A = nn.Linear(2, dim_pe)
        if norm_type == 'batchnorm':
            # Normalizes over the eigenvector-channel dim of the raw
            # (num_nodes, max_freqs, 2) input.
            self.raw_norm = nn.BatchNorm1d(max_freqs)
        else:
            self.raw_norm = None

        if model_type == 'Transformer':
            # Transformer model for LapPE
            encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
                                                       nhead=n_heads,
                                                       batch_first=True)
            self.pe_encoder = nn.TransformerEncoder(encoder_layer,
                                                    num_layers=n_layers)
        else:
            # DeepSet model for LapPE
            layers = []
            if n_layers == 1:
                layers.append(nn.ReLU())
            else:
                # Replaces the Linear(2, dim_pe) created above: the
                # multi-layer DeepSet first projects to a 2*dim_pe width.
                self.linear_A = nn.Linear(2, 2 * dim_pe)
                layers.append(nn.ReLU())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
90 |
91 |
    def forward(self, batch):
        """Compute a LapPE embedding per node and concatenate it to batch.x.

        Expects precomputed `batch.EigVals` / `batch.EigVecs`; NaN entries
        mark padding for graphs with fewer than max_freqs eigenvectors.

        Raises:
            ValueError: If the precomputed eigen-decomposition is missing.
        """
        if not (hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; "
                             "set config 'posenc_LapPE.enable' to True")
        EigVals = batch.EigVals
        EigVecs = batch.EigVecs

        if self.training:
            # Eigenvectors are defined only up to sign: randomly flip signs
            # during training as a sign-invariance augmentation.
            sign_flip = torch.rand(EigVecs.size(1), device=EigVecs.device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            EigVecs = EigVecs * sign_flip.unsqueeze(0)

        pos_enc = torch.cat((EigVecs.unsqueeze(2), EigVals), dim=2) # (Num nodes) x (Num Eigenvectors) x 2
        empty_mask = torch.isnan(pos_enc) # (Num nodes) x (Num Eigenvectors) x 2

        # Zero out the NaN padding before any linear algebra touches it.
        pos_enc[empty_mask] = 0 # (Num nodes) x (Num Eigenvectors) x 2
        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.linear_A(pos_enc) # (Num nodes) x (Num Eigenvectors) x dim_pe

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == 'Transformer':
            pos_enc = self.pe_encoder(src=pos_enc,
                                      src_key_padding_mask=empty_mask[:, :, 0])
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Remove masked sequences; must clone before overwriting masked elements
        pos_enc = pos_enc.clone().masked_fill_(empty_mask[:, :, 0].unsqueeze(2),
                                               0.)

        # Sum pooling
        pos_enc = torch.sum(pos_enc, 1, keepdim=False) # (Num nodes) x dim_pe

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc) # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(batch.x.to(torch.float32))
        else:
            h = batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_LapPE = pos_enc
        return batch
143 |
--------------------------------------------------------------------------------
/gssc/encoder/linear_edge_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym import cfg
3 | from torch_geometric.graphgym.register import register_edge_encoder
4 |
5 |
@register_edge_encoder('LinearEdge')
class LinearEdgeEncoder(torch.nn.Module):
    """Linear projection of raw edge features to `emb_dim` channels.

    The raw edge-feature dimension is hard-coded per dataset; unknown
    datasets are rejected at construction time.
    """

    def __init__(self, emb_dim):
        super().__init__()
        known_dims = {'MNIST': 1, 'CIFAR10': 1, 'ogbn-proteins': 8}
        dataset_name = cfg.dataset.name
        if dataset_name not in known_dims:
            raise ValueError("Input edge feature dim is required to be hardset "
                             "or refactored to use a cfg option.")
        self.in_dim = known_dims[dataset_name]
        self.encoder = torch.nn.Linear(self.in_dim, emb_dim)

    def forward(self, batch):
        # Flatten to (num_edges, in_dim) before the projection.
        flat_attr = batch.edge_attr.view(-1, self.in_dim)
        batch.edge_attr = self.encoder(flat_attr)
        return batch
22 |
--------------------------------------------------------------------------------
/gssc/encoder/linear_node_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym import cfg
3 | from torch_geometric.graphgym.register import register_node_encoder
4 |
5 |
@register_node_encoder('LinearNode')
class LinearNodeEncoder(torch.nn.Module):
    """Encode raw node features with a single linear layer.

    Input width is taken from cfg.share.dim_in; features are cast to
    float before the projection.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim)

    def forward(self, batch):
        batch.x = self.encoder(batch.x.float())
        return batch
24 |
--------------------------------------------------------------------------------
/gssc/encoder/ppa_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.register import (register_node_encoder,
3 | register_edge_encoder)
4 |
5 |
@register_node_encoder('PPANode')
class PPANodeEncoder(torch.nn.Module):
    """Uniform input node embedding for PPA, whose graphs have no node
    features: every node looks up the same learned `emb_dim` vector.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Single-row table; all nodes are expected to carry index 0.
        self.encoder = torch.nn.Embedding(1, emb_dim)

    def forward(self, batch):
        batch.x = self.encoder(batch.x)
        return batch
19 |
20 |
@register_edge_encoder('PPAEdge')
class PPAEdgeEncoder(torch.nn.Module):
    """Linear encoder for PPA's fixed 7-dimensional raw edge features."""

    def __init__(self, emb_dim):
        super().__init__()
        raw_edge_dim = 7  # PPA edge features are hard-set to 7 dims
        self.encoder = torch.nn.Linear(raw_edge_dim, emb_dim)

    def forward(self, batch):
        batch.edge_attr = self.encoder(batch.edge_attr)
        return batch
30 |
--------------------------------------------------------------------------------
/gssc/encoder/type_dict_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.register import (register_node_encoder,
4 | register_edge_encoder)
5 |
6 | """
7 | Generic Node and Edge encoders for datasets with node/edge features that
8 | consist of only one type dictionary thus require a single nn.Embedding layer.
9 |
10 | The number of possible Node and Edge types must be set by cfg options:
11 | 1) cfg.dataset.node_encoder_num_types
12 | 2) cfg.dataset.edge_encoder_num_types
13 |
14 | In case of a more complex feature set, use a data-specific encoder.
15 |
16 | These generic encoders can be used e.g. for:
17 | * ZINC
18 | cfg.dataset.node_encoder_num_types: 28
19 | cfg.dataset.edge_encoder_num_types: 4
20 |
21 | * AQSOL
22 | cfg.dataset.node_encoder_num_types: 65
23 | cfg.dataset.edge_encoder_num_types: 5
24 |
25 |
26 | === Description of the ZINC dataset ===
27 | https://github.com/graphdeeplearning/benchmarking-gnns/issues/42
28 | The node labels are atom types and the edge labels atom bond types.
29 |
30 | Node labels:
31 | 'C': 0
32 | 'O': 1
33 | 'N': 2
34 | 'F': 3
35 | 'C H1': 4
36 | 'S': 5
37 | 'Cl': 6
38 | 'O -': 7
39 | 'N H1 +': 8
40 | 'Br': 9
41 | 'N H3 +': 10
42 | 'N H2 +': 11
43 | 'N +': 12
44 | 'N -': 13
45 | 'S -': 14
46 | 'I': 15
47 | 'P': 16
48 | 'O H1 +': 17
49 | 'N H1 -': 18
50 | 'O +': 19
51 | 'S +': 20
52 | 'P H1': 21
53 | 'P H2': 22
54 | 'C H2 -': 23
55 | 'P +': 24
56 | 'S H1 +': 25
57 | 'C H1 -': 26
58 | 'P H1 +': 27
59 |
60 | Edge labels:
61 | 'NONE': 0
62 | 'SINGLE': 1
63 | 'DOUBLE': 2
64 | 'TRIPLE': 3
65 |
66 |
67 | === Description of the AQSOL dataset ===
68 | Node labels:
69 | 'Br': 0, 'C': 1, 'N': 2, 'O': 3, 'Cl': 4, 'Zn': 5, 'F': 6, 'P': 7, 'S': 8, 'Na': 9, 'Al': 10,
70 | 'Si': 11, 'Mo': 12, 'Ca': 13, 'W': 14, 'Pb': 15, 'B': 16, 'V': 17, 'Co': 18, 'Mg': 19, 'Bi': 20, 'Fe': 21,
71 | 'Ba': 22, 'K': 23, 'Ti': 24, 'Sn': 25, 'Cd': 26, 'I': 27, 'Re': 28, 'Sr': 29, 'H': 30, 'Cu': 31, 'Ni': 32,
72 | 'Lu': 33, 'Pr': 34, 'Te': 35, 'Ce': 36, 'Nd': 37, 'Gd': 38, 'Zr': 39, 'Mn': 40, 'As': 41, 'Hg': 42, 'Sb':
73 | 43, 'Cr': 44, 'Se': 45, 'La': 46, 'Dy': 47, 'Y': 48, 'Pd': 49, 'Ag': 50, 'In': 51, 'Li': 52, 'Rh': 53,
74 | 'Nb': 54, 'Hf': 55, 'Cs': 56, 'Ru': 57, 'Au': 58, 'Sm': 59, 'Ta': 60, 'Pt': 61, 'Ir': 62, 'Be': 63, 'Ge': 64
75 |
76 | Edge labels:
77 | 'NONE': 0, 'SINGLE': 1, 'DOUBLE': 2, 'AROMATIC': 3, 'TRIPLE': 4
78 | """
79 |
80 |
@register_node_encoder('TypeDictNode')
class TypeDictNodeEncoder(torch.nn.Module):
    """Embed a single categorical node-type id into `emb_dim` channels.

    Vocabulary size comes from cfg.dataset.node_encoder_num_types.
    """

    def __init__(self, emb_dim):
        super().__init__()
        num_types = cfg.dataset.node_encoder_num_types
        if num_types < 1:
            raise ValueError(f"Invalid 'node_encoder_num_types': {num_types}")
        self.encoder = torch.nn.Embedding(num_embeddings=num_types,
                                          embedding_dim=emb_dim)

    def forward(self, batch):
        # Only the first feature column holds the type id; extra columns
        # (if any) are ignored.
        batch.x = self.encoder(batch.x[:, 0])
        return batch
99 |
100 |
@register_edge_encoder('TypeDictEdge')
class TypeDictEdgeEncoder(torch.nn.Module):
    """Embed a categorical edge-type id into `emb_dim` channels.

    Vocabulary size comes from cfg.dataset.edge_encoder_num_types.
    """

    def __init__(self, emb_dim):
        super().__init__()
        num_types = cfg.dataset.edge_encoder_num_types
        if num_types < 1:
            raise ValueError(f"Invalid 'edge_encoder_num_types': {num_types}")
        self.encoder = torch.nn.Embedding(num_embeddings=num_types,
                                          embedding_dim=emb_dim)

    def forward(self, batch):
        batch.edge_attr = self.encoder(batch.edge_attr)
        return batch
117 |
--------------------------------------------------------------------------------
/gssc/encoder/voc_superpixels_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.register import (register_node_encoder,
4 | register_edge_encoder)
5 |
6 | """
7 | === Description of the VOCSuperpixels dataset ===
8 | Each graph is a tuple (x, edge_attr, edge_index, y)
9 | Shape of x : [num_nodes, 14]
10 | Shape of edge_attr : [num_edges, 1] or [num_edges, 2]
11 | Shape of edge_index : [2, num_edges]
12 | Shape of y : [num_nodes]
13 | """
14 |
15 | VOC_node_input_dim = 14
16 | # VOC_edge_input_dim = 1 or 2; defined in class VOCEdgeEncoder
17 |
@register_node_encoder('VOCNode')
class VOCNodeEncoder(torch.nn.Module):
    """Linear projection of the 14-dim VOC superpixel node features."""

    def __init__(self, emb_dim):
        super().__init__()
        self.encoder = torch.nn.Linear(VOC_node_input_dim, emb_dim)

    def forward(self, batch):
        batch.x = self.encoder(batch.x)
        return batch
30 |
31 |
@register_edge_encoder('VOCEdge')
class VOCEdgeEncoder(torch.nn.Module):
    """Linear projection of VOC superpixel edge features.

    Raw edge features are 2-dim for the 'edge_wt_region_boundary' graph
    construction and 1-dim for every other variant.
    """

    def __init__(self, emb_dim):
        super().__init__()
        if cfg.dataset.name == 'edge_wt_region_boundary':
            raw_edge_dim = 2
        else:
            raw_edge_dim = 1
        self.encoder = torch.nn.Linear(raw_edge_dim, emb_dim)

    def forward(self, batch):
        batch.edge_attr = self.encoder(batch.edge_attr)
        return batch
44 |
--------------------------------------------------------------------------------
/gssc/finetuning.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import os.path as osp
4 |
5 | import torch
6 | from torch_geometric.graphgym.config import set_cfg
7 | from yacs.config import CfgNode
8 |
9 |
def get_final_pretrained_ckpt(ckpt_dir):
    """Return the path of the highest-epoch checkpoint in `ckpt_dir`.

    Checkpoint files are expected to be named `<epoch>.ckpt`; entries whose
    stem is not an integer (e.g. stray `config.yaml` or hidden files) are
    ignored instead of crashing the epoch parse.

    Args:
        ckpt_dir: Directory holding the pretrained model's checkpoints.

    Returns:
        Full path to the `<max_epoch>.ckpt` file.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no
            epoch-numbered checkpoint files.
    """
    if not osp.exists(ckpt_dir):
        raise FileNotFoundError(f"Pretrained model dir not found: {ckpt_dir}")
    # Collect epoch numbers, skipping any file without a numeric stem
    # (previously such files raised ValueError from int()).
    epochs = [int(stem)
              for stem in (name.split('.')[0] for name in os.listdir(ckpt_dir))
              if stem.isdigit()]
    if not epochs:
        # Previously an empty directory crashed with ValueError from max([]).
        raise FileNotFoundError(f"No pretrained checkpoints found in: "
                                f"{ckpt_dir}")
    return osp.join(ckpt_dir, f'{max(epochs)}.ckpt')
18 |
19 |
def compare_cfg(cfg_main, cfg_secondary, field_name, strict=False):
    """Compare one (possibly nested) field between two configs.

    Args:
        cfg_main: Primary config; its value wins on mismatch.
        cfg_secondary: Config to compare against (e.g. a pretrained one).
        field_name: Dot-separated path of the field, e.g. 'model.type'.
        strict: If True, raise on mismatch instead of warning.

    Raises:
        ValueError: If `strict` and the two values differ.
    """
    main_val, secondary_val = cfg_main, cfg_secondary
    for key in field_name.split('.'):
        main_val = main_val[key]
        secondary_val = secondary_val[key]
    if main_val == secondary_val:
        return
    if strict:
        raise ValueError(f"Main and pretrained configs must match on "
                         f"'{field_name}'")
    logging.warning(f"Pretrained models '{field_name}' differs, "
                    f"using: {main_val}")
32 |
33 |
def set_new_cfg_allowed(config, is_new_allowed):
    """Toggle whether a YACS config (and, recursively, every sub-config)
    accepts new keys when merged with other configs.
    """
    config.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed
    # Sub-configs can live either as instance attributes or as mapping
    # entries; recurse into both kinds.
    children = list(config.__dict__.values()) + list(config.values())
    for child in children:
        if isinstance(child, CfgNode):
            set_new_cfg_allowed(child, is_new_allowed)
46 |
47 |
def load_pretrained_model_cfg(cfg):
    """Update the main config with settings from a pretrained model's config.

    Loads `<cfg.pretrained.dir>/config.yaml`, verifies that critical fields
    agree with the main config, then copies over the PE/SE, GT and GNN
    sub-configs while keeping the main config's prediction-head settings.

    Args:
        cfg: Main YACS config node; modified in place.

    Returns:
        The updated config node.

    Raises:
        FileNotFoundError: If the pretrained config file does not exist.
        ValueError: If a strictly-compared field differs between configs.
    """
    pretrained_cfg_fname = osp.join(cfg.pretrained.dir, 'config.yaml')
    if not os.path.isfile(pretrained_cfg_fname):
        # BUG FIX: the exception was previously constructed but never raised,
        # letting execution continue and fail later with a confusing error.
        raise FileNotFoundError(f"Pretrained model config not found: "
                                f"{pretrained_cfg_fname}")

    logging.info(f"[*] Updating cfg from pretrained model: "
                 f"{pretrained_cfg_fname}")

    pretrained_cfg = CfgNode()
    set_cfg(pretrained_cfg)
    # Allow the file to introduce keys not present in the default config.
    set_new_cfg_allowed(pretrained_cfg, True)
    pretrained_cfg.merge_from_file(pretrained_cfg_fname)

    assert cfg.model.type == 'GPSModel', \
        "Fine-tuning regime is untested for other model types."
    compare_cfg(cfg, pretrained_cfg, 'model.type', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'model.graph_pooling')
    compare_cfg(cfg, pretrained_cfg, 'model.edge_decoding')
    compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder_name', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'dataset.node_encoder_bn', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder_name', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'dataset.edge_encoder_bn', strict=True)

    # Copy over all PE/SE configs
    for key in cfg.keys():
        if key.startswith('posenc_'):
            cfg[key] = pretrained_cfg[key]

    # Copy over GT config
    cfg.gt = pretrained_cfg.gt

    # Copy over GNN cfg but not those for the prediction head
    compare_cfg(cfg, pretrained_cfg, 'gnn.head')
    compare_cfg(cfg, pretrained_cfg, 'gnn.layers_post_mp')
    compare_cfg(cfg, pretrained_cfg, 'gnn.act')
    compare_cfg(cfg, pretrained_cfg, 'gnn.dropout')
    # Stash the head-related settings before overwriting cfg.gnn wholesale,
    # then restore them so the new prediction head keeps the main config.
    head = cfg.gnn.head
    post_mp = cfg.gnn.layers_post_mp
    act = cfg.gnn.act
    drp = cfg.gnn.dropout
    cfg.gnn = pretrained_cfg.gnn
    cfg.gnn.head = head
    cfg.gnn.layers_post_mp = post_mp
    cfg.gnn.act = act
    cfg.gnn.dropout = drp
    return cfg
97 |
98 |
def init_model_from_pretrained(model, pretrained_dir,
                               freeze_main=False, reset_prediction_head=True):
    """ Copy model parameters from pretrained model except the prediction head.

    Args:
        model: Initialized model with random weights.
        pretrained_dir: Root directory of saved pretrained model.
        freeze_main: If True, do not finetune the loaded pretrained parameters
            of the `main body` (train the prediction head only), else train all.
        reset_prediction_head: If True, reset parameters of the prediction head,
            else keep the pretrained weights.

    Returns:
        Updated pytorch model object.
    """
    from torch_geometric.graphgym.checkpoint import MODEL_STATE

    # Checkpoints are expected under '<pretrained_dir>/0/ckpt/<epoch>.ckpt'.
    ckpt_file = get_final_pretrained_ckpt(osp.join(pretrained_dir, '0', 'ckpt'))
    logging.info(f"[*] Loading from pretrained model: {ckpt_file}")

    # NOTE(review): no map_location is passed, so a GPU-saved checkpoint is
    # loaded onto its original device — confirm this is fine for CPU-only runs.
    ckpt = torch.load(ckpt_file)
    pretrained_dict = ckpt[MODEL_STATE]
    model_dict = model.state_dict()

    if reset_prediction_head:
        # Filter out prediction head parameter keys.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if not k.startswith('post_mp')}
    # Overwrite entries in the existing state dict.
    model_dict.update(pretrained_dict)
    # Load the new state dict.
    model.load_state_dict(model_dict)

    if freeze_main:
        # Freeze everything except the prediction head ('post_mp.*') so only
        # the head receives gradient updates during fine-tuning.
        for key, param in model.named_parameters():
            if not key.startswith('post_mp'):
                param.requires_grad = False
    return model
142 |
--------------------------------------------------------------------------------
/gssc/head/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Re-export every sibling Python module except the package initializer.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/head/example.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from torch_geometric.graphgym.register import register_head
4 |
5 |
@register_head('head')
class ExampleNodeHead(nn.Module):
    '''Head of GNN, node prediction.

    Applies a single linear layer to the node features and returns the
    (predictions, labels) pair selected by ``batch.node_label_index``.
    '''
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True)

    def _apply_index(self, batch):
        # If every node carries a label, labels already align with node order;
        # otherwise select the labeled subset from predictions and labels.
        if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
            return batch.x[batch.node_label_index], batch.node_label
        else:
            return batch.x[batch.node_label_index], \
                batch.node_label[batch.node_label_index]

    def forward(self, batch):
        # BUG FIX: nn.Linear operates on tensors, not on the batch container.
        # Previously the whole batch object was passed to the linear layer
        # (and then indexed via `.x`), which cannot work; apply the layer to
        # the node features and store the result back on the batch.
        batch.x = self.layer_post_mp(batch.x)
        pred, label = self._apply_index(batch)
        return pred, label
24 |
--------------------------------------------------------------------------------
/gssc/head/inductive_node.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.models.layer import new_layer_config, MLP
4 | from torch_geometric.graphgym.register import register_head
5 | import torch
6 |
7 |
@register_head('inductive_node')
class GNNInductiveNodeHead(nn.Module):
    """
    GNN prediction head for inductive node prediction tasks.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        layer_cfg = new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp,
                                     has_act=False, has_bias=True, cfg=cfg)
        self.layer_post_mp = MLP(layer_cfg)
        if cfg.extra.jk:
            # Jumping-knowledge fusion MLP over the concatenated per-layer
            # representations (cfg.gt.layers of width cfg.gnn.dim_inner each).
            hidden = cfg.gnn.dim_inner
            self.jk_mlp = nn.Sequential(
                nn.Linear(hidden * cfg.gt.layers, hidden),
                nn.SiLU(),
                nn.Linear(hidden, hidden),
            )

    def _apply_index(self, batch):
        return batch.x, batch.y

    def forward(self, batch):
        if cfg.extra.jk:
            # Fuse the stacked layer outputs kept in batch.all_x.
            batch.x = self.jk_mlp(torch.cat(batch.all_x, dim=-1))
        batch = self.layer_post_mp(batch)
        return self._apply_index(batch)
35 |
--------------------------------------------------------------------------------
/gssc/head/ogb_code_graph.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | import torch_geometric.graphgym.register as register
4 | from torch_geometric.graphgym import cfg
5 | from torch_geometric.graphgym.register import register_head
6 |
7 |
@register_head('ogb_code_graph')
class OGBCodeGraphHead(nn.Module):
    """
    Sequence prediction head for ogbg-code2 graph-level prediction tasks.

    Pools node features into a graph embedding, then applies one linear
    classifier per output position of the (fixed-length) token sequence.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): IGNORED, kept for GraphGym framework compatibility
        L (int): Number of hidden layers.
    """

    def __init__(self, dim_in, dim_out, L=1):
        super().__init__()
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
        self.L = L
        num_vocab = 5002  # token vocabulary size used for ogbg-code2
        self.max_seq_len = 5

        if self.L != 1:
            raise ValueError("Multilayer prediction heads are not supported.")

        # One independent classifier per sequence position.
        self.graph_pred_linear_list = nn.ModuleList(
            nn.Linear(dim_in, num_vocab) for _ in range(self.max_seq_len))

    def _apply_index(self, batch):
        return batch.pred_list, {'y_arr': batch.y_arr, 'y': batch.y}

    def forward(self, batch):
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        batch.pred_list = [head(graph_emb)
                           for head in self.graph_pred_linear_list]
        return self._apply_index(batch)
46 |
--------------------------------------------------------------------------------
/gssc/head/san_graph.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | import torch_geometric.graphgym.register as register
5 | from torch_geometric.graphgym import cfg
6 | from torch_geometric.graphgym.register import register_head
7 |
8 |
@register_head('san_graph')
class SANGraphHead(nn.Module):
    """
    SAN prediction head for graph prediction tasks.

    Pools node features into a graph embedding and pushes it through an
    MLP whose width halves at every hidden layer.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
        L (int): Number of hidden layers.
    """

    def __init__(self, dim_in, dim_out, L=2):
        super().__init__()
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
        # Hidden widths: dim_in, dim_in//2, dim_in//4, ... (L+1 entries).
        dims = [dim_in // 2 ** layer for layer in range(L + 1)]
        fc_layers = [nn.Linear(dims[layer], dims[layer + 1], bias=True)
                     for layer in range(L)]
        fc_layers.append(nn.Linear(dims[L], dim_out, bias=True))
        self.FC_layers = nn.ModuleList(fc_layers)
        self.L = L

    def _apply_index(self, batch):
        return batch.graph_feature, batch.y

    def forward(self, batch):
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        # ReLU after every hidden layer; the output layer stays linear.
        for hidden_fc in self.FC_layers[:-1]:
            graph_emb = F.relu(hidden_fc(graph_emb))
        graph_emb = self.FC_layers[-1](graph_emb)
        batch.graph_feature = graph_emb
        return self._apply_index(batch)
43 |
--------------------------------------------------------------------------------
/gssc/layer/ETransformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_scatter import scatter
6 |
7 | from torch_geometric.graphgym.models.layer import LayerConfig
8 | from torch_geometric.graphgym.config import cfg
9 | from torch_geometric.graphgym.register import register_layer
10 |
11 |
class ETransformer(nn.Module):
    """Mostly Multi-Head Graph Attention Layer.

    Sparse attention: scores are computed only along the edges given by the
    batch attribute named `edge_index` (configurable, e.g. an expander graph).

    Ported to PyG from original repo:
    https://github.com/DevinKreuzer/SAN/blob/main/layers/graph_transformer_layer.py
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, edge_index='edge_index', use_edge_attr=False, edge_attr='edge_attr'):
        super().__init__()

        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads')
        self.out_dim = out_dim // num_heads  # per-head feature dim
        self.num_heads = num_heads
        # Names of the batch attributes holding connectivity / edge features.
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.use_edge_attr = use_edge_attr

        self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        if self.use_edge_attr:
            self.E = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)

    def propagate_attention(self, batch):
        """Compute per-edge attention and scatter weighted values (batch.wV)
        and normalization coefficients (batch.Z) into destination nodes.
        """
        edge_index = getattr(batch, self.edge_index)

        src = batch.K_h[edge_index[0]] # (num real edges) x num_heads x out_dim
        dest = batch.Q_h[edge_index[1]] # (num real edges) x num_heads x out_dim
        score = torch.mul(src, dest) # element-wise multiplication

        # Scale scores by sqrt(d)
        score = score / np.sqrt(self.out_dim)

        # Use available edge features to modify the scores for edges
        if self.use_edge_attr:
            score = torch.mul(score, batch.E) # (num real edges) x num_heads x out_dim
        # Clamp before exp for numerical stability.
        score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5)) # (num real edges) x num_heads x 1

        # Apply attention score to each source node to create edge messages
        msg = batch.V_h[edge_index[0]] * score # (num real edges) x num_heads x out_dim
        # Add-up real msgs in destination nodes as given by batch.edge_index[1]
        batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim
        scatter(msg, edge_index[1], dim=0, out=batch.wV, reduce='add')

        # Compute attention normalization coefficient
        batch.Z = score.new_zeros(batch.size(0), self.num_heads, 1) # (num nodes in batch) x num_heads x 1
        scatter(score, edge_index[1], dim=0, out=batch.Z, reduce='add')

    def forward(self, batch):
        """Return attention-updated node features of shape
        (num_nodes, out_dim * num_heads); does not write them back to batch.x.
        """
        edge_index = getattr(batch, self.edge_index)

        if edge_index is None:
            raise ValueError(f'edge index: f{self.edge_index} not found')

        # Accept (num_edges, 2) layout and transpose it to the (2, num_edges)
        # convention, storing the fixed version back on the batch.
        if edge_index.shape[0] != 2 and edge_index.shape[1] == 2:
            edge_index = torch.t(edge_index)
            setattr(batch, self.edge_index, edge_index)

        if self.use_edge_attr:
            edge_attr = getattr(batch, self.edge_attr)
            if edge_attr is None or edge_attr.shape[0] != edge_index.shape[1]:
                # NOTE(review): this mutates module state, permanently
                # disabling edge features for all subsequent batches after a
                # single mismatched one — confirm this is intended.
                print('edge_attr shape does not match edge_index shape, ignoring edge_attr')
                self.use_edge_attr = False

        Q_h = self.Q(batch.x)
        K_h = self.K(batch.x)
        if self.use_edge_attr:
            E = self.E(edge_attr)
        V_h = self.V(batch.x)

        # Reshaping into [num_nodes, num_heads, feat_dim] to
        # get projections for multi-head attention
        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        if self.use_edge_attr:
            batch.E = E.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(batch)

        # Normalize the aggregated values; epsilon guards isolated nodes
        # that received no messages (Z == 0).
        h_out = batch.wV / (batch.Z + 1e-6)

        h_out = h_out.view(-1, self.out_dim * self.num_heads)

        return h_out


register_layer('etransformer', ETransformer)
--------------------------------------------------------------------------------
/gssc/layer/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Re-export every sibling Python module except the package initializer.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/layer/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | from torch_geometric.graphgym.config import cfg
6 | from torch_geometric.graphgym.register import register_layer
7 | from torch_geometric.nn.conv import MessagePassing
8 | from torch_geometric.nn.inits import glorot, zeros
9 |
10 | # Note: A registered GNN layer should take 'batch' as input
11 | # and 'batch' as output
12 |
13 |
14 | # Example 1: Directly define a GraphGym format Conv
15 | # take 'batch' as input and 'batch' as output
@register_layer('exampleconv1')
class ExampleConv1(MessagePassing):
    r"""Example GNN layer: linear transform of node features followed by
    neighborhood aggregation (aggregator taken from cfg.gnn.agg).
    Operates directly on a GraphGym batch.
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super().__init__(aggr=cfg.gnn.agg, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Glorot init for the weight, zeros for the (optional) bias.
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, batch):
        transformed = torch.matmul(batch.x, self.weight)
        batch.x = self.propagate(batch.edge_index, x=transformed)
        return batch

    def message(self, x_j):
        # Identity message: neighbor features pass through unchanged.
        return x_j

    def update(self, aggr_out):
        return aggr_out if self.bias is None else aggr_out + self.bias
55 |
56 |
57 | # Example 2: First define a PyG format Conv layer
58 | # Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer in plain PyG style (tensor in, tensor out):
    linear transform of node features followed by neighborhood
    aggregation (aggregator taken from cfg.gnn.agg).
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super().__init__(aggr=cfg.gnn.agg, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Glorot init for the weight, zeros for the (optional) bias.
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=torch.matmul(x, self.weight))

    def message(self, x_j):
        # Identity message: neighbor features pass through unchanged.
        return x_j

    def update(self, aggr_out):
        return aggr_out if self.bias is None else aggr_out + self.bias
94 |
95 |
@register_layer('exampleconv2')
class ExampleConv2(nn.Module):
    """GraphGym wrapper that adapts the PyG-style ExampleConv2Layer to
    batch-in/batch-out form."""

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
105 |
--------------------------------------------------------------------------------
/gssc/layer/gatedgcn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch_geometric.nn as pyg_nn
5 | from torch_geometric.graphgym.models.layer import LayerConfig
6 | from torch_scatter import scatter
7 |
8 | from torch_geometric.graphgym.config import cfg
9 | from torch_geometric.graphgym.register import register_layer
10 |
11 |
12 | class GatedGCNLayer(pyg_nn.conv.MessagePassing):
13 | """
14 | GatedGCN layer
15 | Residual Gated Graph ConvNets
16 | https://arxiv.org/pdf/1711.07553.pdf
17 | """
    def __init__(self, in_dim, out_dim, dropout, residual,
                 equivstable_pe=False, **kwargs):
        """Construct a GatedGCN layer.

        Args:
            in_dim: Input node/edge feature dimension.
            out_dim: Output node/edge feature dimension.
            dropout: Dropout probability applied to node and edge features.
            residual: If True, add residual connections (this presumably
                requires in_dim == out_dim — confirm with callers).
            equivstable_pe: Enable the EquivStable-LapPE gating term.
        """
        super().__init__(**kwargs)
        # Five linear maps of the GatedGCN update:
        # A: node self-term, B: neighbor term, C: edge term,
        # D/E: destination/source terms of the edge gate.
        self.A = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.B = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.C = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.D = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.E = pyg_nn.Linear(in_dim, out_dim, bias=True)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        self.EquivStablePE = equivstable_pe
        if self.EquivStablePE:
            self.mlp_r_ij = nn.Sequential(
                nn.Linear(1, out_dim), nn.ReLU(),
                nn.Linear(out_dim, 1),
                nn.Sigmoid())

        self.bn_node_x = nn.BatchNorm1d(out_dim)
        self.bn_edge_e = nn.BatchNorm1d(out_dim)
        self.dropout = dropout
        self.residual = residual
        # Side channel: message() stores pre-activation edge features here.
        self.e = None
41 |
42 | def forward(self, batch):
43 | x, e, edge_index = batch.x, batch.edge_attr, batch.edge_index
44 |
45 | """
46 | x : [n_nodes, in_dim]
47 | e : [n_edges, in_dim]
48 | edge_index : [2, n_edges]
49 | """
50 | if self.residual:
51 | x_in = x
52 | e_in = e
53 |
54 | Ax = self.A(x)
55 | Bx = self.B(x)
56 | Ce = self.C(e)
57 | Dx = self.D(x)
58 | Ex = self.E(x)
59 |
60 | # Handling for Equivariant and Stable PE using LapPE
61 | # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
62 | pe_LapPE = batch.pe_EquivStableLapPE if self.EquivStablePE else None
63 |
64 | x, e = self.propagate(edge_index,
65 | Bx=Bx, Dx=Dx, Ex=Ex, Ce=Ce,
66 | e=e, Ax=Ax,
67 | PE=pe_LapPE)
68 |
69 | x = self.bn_node_x(x)
70 | e = self.bn_edge_e(e)
71 |
72 | x = F.relu(x)
73 | e = F.relu(e)
74 |
75 | x = F.dropout(x, self.dropout, training=self.training)
76 | e = F.dropout(e, self.dropout, training=self.training)
77 |
78 | if self.residual:
79 | x = x_in + x
80 | e = e_in + e
81 |
82 | batch.x = x
83 | batch.edge_attr = e
84 |
85 | return batch
86 |
87 | def message(self, Dx_i, Ex_j, PE_i, PE_j, Ce):
88 | """
89 | {}x_i : [n_edges, out_dim]
90 | {}x_j : [n_edges, out_dim]
91 | {}e : [n_edges, out_dim]
92 | """
93 | e_ij = Dx_i + Ex_j + Ce
94 | sigma_ij = torch.sigmoid(e_ij)
95 |
96 | # Handling for Equivariant and Stable PE using LapPE
97 | # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
98 | if self.EquivStablePE:
99 | r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
100 | r_ij = self.mlp_r_ij(r_ij) # the MLP is 1 dim --> hidden_dim --> 1 dim
101 | sigma_ij = sigma_ij * r_ij
102 |
103 | self.e = e_ij
104 | return sigma_ij
105 |
106 | def aggregate(self, sigma_ij, index, Bx_j, Bx):
107 | """
108 | sigma_ij : [n_edges, out_dim] ; is the output from message() function
109 | index : [n_edges]
110 | {}x_j : [n_edges, out_dim]
111 | """
112 | dim_size = Bx.shape[0] # or None ?? <--- Double check this
113 |
114 | sum_sigma_x = sigma_ij * Bx_j
115 | numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size,
116 | reduce='sum')
117 |
118 | sum_sigma = sigma_ij
119 | denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size,
120 | reduce='sum')
121 |
122 | out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
123 | return out
124 |
125 | def update(self, aggr_out, Ax):
126 | """
127 | aggr_out : [n_nodes, out_dim] ; is the output from aggregate() function after the aggregation
128 | {}x : [n_nodes, out_dim]
129 | """
130 | x = Ax + aggr_out
131 | e_out = self.e
132 | del self.e
133 | return x, e_out
134 |
135 |
@register_layer('gatedgcnconv')
class GatedGCNGraphGymLayer(nn.Module):
    """GatedGCN layer.
    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf
    """
    def __init__(self, layer_config: LayerConfig, **kwargs):
        super().__init__()
        # Dropout is handled by GraphGym's `GeneralLayer` wrapper and
        # residual connections by its `GNNStackStage` wrapper, so both are
        # disabled inside the wrapped layer.
        self.model = GatedGCNLayer(in_dim=layer_config.dim_in,
                                   out_dim=layer_config.dim_out,
                                   dropout=0.,
                                   residual=False,
                                   **kwargs)

    def forward(self, batch):
        out = self.model(batch)
        return out
152 |
--------------------------------------------------------------------------------
/gssc/layer/gine_conv_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch_geometric.nn as pyg_nn
5 |
6 | from torch_geometric.graphgym.models.layer import LayerConfig
7 | from torch_geometric.graphgym.register import register_layer
8 | from torch_geometric.nn import Linear as Linear_pyg
9 |
10 |
class GINEConvESLapPE(pyg_nn.conv.MessagePassing):
    """GINEConv Layer with EquivStableLapPE implementation.

    Modified torch_geometric.nn.conv.GINEConv layer to perform message scaling
    according to equiv. stable PEG-layer with Laplacian Eigenmap (LapPE):
    ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj

    Args:
        nn: Callable (e.g. ``torch.nn.Sequential``) applied to the
            aggregated messages; its first layer determines the expected
            feature width. (NOTE: this parameter name shadows the module
            alias ``nn`` inside ``__init__``; kept for API compatibility
            with torch_geometric's GINEConv.)
        eps: Initial epsilon of the ``(1 + eps) * x`` self-term.
        train_eps: If True, ``eps`` is a learnable parameter, otherwise a
            fixed buffer.
        edge_dim: If given, edge features are first linearly projected to
            the node feature width.
    """
    def __init__(self, nn, eps=0., train_eps=False, edge_dim=None, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        if edge_dim is not None:
            # The first layer of `nn` may be a plain Linear (in_features)
            # or a pyg Linear (in_channels).
            if hasattr(self.nn[0], 'in_features'):
                in_channels = self.nn[0].in_features
            else:
                in_channels = self.nn[0].in_channels
            self.lin = pyg_nn.Linear(edge_dim, in_channels)
        else:
            self.lin = None

        if hasattr(self.nn[0], 'in_features'):
            out_dim = self.nn[0].out_features
        else:
            out_dim = self.nn[0].out_channels

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        # Maps the per-edge squared PE distance (1-dim) to a (0, 1) scale.
        self.mlp_r_ij = torch.nn.Sequential(
            torch.nn.Linear(1, out_dim), torch.nn.ReLU(),
            torch.nn.Linear(out_dim, 1),
            torch.nn.Sigmoid())

        # Must run only after *all* submodules exist: reset_parameters()
        # dereferences self.mlp_r_ij, and nn.Module raises AttributeError
        # for not-yet-registered submodules. (Previously this was called
        # before mlp_r_ij was created.)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every learnable component of the layer."""
        pyg_nn.inits.reset(self.nn)
        self.eps.data.fill_(self.initial_eps)
        if self.lin is not None:
            self.lin.reset_parameters()
        pyg_nn.inits.reset(self.mlp_r_ij)

    def forward(self, x, edge_index, edge_attr=None, pe_LapPE=None, size=None):
        # Accept either a plain node-feature Tensor or an (x_src, x_dst)
        # pair, mirroring torch_geometric's GINEConv. Without this
        # normalization, `x[1]` below would index node #1's features.
        if isinstance(x, torch.Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr,
                             PE=pe_LapPE, size=size)

        # GIN self-term on the destination-side features.
        x_r = x[1]
        if x_r is not None:
            out += (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j, edge_attr, PE_i, PE_j):
        if self.lin is None and x_j.size(-1) != edge_attr.size(-1):
            raise ValueError("Node and edge feature dimensionalities do not "
                             "match. Consider setting the 'edge_dim' "
                             "attribute of 'GINEConv'")

        if self.lin is not None:
            edge_attr = self.lin(edge_attr)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
        r_ij = self.mlp_r_ij(r_ij)  # the MLP is 1 dim --> hidden_dim --> 1 dim

        # Standard GINE message, scaled by the learned PE-distance factor.
        return ((x_j + edge_attr).relu()) * r_ij

    def __repr__(self):
        return f'{self.__class__.__name__}(nn={self.nn})'
88 |
89 |
class GINEConvLayer(nn.Module):
    """Graph Isomorphism Network with Edge features (GINE) layer."""

    def __init__(self, dim_in, dim_out, dropout, residual):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dropout = dropout
        self.residual = residual

        # Two-layer MLP applied on top of the GINE aggregation.
        self.model = pyg_nn.GINEConv(nn.Sequential(
            pyg_nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            pyg_nn.Linear(dim_out, dim_out),
        ))

    def forward(self, batch):
        # Keep the input features for the optional residual connection.
        h_in = batch.x

        h = self.model(batch.x, batch.edge_index, batch.edge_attr)
        h = F.dropout(F.relu(h), p=self.dropout, training=self.training)

        batch.x = h_in + h if self.residual else h
        return batch
117 |
118 |
@register_layer('gineconv')
class GINEConvGraphGymLayer(nn.Module):
    """Graph Isomorphism Network with Edge features (GINE) layer."""

    def __init__(self, layer_config: LayerConfig, **kwargs):
        super().__init__()
        mlp = nn.Sequential(
            Linear_pyg(layer_config.dim_in, layer_config.dim_out),
            nn.ReLU(),
            Linear_pyg(layer_config.dim_out, layer_config.dim_out),
        )
        self.model = pyg_nn.GINEConv(mlp)

    def forward(self, batch):
        batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr)
        return batch
133 |
--------------------------------------------------------------------------------
/gssc/layer/gssc_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from torch_geometric.utils import to_dense_batch
5 | import torch.nn.functional as F
6 | from torch_geometric.graphgym.config import cfg
7 |
8 |
class GSSC(nn.Module):
    """GSSC global mixing layer.

    Mixes node features across the whole (dense-batched) graph through a
    low-rank quadratic form of precomputed positional encodings: node
    features are first contracted against ``k_pe`` and then expanded with
    ``q_pe``, realising ``x_i <- sum_j <q_pe_i, k_pe_j> x_j`` without ever
    materialising an n x n attention matrix.

    Expects on ``batch``:
      - ``init_pe``: positional encodings, consumed as a 4-D tensor
        ``(batch, nodes, r, d)`` by the einsums below; its last dim is
        assumed to equal ``cfg.extra.init_pe_dim`` -- TODO confirm.
      - ``log_deg``: flat per-node log-degree, densified here.

    ``forward`` receives ``x`` already densified to
    ``(batch, nodes, dim_inner)`` -- presumably by the caller; verify
    against the enclosing model.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Learnable per-channel mixing weights for [x, x * log_deg]
        # (degree scaler: one coefficient pair per hidden channel).
        self.deg_coef = nn.Parameter(torch.zeros(1, 1, cfg.gnn.dim_inner, 2))
        nn.init.xavier_normal_(self.deg_coef)
        if cfg.extra.more_mapping:
            # Optional extra linear projections for the node features, the
            # two PE "heads", and the output.
            self.x_head_mapping = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=False)
            self.q_pe_head_mapping = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False)
            self.k_pe_head_mapping = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False)
            self.out_mapping = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=True)
        else:
            self.x_head_mapping = nn.Identity()
            self.q_pe_head_mapping = nn.Identity()
            self.k_pe_head_mapping = nn.Identity()
            self.out_mapping = nn.Identity()
        if cfg.extra.reweigh_self:
            # Projections used to build an additional per-node self term.
            self.reweigh_pe = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False)
            self.reweigh_x = nn.Linear(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=False)
            if cfg.extra.reweigh_self == 2:
                # Mode 2 uses a second, independent PE projection.
                self.reweigh_pe_2 = nn.Linear(cfg.extra.init_pe_dim, cfg.extra.init_pe_dim, bias=False)

    def forward(self, x, batch):
        init_pe = batch.init_pe
        # Densify log-degree to (batch, nodes, 1) to broadcast over channels.
        log_deg = to_dense_batch(batch.log_deg, batch.batch)[0][..., None]
        # Degree scaling: learnable per-channel blend of x and x * log_deg.
        x = torch.stack([x, x * log_deg], dim=-1)
        x = (x * self.deg_coef).sum(dim=-1)

        if cfg.extra.reweigh_self:
            pe_reweigh = self.reweigh_pe(init_pe)
            x_reweigh = self.reweigh_x(x)
            # Mode 2: independent second projection; mode 1: reuse the first.
            pe_reweigh_2 = self.reweigh_pe_2(init_pe) if cfg.extra.reweigh_self == 2 else pe_reweigh

        x = self.x_head_mapping(x)
        q_pe = self.q_pe_head_mapping(init_pe)
        k_pe = self.k_pe_head_mapping(init_pe)
        # first[b,r,d,l] = sum_n k_pe[b,n,r,d] * x[b,n,l]  (contract nodes)
        first = torch.einsum("bnrd, bnl -> brdl", k_pe, x)
        # x[b,n,l] = sum_{r,d} q_pe[b,n,r,d] * first[b,r,d,l]  (expand back)
        x = torch.einsum("bnrd, brdl -> bnl", q_pe, first)
        x = self.out_mapping(x)

        if cfg.extra.reweigh_self:
            # Self term: each node's own (projected) features, scaled by a
            # learned scalar derived from its positional encoding.
            x = x + (pe_reweigh * pe_reweigh_2).sum(dim=(-1, -2))[..., None] * x_reweigh
        return x
51 |
--------------------------------------------------------------------------------
/gssc/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/gssc/loader/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-COM/GSSC/ebb0ce4c52b6904ff95514b68f4a72aa48dfeb0f/gssc/loader/dataset/__init__.py
--------------------------------------------------------------------------------
/gssc/loader/dataset/aqsol_molecules.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import pickle
5 |
6 | import torch
7 | from tqdm import tqdm
8 | from torch_geometric.data import (InMemoryDataset, Data, download_url,
9 | extract_zip)
10 | from torch_geometric.utils import add_self_loops
11 |
12 |
class AQSOL(InMemoryDataset):
    r"""The AQSOL dataset from Benchmarking GNNs (Dwivedi et al., 2020) is based on AqSolDB
    (Sorkun et al., 2019) which is a standardized database of 9,982 molecular graphs with
    their aqueous solubility values, collected from 9 different data sources.

    The aqueous solubility targets are collected from experimental measurements and standardized
    to LogS units in AqSolDB. These final values as the property to regress in the AQSOL dataset
    which is the resultant collection in 'Benchmarking GNNs' after filtering out few graphs
    with no bonds/edges and a small number of graphs with missing node feature values.

    Thus, the total molecular graphs are 9,823. For each molecular graph, the node features are the
    types f heavy atoms and the edge features are the types of bonds between them, similar as ZINC.

    Size of Dataset: 9,982 molecules.
    Split: Scaffold split (8:1:1) following same code as OGB.
    After cleaning: 7,831 train / 996 val / 996 test
    Number of (unique) atoms: 65
    Number of (unique) bonds: 5
    Performance Metric: MAE, same as ZINC

    Atom Dict: {'Br': 0, 'C': 1, 'N': 2, 'O': 3, 'Cl': 4, 'Zn': 5, 'F': 6, 'P': 7, 'S': 8, 'Na': 9, 'Al': 10,
    'Si': 11, 'Mo': 12, 'Ca': 13, 'W': 14, 'Pb': 15, 'B': 16, 'V': 17, 'Co': 18, 'Mg': 19, 'Bi': 20, 'Fe': 21,
    'Ba': 22, 'K': 23, 'Ti': 24, 'Sn': 25, 'Cd': 26, 'I': 27, 'Re': 28, 'Sr': 29, 'H': 30, 'Cu': 31, 'Ni': 32,
    'Lu': 33, 'Pr': 34, 'Te': 35, 'Ce': 36, 'Nd': 37, 'Gd': 38, 'Zr': 39, 'Mn': 40, 'As': 41, 'Hg': 42, 'Sb':
    43, 'Cr': 44, 'Se': 45, 'La': 46, 'Dy': 47, 'Y': 48, 'Pd': 49, 'Ag': 50, 'In': 51, 'Li': 52, 'Rh': 53,
    'Nb': 54, 'Hf': 55, 'Cs': 56, 'Ru': 57, 'Au': 58, 'Sm': 59, 'Ta': 60, 'Pt': 61, 'Ir': 62, 'Be': 63, 'Ge': 64}

    Bond Dict: {'NONE': 0, 'SINGLE': 1, 'DOUBLE': 2, 'AROMATIC': 3, 'TRIPLE': 4}

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1'

    def __init__(self, root, split='train', transform=None, pre_transform=None,
                 pre_filter=None):
        self.name = "AQSOL"
        assert split in ['train', 'val', 'test']
        # super().__init__ triggers download()/process() if needed, after
        # which only the requested split file is loaded.
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, f'{split}.pt')
        self.data, self.slices = torch.load(path)


    @property
    def raw_file_names(self):
        return ['train.pickle', 'val.pickle', 'test.pickle']

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self):
        # Drop any stale raw files before fetching a fresh copy.
        shutil.rmtree(self.raw_dir)
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        # The archive extracts to a misspelled folder ('asqol' -- sic,
        # presumably a typo in the released zip); normalise it to raw_dir.
        os.rename(osp.join(self.root, 'asqol_graph_raw'), self.raw_dir)
        os.unlink(path)

    def process(self):
        # Each raw pickle holds one split; convert every graph tuple to a
        # PyG Data object, filter out degenerate graphs, and save per split.
        for split in ['train', 'val', 'test']:
            with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
                graphs = pickle.load(f)

            indices = range(len(graphs))

            pbar = tqdm(total=len(indices))
            pbar.set_description(f'Processing {split} dataset')

            data_list = []
            for idx in indices:
                graph = graphs[idx]

                """
                Each `graph` is a tuple (x, edge_attr, edge_index, y)
                Shape of x : [num_nodes, 1]
                Shape of edge_attr : [num_edges]
                Shape of edge_index : [2, num_edges]
                Shape of y : [1]
                """

                x = torch.LongTensor(graph[0]).unsqueeze(-1)
                edge_attr = torch.LongTensor(graph[1])#.unsqueeze(-1)
                edge_index = torch.LongTensor(graph[2])
                y = torch.tensor(graph[3])

                # Minimal Data built only so PyG infers num_nodes from
                # edge_index; compared against len(x) below for validation.
                data = Data(edge_index=edge_index)

                if edge_index.shape[1] == 0:
                    continue  # skipping for graphs with no bonds/edges

                if data.num_nodes != len(x):
                    continue  # cleaning <10 graphs with this discrepancy

                # The real Data object that is kept.
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                            y=y)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)
                pbar.update(1)

            pbar.close()
            torch.save(self.collate(data_list),
                       osp.join(self.processed_dir, f'{split}.pt'))
134 |
--------------------------------------------------------------------------------
/gssc/loader/dataset/count_cycle.py:
--------------------------------------------------------------------------------
1 | """
2 | modified from https://github.com/GraphPKU/I2GNN/blob/master/data_processing.py
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | from torch_geometric.data import InMemoryDataset, Data
8 | import os
9 | from torch_geometric.data import Data, InMemoryDataset
10 | import scipy.io as scio
11 |
12 |
class CountCycle(InMemoryDataset):
    """Synthetic cycle-counting dataset (adapted from the I2GNN repository).

    Loads one train/val/test split from a MATLAB ``.mat`` file and
    standardises the selected regression target, both in raw space and in
    log10 space.

    Args:
        dataname: Sub-directory of ``root`` holding the raw .mat file.
        root: Dataset root directory.
        processed_name: Name of the processed-files directory.
        split: One of 'train' / 'val' / 'test'.
        yidx: Column of the multi-target ``y`` matrix to regress.
        ymean, ystd: Normalisation statistics for the raw target.
        ymean_log, ystd_log: Normalisation statistics for log10(y + 1).
        replace: If True, expose the normalised log target as ``data.y``.
        transform: Optional on-access transform.
    """
    def __init__(
        self,
        dataname="count_cycle",
        root="dataset",
        processed_name="processed",
        split="train",
        yidx: int = 0,
        ymean: float = 0,
        ystd: float = 1,
        ymean_log: float = 0,
        ystd_log: float = 1,
        replace=False,
        transform=None,
    ):
        self.root = root
        self.dataname = dataname
        self.raw = os.path.join(root, dataname)
        self.processed = os.path.join(root, dataname, processed_name)
        # super().__init__ triggers process() on first use.
        super(CountCycle, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
        split_id = 0 if split == "train" else 1 if split == "val" else 2
        data, slices = torch.load(self.processed_paths[split_id])

        # Log-space target: log10(y + 1), standardised and flattened.
        data.log_y = torch.log10(data.y + 1.0)
        data.log_y = (data.log_y[:, [yidx]] - ymean_log) / ystd_log
        data.log_y = data.log_y.reshape(-1)

        # Raw-space target: pick column yidx, standardise, flatten.
        # NOTE: order matters -- log_y must be computed from the raw y above.
        data.y = (data.y[:, [yidx]] - ymean) / ystd
        data.y = data.y.reshape(-1)

        if replace:
            data.y = data.log_y

        self.data, self.slices = data, slices
        self.mean, self.std, self.log_mean, self.log_std = ymean, ystd, ymean_log, ystd_log
        # ((10**((data.log_y * ystd_log) + ymean_log) - 1) - ymean) / ystd

    @property
    def raw_dir(self):
        name = "raw"
        return os.path.join(self.root, self.dataname, name)

    @property
    def processed_dir(self):
        return self.processed

    @property
    def raw_file_names(self):
        names = ["data"]
        return ["{}.mat".format(name) for name in names]

    @property
    def processed_file_names(self):
        return ["data_tr.pt", "data_val.pt", "data_te.pt"]

    def adj2data(self, A, y):
        """Convert a dense adjacency matrix plus targets into a PyG Data.

        Node features are all-ones placeholders; edges are the non-zero
        entries of ``A``.
        """
        # x: (n, d), A: (e, n, n)
        # begin, end = np.where(np.sum(A, axis=0) == 1.)
        begin, end = np.where(A == 1.0)
        edge_index = torch.tensor(np.array([begin, end]))
        num_nodes = A.shape[0]
        if y.ndim == 1:
            y = y.reshape([1, -1])
        x = torch.ones((num_nodes, 1), dtype=torch.long)
        return Data(x=x, edge_index=edge_index, y=torch.tensor(y), num_nodes=torch.tensor([num_nodes]))

    def process(self):
        # process npy data into pyg.Data
        print("Processing data from " + self.raw_dir + "...")
        raw_data = scio.loadmat(self.raw_paths[0])
        # The .mat layout differs between releases; branch on the shape of
        # the target matrix "F" to pick the right indexing scheme.
        if raw_data["F"].shape[0] == 1:
            data_list_all = [
                [self.adj2data(raw_data["A"][0][i], raw_data["F"][0][i]) for i in idx]
                for idx in [raw_data["train_idx"][0], raw_data["val_idx"][0], raw_data["test_idx"][0]]
            ]
        else:
            data_list_all = [
                [self.adj2data(A, y) for A, y in zip(raw_data["A"][0][idx][0], raw_data["F"][idx][0])]
                for idx in [raw_data["train_idx"], raw_data["val_idx"], raw_data["test_idx"]]
            ]
        # One output file per split, in processed_paths order (tr/val/te).
        for save_path, data_list in zip(self.processed_paths, data_list_all):
            print("pre-transforming for data at" + save_path)
            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]
            if self.pre_transform is not None:
                temp = []
                for i, data in enumerate(data_list):
                    if i % 100 == 0:
                        print("Pre-processing %d/%d" % (i, len(data_list)))
                    temp.append(self.pre_transform(data))
                data_list = temp
                # data_list = [self.pre_transform(data) for data in data_list]
            data, slices = self.collate(data_list)
            torch.save((data, slices), save_path)
107 |
--------------------------------------------------------------------------------
/gssc/loader/dataset/malnet_tiny.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable, List
2 |
3 | import os
4 | import glob
5 | import os.path as osp
6 |
7 | import torch
8 | from torch_geometric.data import (InMemoryDataset, Data, download_url,
9 | extract_tar, extract_zip)
10 | from torch_geometric.utils import remove_isolated_nodes
11 |
12 | """
13 | This is a local copy of MalNetTiny class from PyG
14 | https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/malnet_tiny.py
15 |
16 | TODO: Delete and use PyG's version once it is part of a released version.
17 | At the time of writing this class is in the main PyG github branch but is not
18 | included in the current latest released version 2.0.2.
19 | """
20 |
class MalNetTiny(InMemoryDataset):
    r"""The MalNet Tiny dataset from the
    `"A Large-Scale Database for Graph Representation Learning"
    `_ paper.
    :class:`MalNetTiny` contains 5,000 malicious and benign software function
    call graphs across 5 different types. Each graph contains at most 5k nodes.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'http://malnet.cc.gatech.edu/graph-data/malnet-graphs-tiny.tar.gz'
    # 70/10/20 train, val, test split by type
    split_url = 'http://malnet.cc.gatech.edu/split-info/split_info_tiny.zip'

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # One folder per malware/benign class; folder order fixes the
        # integer label `y` assigned in process().
        folders = ['addisplay', 'adware', 'benign', 'downloader', 'trojan']
        return [osp.join('malnet-graphs-tiny', folder) for folder in folders]

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt', 'split_dict.pt']

    def download(self):
        # Graphs archive plus the official split-info archive.
        path = download_url(self.url, self.raw_dir)
        extract_tar(path, self.raw_dir)
        os.unlink(path)
        path = download_url(self.split_url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        data_list = []
        split_dict = {'train': [], 'valid': [], 'test': []}

        # Each split file lists one graph path per line; keep only the
        # basename (the graph id).
        parse = lambda f: set([x.split('/')[-1]
                               for x in f.read().split('\n')[:-1]])  # -1 for empty line at EOF
        split_dir = osp.join(self.raw_dir, 'split_info_tiny', 'type')
        with open(osp.join(split_dir, 'train.txt'), 'r') as f:
            train_names = parse(f)
        assert len(train_names) == 3500
        with open(osp.join(split_dir, 'val.txt'), 'r') as f:
            val_names = parse(f)
        assert len(val_names) == 500
        with open(osp.join(split_dir, 'test.txt'), 'r') as f:
            test_names = parse(f)
        assert len(test_names) == 1000

        # Class label y = index of the class folder in raw_paths order.
        for y, raw_path in enumerate(self.raw_paths):
            # Each class folder contains a single sub-directory with the
            # actual .edgelist files.
            raw_path = osp.join(raw_path, os.listdir(raw_path)[0])
            filenames = glob.glob(osp.join(raw_path, '*.edgelist'))

            for filename in filenames:
                with open(filename, 'r') as f:
                    # First 5 lines are skipped (file header), last is the
                    # empty line at EOF.
                    edges = f.read().split('\n')[5:-1]
                edge_index = [[int(s) for s in edge.split()] for edge in edges]
                edge_index = torch.tensor(edge_index).t().contiguous()
                # Remove isolated nodes, including those with only a self-loop
                edge_index = remove_isolated_nodes(edge_index)[0]
                num_nodes = int(edge_index.max()) + 1
                data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes)
                data_list.append(data)

                # Record which split this graph index belongs to.
                ind = len(data_list) - 1
                graph_id = osp.splitext(osp.basename(filename))[0]
                if graph_id in train_names:
                    split_dict['train'].append(ind)
                elif graph_id in val_names:
                    split_dict['valid'].append(ind)
                elif graph_id in test_names:
                    split_dict['test'].append(ind)
                else:
                    raise ValueError(f'No split assignment for "{graph_id}".')

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # NOTE: split indices are computed before pre_filter/pre_transform;
        # if pre_filter drops graphs the stored indices may be stale --
        # TODO confirm pre_filter is unused with this dataset.
        torch.save(self.collate(data_list), self.processed_paths[0])
        torch.save(split_dict, self.processed_paths[1])

    def get_idx_split(self):
        """Return the dict of 'train'/'valid'/'test' graph indices."""
        return torch.load(self.processed_paths[1])
125 |
--------------------------------------------------------------------------------
/gssc/loader/dataset/peptides_functional.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os.path as osp
3 | import pickle
4 | import shutil
5 |
6 | import pandas as pd
7 | import torch
8 | from ogb.utils import smiles2graph
9 | from ogb.utils.torch_util import replace_numpy_with_torchtensor
10 | from ogb.utils.url import decide_download
11 | from torch_geometric.data import Data, InMemoryDataset, download_url
12 | from tqdm import tqdm
13 |
14 |
class PeptidesFunctionalDataset(InMemoryDataset):
    def __init__(self, root='datasets', smiles2graph=smiles2graph,
                 transform=None, pre_transform=None):
        """
        PyG dataset of 15,535 peptides represented as their molecular graph
        (SMILES) with 10-way multi-task binary classification of their
        functional classes.

        The goal is use the molecular representation of peptides instead
        of amino acid sequence representation ('peptide_seq' field in the file,
        provided for possible baseline benchmarking but not used here) to test
        GNNs' representation capability.

        The 10 classes represent the following functional classes (in order):
            ['antifungal', 'cell_cell_communication', 'anticancer',
            'drug_delivery_vehicle', 'antimicrobial', 'antiviral',
            'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']

        Args:
            root (string): Root directory where the dataset should be saved.
            smiles2graph (callable): A callable function that converts a SMILES
                string into a graph object. We use the OGB featurization.
                * The default smiles2graph requires rdkit to be installed *
        """

        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = osp.join(root, 'peptides-functional')

        self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1'
        self.version = '701eb743e899f4d793f0e13c8fa5a1b4'  # MD5 hash of the intended dataset file
        self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1'
        self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061'

        # Check version and update if necessary.
        # A marker file named after the expected MD5 is created by
        # download(); its absence means the on-disk copy is outdated.
        release_tag = osp.join(self.folder, self.version)
        if osp.isdir(self.folder) and (not osp.exists(release_tag)):
            print(f"{self.__class__.__name__} has been updated.")
            if input("Will you update the dataset now? (y/N)\n").lower() == 'y':
                shutil.rmtree(self.folder)

        super().__init__(self.folder, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return 'peptide_multi_class_dataset.csv.gz'

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def _md5sum(self, path):
        """Return the MD5 hex digest of the file at ``path``."""
        hash_md5 = hashlib.md5()
        with open(path, 'rb') as f:
            buffer = f.read()
            hash_md5.update(buffer)
        return hash_md5.hexdigest()

    def download(self):
        if decide_download(self.url):
            path = download_url(self.url, self.raw_dir)
            # Save to disk the MD5 hash of the downloaded file.
            # NOTE: `hash` shadows the builtin of the same name here.
            hash = self._md5sum(path)
            if hash != self.version:
                raise ValueError("Unexpected MD5 hash of the downloaded file")
            # Touch an empty marker file named after the hash; its presence
            # is the version check performed in __init__.
            open(osp.join(self.root, hash), 'w').close()
            # Download train/val/test splits.
            path_split1 = download_url(self.url_stratified_split, self.root)
            assert self._md5sum(path_split1) == self.md5sum_stratified_split
        else:
            print('Stop download.')
            exit(-1)

    def process(self):
        data_df = pd.read_csv(osp.join(self.raw_dir,
                                       'peptide_multi_class_dataset.csv.gz'))
        smiles_list = data_df['smiles']

        print('Converting SMILES strings into graphs...')
        data_list = []
        for i in tqdm(range(len(smiles_list))):
            data = Data()

            smiles = smiles_list[i]
            graph = self.smiles2graph(smiles)

            assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
            assert (len(graph['node_feat']) == graph['num_nodes'])

            data.__num_nodes__ = int(graph['num_nodes'])
            data.edge_index = torch.from_numpy(graph['edge_index']).to(
                torch.int64)
            data.edge_attr = torch.from_numpy(graph['edge_feat']).to(
                torch.int64)
            data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
            # The 'labels' column stores (presumably) a Python list literal
            # per row; eval parses it. NOTE(review): eval on file contents
            # is acceptable only because the file's MD5 is verified in
            # download() -- confirm before reusing elsewhere.
            data.y = torch.Tensor([eval(data_df['labels'].iloc[i])])

            data_list.append(data)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])

    def get_idx_split(self):
        """ Get dataset splits.

        Returns:
            Dict with 'train', 'val', 'test', splits indices.
        """
        split_file = osp.join(self.root,
                              "splits_random_stratified_peptide.pickle")
        with open(split_file, 'rb') as f:
            splits = pickle.load(f)
        split_dict = replace_numpy_with_torchtensor(splits)
        return split_dict
135 |
136 |
if __name__ == '__main__':
    # Quick smoke test: build the dataset and print a few summary fields.
    ds = PeptidesFunctionalDataset()
    for item in (
        ds,
        ds.data.edge_index,
        ds.data.edge_index.shape,
        ds.data.x.shape,
        ds[100],
        ds[100].y,
        ds.get_idx_split(),
    ):
        print(item)
146 |
--------------------------------------------------------------------------------
/gssc/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/gssc/loss/l1.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.register import register_loss
4 |
5 |
@register_loss('l1_losses')
def l1_losses(pred, true):
    """Regression losses (L1 / smooth-L1 / MSE) selected by cfg.model.loss_fun.

    Returns (loss, pred) when the configured loss matches, else None so
    other registered losses may handle the call.
    """
    criterion_by_name = {
        'l1': nn.L1Loss,
        'smoothl1': nn.SmoothL1Loss,
        'mse': nn.MSELoss,
    }
    criterion_cls = criterion_by_name.get(cfg.model.loss_fun)
    if criterion_cls is not None:
        return criterion_cls()(pred, true), pred
20 |
--------------------------------------------------------------------------------
/gssc/loss/multilabel_classification_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.register import register_loss
4 |
5 |
@register_loss('multilabel_cross_entropy')
def multilabel_cross_entropy(pred, true):
    """Multilabel cross-entropy loss.
    """
    if cfg.dataset.task_type != 'classification_multilabel':
        return
    if cfg.model.loss_fun != 'cross_entropy':
        raise ValueError("Only 'cross_entropy' loss_fun supported with "
                         "'classification_multilabel' task_type.")
    labeled_mask = true == true  # NaN != NaN, so this drops unlabeled entries
    criterion = nn.BCEWithLogitsLoss()
    return criterion(pred[labeled_mask], true[labeled_mask].float()), pred
17 |
--------------------------------------------------------------------------------
/gssc/loss/subtoken_prediction_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.graphgym.config import cfg
3 | from torch_geometric.graphgym.register import register_loss
4 |
5 |
@register_loss('subtoken_cross_entropy')
def subtoken_cross_entropy(pred_list, true):
    """Subtoken prediction cross-entropy loss for ogbg-code2.
    """
    if cfg.dataset.task_type != 'subtoken_prediction':
        return
    if cfg.model.loss_fun != 'cross_entropy':
        raise ValueError("Only 'cross_entropy' loss_fun supported with "
                         "'subtoken_prediction' task_type.")
    criterion = torch.nn.CrossEntropyLoss()
    # One cross-entropy term per subtoken position, averaged over positions.
    per_position = [
        criterion(logits.to(torch.float32), true['y_arr'][:, position])
        for position, logits in enumerate(pred_list)
    ]
    loss = sum(per_position) / len(pred_list)
    return loss, pred_list
21 |
--------------------------------------------------------------------------------
/gssc/loss/weighted_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_loss
5 |
6 |
@register_loss('weighted_cross_entropy')
def weighted_cross_entropy(pred, true):
    """Weighted cross-entropy for unbalanced classes.

    Per-batch class weights are (V - n_c) / V, where V is the number of
    labeled nodes in the batch and n_c the count of class c; classes absent
    from the batch get weight 0. Returns (loss, pred_scores).
    """
    if cfg.model.loss_fun == 'weighted_cross_entropy':
        # calculating label weights for weighted loss computation
        V = true.size(0)
        # With 1-D predictions this is a binary problem.
        n_classes = pred.shape[1] if pred.ndim > 1 else 2
        label_count = torch.bincount(true)
        # Keep only counts of classes actually present in the batch.
        label_count = label_count[label_count.nonzero(as_tuple=True)].squeeze()
        cluster_sizes = torch.zeros(n_classes, device=pred.device).long()
        # Scatter the present-class counts back to their class indices.
        cluster_sizes[torch.unique(true)] = label_count
        weight = (V - cluster_sizes).float() / V
        weight *= (cluster_sizes > 0).float()  # zero weight for absent classes
        # multiclass
        if pred.ndim > 1:
            pred = F.log_softmax(pred, dim=-1)
            return F.nll_loss(pred, true, weight=weight), pred
        # binary
        else:
            loss = F.binary_cross_entropy_with_logits(pred, true.float(),
                                                      weight=weight[true])
            return loss, torch.sigmoid(pred)
30 |
--------------------------------------------------------------------------------
/gssc/metrics_ogb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score, average_precision_score
3 |
4 | """
5 | Evaluation functions from OGB.
6 | https://github.com/snap-stanford/ogb/blob/master/ogb/graphproppred/evaluate.py
7 | """
8 |
def eval_rocauc(y_true, y_pred):
    '''
    compute ROC-AUC averaged across tasks
    '''
    per_task_scores = []

    for task in range(y_true.shape[1]):
        col_true = y_true[:, task]
        # AUC is defined only when both classes occur in this task.
        if np.sum(col_true == 1) > 0 and np.sum(col_true == 0) > 0:
            labeled = col_true == col_true  # NaN != NaN drops unlabeled rows
            per_task_scores.append(
                roc_auc_score(col_true[labeled], y_pred[labeled, task]))

    if not per_task_scores:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute ROC-AUC.')

    return {'rocauc': sum(per_task_scores) / len(per_task_scores)}
29 |
30 |
def eval_ap(y_true, y_pred):
    '''
    compute Average Precision (AP) averaged across tasks
    '''
    per_task_ap = []

    for task in range(y_true.shape[1]):
        col_true = y_true[:, task]
        # AP is defined only when both classes occur in this task.
        if np.sum(col_true == 1) > 0 and np.sum(col_true == 0) > 0:
            labeled = col_true == col_true  # NaN != NaN drops unlabeled rows
            per_task_ap.append(
                average_precision_score(col_true[labeled],
                                        y_pred[labeled, task]))

    if not per_task_ap:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute Average Precision.')

    return {'ap': sum(per_task_ap) / len(per_task_ap)}
53 |
54 |
def eval_rmse(y_true, y_pred):
    '''
    compute RMSE score averaged across tasks
    '''
    per_task_rmse = []

    for task in range(y_true.shape[1]):
        labeled = y_true[:, task] == y_true[:, task]  # NaN-safe label mask
        residual = y_true[labeled, task] - y_pred[labeled, task]
        per_task_rmse.append(np.sqrt((residual ** 2).mean()))

    return {'rmse': sum(per_task_rmse) / len(per_task_rmse)}
68 |
69 |
def eval_acc(y_true, y_pred):
    # Accuracy per task (ignoring NaN labels), averaged across tasks.
    per_task_acc = []

    for task in range(y_true.shape[1]):
        labeled = y_true[:, task] == y_true[:, task]  # NaN-safe label mask
        hits = y_true[labeled, task] == y_pred[labeled, task]
        per_task_acc.append(float(np.sum(hits)) / len(hits))

    return {'acc': sum(per_task_acc) / len(per_task_acc)}
79 |
80 |
def eval_F1(seq_ref, seq_pred):
    # Set-based precision / recall / F1, averaged over samples.
    precisions, recalls, f1s = [], [], []

    for ref_seq, pred_seq in zip(seq_ref, seq_pred):
        ref_set = set(ref_seq)
        pred_set = set(pred_seq)
        tp = len(ref_set.intersection(pred_set))
        fp = len(pred_set - ref_set)
        fn = len(ref_set - pred_set)

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall > 0 else 0)

        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)

    return {'precision': np.average(precisions),
            'recall': np.average(recalls),
            'F1': np.average(f1s)}
118 |
--------------------------------------------------------------------------------
/gssc/network/MaskedReduce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, BoolTensor
3 | '''
4 | x (B, N, d)
5 | mask (B, N, 1)
6 | '''
def maskedSum(x: Tensor, mask: BoolTensor, dim: int):
    '''
    Sum over `dim`, treating entries where `mask` is True as zero.
    '''
    zeroed = torch.where(mask, 0, x)
    return zeroed.sum(dim=dim)
12 |
def maskedMean(x: Tensor, mask: BoolTensor, dim: int, gsize: Tensor = None):
    '''
    Mean over `dim`, treating entries where `mask` is True as zero.
    `gsize` optionally supplies the per-slice count of unmasked entries.
    '''
    if gsize is None:
        gsize = x.shape[dim] - mask.sum(dim=dim)
    zeroed = torch.where(mask, 0, x)
    return zeroed.sum(dim=dim) / gsize
20 |
def maskedMax(x: Tensor, mask: BoolTensor, dim: int):
    # Masked entries are replaced by -inf so they never win the max.
    filled = torch.where(mask, -torch.inf, x)
    return filled.max(dim=dim)[0]
23 |
def maskedMin(x: Tensor, mask: BoolTensor, dim: int):
    # Masked entries are replaced by +inf so they never win the min.
    filled = torch.where(mask, torch.inf, x)
    return filled.min(dim=dim)[0]
26 |
def maskednone(x: Tensor, mask: BoolTensor, dim: int):
    # Identity "reduction": returns the input unchanged.
    return x
29 |
# Name -> masked reduction dispatch table. Fix: `maskedMin` was defined above
# but never registered here; expose it under "min" (backward-compatible —
# existing keys are unchanged).
reduce_dict = {
    "sum": maskedSum,
    "mean": maskedMean,
    "max": maskedMax,
    "min": maskedMin,
    "none": maskednone
}
--------------------------------------------------------------------------------
/gssc/network/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Export every sibling .py module so `import *` registers all networks.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = []
for module_path in modules:
    if isfile(module_path) and not module_path.endswith('__init__.py'):
        __all__.append(basename(module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/network/big_bird.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.register as register
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
5 | from torch_geometric.graphgym.register import register_network
6 |
7 | from gssc.layer.bigbird_layer import BigBirdModel as BackboneBigBird
8 |
9 |
@register_network('BigBird')
class BigBird(torch.nn.Module):
    """BigBird graph model without edge features.

    Ignores edge features and runs a sparse-attention transformer over the
    node features only. BigBird's random sparse attention approaches O(N)
    cost as sequences grow. https://arxiv.org/abs/2007.14062
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # NOTE: children() order drives forward(); keep encoder -> pre_mp ->
        # transformer -> head registration order.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # Mirror the main Transformer hyperparameters onto the BigBird config.
        bb_cfg = cfg.gt.bigbird
        bb_cfg.layers = cfg.gt.layers
        bb_cfg.n_heads = cfg.gt.n_heads
        bb_cfg.dim_hidden = cfg.gt.dim_hidden
        bb_cfg.dropout = cfg.gt.dropout
        self.trf = BackboneBigBird(config=bb_cfg)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        for stage in self.children():
            batch = stage(batch)
        return batch
47 |
--------------------------------------------------------------------------------
/gssc/network/custom_gnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.models.head # noqa, register module
3 | import torch_geometric.graphgym.register as register
4 | from torch_geometric.graphgym.config import cfg
5 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
6 | from torch_geometric.graphgym.register import register_network
7 |
8 | from gssc.layer.gatedgcn_layer import GatedGCNLayer
9 | from gssc.layer.gine_conv_layer import GINEConvLayer
10 |
11 |
@register_network('custom_gnn')
class CustomGNN(torch.nn.Module):
    """
    GNN model that customizes torch_geometric.graphgym.models.gnn.GNN
    to support specific handling of new conv layers.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        conv_cls = self.build_conv_model(cfg.gnn.layer_type)
        stack = [
            conv_cls(dim_in, dim_in,
                     dropout=cfg.gnn.dropout,
                     residual=cfg.gnn.residual)
            for _ in range(cfg.gnn.layers_mp)
        ]
        self.gnn_layers = torch.nn.Sequential(*stack)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Map a config layer-type string to its conv layer class."""
        conv_by_name = {
            'gatedgcnconv': GatedGCNLayer,
            'gineconv': GINEConvLayer,
        }
        if model_type not in conv_by_name:
            raise ValueError("Model {} unavailable".format(model_type))
        return conv_by_name[model_type]

    def forward(self, batch):
        for stage in self.children():
            batch = stage(batch)
        return batch
56 |
--------------------------------------------------------------------------------
/gssc/network/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torch_geometric.graphgym.models.head # noqa, register module
6 | import torch_geometric.graphgym.register as register
7 | import torch_geometric.nn as pyg_nn
8 | from torch_geometric.graphgym.config import cfg
9 | from torch_geometric.graphgym.register import register_network
10 |
11 |
@register_network('example')
class ExampleGNN(torch.nn.Module):
    """Minimal example GNN: a stack of identical conv layers plus a head."""

    def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):
        super().__init__()
        conv_cls = self.build_conv_model(model_type)
        # num_layers identical dim_in -> dim_in conv layers.
        self.convs = nn.ModuleList(
            conv_cls(dim_in, dim_in) for _ in range(num_layers))

        GNNHead = register.head_dict[cfg.dataset.task]
        self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Map a model-type string to its PyG conv class."""
        conv_by_name = {
            'GCN': pyg_nn.GCNConv,
            'GAT': pyg_nn.GATConv,
            'GraphSage': pyg_nn.SAGEConv,
        }
        if model_type not in conv_by_name:
            raise ValueError(f'Model {model_type} unavailable')
        return conv_by_name[model_type]

    def forward(self, batch):
        h, edges = batch.x, batch.edge_index

        for conv in self.convs:
            h = conv(h, edges)
            h = F.relu(h)
            h = F.dropout(h, p=0.1, training=self.training)

        batch.x = h
        batch = self.post_mp(batch)

        return batch
--------------------------------------------------------------------------------
/gssc/network/gps_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.register as register
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.models.gnn import GNNPreMP
5 | from torch_geometric.graphgym.models.layer import new_layer_config, BatchNorm1dNode
6 | from torch_geometric.graphgym.register import register_network
7 | from gssc.encoder.ER_edge_encoder import EREdgeEncoder
8 | from gssc.layer.gps_layer import GPSLayer
9 | from gssc.network.utils import InitPEs
10 |
11 |
class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    Args:
        dim_in (int): Input feature dimension
    """

    def __init__(self, dim_in):
        super(FeatureEncoder, self).__init__()
        self.dim_in = dim_in
        # NOTE(review): forward() applies self.children() in registration
        # order, so the submodule assignment order below is significant.
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embeddings
            NodeEncoder = register.node_encoder_dict[cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg)
                )
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Hard-set edge dim for PNA.
            cfg.gnn.dim_edge = 16 if "PNA" in cfg.gt.layer_type else cfg.gnn.dim_inner
            if cfg.dataset.edge_encoder_name == "ER":
                # Effective-resistance edge encoder only.
                self.edge_encoder = EREdgeEncoder(cfg.gnn.dim_edge)
            elif cfg.dataset.edge_encoder_name.endswith("+ER"):
                # Base edge encoder plus ER encoder; their output dims sum
                # to cfg.gnn.dim_edge.
                EdgeEncoder = register.edge_encoder_dict[cfg.dataset.edge_encoder_name[:-3]]
                self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge - cfg.posenc_ERE.dim_pe)
                self.edge_encoder_er = EREdgeEncoder(cfg.posenc_ERE.dim_pe, use_edge_attr=True)
            else:
                EdgeEncoder = register.edge_encoder_dict[cfg.dataset.edge_encoder_name]
                self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge)

            if cfg.dataset.edge_encoder_bn:
                self.edge_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, has_bias=False, cfg=cfg)
                )

    def forward(self, batch):
        # Apply all registered encoders sequentially to the batch.
        for module in self.children():
            batch = module(batch)
        return batch
55 |
56 |
@register_network("GPSModel")
class GPSModel(torch.nn.Module):
    """Multi-scale graph x-former."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # NOTE(review): forward() applies self.children() in registration
        # order (encoder -> pre_mp -> [get_init_pe] -> layers -> post_mp).
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, "The inner and hidden dims must match."

        # Layer type is "<local GNN>+<global model>", e.g. "GINE+GSSC".
        try:
            local_gnn_type, global_model_type = cfg.gt.layer_type.split("+")
        except:
            raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}")

        if global_model_type == "GSSC":
            # Presumably computes initial positional encodings for GSSC
            # layers — confirm against InitPEs in gssc.network.utils.
            self.get_init_pe = InitPEs(cfg.extra.init_pe_dim)

        layers = []
        for _ in range(cfg.gt.layers):
            layers.append(
                GPSLayer(
                    dim_h=cfg.gt.dim_hidden,
                    local_gnn_type=local_gnn_type,
                    global_model_type=global_model_type,
                    num_heads=cfg.gt.n_heads,
                    pna_degrees=cfg.gt.pna_degrees,
                    equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                    dropout=cfg.gt.dropout,
                    attn_dropout=cfg.gt.attn_dropout,
                    layer_norm=cfg.gt.layer_norm,
                    batch_norm=cfg.gt.batch_norm,
                    bigbird_cfg=cfg.gt.bigbird,
                )
            )
        self.layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Per-layer node-feature collection; downstream modules may append.
        batch.all_x = []
        for module in self.children():
            batch = module(batch)
        return batch
107 |
--------------------------------------------------------------------------------
/gssc/network/multi_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.register as register
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.models.gnn import GNNPreMP
5 | from torch_geometric.graphgym.models.layer import (new_layer_config,
6 | BatchNorm1dNode)
7 | from torch_geometric.graphgym.register import register_network
8 |
9 | from gssc.layer.multi_model_layer import MultiLayer, SingleLayer
10 | from gssc.encoder.ER_edge_encoder import EREdgeEncoder
11 | from gssc.encoder.exp_edge_fixer import ExpanderEdgeFixer
12 |
13 |
class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    Args:
        dim_in (int): Input feature dimension
    """
    def __init__(self, dim_in):
        super(FeatureEncoder, self).__init__()
        self.dim_in = dim_in
        # NOTE(review): forward() applies self.children() in registration
        # order, so the submodule assignment order below is significant.
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embeddings
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Default the edge dim to the hidden dim when not configured.
            if not hasattr(cfg.gt, 'dim_edge') or cfg.gt.dim_edge is None:
                cfg.gt.dim_edge = cfg.gt.dim_hidden

            if cfg.dataset.edge_encoder_name == 'ER':
                # Effective-resistance edge encoder only.
                self.edge_encoder = EREdgeEncoder(cfg.gt.dim_edge)
            elif cfg.dataset.edge_encoder_name.endswith('+ER'):
                # Base edge encoder plus ER encoder; their output dims sum
                # to cfg.gt.dim_edge.
                EdgeEncoder = register.edge_encoder_dict[
                    cfg.dataset.edge_encoder_name[:-3]]
                self.edge_encoder = EdgeEncoder(cfg.gt.dim_edge - cfg.posenc_ERE.dim_pe)
                self.edge_encoder_er = EREdgeEncoder(cfg.posenc_ERE.dim_pe, use_edge_attr=True)
            else:
                EdgeEncoder = register.edge_encoder_dict[
                    cfg.dataset.edge_encoder_name]
                self.edge_encoder = EdgeEncoder(cfg.gt.dim_edge)

            if cfg.dataset.edge_encoder_bn:
                self.edge_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gt.dim_edge, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))

        if 'Exphormer' in cfg.gt.layer_type:
            # Adds expander/virtual-node edges required by Exphormer layers.
            self.exp_edge_fixer = ExpanderEdgeFixer(add_edge_index=cfg.prep.add_edge_index,
                                                    num_virt_node=cfg.prep.num_virt_node)

    def forward(self, batch):
        # Apply all registered encoders sequentially to the batch.
        for module in self.children():
            batch = module(batch)
        return batch
64 |
65 |
class MultiModel(torch.nn.Module):
    """Multiple layer types can be combined here.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # children() order drives forward(); keep registration order.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        try:
            model_types = cfg.gt.layer_type.split('+')
        except:
            raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}")

        stack = [
            MultiLayer(
                dim_h=cfg.gt.dim_hidden,
                model_types=model_types,
                num_heads=cfg.gt.n_heads,
                pna_degrees=cfg.gt.pna_degrees,
                equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                dropout=cfg.gt.dropout,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                bigbird_cfg=cfg.gt.bigbird,
                exp_edges_cfg=cfg.prep,
            )
            for _ in range(cfg.gt.layers)
        ]
        self.layers = torch.nn.Sequential(*stack)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        for stage in self.children():
            batch = stage(batch)
        return batch
111 |
112 |
class SingleModel(torch.nn.Module):
    """A single layer type can be used without FFN between the layers.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # children() order drives forward(); keep registration order.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        stack = [
            SingleLayer(
                dim_h=cfg.gt.dim_hidden,
                model_type=cfg.gt.layer_type,
                num_heads=cfg.gt.n_heads,
                pna_degrees=cfg.gt.pna_degrees,
                equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                dropout=cfg.gt.dropout,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                bigbird_cfg=cfg.gt.bigbird,
                exp_edges_cfg=cfg.prep,
            )
            for _ in range(cfg.gt.layers)
        ]
        self.layers = torch.nn.Sequential(*stack)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        for stage in self.children():
            batch = stage(batch)
        return batch
154 |
155 |
156 | register_network('MultiModel', MultiModel)
157 | register_network('SingleModel', SingleModel)
158 |
--------------------------------------------------------------------------------
/gssc/network/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Callable
3 | from torch_geometric.nn.norm import GraphNorm as PygGN, InstanceNorm as PygIN
4 | from torch import Tensor
5 | import torch.nn as nn
6 |
def expandbatch(x: Tensor, batch: Tensor):
    """Flatten a stacked batch dim of R copies into one long batch.

    Graph ids in `batch` are offset per copy so every copy keeps distinct
    graph indices. With `batch=None` only the flatten is performed.
    """
    if batch is None:
        return x.flatten(0, 1), None
    num_copies = x.shape[0]
    num_graphs = batch[-1] + 1
    offsets = num_graphs * torch.arange(num_copies,
                                        device=x.device).reshape(-1, 1)
    expanded = batch.unsqueeze(0) + offsets
    return x.flatten(0, 1), expanded.flatten()
16 |
17 |
18 | class NormMomentumScheduler:
19 | def __init__(self, mfunc: Callable, initmomentum: float, normtype=nn.BatchNorm1d) -> None:
20 | super().__init__()
21 | self.normtype = normtype
22 | self.mfunc = mfunc
23 | self.epoch = 0
24 | self.initmomentum = initmomentum
25 |
26 | def step(self, model: nn.Module):
27 | ratio = self.mfunc(self.epoch)
28 | if 1-1e-6 None:
39 | super().__init__()
40 | self.num_features = dim
41 |
42 | def forward(self, x):
43 | return x
44 |
class BatchNorm(nn.Module):
    """BatchNorm1d wrapper that also accepts >2-D inputs by flattening all
    leading dims into the batch dimension."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.BatchNorm1d(dim, momentum=normparam)

    def forward(self, x: Tensor):
        if x.dim() == 2:
            return self.norm(x)
        if x.dim() > 2:
            original_shape = x.shape
            flattened = x.flatten(0, -2)
            return self.norm(flattened).reshape(original_shape)
        raise NotImplementedError
60 |
class LayerNorm(nn.Module):
    """nn.LayerNorm wrapper; `normparam` is accepted for interface parity
    with the other norms but is unused here."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor):
        normalized = self.norm(x)
        return normalized
69 |
class InstanceNorm(nn.Module):
    """PyG InstanceNorm wrapper that also accepts >2-D inputs by flattening
    all leading dims into the batch dimension."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.norm = PygIN(dim, momentum=normparam)
        self.num_features = dim

    def forward(self, x: Tensor):
        if x.dim() == 2:
            return self.norm(x)
        if x.dim() > 2:
            original_shape = x.shape
            flattened = x.flatten(0, -2)
            return self.norm(flattened).reshape(original_shape)
        raise NotImplementedError
85 |
# Name -> wrapper norm-module factory used by model builders; "none" is identity.
normdict = {"bn": BatchNorm, "ln": LayerNorm, "in": InstanceNorm, "none": NoneNorm}
# Name -> underlying torch/PyG norm class (None where no momentum-bearing base
# class applies, e.g. for momentum scheduling).
basenormdict = {"bn": nn.BatchNorm1d, "ln": None, "in": PygIN, "gn": None, "none": None}
88 |
if __name__ == "__main__":
    # Manual smoke test for expandbatch (with and without a batch vector)
    # and for InstanceNorm's module structure.
    x = torch.randn((3,4,5))
    batch = torch.tensor((0,0,1,2))
    x, batch = expandbatch(x, batch)
    print(x.shape, batch)
    x = torch.randn((3,4,5))
    batch = None
    x, batch = expandbatch(x, batch)
    print(x.shape, batch)

    print(list(InstanceNorm(1000).modules()))
--------------------------------------------------------------------------------
/gssc/network/performer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.register as register
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
5 | from torch_geometric.graphgym.register import register_network
6 |
7 | from gssc.layer.performer_layer import Performer as BackbonePerformer
8 |
9 |
@register_network('Performer')
class Performer(torch.nn.Module):
    """Performer without edge features.

    Runs a linear-attention transformer over node features only, disregarding
    edge features entirely. https://arxiv.org/abs/2009.14794
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # children() order drives forward(); keep registration order.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        n_heads = cfg.gt.n_heads
        self.trf = BackbonePerformer(
            dim=cfg.gt.dim_hidden,
            depth=cfg.gt.layers,
            heads=n_heads,
            dim_head=cfg.gt.dim_hidden // n_heads,
        )

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        for stage in self.children():
            batch = stage(batch)
        return batch
44 |
--------------------------------------------------------------------------------
/gssc/network/san_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.graphgym.register as register
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
5 | from torch_geometric.graphgym.register import register_network
6 |
7 | from gssc.layer.san_layer import SANLayer
8 | from gssc.layer.san2_layer import SAN2Layer
9 |
10 |
@register_network('SANTransformer')
class SANTransformer(torch.nn.Module):
    """Spectral Attention Network (SAN) Graph Transformer.
    https://arxiv.org/abs/2106.03893
    """

    def __init__(self, dim_in, dim_out):
        """Build encoder, optional pre-MP GNN, SAN layer stack, and head.

        Args:
            dim_in: Input node feature dimension.
            dim_out: Output dimension of the prediction head.

        Raises:
            ValueError: If ``cfg.gt.layer_type`` is not a known SAN layer.
        """
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # Shared dummy edge embedding passed to every SAN layer.
        fake_edge_emb = torch.nn.Embedding(1, cfg.gt.dim_hidden)
        # torch.nn.init.xavier_uniform_(fake_edge_emb.weight.data)
        Layer = {
            'SANLayer': SANLayer,
            'SAN2Layer': SAN2Layer
        }.get(cfg.gt.layer_type)
        if Layer is None:
            # Fix: previously an unknown layer type fell through as None and
            # failed later with "'NoneType' object is not callable".
            raise ValueError(f"Unexpected SAN layer type: {cfg.gt.layer_type}")
        layers = []
        for _ in range(cfg.gt.layers):
            layers.append(Layer(gamma=cfg.gt.gamma,
                                in_dim=cfg.gt.dim_hidden,
                                out_dim=cfg.gt.dim_hidden,
                                num_heads=cfg.gt.n_heads,
                                secondary_edges=cfg.gt.secondary_edges,
                                fake_edge_emb=fake_edge_emb,
                                dropout=cfg.gt.dropout,
                                layer_norm=cfg.gt.layer_norm,
                                batch_norm=cfg.gt.batch_norm,
                                residual=cfg.gt.residual))
        self.trf_layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        # Apply encoder, optional pre-MP, SAN layers, and head in order.
        for module in self.children():
            batch = module(batch)
        return batch
57 |
--------------------------------------------------------------------------------
/gssc/optimizer/__init__.py:
--------------------------------------------------------------------------------
from os.path import dirname, basename, isfile, join
import glob

# Export every sibling .py module so `import *` registers all optimizers.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = []
for module_path in modules:
    if isfile(module_path) and not module_path.endswith('__init__.py'):
        __all__.append(basename(module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/pooling/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Export every sibling .py module (excluding this package initializer) so
# that `from <package> import *` pulls them all in.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/pooling/example.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_pooling
2 | from torch_scatter import scatter
3 |
4 |
@register_pooling('example')
def global_example_pool(x, batch, size=None):
    """Sum-pool node features `x` per graph given the `batch` assignment.

    When `size` is not given, it is inferred as the number of graphs in the
    batch (largest batch index + 1).
    """
    if size is None:
        size = batch.max().item() + 1
    return scatter(x, batch, dim=0, dim_size=size, reduce='add')
9 |
--------------------------------------------------------------------------------
/gssc/stage/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Export every sibling .py module (excluding this package initializer) so
# that `from <package> import *` pulls them all in.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/stage/example.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from torch_geometric.graphgym.config import cfg
5 | from torch_geometric.graphgym.models.layer import GeneralLayer
6 | from torch_geometric.graphgym.register import register_stage
7 |
8 |
def GNNLayer(dim_in, dim_out, has_act=True):
    """Build one GNN layer of the globally configured `cfg.gnn.layer_type`."""
    layer_type = cfg.gnn.layer_type
    return GeneralLayer(layer_type, dim_in, dim_out, has_act)
11 |
12 |
@register_stage('example')
class GNNStackStage(nn.Module):
    """Simple stage that stacks GNN layers sequentially.

    The first layer maps dim_in -> dim_out; all subsequent layers keep
    dim_out. Children are registered as 'layer0', 'layer1', ...
    """
    def __init__(self, dim_in, dim_out, num_layers):
        super().__init__()
        for idx in range(num_layers):
            layer_dim_in = dim_in if idx == 0 else dim_out
            self.add_module(f'layer{idx}', GNNLayer(layer_dim_in, dim_out))
        self.dim_out = dim_out

    def forward(self, batch):
        """Apply each stacked layer, then optionally L2-normalize features."""
        for child in self.children():
            batch = child(batch)
        if cfg.gnn.l2norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)
        return batch
30 |
--------------------------------------------------------------------------------
/gssc/train/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Export every sibling .py module (excluding this package initializer) so
# that `from <package> import *` pulls them all in.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/train/example.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import torch
5 |
6 | from torch_geometric.graphgym.checkpoint import (
7 | clean_ckpt,
8 | load_ckpt,
9 | save_ckpt,
10 | )
11 | from torch_geometric.graphgym.config import cfg
12 | from torch_geometric.graphgym.loss import compute_loss
13 | from torch_geometric.graphgym.register import register_train
14 | from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch
15 |
16 |
def train_epoch(logger, loader, model, optimizer, scheduler):
    """Run one training epoch over `loader`, logging per-batch statistics.

    Performs the standard step per batch (zero grads, forward, loss,
    backward, optimizer step) and advances the LR scheduler once at the end
    of the epoch.
    """
    model.train()
    tic = time.time()
    for batch in loader:
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        loss.backward()
        optimizer.step()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(),
                            loss=loss.item(),
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - tic,
                            params=cfg.params)
        tic = time.time()
    scheduler.step()
34 |
35 |
def eval_epoch(logger, loader, model):
    """Evaluate `model` on every batch of `loader`, logging per-batch stats.

    Runs in eval mode and under `torch.no_grad()` so no autograd graph is
    built during evaluation; the original version tracked gradients that
    were never backpropagated, wasting memory and compute.
    """
    model.eval()
    time_start = time.time()
    with torch.no_grad():
        for batch in loader:
            batch.to(torch.device(cfg.device))
            pred, true = model(batch)
            loss, pred_score = compute_loss(pred, true)
            # lr=0: no optimizer step happens during evaluation.
            logger.update_stats(true=true.detach().cpu(),
                                pred=pred_score.detach().cpu(),
                                loss=loss.item(),
                                lr=0, time_used=time.time() - time_start,
                                params=cfg.params)
            time_start = time.time()
48 |
49 |
@register_train('example')
def train_example(loggers, loaders, model, optimizer, scheduler):
    """Standard training loop with periodic evaluation and checkpointing.

    Split 0 is trained every epoch; the remaining splits are evaluated on
    eval epochs. Checkpoints are saved on checkpoint epochs and optionally
    cleaned up once training finishes.
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    num_splits = len(loggers)
    for epoch in range(start_epoch, cfg.optim.max_epoch):
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        loggers[0].write_epoch(epoch)
        if is_eval_epoch(epoch):
            for split in range(1, num_splits):
                eval_epoch(loggers[split], loaders[split], model)
                loggers[split].write_epoch(epoch)
        if is_ckpt_epoch(epoch):
            save_ckpt(model, optimizer, scheduler, epoch)
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in %s', cfg.run_dir)
77 |
--------------------------------------------------------------------------------
/gssc/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
modules = glob.glob(join(dirname(__file__), "*.py"))
# Export every sibling .py module (excluding this package initializer) so
# that `from <package> import *` pulls them all in.
__all__ = []
for _module_path in modules:
    if isfile(_module_path) and not _module_path.endswith('__init__.py'):
        __all__.append(basename(_module_path)[:-3])
9 |
--------------------------------------------------------------------------------
/gssc/transform/expander_edges.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import scipy as sp
4 | from typing import Any, Optional
5 | import torch
6 | from gssc.transform.dist_transforms import laplacian_eigenv
7 |
8 |
def generate_random_regular_graph1(num_nodes, degree, rng=None):
    """Generates a random 2d-regular graph with n nodes via the permutations
    algorithm.

    The returned edge list is symmetric: if (x, y) is an edge, so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Desired degree.
        rng: random number generator.

    Returns:
        senders: tail of each edge (np.ndarray).
        receivers: head of each edge (np.ndarray).
    """
    if rng is None:
        rng = np.random.default_rng()

    # One copy of every node id per degree unit on the tail side; each
    # degree unit gets an independent random permutation on the head side.
    tails = list(range(num_nodes)) * degree
    heads = []
    for _ in range(degree):
        heads.extend(rng.permutation(list(range(num_nodes))).tolist())

    # Symmetrize: every (tail, head) edge also appears as (head, tail).
    return np.array(tails + heads), np.array(heads + tails)
36 |
37 |
38 |
def generate_random_regular_graph2(num_nodes, degree, rng=None):
    """Generates a random 2d-regular graph with n nodes using a simple
    variant of the permutations algorithm.

    The returned edge list is symmetric: if (x, y) is an edge, so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Desired degree.
        rng: random number generator.

    Returns:
        senders: tail of each edge (np.ndarray).
        receivers: head of each edge (np.ndarray).
    """
    if rng is None:
        rng = np.random.default_rng()

    senders = [*range(0, num_nodes)] * degree
    receivers = rng.permutation(senders).tolist()

    senders, receivers = [*senders, *receivers], [*receivers, *senders]

    # Return numpy arrays for consistency with the sibling generators
    # (generate_random_regular_graph1 and the Hamiltonian variant), which
    # already return ndarrays; previously this function returned lists.
    return np.array(senders), np.array(receivers)
61 |
62 |
def generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng=None):
    """Generates a 2d-regular graph by overlaying `degree` random Hamiltonian
    cycles.

    The returned edge list is symmetric: if (x, y) is an edge, so is (y, x).

    Args:
        num_nodes: Number of nodes in the desired graph.
        degree: Desired degree (number of cycles to overlay).
        rng: random number generator.

    Returns:
        senders: tail of each edge.
        receivers: head of each edge.
    """
    if rng is None:
        rng = np.random.default_rng()

    senders, receivers = [], []
    for _ in range(degree):
        cycle = rng.permutation(list(range(num_nodes))).tolist()
        # Walk the cycle; `prev` starts at the last element so the wrap-around
        # edge is emitted first, matching consecutive-pair order.
        prev = cycle[-1]
        for node in cycle:
            senders += [node, prev]
            receivers += [prev, node]
            prev = node

    return np.array(senders), np.array(receivers)
92 |
93 |
def generate_random_expander(data, degree, algorithm, rng=None, max_num_iters=100, exp_index=0):
    """Attach a random d-regular expander edge set to `data`.

    Repeatedly samples random d-regular graphs with the chosen algorithm and
    keeps the candidate whose smallest nonzero Laplacian eigenvalue is
    largest, stopping early once it clears the Ramanujan-style lower bound.
    Graphs with at most 10 nodes get the complete graph instead, since random
    regular generation is unreliable there.

    Args:
        data: graph data object; `data.num_nodes` is read and the expander
            edge tensor is written back onto it.
        degree: desired degree (capped at num_nodes - 1).
        algorithm: 'Random-d', 'Random-d-2' or 'Hamiltonian'.
        rng: numpy random generator (a fresh `default_rng()` if None).
        max_num_iters: maximum number of sampling iterations.
        exp_index: 0 stores the result in `data.expander_edges`; any other
            value stores it in `data.expander_edges{exp_index}`.

    Returns:
        The same `data` object with the (self-loop-free) expander edge
        tensor of shape [num_edges, 2] attached.
    """
    num_nodes = data.num_nodes

    if rng is None:
        rng = np.random.default_rng()

    eig_val = -1
    # Alon-Boppana-style bound on the spectral gap of a good d-regular
    # expander, relaxed by 0.1.
    eig_val_lower_bound = max(0, 2 * degree - 2 * math.sqrt(2 * degree - 1) - 0.1)

    max_eig_val_so_far = -1
    max_senders = []
    max_receivers = []
    cur_iter = 1

    if num_nodes <= degree:
        degree = num_nodes - 1

    # If there are too few nodes, random graph generation will fail; in this
    # case, use the complete graph (already free of self loops).
    if num_nodes <= 10:
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j:
                    max_senders.append(i)
                    max_receivers.append(j)
    else:
        while eig_val < eig_val_lower_bound and cur_iter <= max_num_iters:
            if algorithm == 'Random-d':
                senders, receivers = generate_random_regular_graph1(num_nodes, degree, rng)
            elif algorithm == 'Random-d-2':
                senders, receivers = generate_random_regular_graph2(num_nodes, degree, rng)
            elif algorithm == 'Hamiltonian':
                senders, receivers = generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng)
            else:
                raise ValueError('prep.exp_algorithm should be one of the Random-d or Hamiltonian')
            [eig_val, _] = laplacian_eigenv(senders, receivers, k=1, n=num_nodes)
            if len(eig_val) == 0:
                print("num_nodes = %d, degree = %d, cur_iter = %d, max_iters = %d, senders = %d, receivers = %d" % (num_nodes, degree, cur_iter, max_num_iters, len(senders), len(receivers)))
                eig_val = 0
            else:
                eig_val = eig_val[0]

            # Keep the candidate with the largest spectral gap seen so far.
            if eig_val > max_eig_val_so_far:
                max_eig_val_so_far = eig_val
                max_senders = senders
                max_receivers = receivers

            cur_iter += 1

    # Eliminate self loops. BUG FIX: the filtered arrays are now the ones
    # stored on `data`; previously the unfiltered max_senders/max_receivers
    # were converted to tensors, so self loops survived despite this filter.
    non_loops = [
        *filter(lambda i: max_senders[i] != max_receivers[i], range(0, len(max_senders)))
    ]

    senders = np.array(max_senders)[non_loops]
    receivers = np.array(max_receivers)[non_loops]

    senders = torch.as_tensor(senders, dtype=torch.long).view(-1, 1)
    receivers = torch.as_tensor(receivers, dtype=torch.long).view(-1, 1)
    expander_edges = torch.cat([senders, receivers], dim=1)

    if exp_index == 0:
        data.expander_edges = expander_edges
    else:
        setattr(data, f"expander_edges{exp_index}", expander_edges)

    return data
174 |
--------------------------------------------------------------------------------
/gssc/transform/transforms.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import os
4 | import torch
5 | from torch_geometric.utils import subgraph
6 | from tqdm import tqdm
7 |
8 |
def _transformed_data_list(dataset, transform_func, show_progress):
    """Apply `transform_func` to every example of `dataset`, dropping Nones."""
    data_list = [transform_func(dataset.get(i))
                 for i in tqdm(range(len(dataset)),
                               disable=not show_progress,
                               mininterval=10,
                               miniters=len(dataset)//20)]
    return list(filter(None, data_list))


def pre_transform_in_memory(dataset, transform_func, show_progress=False, check_if_pe_done=False):
    """Pre-transform already loaded PyG dataset object.

    Apply transform function to a loaded PyG dataset object so that
    the transformed result is persistent for the lifespan of the object.
    This means the result is not saved to disk, as what PyG's `pre_transform`
    would do, but also the transform is applied only once and not at each
    data access as what PyG's `transform` hook does.

    Implementation is based on torch_geometric.data.in_memory_dataset.copy

    Args:
        dataset: PyG dataset object to modify
        transform_func: transformation function to apply to each data example
        show_progress: show tqdm progress bar
        check_if_pe_done: if True, cache the transformed dataset under
            `dataset.processed_dir` as 'pe_transformed.pt' and load that
            cache instead of re-transforming when it already exists
    """
    if transform_func is None:
        return dataset

    if check_if_pe_done:
        file_name = dataset.processed_dir + "/pe_transformed.pt"
        if os.path.exists(file_name):
            print(f"Loading pre-transformed PE from {file_name}")
            data_list, collate_data, collate_slices = torch.load(file_name)
            dataset._indices, dataset._data_list, dataset.data, dataset.slices = None, data_list, collate_data, collate_slices
            print(f"Loaded pre-transformed PE from {file_name}")
        else:
            print(f"Pre-transforming PE and saving to {file_name}")
            # Previously this transform/filter logic was duplicated in both
            # branches; it now lives in _transformed_data_list.
            data_list = _transformed_data_list(dataset, transform_func, show_progress)
            dataset._indices = None
            dataset._data_list = data_list
            dataset.data, dataset.slices = dataset.collate(data_list)
            torch.save((data_list, dataset.data, dataset.slices), file_name)
            print(f"Saved pre-transformed PE to {file_name}")
    else:
        data_list = _transformed_data_list(dataset, transform_func, show_progress)
        dataset._indices = None
        dataset._data_list = data_list
        dataset.data, dataset.slices = dataset.collate(data_list)
60 |
61 |
62 | def generate_splits(data, g_split):
63 | n_nodes = len(data.x)
64 | train_mask = torch.zeros(n_nodes, dtype=bool)
65 | valid_mask = torch.zeros(n_nodes, dtype=bool)
66 | test_mask = torch.zeros(n_nodes, dtype=bool)
67 | idx = torch.randperm(n_nodes)
68 | val_num = test_num = int(n_nodes * (1 - g_split) / 2)
69 | train_mask[idx[val_num + test_num:]] = True
70 | valid_mask[idx[:val_num]] = True
71 | test_mask[idx[val_num:val_num + test_num]] = True
72 | data.train_mask = train_mask
73 | data.val_mask = valid_mask
74 | data.test_mask = test_mask
75 | return data
76 |
77 |
def typecast_x(data, type_str):
    """Cast node features `data.x` to the requested dtype.

    Args:
        data: graph data object with an `x` tensor.
        type_str: either 'float' or 'long'.

    Raises:
        ValueError: for any other `type_str`.
    """
    if type_str == 'float':
        data.x = data.x.float()
        return data
    if type_str == 'long':
        data.x = data.x.long()
        return data
    raise ValueError(f"Unexpected type '{type_str}'.")
86 |
87 |
def concat_x_and_pos(data):
    """Append node coordinates `data.pos` to the node features `data.x`."""
    data.x = torch.cat([data.x, data.pos], dim=1)
    return data
91 |
def move_node_feat_to_x(data):
    """For ogbn-proteins: expose the `node_species` attribute as `data.x`."""
    data.x = data.node_species
    return data
96 |
def clip_graphs_to_size(data, size_limit=5000):
    """Truncate a graph to at most `size_limit` nodes.

    Keeps nodes 0..size_limit-1 (with their per-node attributes) and drops
    every edge incident to a removed node via `subgraph`. Graphs already
    within the limit are returned unchanged.
    """
    if hasattr(data, 'num_nodes'):
        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
    else:
        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
    if N <= size_limit:
        return data
    else:
        logging.info(f'  ...clip to {size_limit} a graph of size: {N}')
        if hasattr(data, 'edge_attr'):
            edge_attr = data.edge_attr
        else:
            edge_attr = None
        # Keep only edges whose endpoints both survive; node ids are NOT
        # relabeled, which is safe since the kept ids are 0..size_limit-1.
        edge_index, edge_attr = subgraph(list(range(size_limit)),
                                         data.edge_index, edge_attr)
        if hasattr(data, 'x'):
            data.x = data.x[:size_limit]
            data.num_nodes = size_limit
        else:
            data.num_nodes = size_limit
        if hasattr(data, 'node_is_attributed'):  # for ogbg-code2 dataset
            data.node_is_attributed = data.node_is_attributed[:size_limit]
            data.node_dfs_order = data.node_dfs_order[:size_limit]
            data.node_depth = data.node_depth[:size_limit]
        data.edge_index = edge_index
        if hasattr(data, 'edge_attr'):
            data.edge_attr = edge_attr
        return data
125 |
--------------------------------------------------------------------------------
/gssc/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | from torch_geometric.utils import remove_self_loops
5 | from torch_scatter import scatter
6 |
7 | from yacs.config import CfgNode
8 |
9 |
def negate_edge_index(edge_index, batch=None):
    """Negate batched sparse adjacency matrices given by edge indices.

    Returns batched sparse adjacency matrices with exactly those edges that
    are not in the input `edge_index` while ignoring self-loops.

    Implementation inspired by `torch_geometric.utils.to_dense_adj`

    Args:
        edge_index: The edge indices.
        batch: Batch vector, which assigns each node to a specific example.

    Returns:
        Complementary edge index.
    """

    if batch is None:
        # Single graph: every node belongs to example 0.
        batch = edge_index.new_zeros(edge_index.max().item() + 1)

    batch_size = batch.max().item() + 1
    one = batch.new_ones(batch.size(0))
    # Number of nodes per example in the batch.
    num_nodes = scatter(one, batch,
                        dim=0, dim_size=batch_size, reduce='add')
    # Offset of each example's first node within the batched node indexing.
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # idx0: example id of each edge; idx1/idx2: edge endpoints relabeled to
    # within-example (local) node indices.
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

    negative_index_list = []
    for i in range(batch_size):
        n = num_nodes[i].item()
        size = [n, n]
        # Start from a fully-connected N x N adjacency of ones.
        adj = torch.ones(size, dtype=torch.short,
                         device=edge_index.device)

        # Remove existing edges from the full N x N adjacency matrix
        flattened_size = n * n
        adj = adj.view([flattened_size])
        _idx1 = idx1[idx0 == i]
        _idx2 = idx2[idx0 == i]
        idx = _idx1 * n + _idx2
        zero = torch.zeros(_idx1.numel(), dtype=torch.short,
                           device=edge_index.device)
        # Multiplying zeros into the flat positions of existing edges clears
        # them while leaving every other entry at one.
        scatter(zero, idx, dim=0, out=adj, reduce='mul')

        # Convert to edge index format
        adj = adj.view(size)
        _edge_index = adj.nonzero(as_tuple=False).t().contiguous()
        _edge_index, _ = remove_self_loops(_edge_index)
        # Shift local indices back to the batched (global) numbering.
        negative_index_list.append(_edge_index + cum_nodes[i])

    edge_index_negative = torch.cat(negative_index_list, dim=1).contiguous()
    return edge_index_negative
64 |
65 |
def flatten_dict(metrics):
    """Flatten a list of train/val/test metrics into one dict to send to wandb.

    Args:
        metrics: list of per-split metric dict lists, ordered
            train/val/test; only each split's latest entry is kept.

    Returns:
        A flat dictionary with names prefixed with "train/" , "val/" , "test/"
    """
    result = {}
    for prefix, split_metrics in zip(('train', 'val', 'test'), metrics):
        # Take the latest metrics.
        latest = split_metrics[-1]
        for key, value in latest.items():
            result[f"{prefix}/{key}"] = value
    return result
82 |
83 |
def cfg_to_dict(cfg_node, key_list=None):
    """Convert a config node to dictionary.

    Yacs doesn't have a default function to convert the cfg object to plain
    python dict. The following function was taken from
    https://github.com/rbgirshick/yacs/issues/19

    Args:
        cfg_node: a yacs CfgNode or a leaf value.
        key_list: ancestor key path, used only for warning messages.
            (Default changed from a mutable `[]` literal to None to avoid
            the shared-mutable-default anti-pattern.)
    """
    _VALID_TYPES = {tuple, list, str, int, float, bool}

    if key_list is None:
        key_list = []

    if not isinstance(cfg_node, CfgNode):
        if type(cfg_node) not in _VALID_TYPES:
            logging.warning(f"Key {'.'.join(key_list)} with "
                            f"value {type(cfg_node)} is not "
                            f"a valid type; valid types: {_VALID_TYPES}")
        return cfg_node
    else:
        cfg_dict = dict(cfg_node)
        for k, v in cfg_dict.items():
            cfg_dict[k] = cfg_to_dict(v, key_list + [k])
        return cfg_dict
104 |
105 |
def make_wandb_name(cfg):
    """Compose a wandb run name from the dataset, model type and run id."""
    # Format dataset name: strip storage-format prefixes sequentially.
    dataset_name = cfg.dataset.format
    for prefix in ('OGB', 'PyG-'):
        if dataset_name.startswith(prefix):
            dataset_name = dataset_name[len(prefix):]
    if dataset_name in ('GNNBenchmarkDataset', 'TUDataset'):
        # Shorten some verbose dataset naming schemes.
        dataset_name = ""
    if cfg.dataset.name != 'none':
        if dataset_name != "":
            dataset_name += "-"
        if cfg.dataset.name == 'LocalDegreeProfile':
            dataset_name += 'LDP'
        else:
            dataset_name += cfg.dataset.name
    # Format model name.
    model_name = cfg.model.type
    if cfg.model.type in ['gnn', 'custom_gnn']:
        model_name += f".{cfg.gnn.layer_type}"
    elif cfg.model.type == 'GPSModel':
        model_name = f"GPS.{cfg.gt.layer_type}"
    if cfg.name_tag:
        model_name += f".{cfg.name_tag}"
    # Compose wandb run name.
    return f"{dataset_name}.{model_name}.r{cfg.run_id}"
132 |
--------------------------------------------------------------------------------
/wandb/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------