├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── notebooks ├── ray_analysis.ipynb └── visualise_attention.ipynb ├── pyG_install.sh ├── requirements.txt ├── src ├── CGNN.py ├── DIGL_data.py ├── DIGL_seeds.py ├── GNN.py ├── GNN_KNN.py ├── GNN_KNN_early.py ├── GNN_early.py ├── GNN_image.py ├── base_classes.py ├── best_params.py ├── block_constant.py ├── block_constant_rewiring.py ├── block_mixed.py ├── block_transformer_attention.py ├── block_transformer_hard_attention.py ├── block_transformer_rewiring.py ├── data.py ├── data_image.py ├── deepwalk_embeddings.py ├── deepwalk_gen.sh ├── deepwalk_gen_symlinks.py ├── distances_kNN.py ├── early_stop_solver.py ├── function_GAT_attention.py ├── function_laplacian_diffusion.py ├── function_transformer_attention.py ├── graph_rewiring.py ├── heterophilic.py ├── hyperbolic_distances.py ├── model_configurations.py ├── pos_enc_factorisation.py ├── post_analysis_image.py ├── ray_tune.py ├── regularized_ODE_function.py ├── run_GNN.py ├── run_best_ray.py ├── run_explicit_implicit_exp.py ├── run_image.py ├── utils.py └── visualise_attention.py └── test ├── test_ICML_gnn.py ├── test_attention.py ├── test_attention_ode_block.py ├── test_block_mixed.py ├── test_early_stop.py ├── test_function_laplacian_diffusion.py ├── test_gnn.py ├── test_params.py ├── test_transformer_attention.py └── test_utils.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | schedule: 11 | # Run the tests at 00:00 each day 12 | - cron: "0 0 * * *" 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [3.8] 21 | defaults: 22 | run: 23 | shell: bash -l {0} 24 | 25 | steps: 26 | - uses: actions/checkout@v2 27 | - name: cache conda 28 | uses: actions/cache@v2 
29 | env: 30 | # Increase this value to reset cache if etc/example-environment.yml has not changed 31 | CACHE_NUMBER: 0 32 | with: 33 | path: ~/conda_pkgs_dir 34 | key: 35 | ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ 36 | hashFiles('requirements.txt') }} 37 | - uses: conda-incubator/setup-miniconda@v2 38 | with: 39 | activate-environment: test 40 | python-version: 3.8 41 | use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! 42 | - name: Set up env 43 | run: | 44 | conda activate test 45 | conda install pip 46 | - name: Cache pip 47 | uses: actions/cache@v2 48 | with: 49 | # This path is specific to Ubuntu 50 | path: ~/.cache/pip 51 | # Look to see if there is a cache hit for the corresponding requirements file 52 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 53 | restore-keys: | 54 | ${{ runner.os }}-pip- 55 | ${{ runner.os }}- 56 | - name: pytorch 57 | run: | 58 | conda install -y pytorch=1.7.1 torchvision cudatoolkit=10.2 -c pytorch --update-deps 59 | - name: Install pyG 60 | run: | 61 | ./pyG_install.sh cu102 62 | - name: Install dependencies 63 | run: | 64 | pip install -r requirements.txt 65 | - name: Test with pytest 66 | run: | 67 | pytest -v -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | data/ 3 | ray_tune/ 4 | __pycache__/ 5 | src/checkpoint 6 | images/ 7 | ray_results/ 8 | models/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![example workflow](https://github.com/twitter-research/graph-neural-pde/actions/workflows/python-package.yml/badge.svg) 2 | 3 | ![Cora_animation_16](https://user-images.githubusercontent.com/5874124/143270624-265c2d01-39ca-488c-b118-b68f876dfbfa.gif) 4 | 5 | ## Introduction 6 | 7 | This repository contains the source code for the publications [GRAND: Graph Neural Diffusion](https://icml.cc/virtual/2021/poster/8889) and [Beltrami Flow and Neural Diffusion on Graphs (BLEND)](https://arxiv.org/abs/2110.09443). 8 | These approaches treat deep learning on graphs as a continuous diffusion process and Graph Neural 9 | Networks (GNNs) as discretisations of an underlying PDE. In both models, the layer structure and 10 | topology correspond to the discretisation choices 11 | of temporal and spatial operators. 
Our approach allows a principled development of a broad new 12 | class of GNNs that are able to address the common plights of graph learning models such as 13 | depth, oversmoothing, and bottlenecks. Key to 14 | the success of our models are stability with respect to perturbations in the data and this is addressed for both 15 | implicit and explicit discretisation schemes. We develop linear and nonlinear 16 | versions of GRAND, which achieve competitive results on many standard graph benchmarks. BLEND is a non-Euclidean extension of GRAND that jointly evolves the feature and positional encodings of each node providing a principled means to perform graph rewiring. 17 | 18 | ## Running the experiments 19 | 20 | ### Requirements 21 | Dependencies (with python >= 3.7): 22 | Main dependencies are 23 | torch==1.8.1 24 | torch-cluster==1.5.9 25 | torch-geometric==1.7.0 26 | torch-scatter==2.0.6 27 | torch-sparse==0.6.9 28 | torch-spline-conv==1.2.1 29 | torchdiffeq==0.2.1 30 | Commands to install all the dependencies in a new conda environment 31 | ``` 32 | conda create --name grand python=3.7 33 | conda activate grand 34 | 35 | pip install ogb pykeops 36 | pip install torch==1.8.1 37 | pip install torchdiffeq -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html 38 | 39 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html 40 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html 41 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html 42 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html 43 | pip install torch-geometric 44 | ``` 45 | 46 | ### Troubleshooting 47 | 48 | There is a bug in pandas==1.3.1 that could produce the error ImportError: cannot import name 'DtypeObj' from 'pandas._typing' 49 | If encountered, then the fix is 50 | pip install pandas==1.3.0 -U 51 | 52 | ## GRAND (Graph Neural Diffusion) 53 | 54 | ### Dataset 
and Preprocessing 55 | create a root level folder 56 | ``` 57 | ./data 58 | ``` 59 | This will be automatically populated the first time each experiment is run. 60 | 61 | ### Experiments 62 | For example to run for Cora with random splits: 63 | ``` 64 | cd src 65 | python run_GNN.py --dataset Cora 66 | ``` 67 | 68 | ## BLEND (Beltrami Flow and Neural Diffusion on Graphs) 69 | 70 | ### Dataset and Preprocessing 71 | 72 | Create a root level folder 73 | ``` 74 | ./data 75 | ``` 76 | This will be automatically populated the first time each experiment is run. 77 | create a root level folder 78 | ``` 79 | ./data/pos_encodings 80 | ``` 81 | DIGL positional encodings will build automatically and populate this folder, but DeepWalk or Hyperbolic positional encodings will need generating using the relevant generator scripts or downloading. We include a shell script (warning: it's slow) to generate them: 82 | ``` 83 | sh deepwalk_gen.sh 84 | ``` 85 | then create symlinks to them with 86 | ``` 87 | python deepwalk_gen_symlinks.py 88 | ``` 89 | Alternatively, we also provide precomputed positional encodings [here](https://www.dropbox.com/sh/wfktgbfiueikcp0/AABrIjyhR6Yi4EcirnryRXjja?dl=0) 90 | Specifically, the positional encodings required to run the default settings for Citeseer, Computers, Photo and ogbn-arxiv are 91 | - [Citeseer](https://www.dropbox.com/sh/wfktgbfiueikcp0/AAB9HypMFO3QCeDFojRYuQoDa/Citeseer_DW64.pkl?dl=0) 92 | - [Computers](https://www.dropbox.com/sh/wfktgbfiueikcp0/AAD_evlqcwQFLL6MVyGeiKiha/Computers_DW128.pkl?dl=0) 93 | - [Photo](https://www.dropbox.com/sh/wfktgbfiueikcp0/AAAAhsxAcHWB5OGTHLNMXR5-a/Photo_DW128.pkl?dl=0) 94 | - [ogbn-arxiv](https://www.dropbox.com/sh/wfktgbfiueikcp0/AADcRPI5pLrx3iUvUjGBcqD0a/ogbn-arxiv_DW64.pkl?dl=0) 95 | 96 | Download them and place into 97 | ``` 98 | ./data/pos_encodings 99 | ``` 100 | 101 | ### Experiments 102 | 103 | For example to run for Cora with random splits: 104 | ``` 105 | cd src 106 | python run_GNN.py --dataset
Cora --beltrami 107 | ``` 108 | 109 | ## Troubleshooting 110 | 111 | Most problems installing the dependencies are caused by Cuda version mismatches with pytorch geometric. We recommend checking your cuda and pytorch versions 112 | ``` 113 | nvcc --version 114 | python -c "import torch; print(torch.__version__)" 115 | ``` 116 | and then following instructions here to install pytorch geometric 117 | https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html 118 | 119 | ## Cite us 120 | If you found this work useful, please consider citing our papers 121 | ``` 122 | @article 123 | {chamberlain2021grand, 124 | title={GRAND: Graph Neural Diffusion}, 125 | author={Chamberlain, Benjamin Paul and Rowbottom, James and Gorinova, Maria and Webb, Stefan and Rossi, 126 | Emanuele and Bronstein, Michael M}, 127 | journal={Proceedings of the 38th International Conference on Machine Learning, 128 | (ICML) 2021, 18-24 July 2021, Virtual Event}, 129 | year={2021} 130 | } 131 | ``` 132 | and 133 | ``` 134 | @article 135 | {chamberlain2021blend, 136 | title={Beltrami Flow and Neural Diffusion on Graphs}, 137 | author={Chamberlain, Benjamin Paul and Rowbottom, James and Eynard, Davide and Di Giovanni, Francesco and Dong, Xiaowen and Bronstein, Michael M}, 138 | journal={Proceedings of the Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS) 2021, Virtual Event}, 139 | year={2021} 140 | } 141 | ``` 142 | 143 | ## Security Issues? 144 | Please report sensitive security issues via Twitter's bug-bounty program (https://hackerone.com/twitter) rather than GitHub.
145 | 146 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: graph-neural-pde 2 | channels: 3 | - soumith 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - blas=1.0 8 | - ca-certificates=2020.12.8 9 | - certifi=2020.12.5 10 | - cycler=0.10.0 11 | - freetype=2.10.4 12 | - intel-openmp=2020.2 13 | - joblib=1.0.0 14 | - jpeg=9b 15 | - kiwisolver=1.3.0 16 | - lcms2=2.11 17 | - libcxx=10.0.0 18 | - libedit=3.1.20191231 19 | - libffi=3.3 20 | - libgfortran 21 | - libllvm9=9.0.1 22 | - libpng=1.6.37 23 | - libtiff=4.1.0 24 | - libuv=1.40.0 25 | - llvm-openmp=10.0.0 26 | - lz4-c=1.9.2 27 | - matplotlib=3.3.2 28 | - matplotlib-base=3.3.2 29 | - mkl=2019.4 30 | - mkl-service=2.3.0 31 | - mkl_fft=1.2.0 32 | - mkl_random=1.1.1 33 | - ncurses=6.2 34 | - ninja=1.10.2 35 | - numba=0.50.1 36 | - numpy=1.19.2 37 | - numpy-base=1.19.2 38 | - olefile=0.46 39 | - openssl=1.1.1i 40 | - pillow=8.1.0 41 | - pip=20.3.3 42 | - pyparsing=2.4.7 43 | - python=3.8.5 44 | - python-dateutil=2.8.1 45 | - pytorch=1.7.1 46 | - readline=8.0 47 | - setuptools=51.1.2 48 | - six=1.15.0 49 | - sqlite=3.33.0 50 | - tbb=2020.3 51 | - threadpoolctl=2.1.0 52 | - tk=8.6.10 53 | - torchvision=0.2.1 54 | - tornado=6.1 55 | - typing_extensions=3.7.4.3 56 | - wheel=0.36.2 57 | - xz=5.2.5 58 | - zlib=1.2.11 59 | - zstd=1.4.5 60 | - pip: 61 | - ase==3.20.1 62 | - boltons==20.2.1 63 | - chardet==4.0.0 64 | - decorator==4.4.2 65 | - et-xmlfile==1.0.1 66 | - googledrivedownloader==0.4 67 | - h5py==3.1.0 68 | - idna==2.10 69 | - isodate==0.6.0 70 | - jdcal==1.4.1 71 | - jinja2==2.11.2 72 | - littleutils==0.2.2 73 | - llvmlite==0.33.0 74 | - markupsafe==1.1.1 75 | - networkx==2.5 76 | - ogb==1.2.4 77 | - openpyxl==3.0.6 78 | - outdated==0.2.0 79 | - pandas==1.2.0 80 | - pykeops==1.4.2 81 | - python-louvain==0.15 82 | - pytz==2020.5 83 | - rdflib==5.0.0 84 | - requests==2.25.1 
85 | - scikit-learn==0.24.0 86 | - scipy==1.5.4 87 | - torch-cluster==1.5.8 88 | - torch-geometric==1.6.3 89 | - torch-scatter==2.0.5 90 | - torch-sparse==0.6.8 91 | - torch-spline-conv==1.2.0 92 | - torchdiffeq==0.1.1 93 | - torchsde==0.2.4 94 | - tqdm==4.56.0 95 | - trampoline==0.1.2 96 | - urllib3==1.26.2 97 | prefix: /Users/benchamberlain/anaconda/envs/graph-neural-pde 98 | -------------------------------------------------------------------------------- /pyG_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TORCH=1.7.1 3 | CUDA=$1 # Supply as command line cpu or cu102 4 | pip install torch-scatter==2.0.5 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 5 | pip install torch-sparse==0.6.8 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 6 | pip install torch-cluster==1.5.8 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 7 | pip install torch-spline-conv==1.2.0 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 8 | pip install torch-geometric==1.6.3 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: osx-64 4 | h5py==2.10.0 5 | jupyter==1.0.0 6 | matplotlib==3.2.2 7 | networkx==2.4 8 | numba==0.51.2 9 | numpy==1.19.3 10 | ogb==1.2.4 11 | pandas==1.0.5 12 | pip==20.1.1 13 | pykeops==1.4.1 14 | pytest==6.1.1 15 | ray==1.0.0 16 | scikit-learn==0.23.1 17 | scipy==1.5.3 18 | six==1.15.0 19 | tqdm==4.46.1 20 | torchdiffeq==0.1.1 21 | tabulate==0.8.7 22 | 23 | # reliant on cuda version 24 | # torch==1.6.0=pypi_0 25 | # torch-geometric=1.6.1=pypi_0 26 | # torch-scatter=2.0.5=pypi_0 27 | # torch-sparse=0.6.7=pypi_0 28 | # torchdiffeq=0.1.1=pypi_0 29 | # torchsde=0.2.1=pypi_0 30 | # 
# --- chunk boundary: tail of requirements.txt (commented-out package pins) ---
# torchsummary=1.5.1=pypi_0
# torchvision=0.7.0=pypi_0
# --------------------------------------------------------- /src/DIGL_data.py:
__author__ = "Stefan Weißenberger and Johannes Klicpera"
__license__ = "MIT"

import os

import numpy as np
from scipy.linalg import expm

import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid, Amazon, Coauthor
from torch.optim import Adam, Optimizer
from DIGL_seeds import development_seed

# Root directory (relative to the CWD) under which datasets are downloaded.
DATA_PATH = 'data'


def train(model: torch.nn.Module, optimizer: Optimizer, data: Data):
    """Run one full-batch training step.

    Computes NLL loss on the ``data.train_mask`` nodes and applies a single
    optimizer update.  The model is expected to output log-probabilities
    (the loss is ``F.nll_loss``).
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x)  # data
    loss = F.nll_loss(logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


def evaluate(model: torch.nn.Module, data: Data, test: bool):
    """Compute accuracy on the validation (and optionally the test) split.

    Returns a dict with ``val_acc`` and, when ``test`` is true, ``test_acc``.
    Loss computation was disabled in the original and is kept commented out.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data.x)  # data
    eval_dict = {}
    keys = ['val', 'test'] if test else ['val']
    for key in keys:
        mask = data[f'{key}_mask']
        # loss = F.nll_loss(logits[mask], data.y[mask]).item()
        # eval_dict[f'{key}_loss'] = loss
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        eval_dict[f'{key}_acc'] = acc
    return eval_dict


def get_dataset(name: str, use_lcc: bool = True) -> InMemoryDataset:
    """Load a benchmark dataset by name, optionally restricted to its
    largest connected component.

    When ``use_lcc`` is set, nodes outside the largest connected component
    are dropped, edges are relabelled to compact 0-based ids, and empty
    train/val/test masks are attached (a split must be set afterwards).

    Raises ``Exception`` for an unknown dataset name.

    NOTE: the tail of this function was split across a chunk boundary in the
    dump and is reconstructed in full here.
    """
    path = os.path.join(DATA_PATH, name)
    if name in ['Cora', 'Citeseer', 'Pubmed']:
        dataset = Planetoid(path, name)
    elif name in ['Computers', 'Photo']:
        dataset = Amazon(path, name)
    elif name == 'CoauthorCS':
        dataset = Coauthor(path, 'CS')
    else:
        raise Exception('Unknown dataset.')

    if use_lcc:
        lcc = get_largest_connected_component(dataset)
        # Set gives O(1) membership tests instead of O(n) array scans below.
        lcc_set = set(lcc.tolist())

        x_new = dataset.data.x[lcc]
        y_new = dataset.data.y[lcc]

        row, col = dataset.data.edge_index.numpy()
        edges = [[i, j] for i, j in zip(row, col) if i in lcc_set and j in lcc_set]
        edges = remap_edges(edges, get_node_mapper(lcc))

        data = Data(
            x=x_new,
            edge_index=torch.LongTensor(edges),
            y=y_new,
            train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
            test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
            val_mask=torch.zeros(y_new.size()[0], dtype=torch.bool)
        )
        dataset.data = data

    return dataset
# --- chunk boundary fragment: tail of ``get_dataset`` (its ``def`` lives in
# an earlier chunk of this dump).  Preserved verbatim as comments:
#
#         lcc = get_largest_connected_component(dataset)
#
#         x_new = dataset.data.x[lcc]
#         y_new = dataset.data.y[lcc]
#
#         row, col = dataset.data.edge_index.numpy()
#         edges = [[i, j] for i, j in zip(row, col) if i in lcc and j in lcc]
#         edges = remap_edges(edges, get_node_mapper(lcc))
#
#         data = Data(
#             x=x_new,
#             edge_index=torch.LongTensor(edges),
#             y=y_new,
#             train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
#             test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
#             val_mask=torch.zeros(y_new.size()[0], dtype=torch.bool)
#         )
#         dataset.data = data
#
#     return dataset


def get_component(dataset: "InMemoryDataset", start: int = 0) -> set:
    """Return the set of node ids connected to ``start``.

    Flood fill over ``dataset.data.edge_index``: edges are followed from
    row to col only, so on a directed edge list this is the set *reachable*
    from ``start`` rather than a true undirected component.
    """
    visited_nodes = set()
    queued_nodes = set([start])
    row, col = dataset.data.edge_index.numpy()
    while queued_nodes:
        current_node = queued_nodes.pop()
        visited_nodes.update([current_node])
        neighbors = col[np.where(row == current_node)[0]]
        neighbors = [n for n in neighbors if n not in visited_nodes and n not in queued_nodes]
        queued_nodes.update(neighbors)
    return visited_nodes


def get_largest_connected_component(dataset: "InMemoryDataset") -> np.ndarray:
    """Return the node ids of the largest connected component as an array.

    Repeatedly flood-fills from the smallest unvisited node until every node
    is assigned to a component, then returns the biggest one.
    """
    remaining_nodes = set(range(dataset.data.x.shape[0]))
    comps = []
    while remaining_nodes:
        start = min(remaining_nodes)
        comp = get_component(dataset, start)
        comps.append(comp)
        remaining_nodes = remaining_nodes.difference(comp)
    return np.array(list(comps[np.argmax(list(map(len, comps)))]))


def get_node_mapper(lcc: np.ndarray) -> dict:
    """Map each original node id in ``lcc`` to a compact 0-based index."""
    return {node: i for i, node in enumerate(lcc)}


def remap_edges(edges: list, mapper: dict) -> list:
    """Relabel ``[[src, dst], ...]`` pairs through ``mapper``.

    Returns ``[rows, cols]`` suitable for building an ``edge_index`` tensor.
    The closing ``col]`` of this function's return statement sat at the start
    of the next chunk in the dump; the statement is reconstructed in full.
    """
    row = [mapper[e[0]] for e in edges]
    col = [mapper[e[1]] for e in edges]
    return [row, col]
# --- chunk boundary fragment: ``col]`` closed the return statement of
# ``remap_edges``, whose definition sits in an earlier chunk of this dump.
# Preserved verbatim as a comment:
#     col]


def get_adj_matrix(dataset: "InMemoryDataset") -> np.ndarray:
    """Densify ``dataset.data.edge_index`` into an N x N 0/1 adjacency matrix."""
    num_nodes = dataset.data.x.shape[0]
    adj_matrix = np.zeros(shape=(num_nodes, num_nodes))
    for i, j in zip(dataset.data.edge_index[0], dataset.data.edge_index[1]):
        adj_matrix[i, j] = 1.
    return adj_matrix


def get_ppr_matrix(
        adj_matrix: np.ndarray,
        alpha: float = 0.1) -> np.ndarray:
    """Exact personalised-PageRank diffusion matrix.

    Computes ``alpha * (I - (1 - alpha) * H)^-1`` where ``H`` is the
    symmetrically normalised adjacency with self-loops.  O(N^3): the
    inverse is taken densely, so this does not scale to large graphs.
    """
    num_nodes = adj_matrix.shape[0]
    A_tilde = adj_matrix + np.eye(num_nodes)
    D_tilde = np.diag(1 / np.sqrt(A_tilde.sum(axis=1)))
    H = D_tilde @ A_tilde @ D_tilde
    return alpha * np.linalg.inv(np.eye(num_nodes) - (1 - alpha) * H)


def get_heat_matrix(
        adj_matrix: np.ndarray,
        t: float = 5.0) -> np.ndarray:
    """Heat-kernel diffusion matrix ``expm(-t * (I - H))``.

    ``H`` is the symmetrically normalised adjacency with self-loops, as in
    ``get_ppr_matrix``.
    """
    num_nodes = adj_matrix.shape[0]
    A_tilde = adj_matrix + np.eye(num_nodes)
    D_tilde = np.diag(1 / np.sqrt(A_tilde.sum(axis=1)))
    H = D_tilde @ A_tilde @ D_tilde
    return expm(-t * (np.eye(num_nodes) - H))


def get_top_k_matrix(A: np.ndarray, k: int = 128) -> np.ndarray:
    """Keep only the ``k`` largest entries of each *column* of ``A``
    (zeroing the rest, in place) and renormalise columns to sum to one.
    """
    num_nodes = A.shape[0]
    row_idx = np.arange(num_nodes)
    # argsort ascending per column; the first num_nodes - k rows are the
    # smallest entries of each column and get zeroed.
    A[A.argsort(axis=0)[:num_nodes - k], row_idx] = 0.
    norm = A.sum(axis=0)
    norm[norm <= 0] = 1  # avoid dividing by zero
    return A / norm


def get_clipped_matrix(A: np.ndarray, eps: float = 0.01) -> np.ndarray:
    """Zero all entries of ``A`` below ``eps`` (in place) and renormalise
    columns to sum to one.

    NOTE: this function was split across a chunk boundary in the dump; its
    tail is reconstructed here from the following chunk, and the original's
    unused ``num_nodes`` local was dropped.
    """
    A[A < eps] = 0.
    norm = A.sum(axis=0)
    norm[norm <= 0] = 1  # avoid dividing by zero
    return A / norm
# --- chunk boundary fragment: tail of ``get_clipped_matrix`` (its ``def`` is
# in an earlier chunk of this dump).  Preserved verbatim as comments:
#     norm = A.sum(axis=0)
#     norm[norm <= 0] = 1  # avoid dividing by zero
#     return A / norm


def set_train_val_test_split(
        seed: int,
        data: "Data",
        num_development: int = 1500,
        num_per_class: int = 20) -> "Data":
    """Attach random train/val/test boolean masks to ``data`` in place.

    ``num_development`` nodes are drawn as a development pool; from that
    pool, ``num_per_class`` nodes per class form the training set and the
    rest form the validation set.  Every node outside the pool is test.

    Raises ``ValueError`` (from ``numpy``) if some class has fewer than
    ``num_per_class`` members inside the development pool, since sampling
    is without replacement.
    """
    rnd_state = np.random.RandomState(seed)  # seed development_seed)
    num_nodes = data.y.shape[0]
    development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
    # Sets give O(1) membership; the original scanned arrays/lists (O(n^2)).
    development_set = set(development_idx.tolist())
    test_idx = [i for i in np.arange(num_nodes) if i not in development_set]

    train_idx = []
    # NOTE(review): RandomState is re-created with the same seed here, so the
    # per-class draws reuse the same stream as the development draw above —
    # preserved verbatim from the original.
    rnd_state = np.random.RandomState(seed)
    for c in range(data.y.max() + 1):
        class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

    train_set = set(train_idx)
    val_idx = [i for i in development_idx if i not in train_set]

    def get_mask(idx):
        # Boolean node mask with True at the given indices.
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = 1
        return mask

    data.train_mask = get_mask(train_idx)
    data.val_mask = get_mask(val_idx)
    data.test_mask = get_mask(test_idx)

    return data


# --- chunk boundary fragment: head of ``PPRDataset``.  The class body runs
# past the end of this chunk of the dump, so it is preserved verbatim (as
# comments, including the unterminated docstring) rather than reconstructed:
# class PPRDataset(InMemoryDataset):
#     """
#     Dataset preprocessed with GDC using PPR diffusion.
#     Note that this implementations is not scalable
#     since we directly invert the adjacency matrix.
198 | """ 199 | 200 | def __init__(self, 201 | name: str = 'Cora', 202 | use_lcc: bool = True, 203 | alpha: float = 0.1, 204 | k: int = 16, 205 | eps: float = None): 206 | self.name = name 207 | self.use_lcc = use_lcc 208 | self.alpha = alpha 209 | self.k = k 210 | self.eps = eps 211 | 212 | super(PPRDataset, self).__init__(DATA_PATH) 213 | self.data, self.slices = torch.load(self.processed_paths[0]) 214 | 215 | @property 216 | def raw_file_names(self) -> list: 217 | return [] 218 | 219 | @property 220 | def processed_file_names(self) -> list: 221 | return [str(self) + '.pt'] 222 | 223 | def download(self): 224 | pass 225 | 226 | def process(self): 227 | base = get_dataset(name=self.name, use_lcc=self.use_lcc) 228 | # generate adjacency matrix from sparse representation 229 | adj_matrix = get_adj_matrix(base) 230 | # obtain exact PPR matrix 231 | ppr_matrix = get_ppr_matrix(adj_matrix, 232 | alpha=self.alpha) 233 | 234 | if self.k: 235 | print(f'Selecting top {self.k} edges per node.') 236 | ppr_matrix = get_top_k_matrix(ppr_matrix, k=self.k) 237 | elif self.eps: 238 | print(f'Selecting edges with weight greater than {self.eps}.') 239 | ppr_matrix = get_clipped_matrix(ppr_matrix, eps=self.eps) 240 | else: 241 | raise ValueError 242 | 243 | # create PyG Data object 244 | edges_i = [] 245 | edges_j = [] 246 | edge_attr = [] 247 | for i, row in enumerate(ppr_matrix): 248 | for j in np.where(row > 0)[0]: 249 | edges_i.append(i) 250 | edges_j.append(j) 251 | edge_attr.append(ppr_matrix[i, j]) 252 | edge_index = [edges_i, edges_j] 253 | 254 | data = Data( 255 | x=base.data.x, 256 | edge_index=torch.LongTensor(edge_index), 257 | edge_attr=torch.FloatTensor(edge_attr), 258 | y=base.data.y, 259 | train_mask=torch.zeros(base.data.train_mask.size()[0], dtype=torch.bool), 260 | test_mask=torch.zeros(base.data.test_mask.size()[0], dtype=torch.bool), 261 | val_mask=torch.zeros(base.data.val_mask.size()[0], dtype=torch.bool) 262 | ) 263 | 264 | data, slices = self.collate([data]) 
265 | torch.save((data, slices), self.processed_paths[0]) 266 | 267 | def __str__(self) -> str: 268 | return f'{self.name}_ppr_alpha={self.alpha}_k={self.k}_eps={self.eps}_lcc={self.use_lcc}' 269 | 270 | 271 | class HeatDataset(InMemoryDataset): 272 | """ 273 | Dataset preprocessed with GDC using heat kernel diffusion. 274 | Note that this implementations is not scalable 275 | since we directly calculate the matrix exponential 276 | of the adjacency matrix. 277 | """ 278 | 279 | def __init__(self, 280 | name: str = 'Cora', 281 | use_lcc: bool = True, 282 | t: float = 5.0, 283 | k: int = 16, 284 | eps: float = None): 285 | self.name = name 286 | self.use_lcc = use_lcc 287 | self.t = t 288 | self.k = k 289 | self.eps = eps 290 | 291 | super(HeatDataset, self).__init__(DATA_PATH) 292 | self.data, self.slices = torch.load(self.processed_paths[0]) 293 | 294 | @property 295 | def raw_file_names(self) -> list: 296 | return [] 297 | 298 | @property 299 | def processed_file_names(self) -> list: 300 | return [str(self) + '.pt'] 301 | 302 | def download(self): 303 | pass 304 | 305 | def process(self): 306 | base = get_dataset(name=self.name, use_lcc=self.use_lcc) 307 | # generate adjacency matrix from sparse representation 308 | adj_matrix = get_adj_matrix(base) 309 | # get heat matrix as described in Berberidis et al., 2019 310 | heat_matrix = get_heat_matrix(adj_matrix, 311 | t=self.t) 312 | if self.k: 313 | print(f'Selecting top {self.k} edges per node.') 314 | heat_matrix = get_top_k_matrix(heat_matrix, k=self.k) 315 | elif self.eps: 316 | print(f'Selecting edges with weight greater than {self.eps}.') 317 | heat_matrix = get_clipped_matrix(heat_matrix, eps=self.eps) 318 | else: 319 | raise ValueError 320 | 321 | # create PyG Data object 322 | edges_i = [] 323 | edges_j = [] 324 | edge_attr = [] 325 | for i, row in enumerate(heat_matrix): 326 | for j in np.where(row > 0)[0]: 327 | edges_i.append(i) 328 | edges_j.append(j) 329 | edge_attr.append(heat_matrix[i, j]) 330 | 
edge_index = [edges_i, edges_j] 331 | 332 | data = Data( 333 | x=base.data.x, 334 | edge_index=torch.LongTensor(edge_index), 335 | edge_attr=torch.FloatTensor(edge_attr), 336 | y=base.data.y, 337 | train_mask=torch.zeros(base.data.train_mask.size()[0], dtype=torch.bool), 338 | test_mask=torch.zeros(base.data.test_mask.size()[0], dtype=torch.bool), 339 | val_mask=torch.zeros(base.data.val_mask.size()[0], dtype=torch.bool) 340 | ) 341 | 342 | data, slices = self.collate([data]) 343 | torch.save((data, slices), self.processed_paths[0]) 344 | 345 | def __str__(self) -> str: 346 | return f'{self.name}_heat_t={self.t}_k={self.k}_eps={self.eps}_lcc={self.use_lcc}' 347 | -------------------------------------------------------------------------------- /src/DIGL_seeds.py: -------------------------------------------------------------------------------- 1 | __author__ = "Stefan Weißenberger and Johannes Klicpera" 2 | __license__ = "MIT" 3 | 4 | test_seeds = [ 5 | 2406525885, 3164031153, 1454191016, 1583215992, 765984986, 6 | 258270452, 3808600642, 292690791, 2492579272, 1660347731, 7 | 902096533, 1295255868, 3887601419, 2250799892, 4099160157, 8 | 658822373, 1105377040, 1822472846, 2360402805, 2355749367, 9 | 2291281609, 1241963358, 3431144533, 623424053, 78533721, 10 | 1819244826, 1368272433, 555336705, 1979924085, 1064200250, 11 | 256355991, 125892661, 4214462414, 2173868563, 629150633, 12 | 525931699, 3859280724, 1633334170, 1881852583, 2776477614, 13 | 1576005390, 2488832372, 2518362830, 2535216825, 333285849, 14 | 109709634, 2287562222, 3519650116, 3997158861, 3939456016, 15 | 4049817465, 2056937834, 4198936517, 1928038128, 897197605, 16 | 3241375559, 3379824712, 3094687001, 80894711, 1598990667, 17 | 2733558549, 2514977904, 3551930474, 2501047343, 2838870928, 18 | 2323804206, 2609476842, 1941488137, 1647800118, 1544748364, 19 | 983997847, 1907884813, 1261931583, 4094088262, 536998751, 20 | 3788863109, 4023022221, 3116173213, 4019585660, 3278901850, 21 | 3321752075, 
2108550661, 2354669019, 3317723962, 1915553117, 22 | 1464389813, 1648766618, 3423813613, 1338906396, 629014539, 23 | 3330934799, 3295065306, 3212139042, 3653474276, 1078114430, 24 | 2424918363, 3316305951, 2059234307, 1805510917, 1327514671 25 | ] 26 | 27 | val_seeds = [ 28 | 4258031807, 3829679737, 3706579387, 789594926, 3628091752, 29 | 54121625, 825346923, 646393804, 1579300575, 246132812, 30 | 2882726575, 970387138, 413984459, 288449314, 1594895720, 31 | 1950255998, 4015021126, 3798842978, 2668546961, 1254814623, 32 | 1804908540, 674684671, 1988664841, 3361110162, 3784152546, 33 | 3431665473, 1487802115, 1080377472, 1033325667, 2068347440, 34 | 50862517, 1266130159, 3705237643, 2523113545, 1385697073, 35 | 1227694832, 198559329, 1464601500, 490478722, 3144635527, 36 | 4085231799, 2935399337, 3291449301, 2933074791, 1604475278, 37 | 2748278770, 1041151773, 2302537583, 1592364233, 1347718791, 38 | 2260302349, 2870906085, 3324642025, 3383731094, 3268345887, 39 | 3861549985, 1839485103, 2440976226, 1348632978, 1730263803, 40 | 3273174762, 2443236195, 2018253000, 3131053563, 2750855724, 41 | 2142840570, 133334446, 2906772286, 1676623629, 2799515439, 42 | 1950780225, 245027879, 974231345, 1019551316, 418468904, 43 | 3645979760, 2676444879, 2600212003, 243207504, 4050914577, 44 | 395869280, 3037389484, 319467089, 2091061953, 1121224029, 45 | 1506683900, 4265586951, 910928236, 1175970114, 2105285287, 46 | 3164711608, 3255599240, 894959334, 493067366, 3349051410, 47 | 511641138, 2487307261, 951126382, 530590201, 17966177 48 | ] 49 | 50 | development_seed = 1684992425 51 | -------------------------------------------------------------------------------- /src/GNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from base_classes import BaseGNN 5 | from model_configurations import set_block, set_function 6 | 7 | 8 | # Define the GNN model. 
class GNN(BaseGNN):
  """Graph neural diffusion model: feature encoder -> ODE block -> decoder."""

  def __init__(self, opt, dataset, device=torch.device('cpu')):
    super(GNN, self).__init__(opt, dataset, device)
    self.f = set_function(opt)
    block = set_block(opt)
    time_tensor = torch.tensor([0, self.T]).to(device)
    self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)

  def forward(self, x, pos_encoding=None):
    """Encode node features, integrate the ODE, and decode to class logits."""
    # Encode each node based on its feature.
    if self.opt['use_labels']:
      # the trailing num_classes columns carry label information; split them off
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    if self.opt['beltrami']:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.mx(x)
      p = F.dropout(pos_encoding, self.opt['input_dropout'], training=self.training)
      p = self.mp(p)
      x = torch.cat([x, p], dim=1)
    else:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.m1(x)

    if self.opt['use_mlp']:
      x = F.dropout(x, self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper

    if self.opt['use_labels']:
      x = torch.cat([x, y], dim=-1)

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    if self.opt['augment']:
      # double the state dimension with zeros (augmented neural ODE)
      c_aux = torch.zeros(x.shape).to(self.device)
      x = torch.cat([x, c_aux], dim=1)

    self.odeblock.set_x0(x)

    if self.training and self.odeblock.nreg > 0:
      z, self.reg_states = self.odeblock(x)
    else:
      z = self.odeblock(x)

    if self.opt['augment']:
      # drop the auxiliary augmented channels again
      z = torch.split(z, x.shape[1] // 2, dim=1)[0]

    # Activation.
    z = F.relu(z)

    if self.opt['fc_out']:
      z = self.fc(z)
      z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    z = self.m2(z)
    return z


# -----------------------------------------------------------------------------
# src/GNN_KNN.py
# -----------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from base_classes import BaseGNN
from model_configurations import set_block, set_function
from graph_rewiring import KNN, add_edges, edge_sampling, GDCWrapper
from utils import DummyData, get_full_adjacency


# Define the GNN model.
class GNN_KNN(BaseGNN):
  """Graph neural diffusion with kNN-based graph rewiring and an optional
  final 'fully adjacent' (fa) diffusion layer."""

  def __init__(self, opt, dataset, device=torch.device('cpu')):
    super(GNN_KNN, self).__init__(opt, dataset, device)
    self.f = set_function(opt)
    block = set_block(opt)
    time_tensor = torch.tensor([0, self.T]).to(device)
    self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)
    self.data_edge_index = dataset.data.edge_index.to(device)
    self.fa = get_full_adjacency(self.num_nodes).to(device)

  def _apply_fa_layer(self, z):
    """Run one extra diffusion step on a rewired (denser) graph.

    Temporarily overrides the integrator settings (time / method / step_size),
    swaps in the rewired edge_index, integrates once, then restores both the
    settings and the original edges. Shared by forward() and forward_ODE().
    """
    temp_time = self.opt['time']
    temp_method = self.opt['method']
    temp_step_size = self.opt['step_size']

    self.opt['time'] = 1  # self.opt['fa_layer_time'] #1.0
    self.opt['method'] = 'rk4'  # self.opt['fa_layer_method']#'rk4'
    self.opt['step_size'] = 1  # self.opt['fa_layer_step_size']#1.0
    self.odeblock.set_x0(z)
    self.odeblock.odefunc.edge_index = add_edges(self, self.opt)
    if self.opt['edge_sampling_rmv'] != 0:
      edge_sampling(self, z, self.opt)

    z = self.odeblock(z)
    # restore the original graph and integrator settings
    self.odeblock.odefunc.edge_index = self.data_edge_index

    self.opt['time'] = temp_time
    self.opt['method'] = temp_method
    self.opt['step_size'] = temp_step_size
    return z

  def forward(self, x, pos_encoding):
    """Encode, integrate the ODE (plus optional fa layer), and decode."""
    # Encode each node based on its feature.
    if self.opt['use_labels']:
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    if self.opt['beltrami']:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.mx(x)
      if self.opt['dataset'] == 'ogbn-arxiv':
        p = pos_encoding
      else:
        p = F.dropout(pos_encoding, self.opt['input_dropout'], training=self.training)
      p = self.mp(p)
      x = torch.cat([x, p], dim=1)
    else:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.m1(x)

    if self.opt['use_mlp']:
      x = F.dropout(x, self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)

    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper
    # if True:
    #   x = F.relu(x)
    if self.opt['use_labels']:
      x = torch.cat([x, y], dim=-1)

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    if self.opt['augment']:
      c_aux = torch.zeros(x.shape).to(self.device)
      x = torch.cat([x, c_aux], dim=1)

    self.odeblock.set_x0(x)

    if self.training and self.odeblock.nreg > 0:
      z, self.reg_states = self.odeblock(x)
    else:
      z = self.odeblock(x)

    if self.opt['fa_layer']:
      z = self._apply_fa_layer(z)

    if self.opt['augment']:
      z = torch.split(z, x.shape[1] // 2, dim=1)[0]

    # if self.opt['batch_norm']:
    #   z = self.bn_in(z)

    # Activation.
    z = F.relu(z)

    if self.opt['fc_out']:
      z = self.fc(z)
      z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    z = self.m2(z)
    return z

  def forward_encoder(self, x, pos_encoding):
    """Encode node features only (no dropout), without running the ODE."""
    # Encode each node based on its feature.
    if self.opt['use_labels']:
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    if self.opt['beltrami']:
      # x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.mx(x)
      if self.opt['dataset'] == 'ogbn-arxiv':
        # NOTE(review): unlike forward(), the ogbn-arxiv branch here skips
        # self.mp on the positional encoding — confirm this is intentional.
        p = pos_encoding
      else:
        # p = F.dropout(pos_encoding, self.opt['input_dropout'], training=self.training)
        p = self.mp(pos_encoding)
      x = torch.cat([x, p], dim=1)
    else:
      # x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.m1(x)

    if self.opt['use_mlp']:
      # x = F.dropout(x, self.opt['dropout'], training=self.training)
      # x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      # x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
      x = x + self.m11(F.relu(x))
      x = x + self.m12(F.relu(x))

    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper
    # if True:
    #   x = F.relu(x)
    if self.opt['use_labels']:
      x = torch.cat([x, y], dim=-1)

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    if self.opt['augment']:
      c_aux = torch.zeros(x.shape).to(self.device)
      x = torch.cat([x, c_aux], dim=1)

    return x

  def forward_ODE(self, x, pos_encoding):
    """Encode and integrate the ODE (plus optional fa layer); no decoding."""
    x = self.forward_encoder(x, pos_encoding)

    self.odeblock.set_x0(x)

    if self.training and self.odeblock.nreg > 0:
      z, self.reg_states = self.odeblock(x)
    else:
      z = self.odeblock(x)

    if self.opt['fa_layer']:
      z = self._apply_fa_layer(z)

    if self.opt['augment']:
      z = torch.split(z, x.shape[1] // 2, dim=1)[0]

    return z


# -----------------------------------------------------------------------------
# src/GNN_image.py
# -----------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from base_classes import BaseGNN
from model_configurations import set_block, set_function
from collections import namedtuple


# Define the GNN model.
class GNN_image(BaseGNN):
  """Graph neural diffusion over image graphs (pixels as nodes).

  Wraps scalar pixel features in lightweight namedtuple shims so that the
  BaseGNN constructor (which expects a dataset-like object) can be reused.
  """

  def __init__(self, opt, data, num_classes, device=torch.device('cpu')):
    # fake a dataset interface: one feature per pixel, given class count
    DataWrapper = namedtuple('DataWrapper', ['num_features'])
    dw = DataWrapper(1)
    DatasetWrapper = namedtuple('DatasetWrapper', ['data', 'num_classes'])
    dsw = DatasetWrapper(dw, num_classes)
    super(GNN_image, self).__init__(opt, dsw, device)
    self.f = set_function(opt)
    self.block = set_block(opt)
    time_tensor = torch.tensor([0, self.T]).to(device)
    # NOTE(review): self.data and self.n_ode_blocks are assumed to be provided
    # by BaseGNN or elsewhere — they are not set in the code visible here.
    self.odeblocks = nn.ModuleList(
      [self.block(self.f, self.regularization_fns, opt, self.data, device, t=time_tensor) for dummy_i in
       range(self.n_ode_blocks)]).to(self.device)
    self.odeblock = self.block(self.f, self.regularization_fns, opt, self.data, device, t=time_tensor).to(self.device)

    self.m2 = nn.Linear(opt['im_width'] * opt['im_height'] * opt['im_chan'], num_classes)

  def forward(self, x):
    """Diffuse pixel features and decode the flattened image to class logits."""
    # Encode each node based on its feature.
    x = F.dropout(x, self.opt['input_dropout'], training=self.training)

    self.odeblock.set_x0(x)

    # Fix: only unpack regularisation states when regularisers are active,
    # matching GNN / GNN_KNN; the block returns a bare tensor when nreg == 0.
    if self.training and self.odeblock.nreg > 0:
      z, self.reg_states = self.odeblock(x)
    else:
      z = self.odeblock(x)

    # Activation.
    z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    z = z.view(-1, self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'])
    # Decode each node embedding to get node label.
    z = self.m2(z)
    return z

  def forward_plot_T(self, x):  # the same as forward but without the decoder
    """Diffuse and return the flattened node states (no classification head)."""
    # Encode each node based on its feature.
    x = F.dropout(x, self.opt['input_dropout'], training=self.training)

    self.odeblock.set_x0(x)

    # same regularisation guard as forward()
    if self.training and self.odeblock.nreg > 0:
      z, self.reg_states = self.odeblock(x)
    else:
      z = self.odeblock(x)

    # Activation.
    z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    z = z.view(-1, self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'])

    return z

  def forward_plot_path(self, x, frames):  # stitch together ODE integrations
    """Run `frames` successive ODE integrations and stack the intermediate
    flattened states, producing a diffusion trajectory for visualisation."""
    # Encode each node based on its feature.
    x = F.dropout(x, self.opt['input_dropout'], training=self.training)
    z = x
    paths = [z.view(-1, self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'])]
    for f in range(frames):
      self.odeblock.set_x0(z)  # (x)
      # same regularisation guard as forward()
      if self.training and self.odeblock.nreg > 0:
        z, self.reg_states = self.odeblock(z)
      else:
        z = self.odeblock(z)
      # Activation.
      z = F.relu(z)
      # Dropout.
      z = F.dropout(z, self.opt['dropout'], training=self.training)
      path = z.view(-1, self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'])
      print(
        f"Total Pixel intensity of the first image: {torch.sum(z[0:self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'], :])}")
      print(f"{torch.sum(z[1:self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'], :])}")
      print(f"{torch.sum(z[2:self.opt['im_width'] * self.opt['im_height'] * self.opt['im_chan'], :])}")

      paths.append(path)

    paths = torch.stack(paths, dim=1)
    return paths


# -----------------------------------------------------------------------------
# src/base_classes.py
# -----------------------------------------------------------------------------
import torch
from torch import nn
from torch_geometric.nn.conv import MessagePassing
from utils import Meter
from regularized_ODE_function import RegularizedODEfunc
import regularized_ODE_function as reg_lib
import six


# Maps CLI option names to the regularisation functions they enable.
REGULARIZATION_FNS = {
    "kinetic_energy": reg_lib.quadratic_cost,
    "jacobian_norm2": reg_lib.jacobian_frobenius_regularization_fn,
    "total_deriv": reg_lib.total_derivative,
    "directional_penalty": reg_lib.directional_derivative
}


def create_regularization_fns(args):
  """Collect the (function, coefficient) pairs for every regulariser whose
  option in *args* is set (non-None). Returns two parallel lists."""
  regularization_fns = []
  regularization_coeffs = []

  # .items() replaces six.iteritems — this code base is Python-3 only
  for arg_key, reg_fn in REGULARIZATION_FNS.items():
    if args[arg_key] is not None:
      regularization_fns.append(reg_fn)
      regularization_coeffs.append(args[arg_key])

  return regularization_fns, regularization_coeffs


class ODEblock(nn.Module):
  """Base class wrapping an ODE function plus an integrator over [t0, t1]."""

  def __init__(self, odefunc, regularization_fns, opt, data, device, t):
    super(ODEblock, self).__init__()
    self.opt = opt
    self.t = t  # integration interval, a 2-element tensor [t0, t1]

    # double the hidden dimension when using augmented neural ODEs
    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)

    self.nreg = len(regularization_fns)
    self.reg_odefunc = RegularizedODEfunc(self.odefunc, regularization_fns)

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = odeint
    self.test_integrator = None
    self.set_tol()

  def set_x0(self, x0):
    # detach so the initial state is treated as a constant by both functions
    self.odefunc.x0 = x0.clone().detach()
    self.reg_odefunc.odefunc.x0 = x0.clone().detach()

  def set_tol(self):
    self.atol = self.opt['tol_scale'] * 1e-7
    self.rtol = self.opt['tol_scale'] * 1e-9
    if self.opt['adjoint']:
      self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7
      self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9

  def reset_tol(self):
    self.atol = 1e-7
    self.rtol = 1e-9
    self.atol_adjoint = 1e-7
    self.rtol_adjoint = 1e-9

  def set_time(self, time):
    # Fix: ODEblock has no self.device attribute; place the new interval on
    # the same device as the existing time tensor.
    self.t = torch.tensor([0, time], device=self.t.device)

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
class ODEFunc(MessagePassing): 78 | 79 | # currently requires in_features = out_features 80 | def __init__(self, opt, data, device): 81 | super(ODEFunc, self).__init__() 82 | self.opt = opt 83 | self.device = device 84 | self.edge_index = None 85 | self.edge_weight = None 86 | self.attention_weights = None 87 | self.alpha_train = nn.Parameter(torch.tensor(0.0)) 88 | self.beta_train = nn.Parameter(torch.tensor(0.0)) 89 | self.x0 = None 90 | self.nfe = 0 91 | self.alpha_sc = nn.Parameter(torch.ones(1)) 92 | self.beta_sc = nn.Parameter(torch.ones(1)) 93 | 94 | def __repr__(self): 95 | return self.__class__.__name__ 96 | 97 | 98 | class BaseGNN(MessagePassing): 99 | def __init__(self, opt, dataset, device=torch.device('cpu')): 100 | super(BaseGNN, self).__init__() 101 | self.opt = opt 102 | self.T = opt['time'] 103 | self.num_classes = dataset.num_classes 104 | self.num_features = dataset.data.num_features 105 | self.num_nodes = dataset.data.num_nodes 106 | self.device = device 107 | self.fm = Meter() 108 | self.bm = Meter() 109 | 110 | if opt['beltrami']: 111 | self.mx = nn.Linear(self.num_features, opt['feat_hidden_dim']) 112 | self.mp = nn.Linear(opt['pos_enc_dim'], opt['pos_enc_hidden_dim']) 113 | opt['hidden_dim'] = opt['feat_hidden_dim'] + opt['pos_enc_hidden_dim'] 114 | else: 115 | self.m1 = nn.Linear(self.num_features, opt['hidden_dim']) 116 | 117 | if self.opt['use_mlp']: 118 | self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 119 | self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 120 | if opt['use_labels']: 121 | # todo - fastest way to propagate this everywhere, but error prone - refactor later 122 | opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes 123 | else: 124 | self.hidden_dim = opt['hidden_dim'] 125 | if opt['fc_out']: 126 | self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 127 | self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes) 128 | if self.opt['batch_norm']: 129 | self.bn_in = 
torch.nn.BatchNorm1d(opt['hidden_dim']) 130 | self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim']) 131 | 132 | self.regularization_fns, self.regularization_coeffs = create_regularization_fns(self.opt) 133 | 134 | def getNFE(self): 135 | return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe 136 | 137 | def resetNFE(self): 138 | self.odeblock.odefunc.nfe = 0 139 | self.odeblock.reg_odefunc.odefunc.nfe = 0 140 | 141 | def reset(self): 142 | self.m1.reset_parameters() 143 | self.m2.reset_parameters() 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ 147 | -------------------------------------------------------------------------------- /src/block_constant.py: -------------------------------------------------------------------------------- 1 | from base_classes import ODEblock 2 | import torch 3 | from utils import get_rw_adj, gcn_norm_fill_val 4 | 5 | 6 | class ConstantODEblock(ODEblock): 7 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])): 8 | super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 9 | 10 | self.aug_dim = 2 if opt['augment'] else 1 11 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 12 | if opt['data_norm'] == 'rw': 13 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 14 | fill_value=opt['self_loop_weight'], 15 | num_nodes=data.num_nodes, 16 | dtype=data.x.dtype) 17 | else: 18 | edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr, 19 | fill_value=opt['self_loop_weight'], 20 | num_nodes=data.num_nodes, 21 | dtype=data.x.dtype) 22 | self.odefunc.edge_index = edge_index.to(device) 23 | self.odefunc.edge_weight = edge_weight.to(device) 24 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 25 | 26 | if opt['adjoint']: 27 | 
from base_classes import ODEblock
import torch
from utils import get_rw_adj, gcn_norm_fill_val


class ConstantODEblock(ODEblock):
  """ODE block whose diffusion operator is fixed up-front from the normalised
  adjacency (random-walk or GCN normalisation) and never changes during
  integration."""

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])):
    super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    self.aug_dim = 2 if opt['augment'] else 1
    width = self.aug_dim * opt['hidden_dim']
    self.odefunc = odefunc(width, width, opt, data, device)

    # Pick the normalisation scheme once; both helpers return (index, weight).
    if opt['data_norm'] == 'rw':
      ei, ew = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                          fill_value=opt['self_loop_weight'],
                          num_nodes=data.num_nodes,
                          dtype=data.x.dtype)
    else:
      ei, ew = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr,
                                 fill_value=opt['self_loop_weight'],
                                 num_nodes=data.num_nodes,
                                 dtype=data.x.dtype)
    self.odefunc.edge_index = ei.to(device)
    self.odefunc.edge_weight = ew.to(device)
    # The regularised wrapper shares the same operator tensors.
    self.reg_odefunc.odefunc.edge_index = self.odefunc.edge_index
    self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = self.test_integrator = odeint
    self.set_tol()

  def forward(self, x):
    t = self.t.type_as(x)
    integrator = self.train_integrator if self.training else self.test_integrator

    regularised = self.training and self.nreg > 0
    reg_states = tuple(torch.zeros(x.size(0)).to(x) for _ in range(self.nreg))
    func = self.reg_odefunc if regularised else self.odefunc
    state = (x,) + reg_states if regularised else x

    # Assemble the solver arguments once; the adjoint path only adds keys.
    solver_kwargs = dict(
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
        atol=self.atol,
        rtol=self.rtol,
    )
    if self.opt["adjoint"] and self.training:
      solver_kwargs.update(
          adjoint_method=self.opt['adjoint_method'],
          adjoint_options=dict(step_size=self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']),
          adjoint_atol=self.atol_adjoint,
          adjoint_rtol=self.rtol_adjoint,
      )
    state_dt = integrator(func, state, t, **solver_kwargs)

    if regularised:
      # state_dt[0] is the feature trajectory; the rest are regulariser states.
      return state_dt[0][1], tuple(st[1] for st in state_dt[1:])
    return state_dt[1]

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
from base_classes import ODEblock
import torch
from utils import get_rw_adj, gcn_norm_fill_val
import torch_sparse
from torch_geometric.utils import get_laplacian
import numpy as np


class ConstantODEblock(ODEblock):
  """Constant-operator ODE block with optional in-training graph rewiring.

  When ``opt['new_edges']`` is set, the edge set is densified (random edges or
  k-hop products) and then re-sparsified by an attention threshold at the start
  of each training forward pass.
  """

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])):
    super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
    if opt['data_norm'] == 'rw':
      edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                           fill_value=opt['self_loop_weight'],
                                           num_nodes=data.num_nodes,
                                           dtype=data.x.dtype)
    else:
      edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr,
                                                  fill_value=opt['self_loop_weight'],
                                                  num_nodes=data.num_nodes,
                                                  dtype=data.x.dtype)
    self.odefunc.edge_index = edge_index.to(device)
    self.odefunc.edge_weight = edge_weight.to(device)
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint

    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()

  def add_random_edges(self):
    """Append M uniformly random edges and deduplicate the edge list.

    NOTE(review): ``self.num_nodes`` and ``self.data_edge_index`` are not
    initialised in this class -- presumably set by the ODEblock base or a
    sibling method; verify before enabling this code path.
    """
    # todo check if theres a pygeometric function for this
    # M = n * (1 / att_samp_pct - 1), written as in the original.
    M = int(self.num_nodes * (1 / (1 - (1 - self.opt['att_samp_pct'])) - 1))

    with torch.no_grad():
      new_edges = np.random.choice(self.num_nodes, size=(2, M), replace=True, p=None)
      new_edges = torch.tensor(new_edges)
      cat = torch.cat([self.data_edge_index, new_edges], dim=1)
      # BUGFIX: deduplicate edge *columns* (dim=1), matching the working
      # implementation in block_transformer_rewiring.  The original used
      # dim=0, which deduplicates the two rows of the (2, E) tensor and
      # collapses the edge list.
      no_repeats = torch.unique(cat, sorted=False, return_inverse=False,
                                return_counts=False, dim=1)
      self.data_edge_index = no_repeats

  def add_khop_edges(self, k):
    """Densify with k-hop edges via sparse matrix products of the operator.

    NOTE(review): this helper looks unfinished (thresholding/normalisation are
    still todos below), and the 0.5/0.5 weight mix assumes the old and new
    edge sets have equal length, which spspmm does not guarantee -- confirm
    before relying on it.
    """
    n = self.num_nodes
    # do k_hop
    for i in range(k):
      new_edges, new_weights = \
        torch_sparse.spspmm(self.odefunc.edge_index, self.odefunc.edge_weight,
                            self.odefunc.edge_index, self.odefunc.edge_weight, n, n, n, coalesced=False)
      # BUGFIX: the original wrote to self.edge_weight / self.edge_index,
      # attributes never initialised on this block (first use would raise
      # AttributeError); the operative tensors live on self.odefunc, as set
      # in __init__.
      self.odefunc.edge_weight = 0.5 * self.odefunc.edge_weight + 0.5 * new_weights
      cat = torch.cat([self.data_edge_index, new_edges], dim=1)
      # BUGFIX: unique over columns (dim=1), not rows (dim=0) -- see
      # add_random_edges above.
      self.odefunc.edge_index = torch.unique(cat, sorted=False, return_inverse=False,
                                             return_counts=False, dim=1)
      # threshold
      # normalise

  def forward(self, x):
    t = self.t.type_as(x)

    # NOTE(review): get_attention_weights / renormalise_attention /
    # add_rw_edges are not defined in this class -- presumably inherited from
    # ODEblock or copied from the hard-attention block; verify.
    if self.training:
      if self.opt['new_edges'] == 'random':
        self.add_random_edges()
      elif self.opt['new_edges'] == 'k_hop':
        self.add_khop_edges(k=2)
      elif self.opt['new_edges'] == 'random_walk' and self.odefunc.attention_weights is not None:
        self.add_rw_edges()

    attention_weights = self.get_attention_weights(x)
    # create attention mask
    if self.training:
      with torch.no_grad():
        mean_att = attention_weights.mean(dim=1, keepdim=False)
        if self.opt['use_flux']:
          # weight attention by feature-space edge length ("flux")
          src_features = x[self.data_edge_index[0, :], :]
          dst_features = x[self.data_edge_index[1, :], :]
          delta = torch.linalg.norm(src_features - dst_features, dim=1)
          mean_att = mean_att * delta
        threshold = torch.quantile(mean_att, 1 - self.opt['att_samp_pct'])
        mask = mean_att > threshold
        # BUGFIX: mask is 1-D, so the original 'mask.T' was a no-op that is
        # deprecated (warns) on PyTorch >= 1.11; index with the mask directly.
        self.odefunc.edge_index = self.data_edge_index[:, mask]
        sampled_attention_weights = self.renormalise_attention(mean_att[mask])
        print('retaining {} of {} edges'.format(self.odefunc.edge_index.shape[1], self.data_edge_index.shape[1]))
        self.odefunc.attention_weights = sampled_attention_weights
    else:
      self.odefunc.edge_index = self.data_edge_index
      self.odefunc.attention_weights = attention_weights.mean(dim=1, keepdim=False)
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights

    integrator = self.train_integrator if self.training else self.test_integrator

    reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg))

    func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc
    state = (x,) + reg_states if self.training and self.nreg > 0 else x

    if self.opt["adjoint"] and self.training:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
        adjoint_method=self.opt['adjoint_method'],
        adjoint_options=dict(step_size=self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']),
        atol=self.atol,
        rtol=self.rtol,
        adjoint_atol=self.atol_adjoint,
        adjoint_rtol=self.rtol_adjoint)
    else:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
        atol=self.atol,
        rtol=self.rtol)

    if self.training and self.nreg > 0:
      z = state_dt[0][1]
      reg_states = tuple(st[1] for st in state_dt[1:])
      return z, reg_states
    else:
      z = state_dt[1]
      return z

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
import torch
from torch import nn
from function_transformer_attention import SpGraphTransAttentionLayer
from base_classes import ODEblock
from utils import get_rw_adj


class MixedODEblock(ODEblock):
  """ODE block that diffuses with a learnable convex combination of
  multi-head attention scores and the fixed random-walk edge weights."""

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.):
    super(MixedODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    # NOTE(review): self.aug_dim is presumably initialised by the ODEblock
    # base class -- it is not set in this __init__; verify.
    width = self.aug_dim * opt['hidden_dim']
    self.odefunc = odefunc(width, width, opt, data, device)
    ei, ew = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                        fill_value=opt['self_loop_weight'],
                        num_nodes=data.num_nodes,
                        dtype=data.x.dtype)
    self.odefunc.edge_index = ei.to(device)
    self.odefunc.edge_weight = ew.to(device)
    self.reg_odefunc.odefunc.edge_index = self.odefunc.edge_index
    self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = self.test_integrator = odeint
    self.set_tol()
    # parameter trading off between attention and the Laplacian
    self.gamma = nn.Parameter(gamma * torch.ones(1))
    self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                          device).to(device)

  def get_attention_weights(self, x):
    """Per-edge multi-head attention scores for the current features."""
    attention, _values = self.multihead_att_layer(x, self.odefunc.edge_index)
    return attention

  def get_mixed_attention(self, x):
    """Blend head-averaged attention with the static edge weights via sigmoid(gamma)."""
    mix = torch.sigmoid(self.gamma)
    head_mean = self.get_attention_weights(x).mean(dim=1)
    return head_mean * (1 - mix) + self.odefunc.edge_weight * mix

  def forward(self, x):
    t = self.t.type_as(x)
    self.odefunc.attention_weights = self.get_mixed_attention(x)
    integrator = self.train_integrator if self.training else self.test_integrator

    solver_kwargs = dict(method=self.opt['method'],
                         options={'step_size': self.opt['step_size']},
                         atol=self.atol,
                         rtol=self.rtol)
    if self.opt["adjoint"] and self.training:
      solver_kwargs.update(adjoint_method=self.opt['adjoint_method'],
                           adjoint_options={'step_size': self.opt['adjoint_step_size']},
                           adjoint_atol=self.atol_adjoint,
                           adjoint_rtol=self.rtol_adjoint)
    # index [1] selects the state at the terminal time t[1]
    return integrator(self.odefunc, x, t, **solver_kwargs)[1]

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
import torch
from function_transformer_attention import SpGraphTransAttentionLayer
from base_classes import ODEblock
from utils import get_rw_adj


class AttODEblock(ODEblock):
  """ODE block whose diffusion operator is multi-head scaled-dot attention,
  recomputed from the current node features on every forward pass."""

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
    super(AttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    width = self.aug_dim * opt['hidden_dim']
    self.odefunc = odefunc(width, width, opt, data, device)
    ei, ew = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                        fill_value=opt['self_loop_weight'],
                        num_nodes=data.num_nodes,
                        dtype=data.x.dtype)
    self.odefunc.edge_index = ei.to(device)
    self.odefunc.edge_weight = ew.to(device)
    self.reg_odefunc.odefunc.edge_index = self.odefunc.edge_index
    self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = self.test_integrator = odeint
    self.set_tol()
    self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                          device, edge_weights=self.odefunc.edge_weight).to(device)

  def get_attention_weights(self, x):
    """Per-edge multi-head attention scores for the current features."""
    attention, _values = self.multihead_att_layer(x, self.odefunc.edge_index)
    return attention

  def forward(self, x):
    t = self.t.type_as(x)
    # Refresh the attention operator from the current features and share it
    # with the regularised wrapper.
    self.odefunc.attention_weights = self.get_attention_weights(x)
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights
    integrator = self.train_integrator if self.training else self.test_integrator

    regularised = self.training and self.nreg > 0
    reg_states = tuple(torch.zeros(x.size(0)).to(x) for _ in range(self.nreg))
    func = self.reg_odefunc if regularised else self.odefunc
    state = (x,) + reg_states if regularised else x

    solver_kwargs = dict(method=self.opt['method'],
                         options={'step_size': self.opt['step_size']},
                         atol=self.atol,
                         rtol=self.rtol)
    if self.opt["adjoint"] and self.training:
      solver_kwargs.update(adjoint_method=self.opt['adjoint_method'],
                           adjoint_options={'step_size': self.opt['adjoint_step_size']},
                           adjoint_atol=self.atol_adjoint,
                           adjoint_rtol=self.rtol_adjoint)
    state_dt = integrator(func, state, t, **solver_kwargs)

    if regularised:
      return state_dt[0][1], tuple(st[1] for st in state_dt[1:])
    return state_dt[1]

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
import torch
from function_transformer_attention import SpGraphTransAttentionLayer
from base_classes import ODEblock
from utils import get_rw_adj
from torch_scatter import scatter


class HardAttODEblock(ODEblock):
  """ODE block with hard (sampled) attention.

  During training only the top ``att_samp_pct`` fraction of edges, ranked by
  head-averaged attention (optionally scaled by the feature-space edge length
  when ``use_flux`` is set), is retained for diffusion; at evaluation time the
  full edge set is used.
  """

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
    super(HardAttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)
    assert opt['att_samp_pct'] > 0 and opt['att_samp_pct'] <= 1, "attention sampling threshold must be in (0,1]"
    self.opt = opt
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
    self.num_nodes = data.num_nodes
    edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                         fill_value=opt['self_loop_weight'],
                                         num_nodes=data.num_nodes,
                                         dtype=data.x.dtype)
    self.data_edge_index = edge_index.to(device)
    self.odefunc.edge_index = edge_index.to(device)  # this will be changed by attention scores
    self.odefunc.edge_weight = edge_weight.to(device)
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    # GAT/transformer functions carry their own attention layer; only build a
    # separate one for the remaining function types.
    if opt['function'] not in {'GAT', 'transformer'}:
      self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                            device, edge_weights=self.odefunc.edge_weight).to(device)

  def get_attention_weights(self, x):
    """Per-edge multi-head attention over the full (unsampled) edge set."""
    if self.opt['function'] not in {'GAT', 'transformer'}:
      attention, values = self.multihead_att_layer(x, self.data_edge_index)
    else:
      attention, values = self.odefunc.multihead_att_layer(x, self.data_edge_index)
    return attention

  def renormalise_attention(self, attention):
    """Rescale attention so it sums to one over each node's retained edges."""
    index = self.odefunc.edge_index[self.opt['attention_norm_idx']]
    att_sums = scatter(attention, index, dim=0, dim_size=self.num_nodes, reduce='sum')[index]
    return attention / (att_sums + 1e-16)

  def forward(self, x):
    t = self.t.type_as(x)
    attention_weights = self.get_attention_weights(x)
    # create attention mask
    if self.training:
      with torch.no_grad():
        mean_att = attention_weights.mean(dim=1, keepdim=False)
        if self.opt['use_flux']:
          # weight attention by feature-space edge length ("flux")
          src_features = x[self.data_edge_index[0, :], :]
          dst_features = x[self.data_edge_index[1, :], :]
          delta = torch.linalg.norm(src_features - dst_features, dim=1)
          mean_att = mean_att * delta
        threshold = torch.quantile(mean_att, 1 - self.opt['att_samp_pct'])
        mask = mean_att > threshold
        # BUGFIX: 'mask' is 1-D, so the original 'mask.T' was a no-op that is
        # deprecated (warns) on PyTorch >= 1.11; index with the mask directly.
        self.odefunc.edge_index = self.data_edge_index[:, mask]
        sampled_attention_weights = self.renormalise_attention(mean_att[mask])
        print('retaining {} of {} edges'.format(self.odefunc.edge_index.shape[1], self.data_edge_index.shape[1]))
        self.odefunc.attention_weights = sampled_attention_weights
    else:
      self.odefunc.edge_index = self.data_edge_index
      self.odefunc.attention_weights = attention_weights.mean(dim=1, keepdim=False)
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights
    integrator = self.train_integrator if self.training else self.test_integrator

    reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg))

    func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc
    state = (x,) + reg_states if self.training and self.nreg > 0 else x

    if self.opt["adjoint"] and self.training:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options={'step_size': self.opt['step_size']},
        adjoint_method=self.opt['adjoint_method'],
        adjoint_options={'step_size': self.opt['adjoint_step_size']},
        atol=self.atol,
        rtol=self.rtol,
        adjoint_atol=self.atol_adjoint,
        adjoint_rtol=self.rtol_adjoint)
    else:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options={'step_size': self.opt['step_size']},
        atol=self.atol,
        rtol=self.rtol)

    if self.training and self.nreg > 0:
      z = state_dt[0][1]
      reg_states = tuple(st[1] for st in state_dt[1:])
      return z, reg_states
    else:
      z = state_dt[1]
      return z

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
import torch
from function_transformer_attention import SpGraphTransAttentionLayer
from base_classes import ODEblock
from utils import get_rw_adj
from torch_scatter import scatter
import numpy as np
import torch_sparse
from torch_geometric.utils import remove_self_loops


class RewireAttODEblock(ODEblock):
  """ODE block that rewires the graph during training.

  Each training forward pass densifies the edge set (random edges or k-hop
  attention products, per ``opt['new_edges']``), thresholds the densified
  graph back down by attention score, and then integrates on the rewired
  graph.

  NOTE: large commented-out experimental variants (topk_adj / addD_rvR
  thresholding, random-walk edge sampling) were removed as dead code; see
  version control history if needed.
  """

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
    super(RewireAttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)
    assert opt['att_samp_pct'] > 0 and opt['att_samp_pct'] <= 1, "attention sampling threshold must be in (0,1]"
    self.opt = opt
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
    self.num_nodes = data.num_nodes
    edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                         fill_value=opt['self_loop_weight'],
                                         num_nodes=data.num_nodes,
                                         dtype=data.x.dtype)
    self.data_edge_index = edge_index.to(device)
    self.odefunc.edge_index = edge_index.to(device)  # this will be changed by attention scores
    self.odefunc.edge_weight = edge_weight.to(device)
    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    # GAT/transformer functions carry their own attention layer.
    if opt['function'] not in {'GAT', 'transformer'}:
      self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                            device, edge_weights=self.odefunc.edge_weight).to(device)

  def get_attention_weights(self, x):
    """Per-edge multi-head attention over the current candidate edge set."""
    if self.opt['function'] not in {'GAT', 'transformer'}:
      attention, values = self.multihead_att_layer(x, self.data_edge_index)
    else:
      attention, values = self.odefunc.multihead_att_layer(x, self.data_edge_index)
    return attention

  def renormalise_attention(self, attention):
    """Rescale attention so it sums to one over each node's retained edges."""
    index = self.odefunc.edge_index[self.opt['attention_norm_idx']]
    att_sums = scatter(attention, index, dim=0, dim_size=self.num_nodes, reduce='sum')[index]
    return attention / (att_sums + 1e-16)

  def add_random_edges(self):
    """Append M uniformly random edges so later thresholding can prune ~rw_addD."""
    M = int(self.num_nodes * (1 / (1 - (self.opt['rw_addD'])) - 1))

    with torch.no_grad():
      new_edges = np.random.choice(self.num_nodes, size=(2, M), replace=True, p=None)
      new_edges = torch.tensor(new_edges)
      # todo check if should be using coalesce instead of unique
      # eg https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/two_hop.html#TwoHop
      cat = torch.cat([self.data_edge_index, new_edges], dim=1)
      no_repeats = torch.unique(cat, sorted=False, return_inverse=False,
                                return_counts=False, dim=1)
      self.data_edge_index = no_repeats
      self.odefunc.edge_index = self.data_edge_index

  def add_khop_edges(self, k=2, rm_self_loops=True):
    """Densify with attention-weighted k-hop edges: S <- (A + A^2) / 2.

    NOTE(review): the ``rm_self_loops`` flag is currently unused -- self
    loops are always removed from the product; confirm intended behaviour.
    """
    n = self.num_nodes
    for i in range(k - 1):
      new_edges, new_weights = torch_sparse.spspmm(self.odefunc.edge_index, self.odefunc.edge_weight,
                                                   self.odefunc.edge_index, self.odefunc.edge_weight,
                                                   n, n, n, coalesced=True)

      new_edges, new_weights = remove_self_loops(new_edges, new_weights)

      A1pA2_index = torch.cat([self.odefunc.edge_index, new_edges], dim=1)
      A1pA2_value = torch.cat([self.odefunc.edge_weight, new_weights], dim=0) / 2
      ei, ew = torch_sparse.coalesce(A1pA2_index, A1pA2_value, n, n, op="add")

      self.data_edge_index = ei
      self.odefunc.edge_index = self.data_edge_index
      self.odefunc.attention_weights = ew

  def densify_edges(self):
    """Dispatch to the densification scheme named in opt['new_edges']."""
    if self.opt['new_edges'] == 'random':
      self.add_random_edges()
    elif self.opt['new_edges'] == 'random_walk':
      # NOTE(review): add_rw_edges was an unfinished, commented-out helper in
      # the original, so this branch raised AttributeError there too.
      self.add_rw_edges()
    elif self.opt['new_edges'] == 'k_hop_lap':
      pass
    elif self.opt['new_edges'] == 'k_hop_att':
      self.add_khop_edges(k=2)

  def threshold_edges(self, x, threshold):
    """Drop edges whose (possibly flux-weighted) mean attention <= threshold."""
    # get mean attention
    # i) sparsify on S_hat
    if self.opt['new_edges'] == 'k_hop_att' and self.opt['sparsify'] == 'S_hat':
      attention_weights = self.odefunc.attention_weights
      mean_att = attention_weights
    # ii) sparsify on recalculated attentions
    else:
      attention_weights = self.get_attention_weights(x)
      mean_att = attention_weights.mean(dim=1, keepdim=False)

    if self.opt['use_flux']:
      src_features = x[self.data_edge_index[0, :], :]
      dst_features = x[self.data_edge_index[1, :], :]
      delta = torch.linalg.norm(src_features - dst_features, dim=1)
      mean_att = mean_att * delta

    mask = mean_att > threshold
    # BUGFIX: 'mask' is 1-D, so the original 'mask.T' was a no-op that is
    # deprecated (warns) on PyTorch >= 1.11; index with the mask directly.
    self.odefunc.edge_index = self.data_edge_index[:, mask]
    sampled_attention_weights = self.renormalise_attention(mean_att[mask])
    print('retaining {} of {} edges'.format(self.odefunc.edge_index.shape[1], self.data_edge_index.shape[1]))
    self.data_edge_index = self.data_edge_index[:, mask]
    self.odefunc.edge_weight = sampled_attention_weights  # rewiring structure so need to replace any preproc ew's with new ew's
    self.odefunc.attention_weights = sampled_attention_weights

  def forward(self, x):
    t = self.t.type_as(x)

    if self.training:
      with torch.no_grad():
        # calc attentions for transition matrix
        attention_weights = self.get_attention_weights(x)
        self.odefunc.attention_weights = attention_weights.mean(dim=1, keepdim=False)

        # Densify and threshold attention weights
        pre_count = self.odefunc.edge_index.shape[1]
        self.densify_edges()
        post_count = self.odefunc.edge_index.shape[1]
        pc_change = post_count / pre_count - 1
        # NOTE(review): torch.quantile requires its q argument in [0, 1];
        # 1 / (pc_change - rw_addD) can fall outside that range when the
        # densification adds few edges -- confirm the intended formula.
        threshold = torch.quantile(self.odefunc.edge_weight, 1 / (pc_change - self.opt['rw_addD']))

        self.threshold_edges(x, threshold)

    # Recompute attention on the rewired graph for this pass.
    self.odefunc.edge_index = self.data_edge_index
    attention_weights = self.get_attention_weights(x)
    mean_att = attention_weights.mean(dim=1, keepdim=False)
    self.odefunc.edge_weight = mean_att
    self.odefunc.attention_weights = mean_att

    self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight
    self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights
    integrator = self.train_integrator if self.training else self.test_integrator
    reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg))
    func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc
    state = (x,) + reg_states if self.training and self.nreg > 0 else x

    if self.opt["adjoint"] and self.training:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options={'step_size': self.opt['step_size']},
        adjoint_method=self.opt['adjoint_method'],
        adjoint_options={'step_size': self.opt['adjoint_step_size']},
        atol=self.atol,
        rtol=self.rtol,
        adjoint_atol=self.atol_adjoint,
        adjoint_rtol=self.rtol_adjoint)
    else:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options={'step_size': self.opt['step_size']},
        atol=self.atol,
        rtol=self.rtol)

    if self.training and self.nreg > 0:
      z = state_dt[0][1]
      reg_states = tuple(st[1] for st in state_dt[1:])
      return z, reg_states
    else:
      z = state_dt[1]
      return z

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
def rewire(data, opt, data_dir):
  """Apply the graph-rewiring scheme named by opt['rewiring'] to `data`.

  Recognised schemes: 'two_hop', 'gdc', 'pos_enc_knn'.  Any other value
  (including None) leaves the data unchanged.

  :param data: torch_geometric data object to rewire
  :param opt: option dict; only 'rewiring' is read here
  :param data_dir: dataset directory, needed by positional-encoding rewiring
  :return: the (possibly rewired) data object
  """
  scheme = opt['rewiring']
  if scheme == 'two_hop':
    return get_two_hop(data)
  if scheme == 'gdc':
    return apply_gdc(data, opt)
  if scheme == 'pos_enc_knn':
    return apply_pos_dist_rewire(data, opt, data_dir)
  return data
def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
  """Return the set of node indices in the connected component containing `start`.

  Edges are followed in their stored direction (row -> col) only; for a graph
  stored with both directions per edge this yields the full component.

  :param dataset: dataset whose data.edge_index holds the [2, E] edge index
  :param start: node index to start the traversal from
  :return: set of node indices reachable from `start`
  """
  row, col = dataset.data.edge_index.numpy()
  # Build the adjacency map once (O(E)); the previous version re-scanned the
  # entire edge list with np.where for every visited node, i.e. O(V * E).
  adjacency = {}
  for src, dst in zip(row, col):
    adjacency.setdefault(src, []).append(dst)

  visited_nodes = set()
  queued_nodes = {start}
  while queued_nodes:
    current_node = queued_nodes.pop()
    visited_nodes.add(current_node)
    for neighbor in adjacency.get(current_node, []):
      if neighbor not in visited_nodes:
        queued_nodes.add(neighbor)
  return visited_nodes
def set_train_val_test_split(
        seed: int,
        data: Data,
        num_development: int = 1500,
        num_per_class: int = 20) -> Data:
  """Attach random train/val/test boolean masks to `data` in place.

  A development pool of `num_development` nodes is drawn; `num_per_class`
  training nodes are drawn per class from that pool; the remaining pool nodes
  become validation and everything outside the pool becomes test.

  :param seed: RNG seed (both draws use a fresh RandomState with this seed)
  :param data: data object with integer labels in data.y; masks are set on it
  :param num_development: size of the train+val pool
  :param num_per_class: training nodes sampled per class (raises if a class
                        has fewer than this many nodes in the pool)
  :return: the same data object with train/val/test masks set
  """
  rnd_state = np.random.RandomState(seed)
  num_nodes = data.y.shape[0]
  development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
  # Use sets for membership tests: `i in ndarray` / `i in list` are O(n)
  # scans, which made the comprehensions below O(n^2) over all nodes.
  development_set = set(development_idx.tolist())
  test_idx = [i for i in np.arange(num_nodes) if i not in development_set]

  train_idx = []
  # Fresh RandomState so the per-class draws do not depend on how many values
  # the development draw consumed (preserves the original split exactly).
  rnd_state = np.random.RandomState(seed)
  for c in range(data.y.max() + 1):
    class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
    train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

  train_set = {int(i) for i in train_idx}
  val_idx = [i for i in development_idx if i not in train_set]

  def get_mask(idx):
    # Boolean node mask with True at the given indices.
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[idx] = 1
    return mask

  data.train_mask = get_mask(train_idx)
  data.val_mask = get_mask(val_idx)
  data.test_mask = get_mask(test_idx)

  return data
def main(opt):
    """Train DeepWalk-style Node2Vec embeddings for a dataset and pickle them.

    Trains a Node2Vec model (p=q=1, i.e. unbiased random walks = DeepWalk) on
    the dataset's edge index, reports train-time/test-time/loss/accuracy per
    epoch, and stores the final embeddings plus their logistic-regression test
    accuracy under ../data/pos_encodings.

    :param opt: option dict built from the argparse flags below (dataset,
                embedding_dim, walk_length, context_size, walks_per_node,
                neg_pos_ratio, epochs, gpu, not_lcc)
    """
    dataset_name = opt['dataset']

    print(f"[i] Generating embeddings for dataset: {dataset_name}")
    # NOTE(review): opt['not_lcc'] comes from a store_false flag, so it is
    # True by default — the largest connected component IS used unless the
    # flag is passed; the name reads inverted. Confirm against get_dataset.
    dataset = get_dataset(opt, '../data', opt['not_lcc'])
    data = dataset.data

    device = torch.device(f"cuda:{opt['gpu']}" if torch.cuda.is_available() else 'cpu')

    # p=1, q=1 makes the walks unbiased (DeepWalk); sparse=True enables the
    # SparseAdam optimiser below.
    model = Node2Vec(data.edge_index, embedding_dim=opt['embedding_dim'], walk_length=opt['walk_length'],
                     context_size=opt['context_size'], walks_per_node=opt['walks_per_node'],
                     num_negative_samples=opt['neg_pos_ratio'], p=1, q=1, sparse=True).to(device)

    loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

    def train():
        # One epoch over positive/negative random-walk batches; returns the
        # mean skip-gram loss.
        model.train()
        total_loss = 0
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(loader)

    @torch.no_grad()
    def test():
        # Fit a classifier on train-mask embeddings and score on the test
        # mask; returns (accuracy, embeddings).
        model.eval()
        z = model()
        acc = model.test(z[data.train_mask], data.y[data.train_mask],
                         z[data.test_mask], data.y[data.test_mask],
                         max_iter=150)
        return acc, z


    ### here be main code
    t = time.time()
    for epoch in range(1, opt['epochs']+1):
        loss = train()
        train_t = time.time() - t
        t = time.time()
        acc, _ = test()
        test_t = time.time() - t
        print(f'Epoch: {epoch:02d}, Train: {train_t:.2f}, Test: {test_t:.2f}, Loss: {loss:.4f}, Acc: {acc:.4f}')


    acc, z = test()
    print(f"[i] Final accuracy is {acc}")
    print(f"[i] Embedding shape is {z.data.shape}")

    # Filename encodes every hyperparameter so runs don't overwrite each other.
    fname = "DW_%s_emb_%03d_wl_%03d_cs_%02d_wn_%02d_epochs_%03d.pickle" % (
        opt['dataset'], opt['embedding_dim'], opt['walk_length'], opt['context_size'], opt['walks_per_node'], opt['epochs']
    )

    print(f"[i] Storing embeddings in {fname}")

    with open(osp.join("../data/pos_encodings", fname), 'wb') as f:
        # make sure the pickle is not bound to any gpu, and store test acc with data
        pickle.dump({"data": z.data.to(torch.device("cpu")), "acc": acc}, f)
def main(opt):
    """For each dataset/dim, symlink the best DeepWalk embedding pickle.

    Scans ../data/pos_encodings for pickles named DW_<dataset>_emb_<dim>*,
    picks the one with the highest stored test accuracy, and (re)points the
    symlink <dataset>_DW<dim>.pkl at it.

    :param opt: dict with 'dataset' ('ALL' = every known dataset) and
                'embedding_dim' (0 = every known dimension)
    """
    all_datasets = ["Cora", "Citeseer", "Pubmed", "Computers", "Photo", "CoauthorCS", "ogbn-arxiv"]
    all_embedding_dims = [64, 128, 256]
    data_path = "../data/pos_encodings/"

    datasets = all_datasets if opt['dataset'] == "ALL" else [opt['dataset']]
    embedding_dims = all_embedding_dims if opt['embedding_dim'] == 0 else [opt['embedding_dim']]

    for dataset in datasets:
        for embedding_dim in embedding_dims:
            fname = f"DW_{dataset}_emb_{embedding_dim:03d}*"
            pickles = glob.glob(os.path.join(data_path, fname))

            max_acc = 0
            best_emb = None
            for p in pickles:
                with open(p, "rb") as f:
                    data = pickle.load(f)
                acc = data['acc']
                print(f"Model {p} has accuracy {acc}")
                if acc > max_acc:
                    max_acc = acc
                    best_emb = p

            # Robustness: previously an empty/zero-accuracy candidate list fell
            # through and symlinked a stale `p` (or crashed on NameError).
            if best_emb is None:
                print(f"=> No embeddings found for {dataset} dim {embedding_dim}, skipping")
                continue

            print(f"=> The best model is {best_emb} with accuracy {max_acc}")

            print("Removing previous symlink...")
            os.system(f"rm {data_path}{dataset}_DW{embedding_dim}.pkl")

            # BUG FIX: link the best-scoring pickle, not `p`, which was simply
            # the LAST pickle iterated over.
            command = f"ln -s {best_emb[len(data_path):]} {data_path}{dataset}_DW{embedding_dim}.pkl"
            print(f"Running: {command}")
            os.system(command)
def apply_dist_KNN(x, k):
  """Build a kNN edge index from a precomputed pairwise distance matrix.

  :param x: [n, n] precomputed distance matrix (metric='precomputed')
  :param k: neighbours per node; since d(i, i) = 0 the node itself is
            typically among its own neighbours
  :return: [2, n*k] integer edge index (row 0 = source, row 1 = destination)
  """
  nbrs = NearestNeighbors(n_neighbors=k, metric='precomputed').fit(x)
  distances, indices = nbrs.kneighbors(x)
  # Each node is the source of exactly k edges. np.repeat yields the same
  # values as the previous linspace(...)[:-1] // k construction but keeps an
  # integer dtype — edge indices must be integral, and the old version
  # produced a float array once vstacked.
  src = np.repeat(np.arange(len(x)), k)
  dst = indices.reshape(-1)
  ei = np.vstack((src, dst))
  return ei
def run_evaluator(evaluator, data, y_pred):
  """Score predictions on the train/val/test splits with an OGB evaluator.

  :param evaluator: object whose eval({'y_true', 'y_pred'}) returns {'acc': ...}
  :param data: data object with y and train/val/test boolean masks
  :param y_pred: predicted labels, indexable by the same masks as data.y
  :return: (train_acc, valid_acc, test_acc)
  """
  def split_accuracy(mask):
    # Single-split accuracy via the evaluator's dict protocol.
    return evaluator.eval({
      'y_true': data.y[mask],
      'y_pred': y_pred[mask],
    })['acc']

  train_acc = split_accuracy(data.train_mask)
  valid_acc = split_accuracy(data.val_mask)
  test_acc = split_accuracy(data.test_mask)
  return train_acc, valid_acc, test_acc
  def advance(self, next_t):
    """
    Takes steps dt to get to the next user specified time point next_t. In practice this goes past next_t and then interpolates.
    Along the way it evaluates the classifier after every accepted step and
    records the best validation accuracy seen (early stopping over time).
    :param next_t: the time point to advance the solution to
    :return: (time, state) — the state interpolated within the last step
    """
    n_steps = 0
    # Step until the solver's accepted interval covers next_t, or the
    # evaluation step budget runs out.
    while next_t > self.rk_state.t1 and n_steps < self.max_test_steps:
      self.rk_state = self._adaptive_step(self.rk_state)
      n_steps += 1
      # Early-stopping bookkeeping: keep the best validation accuracy and
      # the time at which it occurred.
      train_acc, val_acc, test_acc = self.evaluate(self.rk_state)
      if val_acc > self.best_val:
        self.set_accs(train_acc, val_acc, test_acc, self.rk_state.t1)
    new_t = next_t
    if n_steps < self.max_test_steps:
      return (new_t, _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t))
    else:
      # Budget exhausted before reaching next_t: return the state at the end
      # of the last accepted step. NOTE(review): the returned time is still
      # next_t even though the state corresponds to rk_state.t1 — confirm
      # callers tolerate this mismatch.
      return (new_t, _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, self.rk_state.t1))
  def __init__(self, func, y0, opt, eps=0, **kwargs):
    """Fixed-grid RK4 solver that evaluates the classifier after each step
    and tracks the best validation accuracy (early stopping over time).

    :param func: ODE right-hand side
    :param y0: initial state tensor
    :param opt: option dict; 'dataset' selects the loss/evaluator variant
    :param eps: small time perturbation applied inside _step_func
    """
    super(EarlyStopRK4, self).__init__(func, y0, **kwargs)
    self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device)
    self.lf = torch.nn.CrossEntropyLoss()
    # Decoder (classification head) parameters and dataset holder; these are
    # attached later by the caller / set_data before integration.
    self.m2_weight = None
    self.m2_bias = None
    self.data = None
    # Best metrics observed so far along the trajectory.
    self.best_val = 0
    self.best_test = 0
    self.best_time = 0
    # ogbn-arxiv uses the OGB evaluator and NLL loss on log-softmax outputs;
    # every other dataset uses mask-based accuracy and cross-entropy.
    self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
    self.dataset = opt['dataset']
    if opt['dataset'] == 'ogbn-arxiv':
      self.lf = torch.nn.functional.nll_loss
      self.evaluator = Evaluator(name=opt['dataset'])
  def integrate(self, t):
    """Integrate over the fixed time grid, evaluating the classifier after
    every step and recording the best validation accuracy seen.

    :param t: 1-D tensor of requested output times; t[0]/t[-1] must match the
              constructed grid's endpoints
    :return: (t1, solution) — the final grid time and a tensor whose i-th
             slice is the state (linearly interpolated) at t[i]
    """
    time_grid = self.grid_constructor(self.func, self.y0, t)
    assert time_grid[0] == t[0] and time_grid[-1] == t[-1]

    solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
    solution[0] = self.y0

    j = 1
    y0 = self.y0
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
      dy = self._step_func(self.func, t0, t1 - t0, t1, y0)
      y1 = y0 + dy
      # Early-stopping bookkeeping: track the best validation accuracy and
      # the step end-time at which it occurred.
      train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1)
      if val_acc > self.best_val:
        self.set_accs(train_acc, val_acc, test_acc, t1)

      # Emit every requested output time that falls inside this step by
      # linear interpolation between y0 and y1.
      while j < len(t) and t1 >= t[j]:
        solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
        j += 1
      y0 = y1

    return t1, solution
206 | if not self.m2_weight.shape[1] == z.shape[1]: # system has been augmented 207 | z = torch.split(z, self.m2_weight.shape[1], dim=1)[0] 208 | z = F.relu(z) 209 | z = F.linear(z, self.m2_weight, self.m2_bias) 210 | if self.dataset == 'ogbn-arxiv': 211 | z = z.log_softmax(dim=-1) 212 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask]) 213 | else: 214 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask]) 215 | train_acc, val_acc, test_acc = self.ode_test(z) 216 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 217 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc)) 218 | return train_acc, val_acc, test_acc 219 | 220 | def set_m2(self, m2): 221 | self.m2 = copy.deepcopy(m2) 222 | 223 | def set_data(self, data): 224 | if self.data is None: 225 | self.data = data 226 | 227 | 228 | SOLVERS = { 229 | 'dopri5': EarlyStopDopri5, 230 | 'rk4': EarlyStopRK4 231 | } 232 | 233 | 234 | class EarlyStopInt(torch.nn.Module): 235 | def __init__(self, t, opt, device=None): 236 | super(EarlyStopInt, self).__init__() 237 | self.device = device 238 | self.solver = None 239 | self.data = None 240 | self.max_test_steps = opt['max_test_steps'] 241 | self.m2_weight = None 242 | self.m2_bias = None 243 | self.opt = opt 244 | self.t = torch.tensor([0, opt['earlystopxT'] * t], dtype=torch.float).to(self.device) 245 | 246 | def __call__(self, func, y0, t, method=None, rtol=1e-7, atol=1e-9, 247 | adjoint_method="dopri5", adjoint_atol=1e-9, adjoint_rtol=1e-7, options=None): 248 | """Integrate a system of ordinary differential equations. 249 | 250 | Solves the initial value problem for a non-stiff system of first order ODEs: 251 | ``` 252 | dy/dt = func(t, y), y(t[0]) = y0 253 | ``` 254 | where y is a Tensor of any shape. 255 | 256 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. 
257 | 258 | Args: 259 | func: Function that maps a Tensor holding the state `y` and a scalar Tensor 260 | `t` into a Tensor of state derivatives with respect to time. 261 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May 262 | have any floating point or complex dtype. 263 | t: 1-D Tensor holding a sequence of time points for which to solve for 264 | `y`. The initial time point should be the first element of this sequence, 265 | and each time must be larger than the previous time. May have any floating 266 | point dtype. Converted to a Tensor with float64 dtype. 267 | rtol: optional float64 Tensor specifying an upper bound on relative error, 268 | per element of `y`. 269 | atol: optional float64 Tensor specifying an upper bound on absolute error, 270 | per element of `y`. 271 | method: optional string indicating the integration method to use. 272 | options: optional dict of configuring options for the indicated integration 273 | method. Can only be provided if a `method` is explicitly set. 274 | name: Optional name for this operation. 275 | 276 | Returns: 277 | y: Tensor, where the first dimension corresponds to different 278 | time points. Contains the solved value of y for each desired time point in 279 | `t`, with the initial value `y0` being the first element along the first 280 | dimension. 281 | 282 | Raises: 283 | ValueError: if an invalid `method` is provided. 284 | TypeError: if `options` is supplied without `method`, or if `t` or `y0` has 285 | an invalid dtype. 
286 | """ 287 | method = self.opt['method'] 288 | assert method in ['rk4', 'dopri5'], "Only dopri5 and rk4 implemented with early stopping" 289 | 290 | ver = torchdiffeq.__version__ 291 | if int(ver[0] + ver[2] + ver[4]) >= 20: # 0.2.0 change of signature on this release for event_fn 292 | event_fn = None 293 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, self.t, rtol, 294 | atol, method, options, 295 | event_fn, SOLVERS) 296 | else: 297 | shapes, func, y0, t, rtol, atol, method, options = _check_inputs(func, y0, self.t, rtol, atol, method, options, 298 | SOLVERS) 299 | 300 | self.solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, opt=self.opt, **options) 301 | if self.solver.data is None: 302 | self.solver.data = self.data 303 | self.solver.m2_weight = self.m2_weight 304 | self.solver.m2_bias = self.m2_bias 305 | t, solution = self.solver.integrate(t) 306 | if shapes is not None: 307 | solution = _flat_to_shape(solution, (len(t),), shapes) 308 | return solution 309 | -------------------------------------------------------------------------------- /src/function_GAT_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import softmax 4 | import torch_sparse 5 | from torch_geometric.utils.loop import add_remaining_self_loops 6 | from data import get_dataset 7 | from utils import MaxNFEException 8 | from base_classes import ODEFunc 9 | 10 | 11 | class ODEFuncAtt(ODEFunc): 12 | 13 | def __init__(self, in_features, out_features, opt, data, device): 14 | super(ODEFuncAtt, self).__init__(opt, data, device) 15 | 16 | if opt['self_loop_weight'] > 0: 17 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr, 18 | fill_value=opt['self_loop_weight']) 19 | else: 20 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr 21 | 22 | 
self.multihead_att_layer = SpGraphAttentionLayer(in_features, out_features, opt, 23 | device).to(device) 24 | try: 25 | self.attention_dim = opt['attention_dim'] 26 | except KeyError: 27 | self.attention_dim = out_features 28 | 29 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size" 30 | self.d_k = self.attention_dim // opt['heads'] 31 | 32 | def multiply_attention(self, x, attention, wx): 33 | if self.opt['mix_features']: 34 | wx = torch.mean(torch.stack( 35 | [torch_sparse.spmm(self.edge_index, attention[:, idx], wx.shape[0], wx.shape[0], wx) for idx in 36 | range(self.opt['heads'])], dim=0), 37 | dim=0) 38 | ax = torch.mm(wx, self.multihead_att_layer.Wout) 39 | else: 40 | ax = torch.mean(torch.stack( 41 | [torch_sparse.spmm(self.edge_index, attention[:, idx], x.shape[0], x.shape[0], x) for idx in 42 | range(self.opt['heads'])], dim=0), 43 | dim=0) 44 | return ax 45 | 46 | def forward(self, t, x): # t is needed when called by the integrator 47 | 48 | if self.nfe > self.opt["max_nfe"]: 49 | raise MaxNFEException 50 | 51 | self.nfe += 1 52 | 53 | attention, wx = self.multihead_att_layer(x, self.edge_index) 54 | ax = self.multiply_attention(x, attention, wx) 55 | # todo would be nice if this was more efficient 56 | 57 | if not self.opt['no_alpha_sigmoid']: 58 | alpha = torch.sigmoid(self.alpha_train) 59 | else: 60 | alpha = self.alpha_train 61 | 62 | f = alpha * (ax - x) 63 | if self.opt['add_source']: 64 | f = f + self.beta_train * self.x0 65 | return f 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 69 | 70 | 71 | class SpGraphAttentionLayer(nn.Module): 72 | """ 73 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 74 | """ 75 | 76 | def __init__(self, in_features, out_features, opt, device, concat=True): 77 | super(SpGraphAttentionLayer, self).__init__() 78 | self.in_features = in_features 79 | 
  def forward(self, x, edge):
    """Compute per-edge, per-head attention coefficients.

    :param x: node features [N, in_features]
    :param edge: edge index [2, E]
    :return: (attention [E, heads], wx [N, attention_dim])
    """
    wx = torch.mm(x, self.W)  # wx: [N, attention_dim]
    h = wx.view(-1, self.h, self.d_k)
    h = h.transpose(1, 2)  # h: [N, d_k, heads]

    # Shared additive attention: gather source/target features per edge,
    # concatenate along the feature axis ([E, 2*d_k, heads]) and move the
    # feature axis first so it lines up with `a` ([2*d_k, 1, 1]).
    edge_h = torch.cat((h[edge[0, :], :, :], h[edge[1, :], :, :]), dim=1).transpose(0, 1).to(
        self.device)  # edge_h: [2*d_k, E, heads]
    edge_e = self.leakyrelu(torch.sum(self.a * edge_h, dim=0)).to(self.device)  # scores: [E, heads]
    # Normalise scores over edges sharing the endpoint selected by
    # opt['attention_norm_idx'] (0 = source row, 1 = target row).
    attention = softmax(edge_e, edge[self.opt['attention_norm_idx']])
    return attention, wx
False} 125 | dataset = get_dataset(opt, '../data', False) 126 | t = 1 127 | func = ODEFuncAtt(dataset.data.num_features, 6, opt, dataset.data, device) 128 | out = func(t, dataset.data.x) 129 | -------------------------------------------------------------------------------- /src/function_laplacian_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch_sparse 4 | 5 | from base_classes import ODEFunc 6 | from utils import MaxNFEException 7 | 8 | 9 | # Define the ODE function. 10 | # Input: 11 | # --- t: A tensor with shape [], meaning the current time. 12 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t. 13 | # Output: 14 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t. 15 | class LaplacianODEFunc(ODEFunc): 16 | 17 | # currently requires in_features = out_features 18 | def __init__(self, in_features, out_features, opt, data, device): 19 | super(LaplacianODEFunc, self).__init__(opt, data, device) 20 | 21 | self.in_features = in_features 22 | self.out_features = out_features 23 | self.w = nn.Parameter(torch.eye(opt['hidden_dim'])) 24 | self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1) 25 | self.alpha_sc = nn.Parameter(torch.ones(1)) 26 | self.beta_sc = nn.Parameter(torch.ones(1)) 27 | 28 | def sparse_multiply(self, x): 29 | if self.opt['block'] in ['attention']: # adj is a multihead attention 30 | mean_attention = self.attention_weights.mean(dim=1) 31 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x) 32 | elif self.opt['block'] in ['mixed', 'hard_attention']: # adj is a torch sparse matrix 33 | ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x) 34 | else: # adj is a torch sparse matrix 35 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x) 36 | return ax 37 | 38 | def forward(self, t, x): # 
the t param is needed by the ODE solver. 39 | if self.nfe > self.opt["max_nfe"]: 40 | raise MaxNFEException 41 | self.nfe += 1 42 | ax = self.sparse_multiply(x) 43 | if not self.opt['no_alpha_sigmoid']: 44 | alpha = torch.sigmoid(self.alpha_train) 45 | else: 46 | alpha = self.alpha_train 47 | 48 | f = alpha * (ax - x) 49 | if self.opt['add_source']: 50 | f = f + self.beta_train * self.x0 51 | return f 52 | -------------------------------------------------------------------------------- /src/function_transformer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import softmax 4 | import torch_sparse 5 | from torch_geometric.utils.loop import add_remaining_self_loops 6 | import numpy as np 7 | from data import get_dataset 8 | from utils import MaxNFEException, squareplus 9 | from base_classes import ODEFunc 10 | 11 | 12 | class ODEFuncTransformerAtt(ODEFunc): 13 | 14 | def __init__(self, in_features, out_features, opt, data, device): 15 | super(ODEFuncTransformerAtt, self).__init__(opt, data, device) 16 | 17 | if opt['self_loop_weight'] > 0: 18 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr, 19 | fill_value=opt['self_loop_weight']) 20 | else: 21 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr 22 | self.multihead_att_layer = SpGraphTransAttentionLayer(in_features, out_features, opt, 23 | device, edge_weights=self.edge_weight).to(device) 24 | 25 | def multiply_attention(self, x, attention, v=None): 26 | # todo would be nice if this was more efficient 27 | if self.opt['mix_features']: 28 | vx = torch.mean(torch.stack( 29 | [torch_sparse.spmm(self.edge_index, attention[:, idx], v.shape[0], v.shape[0], v[:, :, idx]) for idx in 30 | range(self.opt['heads'])], dim=0), 31 | dim=0) 32 | ax = self.multihead_att_layer.Wout(vx) 33 | else: 34 | mean_attention = attention.mean(dim=1) 35 | ax = 
torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x) 36 | return ax 37 | 38 | def forward(self, t, x): # t is needed when called by the integrator 39 | if self.nfe > self.opt["max_nfe"]: 40 | raise MaxNFEException 41 | 42 | self.nfe += 1 43 | attention, values = self.multihead_att_layer(x, self.edge_index) 44 | ax = self.multiply_attention(x, attention, values) 45 | 46 | if not self.opt['no_alpha_sigmoid']: 47 | alpha = torch.sigmoid(self.alpha_train) 48 | else: 49 | alpha = self.alpha_train 50 | f = alpha * (ax - x) 51 | if self.opt['add_source']: 52 | f = f + self.beta_train * self.x0 53 | return f 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 57 | 58 | 59 | class SpGraphTransAttentionLayer(nn.Module): 60 | """ 61 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 62 | """ 63 | 64 | def __init__(self, in_features, out_features, opt, device, concat=True, edge_weights=None): 65 | super(SpGraphTransAttentionLayer, self).__init__() 66 | self.in_features = in_features 67 | self.out_features = out_features 68 | self.alpha = opt['leaky_relu_slope'] 69 | self.concat = concat 70 | self.device = device 71 | self.opt = opt 72 | self.h = int(opt['heads']) 73 | self.edge_weights = edge_weights 74 | 75 | try: 76 | self.attention_dim = opt['attention_dim'] 77 | except KeyError: 78 | self.attention_dim = out_features 79 | 80 | assert self.attention_dim % self.h == 0, "Number of heads ({}) must be a factor of the dimension size ({})".format( 81 | self.h, self.attention_dim) 82 | self.d_k = self.attention_dim // self.h 83 | 84 | if self.opt['beltrami'] and self.opt['attention_type'] == "exp_kernel": 85 | self.output_var_x = nn.Parameter(torch.ones(1)) 86 | self.lengthscale_x = nn.Parameter(torch.ones(1)) 87 | self.output_var_p = nn.Parameter(torch.ones(1)) 88 | self.lengthscale_p = nn.Parameter(torch.ones(1)) 89 | self.Qx = 
nn.Linear(opt['hidden_dim']-opt['pos_enc_hidden_dim'], self.attention_dim) 90 | self.init_weights(self.Qx) 91 | self.Vx = nn.Linear(opt['hidden_dim']-opt['pos_enc_hidden_dim'], self.attention_dim) 92 | self.init_weights(self.Vx) 93 | self.Kx = nn.Linear(opt['hidden_dim']-opt['pos_enc_hidden_dim'], self.attention_dim) 94 | self.init_weights(self.Kx) 95 | 96 | self.Qp = nn.Linear(opt['pos_enc_hidden_dim'], self.attention_dim) 97 | self.init_weights(self.Qp) 98 | self.Vp = nn.Linear(opt['pos_enc_hidden_dim'], self.attention_dim) 99 | self.init_weights(self.Vp) 100 | self.Kp = nn.Linear(opt['pos_enc_hidden_dim'], self.attention_dim) 101 | self.init_weights(self.Kp) 102 | 103 | else: 104 | if self.opt['attention_type'] == "exp_kernel": 105 | self.output_var = nn.Parameter(torch.ones(1)) 106 | self.lengthscale = nn.Parameter(torch.ones(1)) 107 | 108 | self.Q = nn.Linear(in_features, self.attention_dim) 109 | self.init_weights(self.Q) 110 | 111 | self.V = nn.Linear(in_features, self.attention_dim) 112 | self.init_weights(self.V) 113 | 114 | self.K = nn.Linear(in_features, self.attention_dim) 115 | self.init_weights(self.K) 116 | 117 | self.activation = nn.Sigmoid() # nn.LeakyReLU(self.alpha) 118 | 119 | self.Wout = nn.Linear(self.d_k, in_features) 120 | self.init_weights(self.Wout) 121 | 122 | def init_weights(self, m): 123 | if type(m) == nn.Linear: 124 | # nn.init.xavier_uniform_(m.weight, gain=1.414) 125 | # m.bias.data.fill_(0.01) 126 | nn.init.constant_(m.weight, 1e-5) 127 | 128 | def forward(self, x, edge): 129 | """ 130 | x might be [features, augmentation, positional encoding, labels] 131 | """ 132 | # if self.opt['beltrami'] and self.opt['attention_type'] == "exp_kernel": 133 | if self.opt['beltrami'] and self.opt['attention_type'] == "exp_kernel": 134 | label_index = self.opt['feat_hidden_dim'] + self.opt['pos_enc_hidden_dim'] 135 | p = x[:, self.opt['feat_hidden_dim']: label_index] 136 | x = torch.cat((x[:, :self.opt['feat_hidden_dim']], x[:, label_index:]), 
dim=1) 137 | 138 | qx = self.Qx(x) 139 | kx = self.Kx(x) 140 | vx = self.Vx(x) 141 | # perform linear operation and split into h heads 142 | kx = kx.view(-1, self.h, self.d_k) 143 | qx = qx.view(-1, self.h, self.d_k) 144 | vx = vx.view(-1, self.h, self.d_k) 145 | # transpose to get dimensions [n_nodes, attention_dim, n_heads] 146 | kx = kx.transpose(1, 2) 147 | qx = qx.transpose(1, 2) 148 | vx = vx.transpose(1, 2) 149 | src_x = qx[edge[0, :], :, :] 150 | dst_x = kx[edge[1, :], :, :] 151 | 152 | qp = self.Qp(p) 153 | kp = self.Kp(p) 154 | vp = self.Vp(p) 155 | # perform linear operation and split into h heads 156 | kp = kp.view(-1, self.h, self.d_k) 157 | qp = qp.view(-1, self.h, self.d_k) 158 | vp = vp.view(-1, self.h, self.d_k) 159 | # transpose to get dimensions [n_nodes, attention_dim, n_heads] 160 | kp = kp.transpose(1, 2) 161 | qp = qp.transpose(1, 2) 162 | vp = vp.transpose(1, 2) 163 | src_p = qp[edge[0, :], :, :] 164 | dst_p = kp[edge[1, :], :, :] 165 | 166 | prods = self.output_var_x ** 2 * torch.exp( 167 | -torch.sum((src_x - dst_x) ** 2, dim=1) / (2 * self.lengthscale_x ** 2)) \ 168 | * self.output_var_p ** 2 * torch.exp( 169 | -torch.sum((src_p - dst_p) ** 2, dim=1) / (2 * self.lengthscale_p ** 2)) 170 | 171 | v = None 172 | 173 | else: 174 | q = self.Q(x) 175 | k = self.K(x) 176 | v = self.V(x) 177 | 178 | # perform linear operation and split into h heads 179 | 180 | k = k.view(-1, self.h, self.d_k) 181 | q = q.view(-1, self.h, self.d_k) 182 | v = v.view(-1, self.h, self.d_k) 183 | 184 | # transpose to get dimensions [n_nodes, attention_dim, n_heads] 185 | 186 | k = k.transpose(1, 2) 187 | q = q.transpose(1, 2) 188 | v = v.transpose(1, 2) 189 | 190 | src = q[edge[0, :], :, :] 191 | dst_k = k[edge[1, :], :, :] 192 | 193 | if not self.opt['beltrami'] and self.opt['attention_type'] == "exp_kernel": 194 | prods = self.output_var ** 2 * torch.exp(-(torch.sum((src - dst_k) ** 2, dim=1) / (2 * self.lengthscale ** 2))) 195 | elif self.opt['attention_type'] == 
"scaled_dot": 196 | prods = torch.sum(src * dst_k, dim=1) / np.sqrt(self.d_k) 197 | elif self.opt['attention_type'] == "cosine_sim": 198 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-5) 199 | prods = cos(src, dst_k) 200 | elif self.opt['attention_type'] == "pearson": 201 | src_mu = torch.mean(src, dim=1, keepdim=True) 202 | dst_mu = torch.mean(dst_k, dim=1, keepdim=True) 203 | src = src - src_mu 204 | dst_k = dst_k - dst_mu 205 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-5) 206 | prods = cos(src, dst_k) 207 | 208 | if self.opt['reweight_attention'] and self.edge_weights is not None: 209 | prods = prods * self.edge_weights.unsqueeze(dim=1) 210 | if self.opt['square_plus']: 211 | attention = squareplus(prods, edge[self.opt['attention_norm_idx']]) 212 | else: 213 | attention = softmax(prods, edge[self.opt['attention_norm_idx']]) 214 | return attention, (v, prods) 215 | 216 | def __repr__(self): 217 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 218 | 219 | 220 | if __name__ == '__main__': 221 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 222 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'heads': 2, 'K': 10, 223 | 'attention_norm_idx': 0, 'add_source': False, 224 | 'alpha_dim': 'sc', 'beta_dim': 'sc', 'max_nfe': 1000, 'mix_features': False 225 | } 226 | dataset = get_dataset(opt, '../data', False) 227 | t = 1 228 | func = ODEFuncTransformerAtt(dataset.data.num_features, 6, opt, dataset.data, device) 229 | out = func(t, dataset.data.x) 230 | -------------------------------------------------------------------------------- /src/hyperbolic_distances.py: -------------------------------------------------------------------------------- 1 | import time 2 | from scipy.spatial.distance import squareform, pdist 3 | import numpy as np 4 | import argparse 5 | import pickle 6 | 7 | def hyperbolize(x): 8 | n = pdist(x.detach().numpy(), "sqeuclidean") 9 | 
MACHINE_EPSILON = np.finfo(np.double).eps 10 | m = squareform(n) 11 | qsqr = np.sum(x ** 2, axis=1) 12 | divisor = np.maximum(1 - qsqr[:, np.newaxis], MACHINE_EPSILON) * np.maximum(1 - qsqr[np.newaxis, :], MACHINE_EPSILON) 13 | m = np.arccosh(1 + 2 * m / divisor ) #** 2 14 | return m 15 | 16 | def main(opt): 17 | dataset = opt['dataset'] 18 | for emb_dim in [16, 8, 4, 2]: 19 | with open(f"../data/pos_encodings/{dataset}_HYPS{emb_dim:02d}.pkl", "rb") as f: 20 | emb = pickle.load(f) 21 | t = time.time() 22 | sqdist = pdist(emb.detach().numpy(), "sqeuclidean") 23 | distances_ = hyperbolize(emb.detach().numpy(), sqdist) 24 | print("Distances calculated in %.2f sec" % (time.time()-t)) 25 | #with open(f"../data/pos_encodings/{dataset}_HYPS{emb_dim:02d}_dists.pkl", "wb") as f: 26 | # pickle.dump(distances, f) 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--dataset', type=str, default='ALL', 30 | help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS') 31 | args = parser.parse_args() 32 | opt = vars(args) 33 | main(opt) -------------------------------------------------------------------------------- /src/model_configurations.py: -------------------------------------------------------------------------------- 1 | from function_transformer_attention import ODEFuncTransformerAtt 2 | from function_GAT_attention import ODEFuncAtt 3 | from function_laplacian_diffusion import LaplacianODEFunc 4 | from block_transformer_attention import AttODEblock 5 | from block_constant import ConstantODEblock 6 | from block_mixed import MixedODEblock 7 | from block_transformer_hard_attention import HardAttODEblock 8 | from block_transformer_rewiring import RewireAttODEblock 9 | 10 | class BlockNotDefined(Exception): 11 | pass 12 | 13 | class FunctionNotDefined(Exception): 14 | pass 15 | 16 | 17 | def set_block(opt): 18 | ode_str = opt['block'] 19 | if ode_str == 'mixed': 20 | block = MixedODEblock 21 | elif ode_str == 'attention': 22 | 
block = AttODEblock 23 | elif ode_str == 'hard_attention': 24 | block = HardAttODEblock 25 | elif ode_str == 'rewire_attention': 26 | block = RewireAttODEblock 27 | elif ode_str == 'constant': 28 | block = ConstantODEblock 29 | else: 30 | raise BlockNotDefined 31 | return block 32 | 33 | 34 | def set_function(opt): 35 | ode_str = opt['function'] 36 | if ode_str == 'laplacian': 37 | f = LaplacianODEFunc 38 | elif ode_str == 'GAT': 39 | f = ODEFuncAtt 40 | elif ode_str == 'transformer': 41 | f = ODEFuncTransformerAtt 42 | else: 43 | raise FunctionNotDefined 44 | return f 45 | -------------------------------------------------------------------------------- /src/pos_enc_factorisation.py: -------------------------------------------------------------------------------- 1 | """ 2 | matrix factorisation of the positional encoding required for arxiv-ogbn 3 | """ 4 | 5 | import numpy as np 6 | import argparse 7 | import os 8 | import pickle 9 | from sklearn.decomposition import NMF 10 | from graph_rewiring import apply_gdc 11 | from data import get_dataset 12 | import time 13 | from libmf import mf 14 | from scipy import sparse 15 | 16 | 17 | POS_ENC_PATH = os.path.join("../data", "pos_encodings") 18 | 19 | 20 | def find_or_make_encodings(opt): 21 | # generate new positional encodings 22 | # do encodings already exist on disk? 23 | fname = os.path.join(POS_ENC_PATH, f"{opt['dataset']}_{opt['pos_enc_type']}.pkl") 24 | print(f"[i] Looking for positional encodings in {fname}...") 25 | 26 | # - if so, just load them 27 | if os.path.exists(fname): 28 | print(" Found them! Loading cached version") 29 | with open(fname, "rb") as f: 30 | pos_encoding = pickle.load(f) 31 | 32 | # - otherwise, calculate... 33 | else: 34 | print(" Encodings not found! 
Calculating and caching them") 35 | # choose different functions for different positional encodings 36 | dataset = get_dataset(opt, '../data', False) 37 | data = dataset.data 38 | if opt['pos_enc_type'] == "GDC": 39 | pos_encoding = apply_gdc(data, opt, type="pos_encoding") 40 | else: 41 | print(f"[x] The positional encoding type you specified ({opt['pos_enc_type']}) does not exist") 42 | quit() 43 | # - ... and store them on disk 44 | if not os.path.exists(POS_ENC_PATH): 45 | os.makedirs(POS_ENC_PATH) 46 | with open(fname, "wb") as f: 47 | pickle.dump(pos_encoding, f) 48 | 49 | return pos_encoding 50 | 51 | def run_libmf(): 52 | pos_encodings = sparse.random(10,10) 53 | engine = mf.MF(k=2, nr_threads=8) 54 | engine.fit(pos_encodings) 55 | embedding = engine.q_factors() 56 | return embedding 57 | 58 | def main(opt): 59 | start_time = time.time() 60 | dim = opt['embedding_dim'] 61 | type = opt['pos_enc_type'] 62 | model = NMF(n_components=dim, init='random', random_state=0, max_iter=opt['max_iter'], verbose=1, tol=opt['tol']) 63 | fname = os.path.join(POS_ENC_PATH, f"{opt['dataset']}_{opt['pos_enc_type']}.pkl") 64 | print(f"[i] Looking for positional encodings in {fname}...") 65 | 66 | pos_encodings = find_or_make_encodings(opt) 67 | 68 | 69 | 70 | # - if so, just load them 71 | print(f"positional encodings retrieved after {time.time()-start_time} seconds. 
Starting matrix factorisation") 72 | 73 | W = model.fit_transform(pos_encodings) 74 | # H = model.components_ 75 | end_time = time.time() 76 | print(f"compression to {dim} dim complete in {(end_time - start_time)} seconds") 77 | 78 | out_path = f"{opt['out_dir']}/compressed_pos_encodings_{dim}_{type}.pkl" 79 | with opt(out_path, 'wb') as f: 80 | pickle.dump(W, f) 81 | 82 | if not os.path.exists(opt['out_path']): 83 | os.makedirs(opt['out_path']) 84 | with open(fname, "wb") as f: 85 | pickle.dump(W, f) 86 | 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument( 91 | "--use_cora_defaults", 92 | action="store_true", 93 | help="Whether to run with best params for cora. Overrides the choice of dataset",) 94 | parser.add_argument("--rewiring", action="store_true", help="Whether to rewire the dataset with GDC") 95 | parser.add_argument("--exact", action="store_true", help="Whether to approximate ppr") 96 | parser.add_argument( 97 | "--data_path", type=str, default=".", help="path to the positional encoding" 98 | ) 99 | parser.add_argument( 100 | "--pos_enc_orientation", type=str, default="row", help="" 101 | ) 102 | parser.add_argument( 103 | "--out_dir", type=str, default="../data", help="path to save compressed encoding" 104 | ) 105 | parser.add_argument("--self_loop_weight", type=int, default=1) 106 | parser.add_argument("--pos_enc_type", type=str, default="GDC", 107 | help="type of encoding to make only GDC currently implemented") 108 | parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff") 109 | parser.add_argument('--gdc_sparsification', type=str, default='topk', help="threshold, topk") 110 | parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk") 111 | parser.add_argument('--gdc_threshold', type=float, default=0.0001, 112 | help="above this edge weight, keep edges when using threshold") 113 | parser.add_argument('--gdc_avg_degree', 
type=int, default=64, 114 | help="if gdc_threshold is not given can be calculated by specifying avg degree") 115 | parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability") 116 | parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernal diffusion for") 117 | parser.add_argument( 118 | "--dataset", type=str, default="ogbn-arxiv", help="type of encoding to make only GDC currently implemented" 119 | ) 120 | parser.add_argument( 121 | "--embedding_dim", type=int, default=1000, help="dimension of compressed encoding" 122 | ) 123 | parser.add_argument( 124 | "--max_iter", type=int, default=100, help="number of training iterations" 125 | ) 126 | parser.add_argument( 127 | "--tol", type=float, default=0.002, help="number of training iterations" 128 | ) 129 | args = parser.parse_args() 130 | opt = vars(args) 131 | # run_libmf() 132 | main(opt) 133 | -------------------------------------------------------------------------------- /src/post_analysis_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | from torch_geometric.nn import GCNConv, ChebConv 6 | from GNN_image import GNN_image 7 | import time 8 | from torch_geometric.data import DataLoader 9 | from data_image import edge_index_calc, create_in_memory_dataset 10 | import matplotlib.pyplot as plt 11 | from matplotlib.animation import FuncAnimation 12 | from torch_geometric.utils import to_dense_adj 13 | import pandas as pd 14 | import torchvision.transforms as transforms 15 | from data_image import load_data 16 | import openpyxl 17 | from utils import get_rw_adj 18 | 19 | @torch.no_grad() 20 | def print_image_T(model, dataset, opt, modelpath, height=2, width=3): 21 | 22 | loader = DataLoader(dataset, batch_size=opt['batch_size'], shuffle=True) 23 | fig = plt.figure()#figsize=(width*10, height*10)) 24 | for batch_idx, batch in 
enumerate(loader): 25 | out = model.forward_plot_T(batch.x) 26 | break 27 | 28 | for i in range(height*width): 29 | # t == 0 30 | plt.subplot(2*height, width, i + 1) 31 | plt.tight_layout() 32 | plt.axis('off') 33 | mask = batch.batch == i 34 | if opt['im_dataset'] == 'MNIST': 35 | plt.imshow(batch.x[torch.nonzero(mask)].view(model.opt['im_height'],model.opt['im_width']), cmap='gray', interpolation='none') 36 | elif opt['im_dataset'] == 'CIFAR': 37 | A = batch.x[torch.nonzero(mask)].view(model.opt['im_height'], model.opt['im_width'], model.opt['im_chan']) 38 | A = A / 2 + 0.5 39 | plt.imshow(A) 40 | plt.title("t=0 Ground Truth: {}".format(batch.y[i].item())) 41 | 42 | #t == T 43 | plt.subplot(2*height, width, height*width + i + 1) 44 | plt.tight_layout() 45 | plt.axis('off') 46 | if opt['im_dataset'] == 'MNIST': 47 | plt.imshow(out[i, :].view(model.opt['im_height'], model.opt['im_width']), cmap='gray', interpolation='none') 48 | elif opt['im_dataset'] == 'CIFAR': 49 | A = out[i, :].view(model.opt['im_height'], model.opt['im_width'], model.opt['im_chan']) 50 | A = A / 2 + 0.5 51 | plt.imshow(A) 52 | plt.title("t=T Ground Truth: {}".format(batch.y[i].item())) 53 | 54 | plt.savefig(f"{modelpath}_imageT.png", format="PNG") 55 | 56 | 57 | @torch.no_grad() 58 | def print_image_path(model, dataset, opt, height, width, frames): 59 | loader = DataLoader(dataset, batch_size=opt['batch_size'], shuffle=True) 60 | # build the data 61 | for batch_idx, batch in enumerate(loader): 62 | paths = model.forward_plot_path(batch.x, frames) 63 | break 64 | # draw graph initial graph 65 | fig = plt.figure() #figsize=(width*10, height*10)) 66 | for i in range(height * width): 67 | plt.subplot(height, width, i + 1) 68 | plt.tight_layout() 69 | mask = batch.batch == i 70 | if opt['im_dataset'] == 'MNIST': 71 | plt.imshow(paths[i,0,:].view(model.opt['im_height'],model.opt['im_width']), cmap='gray', interpolation='none') 72 | elif opt['im_dataset'] == 'CIFAR': 73 | A = 
paths[i,0,:].view(model.opt['im_height'], model.opt['im_width'], model.opt['im_chan']) 74 | A = A / 2 + 0.5 75 | plt.imshow(A) 76 | plt.title("t=0 Ground Truth: {}".format(batch.y[i].item())) 77 | plt.axis('off') 78 | 79 | # loop through data and update plot 80 | def update(ii): 81 | for i in range(height * width): 82 | plt.subplot(height, width, i + 1) 83 | plt.tight_layout() 84 | if opt['im_dataset'] == 'MNIST': 85 | plt.imshow(paths[i,ii,:].view(model.opt['im_height'], model.opt['im_width']), cmap='gray', interpolation='none') 86 | elif opt['im_dataset'] == 'CIFAR': 87 | A = paths[i, ii, :].view(model.opt['im_height'], model.opt['im_width'], model.opt['im_chan']) 88 | A = A / 2 + 0.5 # unnormalize 89 | plt.imshow(A) 90 | plt.title("t={} Ground Truth: {}".format(ii, batch.y[i].item())) 91 | plt.axis('off') 92 | 93 | fig = plt.gcf() 94 | animation = FuncAnimation(fig, func=update, frames=frames, interval=10)#, blit=True) 95 | return animation 96 | 97 | @torch.no_grad() 98 | def plot_att_heat(model, model_key, modelpath): 99 | im_height = model.opt['im_height'] 100 | im_width = model.opt['im_width'] 101 | im_chan = model.opt['im_chan'] 102 | hwc = im_height * im_width * im_chan 103 | slice = torch.tensor(range(hwc+1)) 104 | edge_index = model.odeblock.edge_index 105 | num_nodes = model.opt['num_nodes'] 106 | edge_weight = model.odeblock.odefunc.adj[0,:,:] 107 | 108 | dense_att = to_dense_adj(edge_index=edge_index, edge_attr=edge_weight, max_num_nodes=num_nodes) 109 | square_att = dense_att.view(model.opt['num_nodes'], model.opt['num_nodes']) 110 | 111 | x_np = square_att.numpy() 112 | x_df = pd.DataFrame(x_np) 113 | x_df.to_csv(f"{modelpath}_att.csv") 114 | 115 | fig = plt.figure() 116 | plt.tight_layout() 117 | plt.imshow(square_att, cmap='hot', interpolation='nearest') 118 | plt.title("Attention Heat Map {}".format(model_key)) 119 | plt.savefig(f"{modelpath}_AttHeat.png", format="PNG") 120 | 121 | # useful code to overcome normalisation colour bar 122 | # https: 
// matplotlib.org / 3.3.3 / gallery / images_contours_and_fields / multi_image.html # sphx-glr-gallery-images-contours-and-fields-multi-image-py 123 | 124 | @torch.no_grad() 125 | def plot_att_edges(model): 126 | pass 127 | 128 | def main(opt): 129 | model_key = 'model_20210113-093420' 130 | modelfolder = f"../models/{model_key}" 131 | modelpath = f"../models/{model_key}/{model_key}" 132 | 133 | df = pd.read_excel('../models/models.xlsx', engine='openpyxl', ) 134 | optdf = df.loc[df['model_key'] == model_key] 135 | numeric = ['batch_size', 'train_size', 'test_size', 'Test Acc', 'alpha', 136 | 'hidden_dim', 'input_dropout', 'dropout', 137 | 'lr', 'decay', 'self_loop_weight', 'epoch', 'time', 138 | 'tol_scale', 'ode_blocks', 'dt_min', 'dt', 139 | 'leaky_relu_slope', 'attention_dropout', 'heads', 'attention_norm_idx', 'attention_dim', 140 | 'im_width', 'im_height', 'num_feature', 'num_class', 'im_chan', 'num_nodes'] 141 | df[numeric] = df[numeric].apply(pd.to_numeric) 142 | opt = optdf.to_dict('records')[0] 143 | 144 | print("Loading Data") 145 | use_cuda = False 146 | torch.manual_seed(1) 147 | device = torch.device("cuda" if use_cuda else "cpu") 148 | 149 | Graph_GNN, Graph_train, Graph_test = load_data(opt) 150 | 151 | print("creating GNN model") 152 | # model = GNN_image(opt, Graph_GNN, device).to(device) 153 | #todo this is so fucked, load model with GNN to get num_classes==10 and then augment adj with below 154 | # loader = DataLoader(Graph_train, batch_size=model.opt['batch_size'], shuffle=True) 155 | loader = DataLoader(Graph_train, batch_size=opt['batch_size'], shuffle=True) 156 | for batch_idx, batch in enumerate(loader): 157 | if batch_idx == 0:# only do this for 1st batch/epoch 158 | # model.data = batch #loader.dataset #adding this to reset the data 159 | # model.odeblock.data = batch #loader.dataset.data #why do I need to do this? duplicating data from model to ODE block? 
160 | # model.odeblock.odefunc.adj = get_rw_adj(model.data.edge_index) #to reset adj matrix 161 | break 162 | 163 | model = GNN_image(opt, batch, opt['num_class'], device).to(device) 164 | # model.load_state_dict(torch.load(modelpath)) 165 | # out = model(batch.x) 166 | model.eval() 167 | 168 | # do these as functions that take model key to generate displays on demand 169 | # 1) 170 | print_image_T(model, Graph_test, opt, modelpath, height=2, width=2) #width=3) 171 | # 2) 172 | #TODO Total Pixel intensity seems to increase loads for linear ATT 173 | # animation = print_image_path(model, Graph_test, opt, height=2, width=3, frames=10) 174 | animation = print_image_path(model, Graph_test, opt, height=2, width=2, frames=10) 175 | animation.save(f'{modelpath}_animation.gif', writer='imagemagick', savefig_kwargs={'facecolor': 'white'}, fps=0.5) 176 | # 3) 177 | # plot_att_heat(model, model_key, modelpath) 178 | 179 | 180 | if __name__ == '__main__': 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument('--use_image_defaults', default='MNIST', 183 | help='#Image version# Whether to run with best params for cora. Overrides the choice of dataset') 184 | # parser.add_argument('--use_image_defaults', action='store_true', 185 | # help='Whether to run with best params for cora. 
Overrides the choice of dataset') 186 | # parser.add_argument('--dataset', type=str, default='Cora', 187 | # help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS') 188 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.') ######## NEED 189 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.') 190 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.') 191 | parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.') 192 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') 193 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization') 194 | parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.') 195 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.') 196 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.') 197 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.') 198 | parser.add_argument('--augment', action='store_true', 199 | help='double the length of the feature vector by appending zeros to stabilist ODE learning') 200 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha') 201 | parser.add_argument('--alpha_sigmoid', type=bool, default=True, help='apply sigmoid before multiplying by alpha') 202 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta') 203 | # ODE args 204 | parser.add_argument('--method', type=str, default='dopri5', 205 | help="set the numerical solver: dopri5, euler, rk4, midpoint") ######## NEED 206 | parser.add_argument('--ode', type=str, default='ode', help="set ode block. 
Either 'ode', 'att', 'sde'") ######## NEED 207 | parser.add_argument('--adjoint', default=False, help='use the adjoint ODE method to reduce memory footprint') 208 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol') 209 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run') 210 | parser.add_argument('--simple', type=bool, default=False, 211 | help='If try get rid of alpha param and the beta*x0 source term') 212 | # SDE args 213 | parser.add_argument('--dt_min', type=float, default=1e-5, help='minimum timestep for the SDE solver') 214 | parser.add_argument('--dt', type=float, default=1e-3, help='fixed step size') 215 | parser.add_argument('--adaptive', type=bool, default=False, help='use adaptive step sizes') 216 | # Attention args 217 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2, 218 | help='slope of the negative part of the leaky relu used in attention') 219 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights') 220 | parser.add_argument('--heads', type=int, default=1, help='number of attention heads') 221 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols') 222 | parser.add_argument('--attention_dim', type=int, default=64, 223 | help='the size to project x to before calculating att scores') 224 | parser.add_argument('--mix_features', type=bool, default=False, 225 | help='apply a feature transformation xW to the ODE') 226 | parser.add_argument('--linear_attention', type=bool, default=False, 227 | help='learn the adjacency using attention at the start of each epoch, but do not update inside the ode') 228 | parser.add_argument('--mixed_block', type=bool, default=False, 229 | help='learn the adjacency using a mix of attention and the Laplacian at the start of each epoch, but do not update inside the ode') 230 | 231 | # visualisation args 232 | 
parser.add_argument('--batch_size', type=int, default=64, help='Batch size') 233 | parser.add_argument('--batched', type=bool, default=True, 234 | help='Batching') 235 | parser.add_argument('--im_width', type=int, default=28, help='im_width') 236 | parser.add_argument('--im_height', type=int, default=28, help='im_height') 237 | parser.add_argument('--diags', type=bool, default=False, 238 | help='Edge index include diagonal diffusion') 239 | parser.add_argument('--im_dataset', type=str, default='MNIST', 240 | help='MNIST, CIFAR') 241 | args = parser.parse_args() 242 | opt = vars(args) 243 | main(opt) -------------------------------------------------------------------------------- /src/regularized_ODE_function.py: -------------------------------------------------------------------------------- 1 | ## This code has been adapted from https://github.com/cfinlay/ffjord-rnode/ 2 | ## MIT License 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class RegularizedODEfunc(nn.Module): 9 | def __init__(self, odefunc, regularization_fns): 10 | super(RegularizedODEfunc, self).__init__() 11 | self.odefunc = odefunc 12 | self.regularization_fns = regularization_fns 13 | 14 | def before_odeint(self, *args, **kwargs): 15 | self.odefunc.before_odeint(*args, **kwargs) 16 | 17 | def forward(self, t, state): 18 | 19 | with torch.enable_grad(): 20 | x = state[0] 21 | x.requires_grad_(True) 22 | t.requires_grad_(True) 23 | dstate = self.odefunc(t, x) 24 | if len(state) > 1: 25 | dx = dstate 26 | reg_states = tuple(reg_fn(x, t, dx, self.odefunc) for reg_fn in self.regularization_fns) 27 | return (dstate,) + reg_states 28 | else: 29 | return dstate 30 | 31 | @property 32 | def _num_evals(self): 33 | return self.odefunc._num_evals 34 | 35 | 36 | def total_derivative(x, t, dx, unused_context): 37 | del unused_context 38 | 39 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 40 | 41 | try: 42 | u = torch.full_like(dx, 1 / x.numel(), requires_grad=True) 43 | tmp = 
torch.autograd.grad((u * dx).sum(), t, create_graph=True)[0] 44 | partial_dt = torch.autograd.grad(tmp.sum(), u, create_graph=True)[0] 45 | 46 | total_deriv = directional_dx + partial_dt 47 | except RuntimeError as e: 48 | if 'One of the differentiated Tensors' in e.__str__(): 49 | raise RuntimeError( 50 | 'No partial derivative with respect to time. Use mathematically equivalent "directional_derivative" regularizer instead') 51 | 52 | tdv2 = total_deriv.pow(2).view(x.size(0), -1) 53 | 54 | return 0.5 * tdv2.mean(dim=-1) 55 | 56 | 57 | def directional_derivative(x, t, dx, unused_context): 58 | del t, unused_context 59 | 60 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 61 | ddx2 = directional_dx.pow(2).view(x.size(0), -1) 62 | 63 | return 0.5 * ddx2.mean(dim=-1) 64 | 65 | 66 | def quadratic_cost(x, t, dx, unused_context): 67 | del x, t, unused_context 68 | dx = dx.view(dx.shape[0], -1) 69 | return 0.5 * dx.pow(2).mean(dim=-1) 70 | 71 | 72 | def divergence_bf(dx, x): 73 | sum_diag = 0. 
74 | for i in range(x.shape[1]): 75 | sum_diag += torch.autograd.grad(dx[:, i].sum(), x, create_graph=True)[0].contiguous()[:, i].contiguous() 76 | return sum_diag.contiguous() 77 | 78 | 79 | def jacobian_frobenius_regularization_fn(x, t, dx, context): 80 | del t 81 | return divergence_bf(dx, x) 82 | -------------------------------------------------------------------------------- /src/run_best_ray.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ray.tune import Analysis 3 | import json 4 | import numpy as np 5 | from utils import get_sem, mean_confidence_interval 6 | from ray_tune import train_ray_int 7 | from ray import tune 8 | from functools import partial 9 | import os, time 10 | from ray.tune import CLIReporter 11 | 12 | 13 | def get_best_params_dir(opt): 14 | analysis = Analysis("../ray_tune/{}".format(opt['folder'])) 15 | df = analysis.dataframe(metric=opt['metric'], mode='max') 16 | best_params_dir = df.sort_values('accuracy', ascending=False)['logdir'].iloc[opt['index']] 17 | return best_params_dir 18 | 19 | 20 | def run_best_params(opt): 21 | best_params_dir = get_best_params_dir(opt) 22 | with open(best_params_dir + '/params.json') as f: 23 | best_params = json.loads(f.read()) 24 | # allow params specified at the cmd line to override 25 | best_params_ret = {**best_params, **opt} 26 | try: 27 | best_params_ret['mix_features'] 28 | except KeyError: 29 | best_params_ret['mix_features'] = False 30 | # the exception is number of epochs as we want to use more here than we would for hyperparameter tuning. 
31 | best_params_ret['epoch'] = opt['epoch'] 32 | best_params_ret['max_nfe'] = opt['max_nfe'] 33 | # handle adjoint 34 | if best_params['adjoint'] or opt['adjoint']: 35 | best_params_ret['adjoint'] = True 36 | 37 | print("Running with parameters {}".format(best_params_ret)) 38 | 39 | data_dir = os.path.abspath("../data") 40 | reporter = CLIReporter( 41 | metric_columns=["accuracy", "loss", "test_acc", "train_acc", "best_time", "best_epoch", "training_iteration", "forward_nfe", "backward_nfe"]) 42 | 43 | if opt['name'] is None: 44 | name = opt['folder'] + '_test' 45 | else: 46 | name = opt['name'] 47 | 48 | result = tune.run( 49 | partial(train_ray_int, data_dir=data_dir), 50 | name=name, 51 | resources_per_trial={"cpu": opt['cpus'], "gpu": opt['gpus']}, 52 | search_alg=None, 53 | keep_checkpoints_num=3, 54 | checkpoint_score_attr='accuracy', 55 | config=best_params_ret, 56 | num_samples=opt['reps'] if opt["num_splits"] == 0 else opt["num_splits"] * opt["reps"], 57 | scheduler=None, 58 | max_failures=1, # early stop solver can't recover from failure as it doesn't own m2. 
59 | local_dir='../ray_tune', 60 | progress_reporter=reporter, 61 | raise_on_failed_trial=False) 62 | 63 | df = result.dataframe(metric=opt['metric'], mode="max").sort_values(opt['metric'], ascending=False) 64 | try: 65 | df.to_csv('../ray_results/{}_{}.csv'.format(name, time.strftime("%Y%m%d-%H%M%S"))) 66 | except: 67 | pass 68 | 69 | print(df[['accuracy', 'test_acc', 'train_acc', 'best_time', 'best_epoch']]) 70 | 71 | test_accs = df['test_acc'].values 72 | print("test accuracy {}".format(test_accs)) 73 | log = "mean test {:04f}, test std {:04f}, test sem {:04f}, test 95% conf {:04f}" 74 | print(log.format(test_accs.mean(), np.std(test_accs), get_sem(test_accs), mean_confidence_interval(test_accs))) 75 | 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.') 80 | parser.add_argument('--folder', type=str, default=None, help='experiment folder to read') 81 | parser.add_argument('--index', type=int, default=0, help='index to take from experiment folder') 82 | parser.add_argument('--metric', type=str, default='accuracy', help='metric to sort the hyperparameter tuning runs on') 83 | parser.add_argument('--augment', action='store_true', 84 | help='double the length of the feature vector by appending zeros to stabilise ODE learning') 85 | parser.add_argument('--reps', type=int, default=1, help='the number of random weight initialisations to use') 86 | parser.add_argument('--name', type=str, default=None) 87 | parser.add_argument('--gpus', type=float, default=0, help='number of gpus per trial. Can be fractional') 88 | parser.add_argument('--cpus', type=float, default=1, help='number of cpus per trial. Can be fractional') 89 | parser.add_argument("--num_splits", type=int, default=0, help="Number of random slpits >= 0. 
0 for planetoid split") 90 | parser.add_argument("--adjoint", dest='adjoint', action='store_true', 91 | help="use the adjoint ODE method to reduce memory footprint") 92 | parser.add_argument("--max_nfe", type=int, default=5000, help="Maximum number of function evaluations allowed in an epcoh.") 93 | parser.add_argument("--no_early", action="store_true", 94 | help="Whether or not to use early stopping of the ODE integrator when testing.") 95 | 96 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model') 97 | 98 | args = parser.parse_args() 99 | 100 | opt = vars(args) 101 | run_best_params(opt) 102 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utility functions 3 | """ 4 | import os 5 | 6 | import scipy 7 | from scipy.stats import sem 8 | import numpy as np 9 | from torch_scatter import scatter_add 10 | from torch_geometric.utils import add_remaining_self_loops 11 | from torch_geometric.utils.num_nodes import maybe_num_nodes 12 | from torch_geometric.utils.convert import to_scipy_sparse_matrix 13 | from sklearn.preprocessing import normalize 14 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 15 | 16 | ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | class MaxNFEException(Exception): pass 19 | 20 | 21 | def rms_norm(tensor): 22 | return tensor.pow(2).mean().sqrt() 23 | 24 | 25 | def make_norm(state): 26 | if isinstance(state, tuple): 27 | state = state[0] 28 | state_size = state.numel() 29 | 30 | def norm(aug_state): 31 | y = aug_state[1:1 + state_size] 32 | adj_y = aug_state[1 + state_size:1 + 2 * state_size] 33 | return max(rms_norm(y), rms_norm(adj_y)) 34 | 35 | return norm 36 | 37 | 38 | def print_model_params(model): 39 | total_num_params = 0 40 | print(model) 41 | for name, param in model.named_parameters(): 42 | if 
param.requires_grad: 43 | print(name) 44 | print(param.data.shape) 45 | total_num_params += param.numel() 46 | print("Model has a total of {} params".format(total_num_params)) 47 | 48 | 49 | def adjust_learning_rate(optimizer, lr, epoch, burnin=50): 50 | if epoch <= burnin: 51 | for param_group in optimizer.param_groups: 52 | param_group["lr"] = lr * epoch / burnin 53 | 54 | 55 | def gcn_norm_fill_val(edge_index, edge_weight=None, fill_value=0., num_nodes=None, dtype=None): 56 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 57 | 58 | if edge_weight is None: 59 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 60 | device=edge_index.device) 61 | 62 | if not int(fill_value) == 0: 63 | edge_index, tmp_edge_weight = add_remaining_self_loops( 64 | edge_index, edge_weight, fill_value, num_nodes) 65 | assert tmp_edge_weight is not None 66 | edge_weight = tmp_edge_weight 67 | 68 | row, col = edge_index[0], edge_index[1] 69 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 70 | deg_inv_sqrt = deg.pow_(-0.5) 71 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) 72 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 73 | 74 | 75 | def coo2tensor(coo, device=None): 76 | indices = np.vstack((coo.row, coo.col)) 77 | i = torch.LongTensor(indices) 78 | values = coo.data 79 | v = torch.FloatTensor(values) 80 | shape = coo.shape 81 | print('adjacency matrix generated with shape {}'.format(shape)) 82 | # test 83 | return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to(device) 84 | 85 | 86 | def get_sym_adj(data, opt, improved=False): 87 | edge_index, edge_weight = gcn_norm( # yapf: disable 88 | data.edge_index, data.edge_attr, data.num_nodes, 89 | improved, opt['self_loop_weight'] > 0, dtype=data.x.dtype) 90 | coo = to_scipy_sparse_matrix(edge_index, edge_weight) 91 | return coo2tensor(coo) 92 | 93 | 94 | def get_rw_adj_old(data, opt): 95 | if opt['self_loop_weight'] > 0: 96 | edge_index, edge_weight = 
add_remaining_self_loops(data.edge_index, data.edge_attr, 97 | fill_value=opt['self_loop_weight']) 98 | else: 99 | edge_index, edge_weight = data.edge_index, data.edge_attr 100 | coo = to_scipy_sparse_matrix(edge_index, edge_weight) 101 | normed_csc = normalize(coo, norm='l1', axis=0) 102 | return coo2tensor(normed_csc.tocoo()) 103 | 104 | 105 | def get_rw_adj(edge_index, edge_weight=None, norm_dim=1, fill_value=0., num_nodes=None, dtype=None): 106 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 107 | 108 | if edge_weight is None: 109 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 110 | device=edge_index.device) 111 | 112 | if not fill_value == 0: 113 | edge_index, tmp_edge_weight = add_remaining_self_loops( 114 | edge_index, edge_weight, fill_value, num_nodes) 115 | assert tmp_edge_weight is not None 116 | edge_weight = tmp_edge_weight 117 | 118 | row, col = edge_index[0], edge_index[1] 119 | indices = row if norm_dim == 0 else col 120 | deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes) 121 | deg_inv_sqrt = deg.pow_(-1) 122 | edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices] 123 | return edge_index, edge_weight 124 | 125 | 126 | def mean_confidence_interval(data, confidence=0.95): 127 | """ 128 | As number of samples will be < 10 use t-test for the mean confidence intervals 129 | :param data: NDarray of metric means 130 | :param confidence: The desired confidence interval 131 | :return: Float confidence interval 132 | """ 133 | if len(data) < 2: 134 | return 0 135 | a = 1.0 * np.array(data) 136 | n = len(a) 137 | _, se = np.mean(a), scipy.stats.sem(a) 138 | h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1) 139 | return h 140 | 141 | 142 | def sparse_dense_mul(s, d): 143 | i = s._indices() 144 | v = s._values() 145 | return torch.sparse.FloatTensor(i, v * d, s.size()) 146 | 147 | 148 | def get_sem(vec): 149 | """ 150 | wrapper around the scipy standard error 
metric 151 | :param vec: List of metric means 152 | :return: 153 | """ 154 | if len(vec) > 1: 155 | retval = sem(vec) 156 | else: 157 | retval = 0. 158 | return retval 159 | 160 | 161 | def get_full_adjacency(num_nodes): 162 | # what is the format of the edge index? 163 | edge_index = torch.zeros((2, num_nodes ** 2),dtype=torch.long) 164 | for idx in range(num_nodes): 165 | edge_index[0][idx * num_nodes: (idx + 1) * num_nodes] = idx 166 | edge_index[1][idx * num_nodes: (idx + 1) * num_nodes] = torch.arange(0, num_nodes,dtype=torch.long) 167 | return edge_index 168 | 169 | 170 | 171 | from typing import Optional 172 | import torch 173 | from torch import Tensor 174 | from torch_scatter import scatter, segment_csr, gather_csr 175 | 176 | 177 | # https://twitter.com/jon_barron/status/1387167648669048833?s=12 178 | # @torch.jit.script 179 | def squareplus(src: Tensor, index: Optional[Tensor], ptr: Optional[Tensor] = None, 180 | num_nodes: Optional[int] = None) -> Tensor: 181 | r"""Computes a sparsely evaluated softmax. 182 | Given a value tensor :attr:`src`, this function first groups the values 183 | along the first dimension based on the indices specified in :attr:`index`, 184 | and then proceeds to compute the softmax individually for each group. 185 | 186 | Args: 187 | src (Tensor): The source tensor. 188 | index (LongTensor): The indices of elements for applying the softmax. 189 | ptr (LongTensor, optional): If given, computes the softmax based on 190 | sorted inputs in CSR representation. (default: :obj:`None`) 191 | num_nodes (int, optional): The number of nodes, *i.e.* 192 | :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) 193 | 194 | :rtype: :class:`Tensor` 195 | """ 196 | out = src - src.max() 197 | # out = out.exp() 198 | out = (out + torch.sqrt(out ** 2 + 4)) / 2 199 | 200 | if ptr is not None: 201 | out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) 202 | elif index is not None: 203 | N = maybe_num_nodes(index, num_nodes) 204 | out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index] 205 | else: 206 | raise NotImplementedError 207 | 208 | return out / (out_sum + 1e-16) 209 | 210 | 211 | # Counter of forward and backward passes. 212 | class Meter(object): 213 | 214 | def __init__(self): 215 | self.reset() 216 | 217 | def reset(self): 218 | self.val = None 219 | self.sum = 0 220 | self.cnt = 0 221 | 222 | def update(self, val): 223 | self.val = val 224 | self.sum += val 225 | self.cnt += 1 226 | 227 | def get_average(self): 228 | if self.cnt == 0: 229 | return 0 230 | return self.sum / self.cnt 231 | 232 | def get_value(self): 233 | return self.val 234 | 235 | 236 | class DummyDataset(object): 237 | def __init__(self, data, num_classes): 238 | self.data = data 239 | self.num_classes = num_classes 240 | 241 | 242 | class DummyData(object): 243 | def __init__(self, edge_index=None, edge_Attr=None, num_nodes=None): 244 | self.edge_index = edge_index 245 | self.edge_attr = edge_Attr 246 | self.num_nodes = num_nodes 247 | -------------------------------------------------------------------------------- /src/visualise_attention.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch_geometric.nn import GCNConv, ChebConv # noqa 5 | from GNN import GNN 6 | import time 7 | from data import get_dataset 8 | from run_GNN import get_optimizer, print_model_params, train, test 9 | import networkx as nx 10 | import matplotlib.pyplot as plt 11 | 12 | def construct_graph(model): 13 | edges = model.odeblock.odefunc.edge_index 14 | edge_list = zip(edges[0], edges[1]) 15 
| g = nx.Graph(edge_list) 16 | nx.draw(g) 17 | 18 | def main(opt): 19 | 20 | dataset = get_dataset(opt, '../data', False) 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | model, data = GNN(opt, dataset, device).to(device), dataset.data.to(device) 23 | print(opt) 24 | # todo for some reason the submodule parameters inside the attention module don't show up when running on GPU. 25 | parameters = [p for p in model.parameters() if p.requires_grad] 26 | print_model_params(model) 27 | optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay']) 28 | best_val_acc = test_acc = best_epoch = 0 29 | for epoch in range(1, opt['epoch']): 30 | start_time = time.time() 31 | 32 | loss = train(model, optimizer, data) 33 | train_acc, val_acc, tmp_test_acc = test(model, data) 34 | 35 | if val_acc > best_val_acc: 36 | best_val_acc = val_acc 37 | test_acc = tmp_test_acc 38 | best_epoch = epoch 39 | log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 40 | print( 41 | log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, best_val_acc, test_acc)) 42 | print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc, test_acc, best_epoch)) 43 | 44 | construct_graph(model) 45 | 46 | return train_acc, best_val_acc, test_acc 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--use_cora_defaults', action='store_true', 52 | help='Whether to run with best params for cora. 
Overrides the choice of dataset') 53 | parser.add_argument('--dataset', type=str, default='Cora', 54 | help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS') 55 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.') 56 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.') 57 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.') 58 | parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.') 59 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') 60 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization') 61 | parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.') 62 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.') 63 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.') 64 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.') 65 | parser.add_argument('--augment', action='store_true', 66 | help='double the length of the feature vector by appending zeros to stabilist ODE learning') 67 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha') 68 | parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true', help='apply sigmoid before multiplying by alpha') 69 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta') 70 | parser.add_argument('--block', type=str, default='attention', help='constant, mixed, attention, SDE') 71 | parser.add_argument('--function', type=str, default='laplacian', help='laplacian, transformer, dorsey, GAT, SDE') 72 | # ODE args 73 | parser.add_argument('--method', type=str, default='dopri5', 74 | help="set the 
numerical solver: dopri5, euler, rk4, midpoint") 75 | parser.add_argument( 76 | "--adjoint_method", type=str, default="adaptive_heun", help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint" 77 | ) 78 | parser.add_argument('--adjoint', default=False, help='use the adjoint ODE method to reduce memory footprint') 79 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol') 80 | parser.add_argument("--tol_scale_adjoint", type=float, default=1.0, 81 | help="multiplier for adjoint_atol and adjoint_rtol") 82 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run') 83 | parser.add_argument('--add_source', dest='add_source', action='store_true', 84 | help='If try get rid of alpha param and the beta*x0 source term') 85 | # SDE args 86 | parser.add_argument('--dt_min', type=float, default=1e-5, help='minimum timestep for the SDE solver') 87 | parser.add_argument('--dt', type=float, default=1e-3, help='fixed step size') 88 | parser.add_argument('--adaptive', type=bool, default=False, help='use adaptive step sizes') 89 | # Attention args 90 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2, 91 | help='slope of the negative part of the leaky relu used in attention') 92 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights') 93 | parser.add_argument('--heads', type=int, default=4, help='number of attention heads') 94 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols') 95 | parser.add_argument('--attention_dim', type=int, default=64, 96 | help='the size to project x to before calculating att scores') 97 | parser.add_argument('--mix_features', type=bool, default=False, 98 | help='apply a feature transformation xW to the ODE') 99 | parser.add_argument("--max_nfe", type=int, default=1000, help="Maximum number of function evaluations allowed.") 100 | 
101 | parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2") 102 | parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2") 103 | 104 | parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2") 105 | parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2") 106 | 107 | args = parser.parse_args() 108 | 109 | opt = vars(args) 110 | 111 | main(opt) 112 | -------------------------------------------------------------------------------- /test/test_ICML_gnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Test attention 5 | """ 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 9 | 10 | import unittest 11 | import torch 12 | from torch import tensor 13 | from torch import nn 14 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 15 | from torch_geometric.utils.convert import to_scipy_sparse_matrix 16 | from ray.tune.utils import diagnose_serialization 17 | from functools import partial 18 | 19 | from CGNN import gcn_norm_fill_val, coo2tensor, train_ray 20 | from data import get_dataset 21 | from test_params import OPT 22 | 23 | 24 | class ICMLGNNTests(unittest.TestCase): 25 | def setUp(self): 26 | self.edge = tensor([[0, 2, 2], [1, 0, 1]]) 27 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=float) 28 | self.W = tensor([[2, 1], [3, 2]], dtype=float) 29 | self.alpha = tensor([[1, 2, 3, 4]], dtype=float) 30 | self.leakyrelu = nn.LeakyReLU(0.2) 31 | 32 | def tearDown(self) -> None: 33 | pass 34 | 35 | def test_fill_norm(self): 36 | opt = {'dataset': 'Cora', 'improved': False, 'self_loop_weight': 1., 'rewiring': None, 'no_alpha_sigmoid': False, 37 | 'reweight_attention': False, 'kinetic_energy': None, 'jacobian_norm2': None, 'total_deriv': None, 
'directional_penalty': None, 'beltrami': False} 38 | opt = {**OPT, **opt} 39 | dataset = get_dataset(opt, '../data', False) 40 | data = dataset.data 41 | edge_index1, edge_weight1 = gcn_norm(data.edge_index, data.edge_attr, data.num_nodes, 42 | opt['improved'], opt['self_loop_weight'] > 0, dtype=data.x.dtype) 43 | edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, data.edge_attr, opt['self_loop_weight'], 44 | data.num_nodes, dtype=data.x.dtype) 45 | assert torch.all(edge_index.eq(edge_index1)) 46 | assert torch.all(edge_weight.eq(edge_weight1)) 47 | 48 | 49 | def main(): 50 | data_dir = os.path.abspath("../data") 51 | trainable = partial(train_ray, data_dir=data_dir) 52 | diagnose_serialization(trainable) 53 | opt = {'dataset': 'Cora', 'improved': False, 'self_loop_weight': 1.} 54 | dataset = get_dataset(opt, '../data', False) 55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 56 | data = dataset.data 57 | edge_index1, edge_weight1 = gcn_norm(data.edge_index, data.edge_attr, data.num_nodes, 58 | opt['improved'], opt['self_loop_weight'] > 0, dtype=data.x.dtype) 59 | edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, data.edge_attr, opt['self_loop_weight'], data.num_nodes, 60 | opt['self_loop_weight'] > 0) 61 | assert torch.all(edge_index.eq(edge_index1)) 62 | assert torch.all(edge_weight.eq(edge_weight1)) 63 | coo = to_scipy_sparse_matrix(edge_index, edge_weight) 64 | coo = coo2tensor(coo, device) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /test/test_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Test attention 5 | """ 6 | # needed for CI/CD 7 | import os 8 | import sys 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 10 | 11 | import unittest 12 | import torch 13 | from 
torch import tensor 14 | from torch import nn 15 | 16 | from function_GAT_attention import SpGraphAttentionLayer, ODEFuncAtt 17 | from torch_geometric.utils import softmax, to_dense_adj 18 | from data import get_dataset 19 | from test_params import OPT 20 | 21 | 22 | class AttentionTests(unittest.TestCase): 23 | def setUp(self): 24 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 25 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 26 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 27 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 28 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 29 | self.x1 = torch.ones((3, 2), dtype=torch.float) 30 | 31 | self.leakyrelu = nn.LeakyReLU(0.2) 32 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 34 | 'K': 10, 35 | 'attention_norm_idx': 0, 'add_source': False, 'max_nfe': 1000, 'mix_features': False, 36 | 'attention_dim': 32, 37 | 'mixed_block': False, 'rewiring': None, 'no_alpha_sigmoid': False, 'reweight_attention': False, 38 | 'kinetic_energy': None, 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None} 39 | self.opt = {**OPT, **opt} 40 | 41 | def tearDown(self) -> None: 42 | pass 43 | 44 | def test(self): 45 | h = torch.mm(self.x, self.W) 46 | edge_h = torch.cat((h[self.edge[0, :], :], h[self.edge[1, :], :]), dim=1) 47 | self.assertTrue(edge_h.shape == torch.Size([self.edge.shape[1], 2 * 2])) 48 | ah = self.alpha.mm(edge_h.t()).t() 49 | self.assertTrue(ah.shape == torch.Size([self.edge.shape[1], 1])) 50 | edge_e = self.leakyrelu(ah) 51 | attention = softmax(edge_e, self.edge[1]) 52 | print(attention) 53 | 54 | def test_function(self): 55 | in_features = self.x.shape[1] 56 | out_features = self.x.shape[1] 57 | 58 | def get_round_sum(tens, n_digits=3): 59 | val = torch.sum(tens, dim=int(not self.opt['attention_norm_idx'])) 60 
| return (val * 10 ** n_digits).round() / (10 ** n_digits) 61 | 62 | att_layer = SpGraphAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 63 | attention, _ = att_layer(self.x, self.edge) # should be n_edges x n_heads 64 | self.assertTrue(attention.shape == (self.edge.shape[1], self.opt['heads'])) 65 | dense_attention1 = to_dense_adj(self.edge, edge_attr=attention[:, 0]).squeeze() 66 | dense_attention2 = to_dense_adj(self.edge, edge_attr=attention[:, 1]).squeeze() 67 | 68 | self.assertTrue(torch.all(torch.eq(get_round_sum(dense_attention1), 1.))) 69 | self.assertTrue(torch.all(torch.eq(get_round_sum(dense_attention2), 1.))) 70 | 71 | self.assertTrue(torch.all(attention > 0.)) 72 | self.assertTrue(torch.all(attention <= 1.)) 73 | 74 | dataset = get_dataset(self.opt, '../data', False) 75 | data = dataset.data 76 | in_features = data.x.shape[1] 77 | out_features = data.x.shape[1] 78 | 79 | att_layer = SpGraphAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 80 | attention, _ = att_layer(data.x, data.edge_index) # should be n_edges x n_heads 81 | 82 | self.assertTrue(attention.shape == (data.edge_index.shape[1], self.opt['heads'])) 83 | dense_attention1 = to_dense_adj(data.edge_index, edge_attr=attention[:, 0]).squeeze() 84 | dense_attention2 = to_dense_adj(data.edge_index, edge_attr=attention[:, 1]).squeeze() 85 | self.assertTrue(torch.all(torch.eq(get_round_sum(dense_attention1), 1.))) 86 | self.assertTrue(torch.all(torch.eq(get_round_sum(dense_attention2), 1.))) 87 | self.assertTrue(torch.all(attention > 0.)) 88 | self.assertTrue(torch.all(attention <= 1.)) 89 | 90 | def test_symetric_attention(self): 91 | in_features = self.x1.shape[1] 92 | out_features = self.x1.shape[1] 93 | att_layer = SpGraphAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 94 | attention, _ = att_layer(self.x1, self.edge1) # should be n_edges x n_heads 95 | 96 | self.assertTrue(torch.all(torch.eq(attention, 0.5 
* torch.ones((self.edge1.shape[1], self.x1.shape[1]))))) 97 | 98 | def test_module(self): 99 | dataset = get_dataset(self.opt, '../data', False) 100 | t = 1 101 | out_dim = 6 102 | func = ODEFuncAtt(dataset.data.num_features, out_dim, self.opt, dataset.data, self.device) 103 | out = func(t, dataset.data.x) 104 | print(out.shape) 105 | self.assertTrue(out.shape == (dataset.data.num_nodes, dataset.num_features)) 106 | -------------------------------------------------------------------------------- /test/test_attention_ode_block.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # needed for CI/CD 4 | import os 5 | import sys 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 7 | 8 | import unittest 9 | import torch 10 | from torch import tensor 11 | from torch import nn 12 | 13 | from data import get_dataset 14 | from function_laplacian_diffusion import LaplacianODEFunc 15 | from GNN import GNN 16 | from block_transformer_attention import AttODEblock 17 | from test_params import OPT 18 | 19 | class AttentionODEBlockTests(unittest.TestCase): 20 | def setUp(self): 21 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 22 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 23 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 24 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 25 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 26 | self.x1 = torch.ones((3, 2), dtype=torch.float) 27 | 28 | self.leakyrelu = nn.LeakyReLU(0.2) 29 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | self.opt = OPT 31 | self.dataset = get_dataset(self.opt, '../data', False) 32 | 33 | def tearDown(self) -> None: 34 | pass 35 | 36 | def test_block(self): 37 | data = self.dataset.data 38 | self.opt['hidden_dim'] = self.dataset.num_features 39 | self.opt['heads'] = 1 40 | gnn = GNN(self.opt, 
self.dataset, device=self.device) 41 | odeblock = gnn.odeblock 42 | self.assertTrue(isinstance(odeblock, AttODEblock)) 43 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 44 | gnn.train() 45 | out = odeblock(data.x) 46 | self.assertTrue(data.x.shape == out.shape) 47 | gnn.eval() 48 | out = odeblock(data.x) 49 | print('ode block out', out) 50 | self.assertTrue(data.x.shape == out.shape) 51 | self.opt['heads'] = 2 52 | try: 53 | gnn = GNN(self.opt, self.dataset, device=self.device) 54 | self.assertTrue(False) 55 | except AssertionError: 56 | pass 57 | 58 | def test_gnn(self): 59 | self.opt['attention_dim']=32 60 | gnn = GNN(self.opt, self.dataset, device=self.device) 61 | gnn.train() 62 | out = gnn(self.dataset.data.x) 63 | print(out.shape) 64 | print(torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 65 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 66 | gnn.eval() 67 | out = gnn(self.dataset.data.x) 68 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 69 | 70 | 71 | if __name__ == '__main__': 72 | est = AttentionODEBlockTests() 73 | est.setUp() 74 | est.test_block() 75 | -------------------------------------------------------------------------------- /test/test_block_mixed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # needed for CI/CD 4 | import os 5 | import sys 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 7 | 8 | import unittest 9 | import torch 10 | from torch import tensor 11 | from torch import nn 12 | import numpy as np 13 | 14 | from data import get_dataset 15 | from function_laplacian_diffusion import LaplacianODEFunc 16 | from GNN import GNN 17 | from block_mixed import MixedODEblock 18 | from torch_geometric.data import Data 19 | from torch_geometric.utils import to_dense_adj 
20 | from test_params import OPT 21 | 22 | 23 | class DummyDataset(): 24 | def __init__(self, data, num_classes): 25 | self.data = data 26 | self.num_classes = num_classes 27 | 28 | 29 | class MixedODEBlockTests(unittest.TestCase): 30 | def setUp(self): 31 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 32 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 33 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 34 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 35 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 36 | self.x1 = torch.ones((3, 2), dtype=torch.float) 37 | self.data = Data(x=self.x, edge_index=self.edge) 38 | 39 | self.leakyrelu = nn.LeakyReLU(0.2) 40 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'sc', 'heads': 2, 42 | 'K': 10, 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 43 | 'hidden_dim': 6, 'block': 'mixed', 'function': 'laplacian', 'augment': False, 'adjoint': False, 44 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 45 | 'rewiring': None, 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': None, 46 | 'total_deriv': None, 'directional_penalty': None, 'jacobian_norm2': None, 'step_size':1, 'max_iter': 10, 'beltrami': False} 47 | self.opt = {**OPT, **opt} 48 | 49 | self.dataset = get_dataset(self.opt, '../data', False) 50 | 51 | def tearDown(self) -> None: 52 | pass 53 | 54 | def test_block_toy(self): 55 | # construct a pyg dataset 56 | dataset = DummyDataset(self.data, 3) 57 | self.opt['heads'] = 1 58 | self.opt['hidden_dim'] = 2 # same as the raw data so we don't have to first encode it 59 | gnn = GNN(self.opt, dataset, device=self.device) 60 | odeblock = gnn.odeblock 61 | self.assertTrue(isinstance(odeblock, MixedODEblock)) 62 | self.assertTrue(isinstance(odeblock.odefunc, 
LaplacianODEFunc)) 63 | self.assertTrue(odeblock.gamma.item() == 0.) 64 | 65 | def test_get_mixed_attention(self): 66 | dataset = DummyDataset(self.data, 3) 67 | self.opt['heads'] = 1 68 | self.opt['hidden_dim'] = 2 # same as the raw data so we don't have to first encode it 69 | gnn = GNN(self.opt, dataset, device=self.device) 70 | odeblock = gnn.odeblock 71 | attention = odeblock.get_attention_weights(self.x) 72 | mixed_att_weights = odeblock.get_mixed_attention(self.x) 73 | mixed_att = to_dense_adj(odeblock.odefunc.edge_index, 74 | edge_attr=mixed_att_weights).detach().numpy().squeeze() 75 | rw_arr = to_dense_adj(odeblock.odefunc.edge_index, 76 | edge_attr=odeblock.odefunc.edge_weight).detach().numpy().squeeze() 77 | att_arr = to_dense_adj(odeblock.odefunc.edge_index, edge_attr=attention).detach().numpy().squeeze() 78 | gamma = torch.sigmoid(odeblock.gamma).detach().numpy() 79 | mixed_att_test = (1 - gamma) * att_arr + gamma * rw_arr 80 | self.assertTrue(np.allclose(mixed_att, mixed_att_test)) 81 | 82 | def test_block_cora(self): 83 | data = self.dataset.data 84 | self.opt['hidden_dim'] = self.dataset.num_features 85 | self.opt['heads'] = 1 86 | gnn = GNN(self.opt, self.dataset, device=self.device) 87 | odeblock = gnn.odeblock 88 | self.assertTrue(isinstance(odeblock, MixedODEblock)) 89 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 90 | self.assertTrue(odeblock.gamma.item() == 0.) 
91 | self.assertTrue(odeblock.odefunc.edge_weight.shape is not None) 92 | gnn.train() 93 | out = odeblock(data.x) 94 | self.assertTrue(data.x.shape == out.shape) 95 | gnn.eval() 96 | out = odeblock(data.x) 97 | self.assertTrue(data.x.shape == out.shape) 98 | self.opt['heads'] = 2 99 | try: 100 | gnn = GNN(self.opt, self.dataset, device=self.device) 101 | self.assertTrue(False) 102 | except AssertionError: 103 | pass 104 | 105 | 106 | if __name__ == '__main__': 107 | tests = MixedODEBlockTests() 108 | tests.setUp() 109 | tests.test_block_toy() 110 | -------------------------------------------------------------------------------- /test/test_early_stop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Test early stop 5 | """ 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 9 | 10 | import unittest 11 | import torch 12 | from torch import tensor 13 | from torch import nn 14 | 15 | from data import get_dataset 16 | from function_laplacian_diffusion import LaplacianODEFunc 17 | from GNN_early import GNNEarly 18 | from block_constant import ConstantODEblock 19 | from utils import get_rw_adj 20 | from test_params import OPT 21 | 22 | 23 | class EarlyStopTests(unittest.TestCase): 24 | def setUp(self): 25 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 26 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 27 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 28 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 29 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 30 | self.x1 = torch.ones((3, 2), dtype=torch.float) 31 | 32 | self.leakyrelu = nn.LeakyReLU(0.2) 33 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 'K': 10, 35 | 
'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 36 | 'hidden_dim': 6, 'block': 'constant', 'function': 'laplacian', 'augment': False, 'adjoint': False, 37 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'dopri5', 38 | 'rewiring': None, 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': None, 39 | 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None, 'step_size': 1, 'data_norm': 'rw', 40 | 'earlystopxT': 3, 'max_iters': 10, 'beltrami': False} 41 | self.opt = {**OPT, **opt} 42 | self.dataset = get_dataset(self.opt, '../data', False) 43 | 44 | def tearDown(self) -> None: 45 | pass 46 | 47 | def test_block(self): 48 | data = self.dataset.data 49 | t = 1 50 | out_dim = 6 51 | func = LaplacianODEFunc(self.dataset.data.num_features, out_dim, self.opt, data, self.device) 52 | func.edge_index, func.edge_weight = get_rw_adj(data.edge_index, edge_weight=None, norm_dim=1, 53 | fill_value=self.opt['self_loop_weight'], 54 | num_nodes=data.num_nodes, 55 | dtype=data.x.dtype) 56 | out = func(t, data.x) 57 | print(out.shape) 58 | self.assertTrue(out.shape == (self.dataset.data.num_nodes, self.dataset.num_features)) 59 | gnn = GNNEarly(self.opt, self.dataset, device=self.device) 60 | odeblock = gnn.odeblock 61 | self.assertTrue(isinstance(odeblock, ConstantODEblock)) 62 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 63 | self.assertTrue(odeblock.test_integrator.data.x.shape == data.x.shape) 64 | gnn.train() 65 | out = odeblock(data.x) 66 | self.assertTrue(data.x.shape == out.shape) 67 | gnn.eval() 68 | gnn.set_solver_m2() 69 | gnn.set_solver_data(data) 70 | out = odeblock(data.x) 71 | print('ode block out', out) 72 | self.assertTrue(data.x.shape == out.shape) 73 | 74 | def test_rk4(self): 75 | data = self.dataset.data 76 | t = 1 77 | out_dim = 6 78 | self.opt['method'] = 'rk4' 79 | func = 
LaplacianODEFunc(self.dataset.data.num_features, out_dim, self.opt, data, self.device) 80 | func.edge_index, func.edge_weight = get_rw_adj(data.edge_index, edge_weight=None, norm_dim=1, 81 | fill_value=self.opt['self_loop_weight'], 82 | num_nodes=data.num_nodes, 83 | dtype=data.x.dtype) 84 | out = func(t, data.x) 85 | print(out.shape) 86 | self.assertTrue(out.shape == (self.dataset.data.num_nodes, self.dataset.num_features)) 87 | gnn = GNNEarly(self.opt, self.dataset, device=self.device) 88 | odeblock = gnn.odeblock 89 | self.assertTrue(isinstance(odeblock, ConstantODEblock)) 90 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 91 | self.assertTrue(odeblock.test_integrator.data.x.shape == data.x.shape) 92 | gnn.train() 93 | out = odeblock(data.x) 94 | self.assertTrue(data.x.shape == out.shape) 95 | gnn.eval() 96 | gnn.set_solver_m2() 97 | gnn.set_solver_data(data) 98 | out = odeblock(data.x) 99 | print('ode block out', out) 100 | self.assertTrue(data.x.shape == out.shape) 101 | 102 | def test_gnn(self): 103 | gnn = GNNEarly(self.opt, self.dataset, device=self.device) 104 | gnn.train() 105 | out = gnn(self.dataset.data.x) 106 | print(out.shape) 107 | print(torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 108 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 109 | gnn.eval() 110 | out = gnn(self.dataset.data.x) 111 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 112 | solver = gnn.odeblock.test_integrator.solver 113 | self.assertTrue(solver.best_val >= 0) 114 | self.assertTrue(solver.best_test >= 0) 115 | 116 | 117 | if __name__ == '__main__': 118 | est = EarlyStopTests() 119 | est.setUp() 120 | est.test_gnn() 121 | -------------------------------------------------------------------------------- /test/test_function_laplacian_diffusion.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 6 | 7 | import unittest 8 | import torch 9 | from torch import tensor 10 | from torch import nn 11 | import numpy as np 12 | from sklearn.preprocessing import normalize 13 | 14 | from data import get_dataset 15 | from function_laplacian_diffusion import LaplacianODEFunc 16 | from GNN import GNN 17 | from block_constant import ConstantODEblock 18 | from torch_geometric.data import Data 19 | from torch_geometric.utils import to_dense_adj 20 | from utils import get_rw_adj, get_sym_adj 21 | from test_params import OPT 22 | 23 | 24 | class DummyDataset(): 25 | def __init__(self, data, num_classes): 26 | self.data = data 27 | self.num_classes = num_classes 28 | 29 | 30 | class FunctionLaplacianDiffusionTests(unittest.TestCase): 31 | def setUp(self): 32 | self.edge = tensor([[0, 1, 2, 1], [1, 0, 1, 2]]) 33 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 34 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 35 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 36 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 37 | self.x1 = torch.ones((3, 2), dtype=torch.float) 38 | self.data = Data(x=self.x, edge_index=self.edge) 39 | 40 | self.leakyrelu = nn.LeakyReLU(0.2) 41 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 42 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 43 | 'K': 10, 44 | 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 45 | 'hidden_dim': 6, 'augment': False, 'adjoint': False, 46 | 'block': 'constant', 'function': 'laplacian', 47 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 48 | 'rewiring': None, 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': 
None, 49 | 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None, 'step_size': 1, 'data_norm': 'rw', 50 | 'max_iters': 10, 'beltrami': False} 51 | self.opt = {**OPT, **opt} 52 | 53 | self.dataset = get_dataset(self.opt, '../data', False) 54 | 55 | def tearDown(self) -> None: 56 | pass 57 | 58 | def test_block_toy(self): 59 | # construct a pyg dataset 60 | num_nodes = 3 61 | dataset = DummyDataset(self.data, num_nodes) 62 | gnn = GNN(self.opt, dataset, device=self.device) 63 | odeblock = gnn.odeblock 64 | self.assertTrue(isinstance(odeblock, ConstantODEblock)) 65 | func = odeblock.odefunc 66 | self.assertTrue(isinstance(func, LaplacianODEFunc)) 67 | func.edge_index, func.edge_weight = get_rw_adj(self.edge, edge_weight=None, norm_dim=1, 68 | fill_value=self.opt['self_loop_weight'], num_nodes=3, 69 | dtype=None) 70 | sym_adj = get_sym_adj(self.data, self.opt) 71 | 72 | rw_adj = to_dense_adj(func.edge_index, edge_attr=func.edge_weight).numpy().squeeze() 73 | input_adj = to_dense_adj(dataset.data.edge_index).numpy().squeeze() 74 | augmented_input = input_adj + self.opt['self_loop_weight'] * np.identity(num_nodes) 75 | augmented_degree = augmented_input.sum(axis=1) # symmetric so axis doesn't matter 76 | sqr_degree = np.sqrt(augmented_degree) 77 | test_sym_adj = np.divide(np.divide(augmented_input, sqr_degree[:, None]), sqr_degree[None, :]) 78 | 79 | test_rw_adj = normalize(augmented_input, norm='l1', axis=0) 80 | sym_adj = sym_adj.to_dense() 81 | print('rw adjacency', rw_adj) 82 | print('test rw adjacency', test_rw_adj) 83 | self.assertTrue(np.allclose(test_rw_adj, rw_adj)) 84 | self.assertTrue(np.allclose(test_sym_adj, sym_adj.numpy().squeeze())) 85 | print('sym adjacency', sym_adj) 86 | 87 | def test_block_cora(self): 88 | data = self.dataset.data 89 | self.opt['hidden_dim'] = self.dataset.num_features 90 | self.opt['heads'] = 1 91 | gnn = GNN(self.opt, self.dataset, device=self.device) 92 | odeblock = gnn.odeblock 93 | 
self.assertTrue(isinstance(odeblock, ConstantODEblock)) 94 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 95 | gnn.train() 96 | out = odeblock(data.x) 97 | self.assertTrue(data.x.shape == out.shape) 98 | gnn.eval() 99 | out = odeblock(data.x) 100 | self.assertTrue(data.x.shape == out.shape) 101 | 102 | 103 | if __name__ == '__main__': 104 | tests = FunctionLaplacianDiffusionTests() 105 | tests.setUp() 106 | tests.test_block_cora() 107 | -------------------------------------------------------------------------------- /test/test_gnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Test the GNN class 5 | """ 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 9 | 10 | import unittest 11 | import torch 12 | from torch import tensor 13 | from torch import nn 14 | 15 | from data import get_dataset 16 | from function_laplacian_diffusion import LaplacianODEFunc 17 | from block_constant import ConstantODEblock 18 | from GNN import GNN 19 | from test_params import OPT 20 | 21 | 22 | class GNNTests(unittest.TestCase): 23 | def setUp(self): 24 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 25 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 26 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 27 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 28 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 29 | self.x1 = torch.ones((3, 2), dtype=torch.float) 30 | 31 | self.leakyrelu = nn.LeakyReLU(0.2) 32 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 34 | 'K': 10, 35 | 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 36 | 'hidden_dim': 6, 'augment': False, 'adjoint': 
False, 37 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 38 | 'block': 'constant', 'function': 'laplacian', 'rewiring': None, 'no_alpha_sigmoid': False, 39 | 'reweight_attention': False, 'kinetic_energy': None, 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None 40 | , 'step_size': 1, 'data_norm': 'rw', 'max_iters': 10, 'beltrami': False} 41 | self.opt = {**OPT, **opt} 42 | self.dataset = get_dataset(self.opt, '../data', False) 43 | 44 | def tearDown(self) -> None: 45 | pass 46 | 47 | def test_constant_block(self): 48 | data = self.dataset.data 49 | self.opt['hidden_dim'] = self.dataset.num_features 50 | self.opt['heads'] = 1 51 | gnn = GNN(self.opt, self.dataset, device=self.device) 52 | odeblock = gnn.odeblock 53 | self.assertTrue(isinstance(odeblock, ConstantODEblock)) 54 | self.assertTrue(isinstance(odeblock.odefunc, LaplacianODEFunc)) 55 | gnn.train() 56 | out = odeblock(data.x) 57 | self.assertTrue(data.x.shape == out.shape) 58 | gnn.eval() 59 | out = odeblock(data.x) 60 | print('ode block out', out) 61 | self.assertTrue(data.x.shape == out.shape) 62 | self.opt['heads'] = 2 63 | try: 64 | gnn = GNN(self.opt, self.dataset, device=self.device) 65 | self.assertTrue(False) 66 | except AssertionError: 67 | pass 68 | 69 | def test_gnn(self): 70 | gnn = GNN(self.opt, self.dataset, device=self.device) 71 | gnn.train() 72 | out = gnn(self.dataset.data.x) 73 | print(out.shape) 74 | print(torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 75 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 76 | gnn.eval() 77 | out = gnn(self.dataset.data.x) 78 | self.assertTrue(out.shape == torch.Size([self.dataset.data.num_nodes, self.dataset.num_classes])) 79 | 80 | 81 | if __name__ == '__main__': 82 | est = GNNTests() 83 | est.setUp() 84 | est.test_block() 85 | -------------------------------------------------------------------------------- 
/test/test_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Store the global parameter dictionary to be imported and modified by each test 3 | """ 4 | 5 | OPT = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'heads': 2, 'K': 10, 6 | 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 7 | 'hidden_dim': 6, 'block': 'attention', 'function': 'laplacian', 'augment': False, 'adjoint': False, 8 | 'tol_scale': 1, 'time': 1, 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 'rewiring': None, 9 | 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': None, 'jacobian_norm2': None, 10 | 'total_deriv': None, 'directional_penalty': None, 'step_size': 1, 'beltrami': False, 'use_mlp': False, 11 | 'use_labels': False, 'fc_out': False, 'attention_type': "scaled_dot", 'batch_norm': False, 'square_plus': False, 12 | 'feat_hidden_dim': 16, 'pos_enc_hidden_dim': 8, 'gdc_method': 'ppr', 'gdc_sparsification': 'topk', 'gdc_k': 4, 13 | 'gdc_threshold': 1e-5, 'ppr_alpha': 0.05, 'exact': True, 'pos_enc_orientation': 'row', 'pos_enc_type': 'GDC', 14 | 'max_nfe': 1000, 'pos_enc_csv': False, 'max_test_steps': 1000, 'edge_sampling_add_type': 'importance', 15 | 'fa_layer': False, 'att_samp_pct': 1, 'edge_sampling_sym': False, 'data_norm': 'rw', 'lr': 0.01, 'decay': 0, 16 | 'max_iters': 1000, 'geom_gcn_splits': False} 17 | -------------------------------------------------------------------------------- /test/test_transformer_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Test attention 5 | """ 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 9 | 10 | import unittest 11 | import torch 12 | from torch import tensor 13 | from torch import nn 14 | import torch_sparse 15 | from 
torch_geometric.utils import softmax, to_dense_adj 16 | 17 | from function_transformer_attention import SpGraphTransAttentionLayer, ODEFuncTransformerAtt 18 | from data import get_dataset 19 | from test_params import OPT 20 | from utils import ROOT_DIR 21 | 22 | class AttentionTests(unittest.TestCase): 23 | def setUp(self): 24 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 25 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 26 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 27 | self.x1 = torch.ones((3, 2), dtype=torch.float) 28 | 29 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 30 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 31 | self.leakyrelu = nn.LeakyReLU(0.2) 32 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | opt = {'dataset': 'Citeseer', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 34 | 'K': 10, 35 | 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 36 | 'hidden_dim': 6, 'linear_attention': True, 'augment': False, 'adjoint': False, 37 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 38 | 'mixed_block': True, 'max_nfe': 1000, 'mix_features': False, 'attention_dim': 32, 'rewiring': None, 39 | 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': None, 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None, 'beltrami': False} 40 | self.opt = {**OPT, **opt} 41 | def tearDown(self) -> None: 42 | pass 43 | 44 | def test(self): 45 | h = torch.mm(self.x, self.W) 46 | edge_h = torch.cat((h[self.edge[0, :], :], h[self.edge[1, :], :]), dim=1) 47 | self.assertTrue(edge_h.shape == torch.Size([self.edge.shape[1], 2 * 2])) 48 | ah = self.alpha.mm(edge_h.t()).t() 49 | self.assertTrue(ah.shape == torch.Size([self.edge.shape[1], 1])) 50 | edge_e = self.leakyrelu(ah) 51 | attention = softmax(edge_e, self.edge[1]) 52 | 
print(attention) 53 | 54 | def test_function(self): 55 | in_features = self.x.shape[1] 56 | out_features = self.x.shape[1] 57 | att_layer = SpGraphTransAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 58 | attention, _ = att_layer(self.x, self.edge) # should be n_edges x n_heads 59 | self.assertTrue(attention.shape == (self.edge.shape[1], self.opt['heads'])) 60 | dense_attention1 = to_dense_adj(self.edge, edge_attr=attention[:, 0]).squeeze() 61 | dense_attention2 = to_dense_adj(self.edge, edge_attr=attention[:, 1]).squeeze() 62 | 63 | def get_round_sum(tens, n_digits=3): 64 | val = torch.sum(tens, dim=int(not self.opt['attention_norm_idx'])) 65 | round_sum = (val * 10 ** n_digits).round() / (10 ** n_digits) 66 | print('round sum', round_sum) 67 | return round_sum 68 | 69 | self.assertTrue(torch.all(torch.isclose(get_round_sum(dense_attention1), torch.ones(size=dense_attention1.shape)))) 70 | self.assertTrue(torch.all(torch.isclose(get_round_sum(dense_attention2), torch.ones(size=dense_attention1.shape)))) 71 | self.assertTrue(torch.all(attention > 0.)) 72 | self.assertTrue(torch.all(attention <= 1.)) 73 | 74 | dataset = get_dataset(self.opt, f'{ROOT_DIR}/data', True) 75 | data = dataset.data 76 | in_features = data.x.shape[1] 77 | out_features = data.x.shape[1] 78 | 79 | att_layer = SpGraphTransAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 80 | attention, _ = att_layer(data.x, data.edge_index) # should be n_edges x n_heads 81 | self.assertTrue(attention.shape == (data.edge_index.shape[1], self.opt['heads'])) 82 | dense_attention1 = to_dense_adj(data.edge_index, edge_attr=attention[:, 0]).squeeze() 83 | dense_attention2 = to_dense_adj(data.edge_index, edge_attr=attention[:, 1]).squeeze() 84 | print('sums:', torch.sum(torch.isclose(dense_attention1, torch.ones(size=dense_attention1.shape))), dense_attention1.shape) 85 | print('da1', dense_attention1) 86 | print('da2', dense_attention2) 87 | 
self.assertTrue(torch.all(torch.isclose(get_round_sum(dense_attention1), torch.ones(size=dense_attention1.shape)))) 88 | self.assertTrue(torch.all(torch.isclose(get_round_sum(dense_attention2), torch.ones(size=dense_attention2.shape)))) 89 | self.assertTrue(torch.all(attention > 0.)) 90 | self.assertTrue(torch.all(attention <= 1.)) 91 | 92 | def test_symmetric_attention(self): 93 | in_features = self.x1.shape[1] 94 | out_features = self.x1.shape[1] 95 | att_layer = SpGraphTransAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 96 | attention, _ = att_layer(self.x1, self.edge1) # should be n_edges x n_heads 97 | 98 | self.assertTrue(torch.all(torch.isclose(att_layer.Q.weight, att_layer.K.weight))) 99 | self.assertTrue(torch.all(torch.eq(attention, 0.5 * torch.ones((self.edge1.shape[1], self.x1.shape[1]))))) 100 | 101 | def test_module(self): 102 | dataset = get_dataset(self.opt, f'{ROOT_DIR}/data', False) 103 | t = 1 104 | out_dim = 6 105 | func = ODEFuncTransformerAtt(dataset.data.num_features, out_dim, self.opt, dataset.data, self.device) 106 | out = func(t, dataset.data.x) 107 | print(out.shape) 108 | self.assertTrue(out.shape == (dataset.data.num_nodes, dataset.num_features)) 109 | 110 | def test_head_aggregation(self): 111 | in_features = self.x.shape[1] 112 | out_features = self.x.shape[1] 113 | self.opt['head'] = 4 114 | att_layer = SpGraphTransAttentionLayer(in_features, out_features, self.opt, self.device, concat=True) 115 | attention, _ = att_layer(self.x, self.edge) 116 | ax1 = torch.mean(torch.stack( 117 | [torch_sparse.spmm(self.edge, attention[:, idx], self.x.shape[0], self.x.shape[0], self.x) for idx in 118 | range(self.opt['heads'])], dim=0), dim=0) 119 | mean_attention = attention.mean(dim=1) 120 | ax2 = torch_sparse.spmm(self.edge, mean_attention, self.x.shape[0], self.x.shape[0], self.x) 121 | self.assertTrue(torch.all(torch.isclose(ax1,ax2))) 122 | 123 | def test_two_way_edge(self): 124 | dataset = get_dataset(self.opt, 
f'{ROOT_DIR}/data', False) 125 | edge = dataset.data.edge_index 126 | print(f"is_undirected {dataset.data.is_undirected()}") 127 | 128 | edge_dict = {} 129 | 130 | for idx, src in enumerate(edge[0, :]): 131 | src = int(src) 132 | if src in edge_dict: 133 | edge_dict[src].add(int(edge[1, idx])) 134 | else: 135 | edge_dict[src] = set([int(edge[1, idx])]) 136 | 137 | print(f"edge shape {edge.shape}") 138 | src_test = edge[:, edge[0, :] == 1][1, :] 139 | dst_test = edge[:, edge[1, :] == 1][0, :] 140 | print('dst where src = 1', src_test) 141 | print('src where dst = 1', dst_test) 142 | 143 | for idx, dst in enumerate(edge[1, :]): 144 | dst = int(dst) 145 | self.assertTrue(int(edge[0, idx]) in edge_dict[dst]) 146 | 147 | 148 | if __name__ == '__main__': 149 | AT = AttentionTests() 150 | AT.test_symmetric_attention() -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # needed for CI/CD 4 | import os 5 | import sys 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) 7 | 8 | import unittest 9 | import torch 10 | from torch import tensor 11 | from torch import nn 12 | from data import get_dataset 13 | from torch_geometric.data import Data 14 | from torch_geometric.utils import to_dense_adj 15 | import numpy as np 16 | from utils import get_rw_adj, gcn_norm_fill_val 17 | from sklearn.preprocessing import normalize 18 | from test_params import OPT 19 | 20 | 21 | class DummyDataset(): 22 | def __init__(self, data, num_classes): 23 | self.data = data 24 | self.num_classes = num_classes 25 | 26 | 27 | def get_rw_numpy(arr, self_loops, norm_dim): 28 | new_arr = arr + np.identity(arr.shape[0]) * self_loops 29 | dim = 0 if norm_dim == 1 else 1 30 | return normalize(new_arr, norm='l1', axis=dim) 31 | 32 | 33 | class UtilsTests(unittest.TestCase): 34 | def 
setUp(self): 35 | self.edge = tensor([[0, 2, 2, 1], [1, 0, 1, 2]]) 36 | self.x = tensor([[1., 2.], [3., 2.], [4., 5.]], dtype=torch.float) 37 | self.W = tensor([[2, 1], [3, 2]], dtype=torch.float) 38 | self.alpha = tensor([[1, 2, 3, 4]], dtype=torch.float) 39 | self.edge1 = tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]) 40 | self.x1 = torch.ones((3, 2), dtype=torch.float) 41 | self.data = Data(x=self.x, edge_index=self.edge) 42 | 43 | self.leakyrelu = nn.LeakyReLU(0.2) 44 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 45 | opt = {'dataset': 'Cora', 'self_loop_weight': 0, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 46 | 'K': 10, 47 | 'attention_norm_idx': 0, 'add_source': False, 'alpha': 1, 'alpha_dim': 'vc', 'beta_dim': 'vc', 48 | 'hidden_dim': 6, 'linear_attention': True, 'augment': False, 'adjoint': False, 49 | 'tol_scale': 1, 'time': 1, 'ode': 'ode', 'input_dropout': 0.5, 'dropout': 0.5, 'method': 'euler', 50 | 'rewiring': None, 'no_alpha_sigmoid': False, 'reweight_attention': False, 'kinetic_energy': None, 51 | 'jacobian_norm2': None, 'total_deriv': None, 'directional_penalty': None, 'beltrami': False} 52 | self.opt = {**OPT, **opt} 53 | self.dataset = get_dataset(self.opt, '../data', False) 54 | 55 | def tearDown(self) -> None: 56 | pass 57 | 58 | def test_gcn_norm_fill_val(self): 59 | edge_index, edge_weight = gcn_norm_fill_val(self.edge, None, self.opt['self_loop_weight'], self.x.shape[0], 60 | dtype=self.x.dtype) 61 | 62 | def self_loop_test(self, base_adj, self_loop=0, norm_dim=0): 63 | edge_index, edge_weight = get_rw_adj(self.edge, norm_dim=norm_dim, fill_value=self_loop, 64 | num_nodes=self.x.shape[0]) 65 | dense_rw_adj = to_dense_adj(edge_index, edge_attr=edge_weight).numpy().squeeze() 66 | numpy_arr = get_rw_numpy(base_adj, self_loop, norm_dim) 67 | print('self loop', self_loop) 68 | print('numpy arr', numpy_arr) 69 | print('torch arr', dense_rw_adj) 70 | self.assertTrue(np.allclose(numpy_arr, dense_rw_adj)) 
71 | 72 | def test_get_rw_adj(self): 73 | base_adj = to_dense_adj(self.edge).numpy().squeeze() 74 | self_loops = [0, 0.3, 1, 3.2] 75 | for self_loop in self_loops: 76 | self.self_loop_test(base_adj, self_loop, norm_dim=0) 77 | self.self_loop_test(base_adj, self_loop, norm_dim=1) 78 | 79 | if __name__ == '__main__': 80 | tests = UtilsTests() 81 | tests.setUp() 82 | tests.test_get_rw_adj() 83 | --------------------------------------------------------------------------------