├── .gitignore ├── .gitmodules ├── AUTHORS ├── LICENSE ├── Makefile ├── README.md ├── bin ├── generate │ ├── balanced_tree.yaml │ ├── kronecker_graph.yaml │ ├── nCRP.yaml │ ├── price.yaml │ ├── random_tree.yaml │ └── scale_free_network.yaml ├── launch_generation_sweep.sh ├── launch_sweep.sh ├── launch_train_sweep.sh ├── learning │ ├── scale_free_network_box.yaml │ └── scale_free_network_vector.yaml ├── plot │ ├── plot_bar.py │ ├── plot_graphs.py │ ├── plot_parallel_coord.py │ ├── upload_results.py │ └── upload_results2.py └── run_sweep.sh ├── fastentrypoints.py ├── requirements.txt ├── scripts ├── README.md ├── auto_sweep.py ├── copy_results.py ├── graph-modeling └── wandb_sweep_example.yaml ├── setup.py ├── src └── graph_modeling │ ├── __init__.py │ ├── __main__.py │ ├── enums.py │ ├── generate │ ├── __init__.py │ ├── __main__.py │ ├── balanced_tree.py │ ├── generic.py │ ├── hac.py │ ├── knn_graph.py │ ├── kronecker_graph.py │ ├── nested_chinese_restaurant_process.py │ ├── price.py │ ├── random_tree.py │ ├── scale_free_network.py │ └── wordnet.py │ ├── metric_logger.py │ ├── metrics │ ├── __init__.py │ ├── __main__.py │ ├── calculate.py │ └── collect.py │ ├── models │ ├── __init__.py │ ├── box.py │ ├── hyperbolic.py │ ├── poe.py │ ├── temps.py │ └── vector.py │ └── training │ ├── __init__.py │ ├── __main__.py │ ├── dataset.py │ ├── loopers.py │ ├── loss.py │ ├── metrics.py │ └── train.py └── tests ├── conftest.py ├── data ├── kronecker_graph.npz ├── kronecker_graph.toml ├── random_tree.npz └── random_tree.param.toml ├── models ├── test_box.py └── test_temps.py └── training └── test_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Manual Additions 2 | wandb 3 | /data 4 | /models 5 | scratch 6 | .ipynb_checkpoints 7 | *.ipynb 8 | .idea 9 | /tests/**/*.metric 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # 
Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 98 | # install all needed dependencies. 
99 | #Pipfile.lock 100 | 101 | # celery beat schedule file 102 | celerybeat-schedule 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | 135 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 136 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 137 | 138 | # User-specific stuff 139 | .idea/**/workspace.xml 140 | .idea/**/tasks.xml 141 | .idea/**/usage.statistics.xml 142 | .idea/**/dictionaries 143 | .idea/**/shelf 144 | 145 | # Generated files 146 | .idea/**/contentModel.xml 147 | 148 | # Sensitive or high-churn files 149 | .idea/**/dataSources/ 150 | .idea/**/dataSources.ids 151 | .idea/**/dataSources.local.xml 152 | .idea/**/sqlDataSources.xml 153 | .idea/**/dynamic.xml 154 | .idea/**/uiDesigner.xml 155 | .idea/**/dbnavigator.xml 156 | 157 | # Gradle 158 | .idea/**/gradle.xml 159 | .idea/**/libraries 160 | 161 | # Gradle and Maven with auto-import 162 | # When using Gradle or Maven with auto-import, you should exclude module files, 163 | # since they will be recreated, and may cause churn. Uncomment if using 164 | # auto-import. 
165 | # .idea/modules.xml 166 | # .idea/*.iml 167 | # .idea/modules 168 | 169 | # CMake 170 | cmake-build-*/ 171 | 172 | # Mongo Explorer plugin 173 | .idea/**/mongoSettings.xml 174 | 175 | # File-based project format 176 | *.iws 177 | 178 | # IntelliJ 179 | out/ 180 | 181 | # mpeltonen/sbt-idea plugin 182 | .idea_modules/ 183 | 184 | # JIRA plugin 185 | atlassian-ide-plugin.xml 186 | 187 | # Cursive Clojure plugin 188 | .idea/replstate.xml 189 | 190 | # Crashlytics plugin (for Android Studio and IntelliJ) 191 | com_crashlytics_export_strings.xml 192 | crashlytics.properties 193 | crashlytics-build.properties 194 | fabric.properties 195 | 196 | # Editor-based Rest Client 197 | .idea/httpRequests 198 | 199 | # Android studio 3.1+ serialized cache file 200 | .idea/caches/build_file_checksums.ser 201 | 202 | # Ignore these as they only make sense when not on the server 203 | .idea/webServers.xml 204 | .idea/deployment.xml 205 | 206 | # Mac OS X files 207 | .DS_Store 208 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lib/box-embeddings"] 2 | path = lib/box-embeddings 3 | url = https://github.com/iesl/box-embeddings.git 4 | [submodule "lib/pytorch-utils"] 5 | path = lib/pytorch-utils 6 | url = https://gitlab.com/boratko/pytorch-utils.git 7 | [submodule "lib/wandb-utils"] 8 | path = lib/wandb-utils 9 | url = https://gitlab.com/boratko/wandb-utils.git 10 | [submodule "libc/snap"] 11 | path = libc/snap 12 | url = https://github.com/snap-stanford/snap.git 13 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Geometric Graph Embedding's significant contributors. 
2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | 8 | Michael Boratko 9 | Dongxu Zhang 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: base generation 2 | 3 | base: 4 | @echo "Checking to make sure submodules are cloned..." 5 | git submodule update --init --recursive 6 | @echo "Installing: local python libraries..." 7 | pip install -e lib/* 8 | pip install -e . 9 | @echo "Completed: local python library installation" 10 | 11 | generation: 12 | @echo "Compiling: Kronecker generation from snap..." 13 | cd libc/snap/examples/krongen/; make all 14 | @echo "Completed: Kronecker generation compilation" 15 | @echo "Installing: graph-tools..." 16 | conda install -c conda-forge graph-tool 17 | @echo "Completed: graph-tools installation" 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Modeling 2 | This repository contains code which accompanies the paper [Capacity and Bias of Learned Geometric Embeddings for Directed Graphs (Boratko et al. 
2021)](https://proceedings.neurips.cc/paper/2021/hash/88d25099b103efd638163ecb40a55589-Abstract.html). 3 | 4 | Code for the following papers will also be added shortly: 5 | * [Modeling Transitivity and Cyclicity in Directed Graphs via Binary Code Box Embeddings (Zhang et al. 2022)](https://proceedings.neurips.cc/paper_files/paper/2022/hash/44a1f18afd6d5cc34d7e5c3d8a80f63b-Abstract-Conference.html) 6 | * [Learning Representations for Hierarchies with Minimal Support (Rozonoyer et al. 2024)](https://openreview.net/forum?id=HFS800reZK) 7 | 8 | This code includes implementations of many geometric embedding methods: 9 | - Vector Similarity and Distance 10 | - Bilinear Vector Model [(Nickel et al. 2011)](https://openreview.net/forum?id=H14QEiZ_WS) 11 | - ComplEx Embeddings [(Trouillon et al. 2016)](https://arxiv.org/abs/1606.06357) 12 | - Order Embeddings [(Vendrov et al. 2015)](https://arxiv.org/abs/1511.06361) and Probabilistic Order Embeddings [(Lai and Hockenmaier 2017)](https://aclanthology.org/E17-1068.pdf) 13 | - Hyperbolic Embeddings, including: 14 | - "Lorentzian" - uses the squared Lorentzian distance on the Hyperboloid as in [(Law et al. 2019)](http://proceedings.mlr.press/v97/law19a.html), trains undirected but uses the asymmetric score function from [(Nickel and Kiela 2017)](https://proceedings.neurips.cc/paper/2017/file/59dfa2df42d9e3d41f5b02bfc32229dd-Paper.pdf) to determine edge direction at inference 15 | - "Lorentzian Score" - uses the asymmetric score above directly in training loss 16 | - "Lorentzian Distance" - Hyperbolic model for directed graphs as described in section 2.3 of [(Boratko et al. 2021)](https://proceedings.neurips.cc/paper/2021/hash/88d25099b103efd638163ecb40a55589-Abstract.html) 17 | - Hyperbolic Entailment Cones [(Ganea et al. 2018)](https://arxiv.org/abs/1804.01882) 18 | - Gumbel Box Embeddings [(Dasgupta et al. 2020)](https://arxiv.org/abs/2010.04831) 19 | - t-Box model as described in section 3 of [(Boratko et al. 
2021)](https://proceedings.neurips.cc/paper/2021/hash/88d25099b103efd638163ecb40a55589-Abstract.html) 20 | 21 | It also provides a general-purpose pipeline to explore correlation between graph characteristics and models' learning capabilities. 22 | 23 | ## Installation 24 | 25 | This repository makes use of submodules, to clone them you should use the `--recurse-submodules` flag, eg. 26 | ```bash 27 | git clone --recurse-submodules 28 | ``` 29 | After cloning the repo, you should create an environment and install pytorch. For example, 30 | 31 | ```bash 32 | conda create -n graph-modeling python=3.8 33 | conda activate graph-modeling 34 | conda install -c pytorch cudatoolkit=11.3 pytorch 35 | ``` 36 | 37 | You can then run `make all` to install the remaining modules and their dependencies. **Note:** 38 | 1. This will install Python modules, so you should run this command with the virtual environment created previously activated. 39 | 2. Certain graph generation methods (Kronecker and Price Network) will require additional dependencies to be compiled. In particular, Price requires that you use `conda`. If you are not interested in generating Kronecker or Price graphs you can skip this by using `make base` instead of `make all`. 40 | 41 | ## Usage 42 | 43 | This module provides a command line interface available with `graph_modeling`. 44 | 45 | Run `graph_modeling --help` to see available options. 46 | 47 | ### Generate Graphs 48 | To generate a graph, run `graph_modeling generate `, eg. `graph_modeling generate scale-free-network`. 49 | 50 | - `graph_modeling generate --help` provides a list of available graphs that can be generated 51 | - `graph_modeling generate --help` provides a list of parameters for generation 52 | 53 | By default, graphs will be output in `data/graphs`, using a subfolder for their graph type and parameter settings. You can override this with the `--outdir` parameter. 
54 | 55 | ### Train Graph Representations 56 | You can train graph representations using the `graph_modeling train` command, run `graph_modeling train --help` to see available options. The only required parameter is `--data_path`, which specifies either a specific graph file or a folder, in which case it will pick a graph in the folder uniformly randomly. The `--model` option allows for a selection of different embedding models. Most other options apply to every model (eg. `--dim`) or training in general (eg. `--log_batch_size`). Model-specific options are prefaced with the model name (eg. `--box_intersection_temp`). Please see the help text for the options for more details, and submit an issue if anything is unclear. 57 | 58 | ## Citation 59 | If you found the code contained in this repository helpful in your research, please cite the following paper: 60 | 61 | ``` 62 | @inproceedings{boratko2021capacity, 63 | title={Capacity and Bias of Learned Geometric Embeddings for Directed Graphs}, 64 | author={Boratko, Michael and Zhang, Dongxu and Monath, Nicholas and Vilnis, Luke and Clarkson, Kenneth L and McCallum, Andrew}, 65 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 66 | year={2021} 67 | } 68 | ``` 69 | 70 | 71 | -------------------------------------------------------------------------------- /bin/generate/balanced_tree.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - balanced-tree 24 | - --transitive_closure 25 | - --outdir=data/graphs13/ 26 | - ${args} 27 | method: grid 28 | project: generate_graphs 29 | parameters: 30 | branching: 31 | values: [2, 3, 5, 10] 32 | log_num_nodes: 33 | values: [13] 34 | -------------------------------------------------------------------------------- /bin/generate/kronecker_graph.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - kronecker 24 | - ${args} 25 | method: grid 26 | project: generate_graphs 27 | parameters: 28 | log_num_nodes: 29 | values: [13] 30 | a: 31 | values: [1.0, 0.8] 32 | b: 33 | values: [0.6, 0.4] 34 | c: 35 | values: [0.5, 0.3] 36 | d: 37 | values: [0.3, 0.1] 38 | seed: 39 | values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 40 | -------------------------------------------------------------------------------- /bin/generate/nCRP.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - ncrp 24 | - --no_transitive_closure 25 | - ${args} 26 | method: grid 27 | project: generate_graphs 28 | parameters: 29 | log_num_nodes: 30 | values: [13] 31 | alpha: 32 | values: [10, 100, 500] 33 | seed: 34 | values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 35 | 36 | -------------------------------------------------------------------------------- /bin/generate/price.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - price 24 | - --outdir=data/graphs13/ 25 | - --transitive_closure 26 | - ${args} 27 | method: grid 28 | project: generate_graphs 29 | parameters: 30 | log_num_nodes: 31 | values: [13] 32 | m: 33 | values: [1, 5, 10] 34 | c: 35 | values: [0.01, 0.1] 36 | seed: 37 | values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 38 | -------------------------------------------------------------------------------- /bin/generate/random_tree.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - random-tree 24 | - ${args} 25 | method: grid 26 | project: generate_graphs 27 | parameters: 28 | log_num_nodes: 29 | values: [12, 14, 16] 30 | seed: 31 | values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 32 | -------------------------------------------------------------------------------- /bin/generate/scale_free_network.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - generate 23 | - scale-free-network 24 | - --outdir=data/graphs13/ 25 | - ${args} 26 | method: grid 27 | project: generate_graphs 28 | parameters: 29 | log_num_nodes: 30 | values: [13] 31 | alpha: 32 | values: [0.1, 0.3] 33 | gamma: 34 | values: [0.4, 0.6] 35 | delta_in: 36 | values: [0.0, 1.0] 37 | delta_out: 38 | values: [0.0, 1.0] 39 | seed: 40 | values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 41 | -------------------------------------------------------------------------------- /bin/launch_generation_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2021 The Geometric Graph Embedding Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# Submit a wandb graph-generation sweep as a SLURM array job on the
# 1080ti-short partition.
# Usage: launch_generation_sweep.sh <sweep_id> [num_machines] [threads] [mem]

set -exu

sweep_id=$1
num_machines=${2:-20}   # array indices run 0..num_machines (inclusive)
threads=${3:-2}
mem=${4:-10000}         # memory per task, in MB (SLURM --mem default unit)

TIME=`(date +%Y-%m-%d-%H-%M-%S-%N)`

# Cap BLAS/OpenMP thread pools to the CPUs requested from SLURM.
export MKL_NUM_THREADS=$threads
export OPENBLAS_NUM_THREADS=$threads
export OMP_NUM_THREADS=$threads

model_name="wandb"
dataset=$sweep_id
job_name="$model_name-$dataset-$TIME"
log_dir=logs/$model_name/$dataset/$TIME
log_base=$log_dir/log

partition='1080ti-short'

mkdir -p $log_dir

# Each array task runs one sweep agent with a very high run cap (100000).
sbatch -J $job_name \
    -e $log_base.err \
    -o $log_base.log \
    --cpus-per-task $threads \
    --partition=$partition \
    --gres=gpu:1 \
    --ntasks=1 \
    --nodes=1 \
    --mem=$mem \
    --array=0-$num_machines \
    bin/run_sweep.sh $sweep_id $threads 100000
#
# SPDX-License-Identifier: Apache-2.0

# Submit a wandb sweep as a SLURM array job on the titanx-long partition.
# Usage: launch_sweep.sh <sweep_id> [num_machines] [threads] [mem]

set -exu

sweep_id=$1
num_machines=${2:-0}   # array indices run 0..num_machines (inclusive)
threads=${3:-1}
mem=${4:-20000}        # memory per task, in MB (SLURM --mem default unit)

TIME=`(date +%Y-%m-%d-%H-%M-%S-%N)`

# Cap BLAS/OpenMP thread pools to the CPUs requested from SLURM.
export MKL_NUM_THREADS=$threads
export OPENBLAS_NUM_THREADS=$threads
export OMP_NUM_THREADS=$threads

model_name="wandb"
dataset=$sweep_id
job_name="$model_name-$dataset-$TIME"
log_dir=logs/$model_name/$dataset/$TIME
log_base=$log_dir/log

partition='titanx-long'

mkdir -p $log_dir

# NOTE(review): run_sweep.sh reads a third argument ($3, the agent run
# count) under `set -u`, but it is not passed here — confirm the agents
# actually start, or pass an explicit count as the launch_train variant does.
sbatch -J $job_name \
    -e $log_base.err \
    -o $log_base.log \
    --cpus-per-task $threads \
    --partition=$partition \
    --gres=gpu:1 \
    --ntasks=1 \
    --nodes=1 \
    --mem=$mem \
    --array=0-$num_machines \
    bin/run_sweep.sh $sweep_id $threads
#
# SPDX-License-Identifier: Apache-2.0

# Submit a wandb training sweep as a SLURM array job.
# Usage: launch_train_sweep.sh <sweep_id> <partition> [max_run] [num_machines] [threads] [mem]

set -exu


sweep_id=$1
partition=$2
max_run=${3:-100}        # max runs each sweep agent will execute
num_machines=${4:-0}     # array indices run 0..num_machines (inclusive)
threads=${5:-1}
mem=${6:-25000}          # memory per task, in MB (SLURM --mem default unit)


TIME=`(date +%Y-%m-%d-%H-%M-%S-%N)`

# Cap BLAS/OpenMP thread pools to the CPUs requested from SLURM.
export MKL_NUM_THREADS=$threads
export OPENBLAS_NUM_THREADS=$threads
export OMP_NUM_THREADS=$threads

model_name="wandb"
dataset=$sweep_id
job_name="$model_name-$dataset-$TIME"
log_dir=logs/$model_name/$dataset/$TIME
log_base=$log_dir/log

# NOTE(review): self-assignment is a no-op; kept for structural parity with
# the sibling launch scripts, which hard-code the partition at this point.
partition=$partition

mkdir -p $log_dir

# --exclude skips nodes presumed flaky/incompatible — verify the list is current.
sbatch -J $job_name \
    -e $log_base.err \
    -o $log_base.log \
    --cpus-per-task $threads \
    --partition=$partition \
    --gres=gpu:1 \
    --ntasks=1 \
    --nodes=1 \
    --mem=$mem \
    --array=0-$num_machines \
    --exclude=node026,node030,node040,node057,node059,node072,node095,node099,node123,node125,node167,node169,node176 \
    bin/run_sweep.sh $sweep_id $threads $max_run
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - train 23 | - ${args} 24 | - --model_type=box 25 | - --log_interval=10000 26 | - --log_eval_batch_size=20 27 | - --graph_file_stub=data/graphs/scale_free_network-alpha=0.4-gamma=0.5-delta_in=0.2-num_nodes=10000-seed=1-transitive_closure=False-delta_out=0.0 28 | method: bayes 29 | metric: 30 | goal: maximize 31 | name: '[Train] F1' 32 | parameters: 33 | learning_rate: 34 | values: [0.001, 0.01] 35 | dim: 36 | values: [64] 37 | log_batch_size: 38 | values: [5, 7] 39 | negative_weight: 40 | values: [0.5, 0.9] 41 | negative_ratio: 42 | values: [1, 10, 100] 43 | box_intersection_temp: 44 | values: [0.001, 0.01, 0.1] 45 | box_volume_temp: 46 | values: [1.0, 0.1] 47 | -------------------------------------------------------------------------------- /bin/learning/scale_free_network_vector.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - train 23 | - ${args} 24 | - --model_type=vector 25 | - --log_interval=10000 26 | - --log_eval_batch_size=20 27 | - --graph_file_stub=data/graphs/scale_free_network-alpha=0.4-gamma=0.5-delta_in=0.2-num_nodes=10000-seed=1-transitive_closure=False-delta_out=0.0 28 | method: bayes 29 | metric: 30 | goal: maximize 31 | name: '[Train] F1' 32 | parameters: 33 | learning_rate: 34 | values: [0.001, 0.01] 35 | dim: 36 | values: [2, 16, 64, 256] 37 | log_batch_size: 38 | values: [5, 7] 39 | negative_weight: 40 | values: [0.5, 0.9] 41 | negative_ratio: 42 | values: [10] -------------------------------------------------------------------------------- /bin/plot/plot_graphs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def plot_gv(inpath, outpath):
    """Convert a sparse adjacency matrix (.npz, COO format) to a Graphviz .gv file.

    Each nonzero entry (row, col) becomes a directed edge ``row -> col``,
    and every node that appears in some edge is rendered as a point.

    :param inpath: path to a scipy sparse matrix saved with ``save_npz``
    :param outpath: path of the Graphviz dot file to write
    """
    digraph_coo = load_npz(inpath)
    out_node_list = digraph_coo.row
    in_node_list = digraph_coo.col
    nodes = set()
    with open(outpath, "w") as fout:
        fout.write("digraph g {\n")
        for i in range(out_node_list.shape[0]):
            nodes.add(out_node_list[i])
            nodes.add(in_node_list[i])
            fout.write("%s -> %s;\n" % (out_node_list[i], in_node_list[i]))
        # Sort for deterministic output (set iteration order is arbitrary),
        # so repeated runs produce identical files.
        for p in sorted(nodes):
            fout.write("%s[shape=point];\n" % p)
        fout.write("}")
def plotme_single_f1(np_mat, ynames, title):
    """Draw a parallel-coordinates plot of ``np_mat`` and save it as a PDF.

    np_mat - N rows by C columns, all values to plot (real valued)
    ynames - C items, strings naming the C columns of np_mat
    title  - title for the plot

    Each row is drawn as a bezier curve across the C vertical axes; the last
    column's value selects the curve's color from the viridis colormap.
    Output is written to /tmp/parallel_coordinates.pdf.
    """
    # https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.path import Path
    import matplotlib.patches as patches
    import numpy as np

    ys = np_mat

    ymins = ys.min(axis=0)
    ymaxs = ys.max(axis=0)
    dys = ymaxs - ymins
    ymins -= dys * 0.05  # add 5% padding below and above
    ymaxs += dys * 0.05
    dys = ymaxs - ymins

    # Rescale every column into the coordinate system of the main (leftmost)
    # axis so all curves can be drawn on one Axes object.
    zs = np.zeros_like(ys)
    zs[:, 0] = ys[:, 0]
    zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

    fig, host = plt.subplots(figsize=(10, 4))
    from matplotlib import cm

    viridis = cm.get_cmap("viridis", 256)

    # One twinned y-axis per extra column, evenly spaced along x.
    axes = [host] + [host.twinx() for i in range(ys.shape[1] - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins[i], ymaxs[i])
        ax.spines["top"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        if ax != host:
            ax.spines["left"].set_visible(False)
            ax.yaxis.set_ticks_position("right")
            ax.spines["right"].set_position(("axes", i / (ys.shape[1] - 1)))

    host.set_xlim(0, ys.shape[1] - 1)
    host.set_xticks(range(ys.shape[1]))
    host.set_xticklabels(ynames, fontsize=14)
    host.tick_params(axis="x", which="major", pad=7)
    host.spines["right"].set_visible(False)
    host.xaxis.tick_top()
    host.set_title(title, fontsize=18)

    num_cols = ys.shape[1]
    # one bezier curve per row of the matrix
    for j in range(ys.shape[0]):
        # Control vertices: one at each point, plus one a third of the way
        # toward each neighboring axis (first/last axis get one fewer).
        # BUGFIX: the x grid must span the C columns (the axes), not the N
        # rows; the original used len(ys) (= number of rows), which truncated
        # or mis-scaled the curves whenever N != C (zip silently truncates).
        verts = list(
            zip(
                np.linspace(0, num_cols - 1, num_cols * 3 - 2, endpoint=True),
                np.repeat(zs[j, :], 3)[1:-1],
            )
        )
        codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)]
        path = Path(verts, codes)
        patch = patches.PathPatch(
            path, facecolor="none", lw=1, edgecolor=viridis(ys[j, ys.shape[1] - 1])
        )
        host.add_patch(patch)
    plt.tight_layout()
    plt.savefig("/tmp/parallel_coordinates.pdf")
def load_csv(filename, metric="F1"):
    """Parse a tab-separated results file and keep the best score per graph.

    :param filename: path to a TSV whose header row names the columns
        (must include model_type, dim, type, transitive_closure, path,
        and the requested metric column).
    :param metric: which column to rank by, "F1" (default) or "AUC".
    :return: nested dict
        graph2method2dim[graph_type][method_name][dim][graph_path] = best score,
        where graph_type is "<type>_<transitive_closure>".
    """
    with open(filename) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter="\t", quotechar='"')
        key2id = {}
        graph2method2dim = {}
        for line_count, row in enumerate(csv_reader):
            if line_count == 0:
                # header row: map column name -> index
                key2id = {k: i for i, k in enumerate(row)}
                print(key2id.items())
                continue
            method_name = row[key2id["model_type"]]
            dim = int(row[key2id["dim"]])
            transitive_closure = row[key2id["transitive_closure"]]
            graph_type = row[key2id["type"]] + "_" + str(transitive_closure)
            graph_path = row[key2id["path"]]
            # BUGFIX: the original overwrote the `metric` parameter with the
            # row's score, so `metric == "AUC"` could only match on the first
            # data row and every later row silently fell back to the F1
            # column.  Use a separate local for the row's value.
            if metric == "AUC":
                metric_value = float(row[key2id["AUC"]])
            else:
                metric_value = float(row[key2id["F1"]])
            if graph_type not in graph2method2dim:
                graph2method2dim[graph_type] = {}
            if method_name not in graph2method2dim[graph_type]:
                # box/vector models are swept over one set of dims, all other
                # model types over a disjoint set
                if method_name in ["box", "vector", "complex_vector"]:
                    graph2method2dim[graph_type][method_name] = {4: {}, 16: {}, 64: {}}
                else:
                    graph2method2dim[graph_type][method_name] = {8: {}, 32: {}, 128: {}}
            if dim not in graph2method2dim[graph_type][method_name]:
                continue  # dim outside the tracked sweep: skip the row
            best = graph2method2dim[graph_type][method_name][dim]
            if metric_value > best.get(graph_path, -np.inf):
                best[graph_path] = metric_value
    return graph2method2dim
def upload_results(graph2method2dim, graph_stats, graph_col):
    """Flatten best-score results into per-(method, dim) rows of stats + F1.

    :param graph2method2dim: nested dict
        [graph_type][method_name][dim][graph_path] -> best F1
    :param graph_stats: dict mapping graph_path -> list of metric values,
        aligned positionally with ``graph_col``
    :param graph_col: column names for the entries of each graph_stats row
    :return: (method2dim2rows, numpy_mat_cols) where each row is
        [sparsity, avg_degree, transitivity, reciprocity, f1]
    """
    import collections

    # Graph statistics copied into each row, in this order.  (The original
    # also listed assortativity/flow_hierarchy/branching in a dead first
    # assignment that was immediately overwritten.)
    numpy_mat_cols = ["sparsity", "avg_degree", "transitivity", "reciprocity"]

    method2dim2rows = collections.defaultdict(dict)
    for graph_type, method2rest in graph2method2dim.items():
        for method_name, dim2rest in method2rest.items():
            for dim, path2f1 in dim2rest.items():
                rows = method2dim2rows[method_name].setdefault(dim, [])
                for graph_path, f1 in path2f1.items():
                    # Build the whole row before appending so a failure on
                    # any column leaves no partial row behind (same behavior
                    # as the original try/except around the full body).
                    try:
                        stats = dict(zip(graph_col, graph_stats[graph_path]))
                        row = [float(stats[c]) for c in numpy_mat_cols]
                        row.append(f1)
                        rows.append(row)
                    except (KeyError, ValueError, TypeError):
                        # missing stats entry or unparsable value
                        print("error processing one line %s" % str(graph_path))
    return method2dim2rows, numpy_mat_cols
graph_stats, graph_col) 247 | for method_name in method2dim2rows: 248 | print(method_name) 249 | for dim in method2dim2rows[method_name]: 250 | print(dim) 251 | results = np.array(method2dim2rows[method_name][dim]) 252 | plotme_single_f1(results, columns + ["f1"], rf) 253 | return "done!" 254 | 255 | 256 | if __name__ == "__main__": 257 | # f(result_files[int(sys.argv[1])]) 258 | f("results/balanced_tree.tsv") 259 | -------------------------------------------------------------------------------- /bin/plot/upload_results.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def load_results(fname):
    """Load a tab-separated results file keyed by each row's third field.

    The first line is treated as a header and returned as ``columns``
    (prefixed with an empty string so header indices line up with data
    indices); every other line is split on tabs, numeric-looking fields are
    converted to float, and the parsed row is stored under its third field.

    :param fname: path to the TSV file
    :return: (path2metrics, columns)
    """

    def _coerce(token):
        # float when parseable, otherwise keep the raw string
        try:
            return float(token)
        except ValueError:
            return token

    columns = []
    path2metrics = {}
    with open(fname) as handle:
        for row_idx, raw_line in enumerate(handle):
            fields = raw_line.strip().split("\t")
            if row_idx == 0:
                columns = [""] + fields
                print(columns)
            else:
                parsed = [_coerce(field) for field in fields]
                path2metrics[parsed[2]] = parsed
    return path2metrics, columns
def load_csv(filename, metric="F1"):
    """Parse a tab-separated results file and keep the best score per graph.

    :param filename: path to a TSV whose header row names the columns
        (must include model_type, dim, type, transitive_closure, path,
        and the requested metric column).
    :param metric: which column to rank by, "F1" (default) or "AUC".
    :return: nested dict
        graph2method2dim[graph_type][method_name][dim][graph_path] = best score,
        where graph_type is "<type>_<transitive_closure>".
    """
    with open(filename) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter="\t", quotechar='"')
        key2id = {}
        graph2method2dim = {}
        for line_count, row in enumerate(csv_reader):
            if line_count == 0:
                # header row: map column name -> index
                key2id = {k: i for i, k in enumerate(row)}
                print(key2id.items())
                continue
            method_name = row[key2id["model_type"]]
            dim = int(row[key2id["dim"]])
            transitive_closure = row[key2id["transitive_closure"]]
            graph_type = row[key2id["type"]] + "_" + str(transitive_closure)
            graph_path = row[key2id["path"]]
            # BUGFIX: the original overwrote the `metric` parameter with the
            # row's score, so `metric == "AUC"` could only match on the first
            # data row and every later row silently fell back to the F1
            # column.  Use a separate local for the row's value.
            if metric == "AUC":
                metric_value = float(row[key2id["AUC"]])
            else:
                metric_value = float(row[key2id["F1"]])
            if graph_type not in graph2method2dim:
                graph2method2dim[graph_type] = {}
            if method_name not in graph2method2dim[graph_type]:
                # box/vector models are swept over one set of dims, all other
                # model types over a disjoint set
                if method_name in ["box", "vector", "complex_vector"]:
                    graph2method2dim[graph_type][method_name] = {4: {}, 16: {}, 64: {}}
                else:
                    graph2method2dim[graph_type][method_name] = {8: {}, 32: {}, 128: {}}
            if dim not in graph2method2dim[graph_type][method_name]:
                continue  # dim outside the tracked sweep: skip the row
            best = graph2method2dim[graph_type][method_name][dim]
            if metric_value > best.get(graph_path, -np.inf):
                best[graph_path] = metric_value
    return graph2method2dim
def f(rf):
    """Upload the best results from one results file to Weights & Biases.

    Loads per-graph statistics from results/graph_metrics.tsv, picks the best
    score per (graph_type, method, dim, graph_path) from ``rf`` via
    load_csv(), and pushes each combination as a wandb run through
    upload_results().

    :param rf: path to a tab-separated results file
    :return: the string "done!"
    """
    print(rf)
    graph_stats, graph_col = load_metrics("results/graph_metrics.tsv")
    # NOTE(review): tqdm is imported but never used in this function.
    from tqdm import tqdm

    bests = load_csv(rf)
    upload_results(bests, graph_stats, graph_col)
    return "done!"
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# Run one wandb sweep agent (invoked per SLURM array task by the launch scripts).
# Usage: run_sweep.sh <sweep_id> <threads> <count>

set -exu

sweep_id=$1
threads=$2
count=$3     # maximum number of runs this agent will execute

# Cap BLAS/OpenMP thread pools to the CPUs allocated to this job.
export MKL_NUM_THREADS=$threads
export OPENBLAS_NUM_THREADS=$threads
export OMP_NUM_THREADS=$threads

echo $OMP_NUM_THREADS
echo $OMP_NUM_THREADS
echo $OMP_NUM_THREADS

wandb agent --count $count $sweep_id
IN NO EVENT SHALL THE COPYRIGHT 20 | # HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 22 | # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | """ 28 | Monkey patch setuptools to write faster console_scripts with this format: 29 | 30 | import sys 31 | from mymodule import entry_function 32 | sys.exit(entry_function()) 33 | 34 | This is better. 35 | 36 | (c) 2016, Aaron Christianson 37 | http://github.com/ninjaaron/fast-entry_points 38 | """ 39 | from setuptools.command import easy_install 40 | import re 41 | 42 | TEMPLATE = r""" 43 | # -*- coding: utf-8 -*- 44 | # EASY-INSTALL-ENTRY-SCRIPT: '{3}','{4}','{5}' 45 | __requires__ = '{3}' 46 | import re 47 | import sys 48 | 49 | from {0} import {1} 50 | 51 | if __name__ == '__main__': 52 | sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) 53 | sys.exit({2}()) 54 | """.lstrip() 55 | 56 | 57 | @classmethod 58 | def get_args(cls, dist, header=None): # noqa: D205,D400 59 | """ 60 | Yield write_script() argument tuples for a distribution's 61 | console_scripts and gui_scripts entry points. 
62 | """ 63 | if header is None: 64 | # pylint: disable=E1101 65 | header = cls.get_header() 66 | spec = str(dist.as_requirement()) 67 | for type_ in "console", "gui": 68 | group = type_ + "_scripts" 69 | for name, ep in dist.get_entry_map(group).items(): 70 | # ensure_safe_name 71 | if re.search(r"[\\/]", name): 72 | raise ValueError("Path separators not allowed in script names") 73 | script_text = TEMPLATE.format( 74 | ep.module_name, ep.attrs[0], ".".join(ep.attrs), spec, group, name 75 | ) 76 | # pylint: disable=E1101 77 | args = cls._get_script_args(type_, name, header, script_text) 78 | for res in args: 79 | yield res 80 | 81 | 82 | # pylint: disable=E1101 83 | easy_install.ScriptWriter.get_args = get_args 84 | 85 | 86 | def main(): 87 | import os 88 | import re 89 | import shutil 90 | import sys 91 | 92 | dests = sys.argv[1:] or ["."] 93 | filename = re.sub(r"\.pyc$", ".py", __file__) 94 | 95 | for dst in dests: 96 | shutil.copy(filename, dst) 97 | manifest_path = os.path.join(dst, "MANIFEST.in") 98 | setup_path = os.path.join(dst, "setup.py") 99 | 100 | # Insert the include statement to MANIFEST.in if not present 101 | with open(manifest_path, "a+") as manifest: 102 | manifest.seek(0) 103 | manifest_content = manifest.read() 104 | if "include fastentrypoints.py" not in manifest_content: 105 | manifest.write( 106 | ("\n" if manifest_content else "") + "include fastentrypoints.py" 107 | ) 108 | 109 | # Insert the import statement to setup.py if not present 110 | with open(setup_path, "a+") as setup: 111 | setup.seek(0) 112 | setup_content = setup.read() 113 | if "import fastentrypoints" not in setup_content: 114 | setup.seek(0) 115 | setup.truncate() 116 | setup.write("import fastentrypoints\n" + setup_content) 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy~=1.7.1 2 | tqdm~=4.62.3 3 | wandb~=0.12.7 4 | 
networkx~=2.6.3 5 | scikit-learn~=1.0.1 6 | attrs~=21.2.0 7 | pandas~=1.3.4 8 | numpy~=1.21.2 9 | toml~=0.10.2 10 | setuptools~=58.0.4 11 | loguru~=0.5.3 12 | hypothesis~=6.27.1 13 | pytest~=6.2.5 14 | click~=8.0.3 -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Graph Modeling Scripts 2 | In this directory are some small self-contained scripts used for orchestrating the sweeps for Bayesian hyperparameter tuning with Weights and Biases, submitting jobs on SLURM, and collating the data. 3 | 4 | To submit multiple sweeps given a two-level directory of graph data: 5 | `python scripts/auto_sweep.py --data_path data/graphs/ --model_type=box --dim=16 --partition=1080ti-long --max_run 100` 6 | This will also save sweep configs to ./sweeps_config/ 7 | 8 | (There is also a file [graph-modeling](graph-modeling) which provides a script-like interface to the `graph_modeling` command resulting from installing this module. This is the entry point we use with wandb, and can also be useful as a target for your debugger.) 9 | -------------------------------------------------------------------------------- /scripts/auto_sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# input: model_type, dim
# output: a list of "wandb sweep" commands (one sweep per graph directory)


import argparse
import json
import os
from pathlib import Path

# Model types whose dimension is used as-is; every other model type gets its
# dimension doubled before being passed to training.
SAME_DIM_MODEL_TYPES = [
    "box",
    "per_entity_learned_temp_box",
    "per_dim_learned_temp_box",
    "global_learned_temp_box",
    "pure_gumbel",
    "vector",
    "complex_vector",
    "vector_dist",
]

# Every model type this script knows how to sweep.
ALL_MODEL_TYPES = [
    "box",
    "global_learned_temp_box",
    "per_entity_learned_temp_box",
    "per_dim_learned_temp_box",
    "pure_gumbel",
    "oe",
    "poe",
    "vector",
    "vector_dist",
    "bilinear_vector",
    "complex_vector",
    "lorentzian_distance",
    "lorentzian_score",
    "lorentzian",
]


def config_generation(model_type, dim, path):
    """Build the wandb Bayesian-sweep config for one graph directory.

    :param model_type: name of the embedding model (see ALL_MODEL_TYPES)
    :param dim: embedding dimension passed through as --dim
    :param path: graph data directory passed through as --data_path
    :return: dict in the format expected by ``wandb.sweep``
    """
    sweep_config = {
        # "controller": {"type": "local"},
        "program": "scripts/graph-modeling",
        "command": [
            "${env}",
            "${interpreter}",
            "${program}",
            "train",
            "${args}",
            "--model_type=" + model_type,
            "--log_interval=0.1",
            "--patience=21",
            "--log_eval_batch_size=17",
            "--epochs=10000",
            "--negative_ratio=128",
            "--dim=" + str(dim),
            "--data_path=" + path,
        ],
        "method": "bayes",
        "metric": {"goal": "maximize", "name": "[Train] F1"},
        "parameters": {
            "learning_rate": {"distribution": "log_uniform", "min": -9.2, "max": 0},
            "log_batch_size": {"distribution": "int_uniform", "min": 8, "max": 11},
            "negative_weight": {"distribution": "uniform", "min": 0.0, "max": 1.0},
        },
    }

    # Box-family models additionally tune their two temperatures.
    if model_type in [
        "box",
        "global_learned_temp_box",
        "per_entity_learned_temp_box",
        "per_dim_learned_temp_box",
        "pure_gumbel",
    ]:
        sweep_config["parameters"]["box_intersection_temp"] = {
            "distribution": "log_uniform",
            "min": -9.2,
            "max": -0.69,
        }
        sweep_config["parameters"]["box_volume_temp"] = {
            "distribution": "log_uniform",
            "min": -2.3,
            "max": 2.3,
        }
    if model_type == "oe":
        sweep_config["parameters"]["margin"] = {
            "distribution": "uniform",
            "min": 0,
            "max": 10,
        }
    if (
        model_type == "lorentzian_distance"
        or model_type == "lorentzian"
        or model_type == "lorentzian_score"
    ):
        sweep_config["parameters"]["lorentzian_alpha"] = {
            "distribution": "uniform",
            "min": 0.0,
            "max": 10.0,
        }
        sweep_config["parameters"]["lorentzian_beta"] = {
            "distribution": "uniform",
            "min": 0.0,
            "max": 10,
        }
    if model_type == "vector_dist":
        sweep_config["parameters"]["margin"] = {
            "distribution": "uniform",
            "min": 1,
            "max": 30,
        }
        sweep_config["command"].append("--separate_io")
    if model_type == "vector":
        sweep_config["command"].append("--separate_io")
    if model_type == "bilinear_vector":
        sweep_config["command"].append("--no_separate_io")

    return sweep_config


def main(config):
    """Report sweep progress under ``config.data_path``; with --bayes_run,
    (re)launch a wandb sweep for every graph directory that has fewer than
    95% of ``config.max_run`` results.

    CAUTION: --bayes_run deletes the saved results of incomplete sweeps
    before relaunching them.
    """
    import wandb  # deferred: heavy import, only needed when launching sweeps

    if config.model_type not in SAME_DIM_MODEL_TYPES:
        config.dim = config.dim * 2

    fout = None
    if config.bayes_run:
        os.makedirs("sweeps_configs", exist_ok=True)
        fout = open(
            f"sweeps_configs/{config.model_type}_{config.dim}_{config.data_path.replace('/','_')}.jsonl",
            "w+",
        )
    count_sweep = 0
    count_start_sweep = 0
    count_halfway_sweep = 0
    count_ninety_sweep = 0
    count_finished_sweep = 0
    count_finished_runs = 0
    # Results of one sweep live under <graph dir>/results/<model>_<dim>/*.metric
    target_filename = f"results/{config.model_type}_{config.dim}/*metric"
    for path in Path(config.data_path).glob("**/*log_num_nodes=*"):
        count_sweep += 1
        count_this_sweep = 0
        best_hyperparams = None
        best_metric = 0.0
        for f in path.glob(target_filename):
            count_finished_runs += 1
            count_this_sweep += 1
            # First line: hyperparameters (JSON); second line: metrics (JSON).
            # splitlines() tolerates a trailing newline, unlike tuple-unpacking
            # the result of split("\n").
            lines = f.read_text().splitlines()
            params = json.loads(lines[0])
            metrics = json.loads(lines[1])
            metric = metrics[0]["F1"]
            if metric > best_metric:
                # BUGFIX: best_metric was never updated, so best_hyperparams
                # ended up being the params of the last file scanned rather
                # than the params of the best run.
                best_metric = metric
                best_hyperparams = params

        if count_this_sweep > 0:
            count_start_sweep += 1
        if count_this_sweep > config.max_run / 2:
            count_halfway_sweep += 1
        if count_this_sweep > config.max_run * 0.95:
            count_ninety_sweep += 1
        if count_this_sweep >= config.max_run:
            count_finished_sweep += 1

        print(f"{count_this_sweep} / {config.max_run} finished under {str(path)}")

        # In bayes mode, if fewer than 95% of the results are present, clean
        # the partial results and start a fresh sweep.
        if config.bayes_run and count_this_sweep <= 0.95 * config.max_run:
            print("deleting saved results in this sweep")
            for f in path.glob(target_filename):
                f.unlink()

            sweep_config = config_generation(
                model_type=config.model_type, dim=config.dim, path=str(path)
            )
            sweep_id = wandb.sweep(sweep_config, project="learning_generated_graph")
            os.system(
                f"sh bin/launch_train_sweep.sh dongxu/learning_generated_graph/{sweep_id} {config.partition} {config.max_run} "
            )
            fout.write(f"{sweep_id} {json.dumps(sweep_config)}\n")

    if config.bayes_run:
        fout.close()

    print(f"# sweep started: {count_start_sweep}/{count_sweep}")
    print(f"# sweep halfway: {count_halfway_sweep}/{count_sweep}")
    print(f"# sweep 95% finished: {count_ninety_sweep}/{count_sweep}")
    print(f"# sweep finished: {count_finished_sweep}/{count_sweep}")
    print(f"# run finished: {count_finished_runs}/{count_sweep * config.max_run}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Submit multiple sweeps over slurm servers."
        + " Without --bayes_run the script only checks the current sweep status."
    )
    parser.add_argument(
        "--bayes_run",
        action="store_true",
        help="do bayes hyper-parameter search for each sweep "
        + "(CAUTION: this will clear all existing results under each sweep data path (if not completed))",
    )
    parser.add_argument("--model_type", type=str)
    parser.add_argument(
        "--dim",
        type=int,
        help="dimension will double when the model_type is not box or vector",
    )
    parser.add_argument("--partition", type=str, default="titanx-short")
    parser.add_argument("--max_run", type=int, default=100)
    parser.add_argument("--data_path", type=str, default="data/graphs/")
    config = parser.parse_args()
    if config.model_type not in ALL_MODEL_TYPES:
        raise Exception(f"model type {config.model_type} does not exist")
    main(config)
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Copy sweep result files from one results tree into another, preserving
the relative directory structure."""

import argparse
import os
import shutil
from pathlib import Path

# Every model type that may have results on disk.
_VALID_MODEL_TYPES = [
    "pure_gumbel",
    "per_entity_learned_temp_box",
    "per_dim_learned_temp_box",
    "global_learned_temp_box",
    "box",
    "oe",
    "poe",
    "vector",
    "vector_dist",
    "bilinear_vector",
    "complex_vector",
    "lorentzian_distance",
]

# Model types whose dimension is stored as given; all others double it.
_SAME_DIM_MODEL_TYPES = [
    "box",
    "pure_gumbel",
    "per_entity_learned_temp_box",
    "per_dim_learned_temp_box",
    "global_learned_temp_box",
    "vector",
    "complex_vector",
    "vector_dist",
]


def copy_metric_results(src, tgt, model_type, dim, progress=lambda it: it):
    """Copy every file matching **/results/<model_type>_<dim>/* under ``src``
    to the same relative location under ``tgt``, creating directories as needed.

    :param src: root of the source results tree
    :param tgt: root of the target results tree
    :param model_type: model type component of the results directory name
    :param dim: dimension component of the results directory name
    :param progress: optional wrapper around the file iterator (e.g. tqdm)
    """
    for path in progress(Path(src).glob(f"**/results/{model_type}_{dim}/*")):
        tgt_path = tgt + str(path).split(src)[1]
        tgt_dir = os.path.dirname(tgt_path)
        if not os.path.exists(tgt_dir):
            os.makedirs(tgt_dir)
        # BUGFIX: shutil.copy instead of os.system(f"cp ...") — portable, and
        # safe for paths containing spaces or shell metacharacters.
        shutil.copy(str(path), tgt_path)


def main():
    """Parse CLI arguments, validate the model type, and copy the results."""
    parser = argparse.ArgumentParser(description="Submit multiple sweeps")
    parser.add_argument("--src", type=str, default="")
    parser.add_argument("--tgt", type=str, default="")
    parser.add_argument("--model_type", type=str)
    parser.add_argument("--dim", type=int, help="dimension ")
    config = parser.parse_args()
    if config.model_type not in _VALID_MODEL_TYPES:
        raise Exception(f"model type {config.model_type} does not exist")
    if config.model_type not in _SAME_DIM_MODEL_TYPES:
        config.dim = config.dim * 2

    from tqdm import tqdm  # deferred third-party import

    copy_metric_results(
        config.src, config.tgt, config.model_type, config.dim, progress=tqdm
    )


if __name__ == "__main__":
    main()
graph_modeling.__main__ import main 9 | 10 | if __name__ == "__main__": 11 | sys.argv[0] = re.sub(r"(-script\.pyw?|\.exe)?$", "", sys.argv[0]) 12 | sys.exit(main()) 13 | -------------------------------------------------------------------------------- /scripts/wandb_sweep_example.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | program: scripts/graph-modeling 18 | command: 19 | - ${env} 20 | - ${interpreter} 21 | - ${program} 22 | - train 23 | - ${args} 24 | - --log_interval=10000 25 | - --graph_file_stub=data/graph/some_name 26 | method: bayes 27 | metric: 28 | goal: maximize 29 | name: '[Valid] F1' 30 | parameters: 31 | epochs: 32 | value: 1000 33 | learning_rate: 34 | distribution: log_uniform 35 | max: 0 36 | min: -10 37 | log_batch_size: 38 | values: [6,7,8] 39 | log_eval_batch_size: 40 | value: 12 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | """ 18 | Scripts to generate graphs, train and evaluate graph representations 19 | """ 20 | import fastentrypoints 21 | from setuptools import find_packages, setup 22 | 23 | setup( 24 | name="graph_modeling", 25 | version="0.1", 26 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 27 | package_dir={"": "src"}, 28 | description="Scripts to generate graphs, train and evaluate graph representations", 29 | install_requires=[ 30 | "Click>=7.1.2", 31 | "networkx", 32 | "scipy", 33 | "scikit-learn", 34 | "numpy", 35 | "xopen", 36 | "toml", 37 | "torch", 38 | "pandas", 39 | "loguru", 40 | "tqdm", 41 | "wandb", # TODO: break out this dependency 42 | ], 43 | extras_require={ 44 | "price_generation": ["graph_tool"], 45 | "wordnet_generation": ["nltk"], 46 | "test": ["pytest", "hypothesis"], 47 | }, 48 | entry_points={"console_scripts": ["graph_modeling = graph_modeling.__main__:main"]}, 49 | ) 50 | -------------------------------------------------------------------------------- /src/graph_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iesl/geometric-graph-embedding/46a4ed4406bff18c9570273fce99178d0e5820c8/src/graph_modeling/__init__.py -------------------------------------------------------------------------------- /src/graph_modeling/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric 
Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import click 18 | 19 | from .generate.__main__ import main as generate 20 | from .metrics.__main__ import main as metrics 21 | from .training.__main__ import train 22 | 23 | 24 | @click.group() 25 | def main(): 26 | """Scripts to generate graphs, train and evaluate graph representations""" 27 | pass 28 | 29 | 30 | main.add_command(generate, "generate") 31 | main.add_command(train, "train") 32 | main.add_command(metrics, "metrics") 33 | -------------------------------------------------------------------------------- /src/graph_modeling/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum

__all__ = ["PermutationOption"]


class PermutationOption(Enum):
    # Which permutation mode is requested: none, head, or tail.
    none = "none"
    head = "head"
    tail = "tail"


# --- src/graph_modeling/generate/__init__.py ---
# Copyright 2021 The Geometric Graph Embedding Authors.
# SPDX-License-Identifier: Apache-2.0

from importlib import import_module
from pathlib import Path
from pprint import pformat
from typing import *

import networkx as nx
import toml
from loguru import logger
from scipy.sparse import save_npz

__all__ = [
    "write_graph",
]


def write_graph(out_dir: Union[str, Path], **graph_config):
    """Generate the graph described by ``graph_config`` and save it to disk.

    The generator module is selected by ``graph_config["type"]``. Output is a
    sparse COO adjacency matrix (.npz) plus a .toml copy of the config, under
    ``out_dir/<type>/<key=value-...>/<seed>``.
    """
    target_dir = Path(out_dir).expanduser()
    target_dir.mkdir(parents=True, exist_ok=True)

    # Dispatch to the generator module named by the "type" key.
    generator = import_module("graph_modeling.generate." + graph_config["type"])

    logger.info("Generating graph with the following config:\n" + pformat(graph_config))
    graph = generator.generate(**graph_config)

    # The folder name encodes every config entry except "seed" and "type".
    folder_name = "-".join(
        f"{key}={graph_config[key]}"
        for key in sorted(graph_config)
        if key not in ["seed", "type"]
    )

    if graph_config["transitive_closure"]:
        graph = nx.transitive_closure(graph)

    logger.info("Converting to sparse matrix")
    adjacency = nx.to_scipy_sparse_matrix(graph, format="coo")
    graph_folder = target_dir / f"{graph_config['type']}/{folder_name}/"
    graph_folder.mkdir(parents=True, exist_ok=True)
    graph_file_stub = graph_folder / str(graph_config["seed"])
    logger.info(f"Saving to {graph_file_stub}")
    save_npz(graph_file_stub.with_suffix(".npz"), adjacency)
    with graph_file_stub.with_suffix(".toml").open("w") as f:
        toml.dump(graph_config, f)
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import functools 18 | import random 19 | 20 | import click 21 | 22 | from . import write_graph 23 | 24 | 25 | @click.group() 26 | def main(): 27 | """Graph generation commands""" 28 | pass 29 | 30 | 31 | def _common_options(func): 32 | """Common options used in all subcommands""" 33 | 34 | @main.command(context_settings=dict(show_default=True)) 35 | @click.option( 36 | "--outdir", 37 | default="data/graphs/", 38 | type=click.Path(writable=True), 39 | help="location to save output", 40 | ) 41 | @click.option( 42 | "--log_num_nodes", type=int, default=12, help="2**log_num_nodes number of nodes" 43 | ) 44 | @click.option("--seed", type=int, default=None, help="manually set random seed") 45 | @click.option( 46 | "--transitive_closure / --no_transitive_closure", 47 | default=False, 48 | help="create the transitive closure of the generated graph", 49 | ) 50 | @functools.wraps(func) 51 | def wrapper(*args, seed, **kwargs): 52 | if seed is None: 53 | seed = random.randint(0, 2 ** 32) 54 | return func(*args, seed=seed, **kwargs) 55 | 56 | return wrapper 57 | 58 | 59 | @_common_options 60 | @click.option("--branching", default=2, help="branching factor") 61 | def balanced_tree(outdir, **graph_config): 62 | """Writes out a balanced directed tree""" 63 | write_graph(outdir, type="balanced_tree", **graph_config) 64 | 65 | 66 | @_common_options 67 | def random_tree(outdir, **graph_config): 68 | """Writes out a random directed tree""" 69 | write_graph(outdir, type="random_tree", **graph_config) 70 | 71 | 72 | @_common_options 73 | @click.option( 74 | "--alpha", 75 | default=0.41, 76 | help="probability for adding a new node connected to an existing node chosen randomly according " 77 | "to the in-degree distribution (0 <= alpha + gamma <= 1)", 78 | ) 79 | @click.option( 80 | "--gamma", 81 | default=0.05, 82 | help="probability for adding a new node connected to an existing node chosen randomly according " 83 | "to the 
out-degree distribution (0 <= alpha + gamma <= 1)", 84 | ) 85 | @click.option( 86 | "--delta_in", 87 | default=0.2, 88 | help="bias for choosing nodes from in-degree distribution", 89 | ) 90 | @click.option( 91 | "--delta_out", 92 | default=0.0, 93 | help="bias for choosing nodes from out-degree distribution", 94 | ) 95 | def scale_free_network(outdir, **graph_config): 96 | """Writes out a scale-free directed graph""" 97 | write_graph(outdir, type="scale_free_network", **graph_config) 98 | 99 | 100 | @_common_options 101 | @click.option( 102 | "--alpha", 103 | default=10, 104 | help="probability of adding a new table is proportional to alpha (>0)", 105 | ) 106 | def ncrp(outdir, **graph_config): 107 | """Writes out a nested Chinese restaurant process graph""" 108 | write_graph(outdir, type="nested_chinese_restaurant_process", **graph_config) 109 | 110 | 111 | @_common_options 112 | @click.option( 113 | "--a", default=1.0, help="first entry of seed graph", 114 | ) 115 | @click.option( 116 | "--b", default=0.6, help="second entry of seed graph", 117 | ) 118 | @click.option( 119 | "--c", default=0.5, help="third entry of seed graph", 120 | ) 121 | @click.option( 122 | "--d", default=0.2, help="fourth entry of seed graph", 123 | ) 124 | def kronecker(outdir, **graph_config): 125 | """Writes out a Kronecker graph""" 126 | write_graph(outdir, type="kronecker_graph", **graph_config) 127 | 128 | 129 | @_common_options 130 | @click.option( 131 | "--m", default=1, help="Out-degree of newly added vertices.", 132 | ) 133 | @click.option( 134 | "--c", 135 | default=1.0, 136 | help="Constant factor added to the probability of a vertex receiving an edge", 137 | ) 138 | @click.option( 139 | "--gamma", default=1.0, help="Preferential attachment exponent", 140 | ) 141 | def price(outdir, **graph_config): 142 | """Writes out a graph produced using the Price model""" 143 | write_graph(outdir, type="price", **graph_config) 144 | 145 | 146 | @_common_options 147 | @click.option( 148 | 
"--vector_file", default="", help="fourth entry of seed graph", 149 | ) 150 | def hac(outdir, **graph_config): 151 | """Writes out a HAC graph""" 152 | write_graph(outdir, type="hac", **graph_config) 153 | 154 | 155 | @_common_options 156 | @click.option( 157 | "--vector_file", default="", help="xcluster format", 158 | ) 159 | @click.option( 160 | "--k", default=5, help="number of neighbors", 161 | ) 162 | def knn_graph(outdir, **graph_config): 163 | """Writes out a KNN graph""" 164 | write_graph(outdir, type="knn_graph", **graph_config) 165 | 166 | 167 | @_common_options 168 | @click.option( 169 | "--root_name", default="entity", help="Name of root node to start traversing from", 170 | ) 171 | @click.option( 172 | "--traversal_method", default="dfs", help="How to expand from the root [dfs or bfs]" 173 | ) 174 | def wordnet(outdir, **graph_config): 175 | write_graph(outdir, type="wordnet", **graph_config) 176 | -------------------------------------------------------------------------------- /src/graph_modeling/generate/balanced_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
#
# SPDX-License-Identifier: Apache-2.0

import networkx as nx

from graph_modeling.generate.generic import convert_to_outtree

__all__ = [
    "generate",
]


def generate(log_num_nodes: int, branching: int, **kwargs) -> nx.DiGraph:
    """Generate a balanced out-tree (edges parent -> child) with exactly
    2**log_num_nodes nodes and the given branching factor.

    Extra keyword arguments (e.g. seed, transitive_closure) are accepted and
    ignored so all generators share a common call signature.
    """
    num_nodes = 2 ** log_num_nodes
    # Find the smallest height whose complete balanced tree has at least
    # num_nodes nodes; count_nodes is the node count of that complete tree.
    height = 0
    count_nodes = 0
    while count_nodes < num_nodes:
        count_nodes += branching ** height
        height += 1

    height -= 1

    tree = nx.balanced_tree(branching, height)
    # Drop the surplus nodes of the last level so exactly num_nodes remain.
    tree.remove_nodes_from(list(range(num_nodes, count_nodes)))
    tree = convert_to_outtree(tree)
    return tree


# --- src/graph_modeling/generate/generic.py ---
# Copyright 2021 The Geometric Graph Embedding Authors.
# SPDX-License-Identifier: Apache-2.0

__all__ = [
    "convert_to_outtree",
    "remove_self_loops",
]


def convert_to_outtree(tree: nx.Graph) -> nx.DiGraph:
    """
    Convert an undirected tree (e.g. from networkx.random_tree(), where
    directed conversion yields both parent -> child and child -> parent
    edges) into a directed out-tree containing only parent -> child edges.

    :param tree: undirected tree whose root is node 0
    :return: directed out-tree rooted at node 0
    """
    digraph = nx.DiGraph(tree)
    # BFS from the root fixes the parent -> child orientation; remove the
    # reverse edge of every traversed pair.
    for u, v in nx.bfs_tree(tree, 0).edges():
        digraph.remove_edge(v, u)
    return digraph


def remove_self_loops(G):
    """Remove every self-loop edge of G in place and return G."""
    # Resolves the old TODO: use networkx's selfloop_edges instead of scanning
    # every edge by hand; materialize the generator first to avoid mutating
    # the graph while iterating over it.
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    return G
#
# SPDX-License-Identifier: Apache-2.0

import networkx as nx
import numpy as np
from loguru import logger
from numpy.random import default_rng
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform
from tqdm import tqdm

__all__ = [
    "generate",
]


def from_linkage_matrix(Z):
    """Convert a scipy linkage matrix into a DiGraph of parent -> child edges.

    Leaf ids are 0..n-1; the internal node created by row i of Z gets id n+i.
    """
    edges = []  # (parent, child) pairs
    for i in range(Z.shape[0]):
        edges.append((Z.shape[0] + 1 + i, int(Z[i, 0])))
        edges.append((Z.shape[0] + 1 + i, int(Z[i, 1])))
    G = nx.DiGraph()
    G.add_edges_from(edges)
    return G


def load_text(infile):
    """Load xcluster-format vectors (id \\t label \\t coords...) and unit-norm
    them.

    :return: (integer ids, labels, unit-normed float32 vector matrix)
    """
    vecs = []
    lbls = []
    with open(infile, "r") as f:
        for i, line in enumerate(f):
            splits = line.strip().split("\t")
            lbls.append(splits[1])
            vecs.append([float(x) for x in splits[2:]])
    vecs = np.array(vecs, dtype=np.float32)
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    num_zero = np.sum(norms == 0)
    # BUGFIX: loguru formats messages with str.format-style braces, not
    # printf-style "%s"; the old call logged a literal "%s" and ignored the arg.
    logger.info("Loaded vectors, and unit norming. {} vectors had 0 norm.", num_zero)
    norms[norms == 0] = 1.0
    vecs /= norms
    return np.arange(vecs.shape[0]), lbls, vecs


def generate(log_num_nodes: int, seed: int, vector_file: str, **kwargs) -> nx.DiGraph:
    """Build a DAG by average-linkage HAC over a random subset of vectors.

    :param log_num_nodes: use 2**log_num_nodes randomly chosen vectors
    :param seed: RNG seed for the subset selection
    :param vector_file: xcluster-format vector file (see load_text)
    :return: directed graph of the dendrogram (parent -> child edges)
    """
    pids, labels, X = load_text(vector_file)

    # Select a random subset of 2**log_num_nodes vectors.
    r = default_rng(seed)
    idx = np.arange(X.shape[0])
    r.shuffle(idx)
    idx = idx[: 2 ** log_num_nodes]
    X = X[idx, :]

    def dot(XA, XB):
        return np.matmul(XA, XB.T)

    def batched_cdist(XA, XB, batch_size=1000, use_tqdm=True):
        # Blockwise XA @ XB.T, shifted by +1 and clamped at 0 so values are
        # non-negative similarities. (Previously the tqdm and non-tqdm paths
        # were copy-pasted; tqdm only wraps the outer iterator.)
        res = np.zeros((XA.shape[0], XB.shape[0]), dtype=np.float32)
        outer = range(0, XA.shape[0], batch_size)
        if use_tqdm:
            outer = tqdm(outer, "cdist")
        for i in outer:
            istart, iend = i, min(XA.shape[0], i + batch_size)
            for j in range(0, XB.shape[0], batch_size):
                jstart, jend = j, min(XB.shape[0], j + batch_size)
                res[istart:iend, jstart:jend] = dot(XA[istart:iend], XB[jstart:jend])
        res = res + 1
        res = np.maximum(res, 0.0)
        return res

    # Similarities -> symmetric distances in [0, 2] with a zero diagonal.
    Xdist = batched_cdist(X, X)
    Xdist = 0.5 * (Xdist + Xdist.T)
    Xdist = 2 - Xdist
    np.fill_diagonal(Xdist, 0.0)
    Xdist = np.maximum(Xdist, 0.0)

    # Condensed form required by scipy's linkage.
    Xdist = squareform(Xdist)

    Z = linkage(Xdist, method="average")

    return from_linkage_matrix(Z)
def load_text(infile):
    """Load tab-separated vectors: each line is `<id>\t<label>\t<v0>\t<v1>...`.

    :returns: (integer ids, labels, unit-normed float32 vectors)
    """
    vecs = []
    lbls = []
    with open(infile, "r") as f:
        for line in f:
            splits = line.strip().split("\t")
            lbls.append(splits[1])
            vecs.append([float(x) for x in splits[2:]])
    vecs = np.array(vecs, dtype=np.float32)
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    num_zero = np.sum(norms == 0)
    # Bug fix: loguru formats positional args with {}-style placeholders;
    # the original "%s" was never substituted, so the count was dropped.
    logger.info("Loaded vectors, and unit norming. {} vectors had 0 norm.", num_zero)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero vectors
    vecs /= norms
    return np.arange(vecs.shape[0]), lbls, vecs


def generate(
    log_num_nodes: int, seed: int, vector_file: str, k: int, **kwargs
) -> nx.DiGraph:
    """Generate a directed k-nearest-neighbor graph over vectors from `vector_file`.

    A random subset of 2**log_num_nodes vectors is selected (seeded); each
    point's top-k most similar points become edges (neighbor -> point).
    """
    pids, labels, X = load_text(vector_file)

    # select a random subset of 2**log_num_nodes vectors
    rng = default_rng(seed)
    idx = np.arange(X.shape[0])
    rng.shuffle(idx)
    idx = idx[: 2 ** log_num_nodes]
    X = X[idx, :]

    def dot(XA, XB):
        return np.matmul(XA, XB.T)

    def batched_knn(XA, XB, K, batch_size=1000, offset=0):
        # Top-K by similarity, computed block-by-block to bound memory.
        K = np.minimum(K, XB.shape[0])
        res_i = np.zeros((XA.shape[0], K), dtype=np.int32)  # source row indices
        res = np.zeros((XA.shape[0], K), dtype=np.int32)  # neighbor col indices
        resd = np.zeros((XA.shape[0], K), dtype=np.float32)  # similarities
        for istart in tqdm(range(0, XA.shape[0], batch_size)):
            iend = min(XA.shape[0], istart + batch_size)
            # `sims` renamed from `r`, which shadowed the outer rng variable.
            sims = np.zeros((iend - istart, XB.shape[0]), dtype=np.float32)
            for jstart in range(0, XB.shape[0], batch_size):
                jend = min(XB.shape[0], jstart + batch_size)
                sims[:, jstart:jend] = dot(XA[istart:iend], XB[jstart:jend])
            # NOTE(review): self-similarity is set to +inf, which guarantees each
            # point appears in its own top-K (producing self-loop edges). If self
            # was meant to be excluded, this should be -np.inf — confirm intent.
            np.put(
                sims,
                np.arange(iend - istart) * sims.shape[1] + np.arange(istart, iend),
                np.inf,
            )
            res[istart:iend, :] = np.argpartition(sims, -K, axis=1)[:, -K:]
            resd[istart:iend, :] = sims[
                np.arange(iend - istart)[:, None], res[istart:iend, :]
            ]
            res_i[istart:iend, :] = (
                np.repeat(np.expand_dims(np.arange(istart, iend), 1), K, axis=1)
                + offset
            )

        from scipy.sparse import coo_matrix

        row = res_i.flatten()
        col = res.flatten()
        d = resd.flatten()
        return coo_matrix(
            (d, (row, col)), dtype=np.float32, shape=(XB.shape[0], XB.shape[0])
        )

    G = batched_knn(X, X, k)
    # edge direction: neighbor (col) -> source (row)
    edges = [(G.col[i], G.row[i]) for i in range(G.row.shape[0])]

    G = nx.DiGraph()
    G.add_edges_from(edges)
    return G
def generate(
    log_num_nodes: int, a: float, b: float, c: float, d: float, seed: int, **kwargs
) -> nx.DiGraph:
    """Generate a stochastic Kronecker graph by shelling out to SNAP's krongen.

    :param log_num_nodes: number of Kronecker iterations (2**log_num_nodes nodes)
    :param a, b, c, d: entries of the 2x2 initiator matrix
    :param seed: RNG seed forwarded to krongen
    :returns: generated graph with self-loops removed
    """
    tmp_graph_file = tempfile.NamedTemporaryFile(delete=False)
    tmp_graph_file.close()
    logger.info(f"Generating graph in temporary file {tmp_graph_file.name}")
    args = [
        str(KRONECKER_BINARY_LOCATION),
        f"-o:{tmp_graph_file.name}",
        # NOTE(review): no shell is involved, so the embedded double quotes are
        # passed literally to krongen as part of the argument — confirm krongen
        # accepts/strips them.
        f'-m:"{a} {b}; {c} {d}"',
        f"-i:{log_num_nodes}",
        # Bug fix: this was the plain string "-s:{seed}" (missing the f-prefix),
        # so the literal text "{seed}" was passed and the seed never took effect.
        f"-s:{seed}",
    ]
    logger.info(f"Running subprocess {' '.join(args)}")
    subprocess.call(args)
    logger.info(f"Subprocess krongen completed, reading edge list")
    edge_list = []
    with open(tmp_graph_file.name, "r") as f:
        # krongen writes a 4-line header; the final split element is the empty
        # string after the trailing newline, hence the [4:-1] slice.
        for line in f.read().split("\n")[4:-1]:
            e1, e2 = line.split("\t")
            edge_list.append((int(e1), int(e2)))
    g = nx.DiGraph(edge_list)
    logger.info(
        f"Generated graph has {g.number_of_nodes()} nodes, {g.number_of_edges()} edges"
    )
    logger.info(f"Removing self-loops")
    g = remove_self_loops(g)
    logger.info(f"After removing self-loops, graph has {g.number_of_edges()} edges")
    os.unlink(tmp_graph_file.name)
    return g
def generate(log_num_nodes: int, seed: int, alpha: int, **kwargs) -> nx.DiGraph:
    """
    Generate a nCRP tree with 2**log_num_nodes nodes and parameter `alpha`, which
    represents the "number of people you expect to be sitting at a new table".

    Breadth-first: each node's "customers" are recursively partitioned into
    child nodes by a Chinese restaurant process draw.
    """
    from collections import deque

    num_nodes = 2 ** log_num_nodes
    # BFS frontier of (node_id, num_people, parent_id).
    # deque gives O(1) popleft; the original list.pop(0) was O(n) per dequeue.
    nodes_to_process = deque([(0, num_nodes, None)])
    edges = []
    rng = random.Random(seed)
    next_id = 1
    while nodes_to_process:
        logger.debug(f"num nodes to process: {len(nodes_to_process)}")
        node_id, count, parent = nodes_to_process.popleft()
        if parent is not None:
            edges.append((parent, node_id))
        if count > 1:
            # a node with more than one person is split into children
            for kid_count in sample_crp(count, alpha, rng):
                nodes_to_process.append((next_id, kid_count, node_id))
                next_id += 1

    G = nx.DiGraph()
    G.add_edges_from(edges)
    return G


def sample_crp(N: int, alpha: float, rng: random.Random) -> List[int]:
    """
    Return a list of integers representing the number of people sitting at each
    table in a Chinese restaurant process.

    :param N: number of people to seat
    :param alpha: number of people you expect to be sitting at a new table
    :param rng: random number generator
        (provided as an argument, as we will call this repeatedly and want to seed it prior)

    :returns: List of integers, where the ith integer is the number of people
        seated at the ith table (counts sum to N)
    """
    # begin with one "imagined" table whose mass is alpha
    people_per_table = [alpha]
    for i in range(N):
        # sample a table: total unnormalized mass is i seated people + alpha,
        # so x < sum(people_per_table) and the inner loop always breaks
        x = rng.random() * (i + alpha)
        sampled_table = 0
        for people_at_this_table in people_per_table:
            x -= people_at_this_table
            if x < 0:
                break
            sampled_table += 1

        if sampled_table == len(people_per_table) - 1:
            # the imagined table was chosen: it becomes real with 1 person
            people_per_table[-1] = 1
            # and a fresh imagined table is appended
            people_per_table.append(alpha)
        else:
            # otherwise, simply increment the count on the existing table
            people_per_table[sampled_table] += 1
    # return all but the last "imagined" table
    return people_per_table[:-1]
def generate(
    log_num_nodes: int, seed: int, m: int, c: float, gamma: float, **kwargs
) -> nx.DiGraph:
    """Generate a Price network with 2**log_num_nodes nodes using graph-tool,
    returned as a networkx DiGraph with each edge's direction reversed."""
    gt.seed_rng(seed)
    price_graph = price_network(2 ** log_num_nodes, m, c, gamma)
    result = nx.DiGraph()
    # reverse every (source, target) edge from the graph-tool result
    result.add_edges_from(
        (target, source) for source, target in price_graph.iter_edges()
    )
    return result
def generate(log_num_nodes: int, seed: int, **kwargs) -> nx.DiGraph:
    """Sample a uniformly random tree on 2**log_num_nodes nodes (seeded) and
    orient it via `convert_to_outtree`."""
    undirected_tree = nx.random_tree(n=2 ** log_num_nodes, seed=seed)
    return convert_to_outtree(undirected_tree)
def generate(
    log_num_nodes: int,
    alpha: float,
    gamma: float,
    delta_in: float,
    delta_out: float,
    seed: int,
    **kwargs
) -> nx.DiGraph:
    """Generate a directed scale-free graph on 2**log_num_nodes nodes.

    beta is derived as 1 - alpha - gamma so the three probabilities sum to 1.
    The multigraph returned by networkx is collapsed to a simple DiGraph and
    self-loops are removed.
    """
    raw_graph = nx.scale_free_graph(
        2 ** log_num_nodes,
        alpha=alpha,
        beta=1 - alpha - gamma,
        gamma=gamma,
        delta_in=delta_in,
        delta_out=delta_out,
        seed=seed,
    )
    return remove_self_loops(nx.DiGraph(raw_graph))
def generate(log_num_nodes: int, seed: int, root_name: str, **kwargs) -> nx.DiGraph:
    """Build a DiGraph of WordNet hyponym edges, breadth-first from the synset
    named `root_name`, stopping once 2**log_num_nodes nodes have been added.

    :raises ValueError: if `root_name` matches zero or more than one synset
    """
    G = nx.DiGraph()
    queue = wordnet.synsets(root_name)
    if not queue:
        raise ValueError(f"Synset with name '{root_name}' does not exist")
    if len(queue) > 1:
        raise ValueError(f"More than one synset matches name '{root_name}'.")

    node_limit = 2 ** log_num_nodes
    G.add_node(queue[0])
    reached_limit = False
    while queue and not reached_limit:
        node = queue.pop(0)
        # expand both direct and instance hyponyms of the current synset
        for hyponym in chain(node.hyponyms(), node.instance_hyponyms()):
            if G.number_of_nodes() >= node_limit:
                reached_limit = True
                break
            G.add_edge(node, hyponym)
            queue.append(hyponym)

    return G
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from pytorch_utils.loggers import Logger 18 | 19 | __all__ = [ 20 | "metric_logger", 21 | ] 22 | 23 | metric_logger = Logger() 24 | -------------------------------------------------------------------------------- /src/graph_modeling/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iesl/geometric-graph-embedding/46a4ed4406bff18c9570273fce99178d0e5820c8/src/graph_modeling/metrics/__init__.py -------------------------------------------------------------------------------- /src/graph_modeling/metrics/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import click


@click.group()
def main():
    """Calculate and collect graph metrics / characteristics"""
    pass


@main.command(context_settings=dict(show_default=True),)
@click.argument(
    "data_path", type=click.Path(),
)
@click.option(
    "--metrics",
    "-m",
    type=click.Choice(
        [
            "all",
            "num_edges",
            "num_nodes",
            "avg_degree",
            "sparsity",
            "transitivity",
            "reciprocity",
            "flow_hierarchy",
            "clustering_coefficient",
            "assortativity",
        ],
        case_sensitive=False,
    ),
    default=("all",),
    help="name(s) of graph metric to calculate",
    multiple=True,
)
@click.option(
    "--predictions / --no_predictions",
    "-p/ ",
    default=False,
    help="calculate metrics on predictions (otherwise only calculate on original graphs)",
    multiple=False,
)
def calc(data_path, metrics, predictions=False):
    """Calculate graph metrics / characteristics"""
    from .calculate import write_metrics

    # Bug fix: the --predictions flag was parsed but never forwarded, so
    # metrics were always computed on the original graphs regardless of it.
    write_metrics(data_path, metrics, predictions)


@main.command(context_settings=dict(show_default=True),)
@click.argument(
    "data_path", type=click.Path(),
)
def collect_graph_info(data_path):
    """Collect graph characteristics to a single tsv"""
    from .collect import collect_graph_info

    collect_graph_info(data_path)


@main.command(context_settings=dict(show_default=True),)
@click.argument(
    "data_path", type=click.Path(),
)
def collect_result_info(data_path):
    """Collect results into a single tsv"""
    from .collect import collect_result_info

    collect_result_info(data_path)
# Simple graph statistics; each takes a networkx graph.
num_nodes = lambda G: G.number_of_nodes()
num_edges = lambda G: G.number_of_edges()
# NOTE(review): "sparsity" and "avg_degree" are currently the same quantity
# (edges / nodes); confirm whether sparsity was meant to be normalized by the
# number of possible edges instead.
sparsity = lambda G: G.number_of_edges() / G.number_of_nodes()
avg_degree = lambda G: G.number_of_edges() / G.number_of_nodes()

# Registry of all supported metrics, in canonical order.
metric_functions = {
    "num_nodes": num_nodes,
    "num_edges": num_edges,
    "sparsity": sparsity,
    "avg_degree": avg_degree,
    "transitivity": nx.transitivity,
    "reciprocity": nx.reciprocity,
    "flow_hierarchy": nx.flow_hierarchy,
    "clustering_coefficient": nx.average_clustering,
    "assortativity": nx.degree_pearson_correlation_coefficient,
}


def calc_metrics(
    path: Union[str, Path], metrics_to_calc: Dict[str, Callable]
) -> Dict[str, Any]:
    """Load a graph stored as a scipy sparse .npz file and compute each metric.

    :param path: path to the .npz adjacency matrix
    :param metrics_to_calc: mapping of metric name -> function of a DiGraph
    :returns: dict with one entry per metric, plus "<name>_time" wall-clock
        timing entries
    """
    digraph_coo = load_npz(path)
    G = nx.from_scipy_sparse_matrix(digraph_coo, create_using=nx.DiGraph)
    calculated_character = dict()
    for character, func in metrics_to_calc.items():
        time1 = time.time()
        calculated_character[character] = func(G)
        time2 = time.time()
        # Bug fix: the key was the literal f"character_time" (no placeholder),
        # so every metric overwrote a single shared timing entry.
        calculated_character[f"{character}_time"] = time2 - time1
    return calculated_character


def write_metrics(
    data_path: Union[str, Path],
    metrics_to_calc: Iterable[str],
    predictions: bool = False,
) -> None:
    """
    Calculate graph characteristics and write a .json file next to each graph.

    :param data_path: path to search recursively for graphs in
    :param metrics_to_calc: names of metrics to calculate ("all" for everything)
    :param predictions: if True, evaluate *prediction.npz files instead of the
        original graphs
    """
    data_path = Path(data_path).expanduser()
    unavailable_metrics = set(metrics_to_calc).difference(
        {"all", *metric_functions.keys()}
    )
    if unavailable_metrics:
        logger.warning(
            f"Requested calculation of {unavailable_metrics}, but these are not implemented."
        )
    if "all" in metrics_to_calc:
        metrics_to_calc = metric_functions
    else:
        # sort metrics_to_calc according to metric_functions
        metrics_to_calc = {
            k: v for k, v in metric_functions.items() if k in metrics_to_calc
        }
    if predictions:
        graph_files = data_path.glob("**/*prediction.npz")
    else:
        # NOTE(review): [!prediction] is a character class, not a word-negation;
        # it excludes files whose final pre-suffix character is in that set.
        graph_files = data_path.glob("**/*[!prediction].npz")
    progress_bar = tqdm(graph_files, desc="Evaluating Metrics...")
    for graph_file in progress_bar:
        progress_bar.set_description(f"Evaluating Metrics [{graph_file}]...")
        metrics = calc_metrics(str(graph_file), metrics_to_calc)
        output_path = graph_file.with_suffix(".json")
        logger.debug(f"Writing metrics to {output_path}")
        json.dump(metrics, output_path.open("w"))
    logger.info("Complete!")
def collect_graph_info(data_path: Union[Path, str]) -> pd.DataFrame:
    """
    Collect graph metrics from a given data path and output tsv files.

    Performs a depth-first search of data_path. In each folder it deposits a
    tsv file called "graph_metrics.tsv" where each row represents a single
    graph and the columns are graph characteristics. For convenience, it also
    saves a pandas data frame as "graph_metrics.pkl".

    :param data_path: path to explore recursively
    :returns: DataFrame with aggregated statistics in data_path
    """
    data_path = Path(data_path).expanduser()

    # recurse into subdirectories first (depth-first aggregation)
    subdir_dfs = [
        collect_graph_info(subdir) for subdir in data_path.iterdir() if subdir.is_dir()
    ]

    metrics = []
    for graph_file in data_path.glob("*[!prediction].npz"):
        graph_param = toml.load(graph_file.with_suffix(".toml"))
        graph_metrics = json.load(graph_file.with_suffix(".json").open())
        metrics.append({"path": graph_file, **graph_param, **graph_metrics})

    metrics_df = pd.DataFrame.from_records(metrics)
    aggregated_df = pd.concat((metrics_df, *subdir_dfs), ignore_index=True)
    if not aggregated_df.empty:
        aggregated_df.to_csv(data_path / "graph_metrics.tsv", sep="\t")
        aggregated_df.to_pickle(data_path / "graph_metrics.pkl")

    return aggregated_df


def collect_result_info(data_path: Union[Path, str]) -> pd.DataFrame:
    """
    Collect results from a given data path and output tsv files.

    Performs a depth-first search of data_path. In each folder it deposits a
    tsv file called "results.tsv" where each row represents a single result and
    the columns are performance and graph characteristics. For convenience, it
    also saves a pandas data frame as "results.pkl".

    :param data_path: path to explore recursively
    :returns: DataFrame with aggregated statistics in data_path
    """
    data_path = Path(data_path).expanduser()

    metrics = []
    for result_file in tqdm(data_path.glob("**/*results/**/*.metric")):
        logger.debug(result_file)
        # A .metric file holds a JSON model config line followed by a JSON
        # results line. Use a context manager (the original leaked the handle)
        # and maxsplit=1 so a trailing newline no longer breaks the unpacking.
        with open(result_file) as f:
            model_config, results = f.read().split("\n", 1)
        model_config = json.loads(model_config)
        results = json.loads(results)[0]
        seed = model_config["data_path"].split("/")[-1]
        graph_file = result_file.parent.parent.parent / f"{seed}.npz"
        # missing/invalid sidecar files are tolerated (best-effort enrichment);
        # narrowed from bare `except:`, which also swallowed KeyboardInterrupt
        try:
            graph_param = toml.load(graph_file.with_suffix(".toml"))
        except Exception:
            graph_param = {}
        try:
            graph_metrics = json.load(graph_file.with_suffix(".json").open())
        except Exception:
            graph_metrics = {}
        metrics.append(
            {
                "result_path": result_file,
                "path": graph_file,
                **model_config,
                **graph_param,
                **graph_metrics,
                **results,
            }
        )

    metrics_df = pd.DataFrame.from_records(metrics)
    if not metrics_df.empty:
        metrics_df.to_csv(data_path / "results.tsv", sep="\t")
        metrics_df.to_pickle(data_path / "results.pkl")

    return metrics_df
class BoxMinDeltaSoftplus(Module):
    """Gumbel-box embedding model.

    Each entity is an axis-aligned box whose corners come from a learned center
    and a softplus-positive half side length. `forward` scores an edge
    (e1, e2) as log(vol(e1 ∩ e2)) - log(vol(e2)), using a Gumbel-smoothed
    (logsumexp) intersection and a softplus-smoothed volume.
    """

    def __init__(self, num_entity, dim, volume_temp=1.0, intersection_temp=1.0):
        super().__init__()
        # learned center and (pre-softplus) half side length per entity
        self.centers = torch.nn.Embedding(num_entity, dim)
        self.sidelengths = torch.nn.Embedding(num_entity, dim)
        self.centers.weight.data.uniform_(-0.1, 0.1)
        self.sidelengths.weight.data.zero_()

        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.softplus = torch.nn.Softplus(beta=1 / self.volume_temp)
        # 0.57721... is the Euler–Mascheroni constant
        self.softplus_const = 2 * self.intersection_temp * 0.57721566490153286060

    def log_volume(self, z, Z):
        """Log volume of box [z, Z]: sum over dims of log softplus(side - const)."""
        smoothed_side = self.softplus(Z - z - self.softplus_const)
        return torch.log(smoothed_side).sum(dim=-1)

    def embedding_lookup(self, idx):
        """Return (min corner, max corner) for the entity indices `idx`."""
        center = self.centers(idx)
        half_side = self.softplus(self.sidelengths(idx))
        return center - half_side, center + half_side

    def gumbel_intersection(self, e1_min, e1_max, e2_min, e2_max):
        """Smoothed intersection of two boxes, clamped to the hard intersection."""
        t = self.intersection_temp
        smooth_min = t * torch.stack([e1_min / t, e2_min / t]).logsumexp(0)
        smooth_max = -t * torch.stack([-e1_max / t, -e2_max / t]).logsumexp(0)
        # never let the smoothed corners fall inside the hard intersection
        meet_min = torch.max(smooth_min, torch.max(e1_min, e2_min))
        meet_max = torch.min(smooth_max, torch.min(e1_max, e2_max))
        return meet_min, meet_max

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
        :return: log prob of shape (..., )
        """
        e1_min, e1_max = self.embedding_lookup(idxs[..., 0])
        e2_min, e2_max = self.embedding_lookup(idxs[..., 1])
        meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)
        return self.log_volume(meet_min, meet_max) - self.log_volume(e2_min, e2_max)

    def forward_log_overlap_volume(self, idxs):
        """
        :param idxs: Tensor of shape (N, 2)
        :return: log of overlap volume, shape (N, )
        """
        e1_min, e1_max = self.embedding_lookup(idxs[..., 0])
        e2_min, e2_max = self.embedding_lookup(idxs[..., 1])
        meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)
        return self.log_volume(meet_min, meet_max)

    def forward_log_marginal_volume(self, idxs):
        """
        :param idxs: Tensor of shape (N, )
        :return: log of marginal volume, shape (N, )
        """
        e_min, e_max = self.embedding_lookup(idxs)
        return self.log_volume(e_min, e_max)
class TBox(Module):
    """
    Box embedding model where the temperatures can (optionally) be trained.

    In this model, the self.boxes parameter is of shape (num_entity, 2, dim), where self.boxes[i,:,k] are location
    parameters for Gumbel distributions representing the corners of the ith box in the kth dimension.
        self.boxes[i,0,k] is the location parameter mu_z for a MaxGumbel distribution
        self.boxes[i,1,k] represents -mu_Z, i.e. negation of location parameter, for a MinGumbel distribution
    This rather odd convention is chosen to maximize speed / ease of computation.

    Note that with this parameterization, we allow the location parameter to "flip around", i.e. mu_z > mu_Z.
    This is completely reasonable, from the GumbelBox perspective (in fact, a bit more reasonable than requiring
    mu_Z > mu_z, as this means the distributions are no longer independent).

    :param num_entities: Number of entities to create box embeddings for (eg. number of nodes).
    :param dim: Embedding dimension (i.e. boxes will be in RR^dim).
    :param intersection_temp: Temperature for intersection LogSumExp calculations
    :param volume_temp: Temperature for volume LogSumExp calculations
        Note: Temperatures can either be either a float representing a constant (global) temperature,
        or a Module which, when called, takes a LongTensor of indices and returns their temps.
    """

    def __init__(
        self,
        num_entities: int,
        dim: int,
        intersection_temp: Union[Module, float] = 0.01,
        volume_temp: Union[Module, float] = 1.0,
    ):
        super().__init__()
        # Sort along the corner axis so mu_z <= mu_Z initially, then negate the
        # second row to store (mu_z, -mu_Z) per the class docstring convention.
        self.boxes = Parameter(
            torch.sort(torch.randn((num_entities, 2, dim)), dim=-2).values
            * torch.tensor([1, -1])[None, :, None]
        )
        # Floats are wrapped in ConstTemp modules so both cases are callable.
        self.intersection_temp = convert_float_to_const_temp(intersection_temp)
        self.volume_temp = convert_float_to_const_temp(volume_temp)

    def forward(
        self, idxs: LongTensor
    ) -> Union[Tuple[Tensor, Dict[str, Tensor]], Tensor]:
        """
        A version of the forward pass that is slightly more performant.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        :returns: FloatTensor representing the energy of the edges in `idxs`
        """
        boxes = self.boxes[idxs]  # shape (..., 2, 2 (min/-max), dim)
        # Temperatures are averaged over the two edge endpoints (dim=-3 is the
        # entity axis of the looked-up temps); keepdim differs because the
        # intersection temp must broadcast against `boxes` below.
        intersection_temp = self.intersection_temp(idxs).mean(dim=-3, keepdim=True)
        volume_temp = self.volume_temp(idxs).mean(dim=-3, keepdim=False)

        # calculate Gumbel intersection: smoothed max over the two entities'
        # (mu_z, -mu_Z) rows, then a hard max against the originals.
        intersection = intersection_temp * torch.logsumexp(
            boxes / intersection_temp, dim=-3, keepdim=True
        )
        intersection = torch.max(
            torch.cat((intersection, boxes), dim=-3), dim=-3
        ).values
        # combine intersections and marginals, since we are going to perform the same operations on both
        # (boxes[..., 1, :, :] is the second entity of the edge, i.e. the "rhs" box)
        intersection_and_marginal = torch.stack(
            (intersection, boxes[..., 1, :, :]), dim=-3
        )
        # calculating log volumes
        # keep in mind that the [...,1,:] represents negative max, thus we negate it
        # (summing over dim=-2 yields mu_z + (-mu_Z) = -(side length));
        # 1e-23 guards the log against zero volume.
        log_volumes = torch.sum(
            torch.log(
                volume_temp
                * F.softplus((-intersection_and_marginal.sum(dim=-2)) / volume_temp)
                + 1e-23
            ),
            dim=-1,
        )
        # log P(e1 | e2): log vol(intersection) - log vol(e2)
        out = log_volumes[..., 0] - log_volumes[..., 1]

        # Only pay the cost of histogram collection when a WandB logger is active.
        if self.training and isinstance(metric_logger.metric_logger, WandBLogger):
            regularizer_terms = {
                "intersection_temp": self.intersection_temp(idxs).squeeze(-2),
                "volume_temp": self.volume_temp(idxs).squeeze(-2),
                "log_marginal_vol": log_volumes[..., 1],
                "marginal_vol": log_volumes[..., 1].exp(),
                "side_length": -boxes.sum(dim=-2),
            }
            # assumes idxs is (N, K+1, 2) with the positive edge at index 0 and
            # negatives at 1: -- TODO confirm against the training loop
            metrics_to_collect = {
                "pos": wandb.Histogram(out[..., 0].detach().exp().cpu()),
                "neg": wandb.Histogram(out[..., 1:].detach().exp().cpu()),
            }
            for k, v in regularizer_terms.items():
                metrics_to_collect[k] = wandb.Histogram(v.detach().cpu())

            metric_logger.metric_logger.collect(
                {f"[Train] {k}": v for k, v in metrics_to_collect.items()},
                overwrite=True,
            )
        return out
class Lorentzian(Module):
    """
    This embedding model uses the (symmetric) squared lorentzian distance function from Law et al. 2019
    (http://proceedings.mlr.press/v97/law19a/law19a.pdf) for training, and an adaptation of the score function from
    equation (8) of Nickel & Kiela 2017 (https://arxiv.org/pdf/1705.08039.pdf) for evaluation. Namely, our
    evaluation function will be

        -(1 + alpha (||u||^2 - ||v||^2)) ||u - v||_L^2

    where ||.||_L is the Lorentzian distance.

    :param alpha: penalty for distance, where higher alpha emphasises distance as a determining factor
        for edge direction more
    :param beta: -1/curvature
    """

    def __init__(
        self, num_entity: int, dim: int, alpha: float = 5.0, beta: float = 1.0
    ):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        self.alpha = alpha
        self.beta = beta

    def forward(self, idxs: LongTensor) -> Tensor:
        """
        Returns the score of edges between the nodes in `idxs`.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        :return: score
        """
        euc = self.embeddings_in(idxs)
        on_hyperboloid = hyperboloid_vector(euc, self.beta)
        u = on_hyperboloid[..., 0, :]
        v = on_hyperboloid[..., 1, :]
        dist = squared_lorentzian_distance(u, v, self.beta)
        if self.training:
            # During training, return the (slightly shifted) squared distance;
            # the loss is responsible for the sign convention.
            return dist + 1e-5
        # Evaluation: weight the distance by the norm-difference penalty so that
        # edge direction is reflected in the score.
        sq_norms = euc.pow(2).sum(dim=-1)
        direction_factor = 1 + self.alpha * (sq_norms[..., 0] - sq_norms[..., 1])
        return -direction_factor * dist
class LorentzianDistance(Module):
    """
    This embedding returns a score <=0 for a given edge, where higher is better.
    The score is given by

        -||u - v||_L^2 - alpha softplus(||u|| - ||v||)

    where ||.||_L is the Lorentzian distance.

    :param alpha: penalty for distance, where higher alpha emphasises distance as a determining factor
        for edge direction more
    :param beta: -1/curvature
    """

    def __init__(
        self, num_entity: int, dim: int, alpha: float = 5.0, beta: float = 1.0
    ):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        self.alpha = alpha
        self.beta = beta

    def forward(self, idxs: LongTensor) -> Tensor:
        """
        Returns the score of the edges between the nodes in `idxs`.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        :return: score
        """
        euc = self.embeddings_in(idxs)
        hyp = hyperboloid_vector(euc, self.beta)
        dist = squared_lorentzian_distance(
            hyp[..., 0, :], hyp[..., 1, :], self.beta
        )
        # softplus penalizes edges pointing from larger- to smaller-norm vectors
        norms = torch.norm(euc, dim=-1)
        direction_penalty = F.softplus(norms[..., 0] - norms[..., 1])
        return -dist - self.alpha * direction_penalty


class LorentzianScore(Module):
    """
    This embedding model combines the score function from equation (8) of Nickel & Kiela 2017
    (https://arxiv.org/pdf/1705.08039.pdf) with the Lorentzian distance from Law et al. 2019
    (http://proceedings.mlr.press/v97/law19a/law19a.pdf). The score function is:

        -(1 + alpha (||u||^2 - ||v||^2)) ||u - v||_L^2

    where ||.||_L is the Lorentzian distance.

    :param alpha: penalty for distance, where higher alpha emphasises distance as a determining factor
        for edge direction more
    :param beta: -1/curvature
    """

    def __init__(
        self, num_entity: int, dim: int, alpha: float = 1e-3, beta: float = 1.0
    ):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        self.alpha = alpha
        self.beta = beta

    def forward(self, idxs: LongTensor) -> Tensor:
        """
        Returns the score of the edges between the nodes in `idxs`.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        :return: score
        """
        euc = self.embeddings_in(idxs)
        hyp = hyperboloid_vector(euc, self.beta)
        dist = squared_lorentzian_distance(
            hyp[..., 0, :], hyp[..., 1, :], self.beta
        )
        # Squared euclidean norms, matching the ||.||^2 terms in the docstring.
        sq_norms = euc.pow(2).sum(dim=-1)
        direction_factor = 1 + self.alpha * (sq_norms[..., 0] - sq_norms[..., 1])
        return -direction_factor * dist


def squared_lorentzian_distance(a: Tensor, b: Tensor, beta: float = 1.0) -> Tensor:
    """
    Given vectors a, b in H^{d, beta} we calculate the squared Lorentzian distance:
        ||a - b||_L^2 = -2 beta - 2 <a, b>_L
    where <., .>_L is the Lorentzian inner-product.

    :param a: tensor of shape (..., d+1) representing vectors in H^{d, beta}
    :param b: tensor of shape (..., d+1) representing vectors in H^{d, beta}
    :param beta: -1/curvature
    :return: tensor of shape (...,) representing the distance ||a - b||_L^2
    """
    inner = lorentzian_inner_product(a, b)
    return -2 * beta - 2 * inner
def lorentzian_inner_product(a: Tensor, b: Tensor) -> Tensor:
    """
    Given vectors a, b in H^{d, beta} we calculate the Lorentzian inner product:
        <a, b>_L = -a_0 b_0 + sum_{i>=1} a_i b_i

    :param a: tensor of shape (..., d+1) representing vectors in H^{d, beta}
    :param b: tensor of shape (..., d+1) representing vectors in H^{d, beta}
    :return: tensor of shape (...,) representing the inner product <a, b>_L
    """
    prod = a * b
    # Functional form: the previous implementation negated the time-like
    # coordinate with an in-place slice op (`prod[..., 0] *= -1`), which mutates
    # an autograd intermediate and is unsafe under torch.func transforms
    # (e.g. vmap / forward-mode AD). This expression is mathematically identical.
    return prod[..., 1:].sum(dim=-1) - prod[..., 0]


def hyperboloid_vector(u: Tensor, beta: float = 1.0) -> Tensor:
    """
    Given a vector u in RR^d, we map it to a vector a in the hyperboloid H^{d, beta} (where -1/beta is the
    curvature) by setting
        a_0 = sqrt(||u||^2 + beta), a_i = u_i

    :param u: tensor of shape (..., d) representing vectors in RR^d
    :param beta: -1/curvature
    :return: tensor of shape (..., d+1) representing a vector in H^{d, beta}
    """
    # Prepend the time-like coordinate a_0 computed from the euclidean norm.
    a_0 = (u.pow(2).sum(dim=-1, keepdim=True) + beta).sqrt()
    return torch.cat((a_0, u), dim=-1)
class HyperbolicEntailmentCones(Module):
    """
    This embedding model represents entities as cones in the Poincare disk, where the aperture of the cone is dependent
    on the origin in such a way as to preserve transitivity with respect to containment.
    (See https://arxiv.org/pdf/1804.01882.pdf)

    :param relative_cone_aperture_scale: Number in (0,1] which is the relative size of the aperture with respect to
        distance from origin. Our implementation is such that
            K = relative_cone_aperture_scale * eps_bound / (1 - eps_bound^2)
        (see eq. (25) in the above paper as to why this is required)
    :param eps_bound: bounds vectors in the annulus between eps and 1-eps.
    """

    def __init__(
        self,
        num_entity: int,
        dim: int,
        relative_cone_aperture_scale: float = 1.0,
        eps_bound: float = 0.1,
    ):
        super().__init__()
        self.eps_bound = eps_bound
        # eps must leave a non-empty annulus (eps, 1-eps)
        assert 0 < self.eps_bound < 0.5
        self.cone_aperature_scale = (
            relative_cone_aperture_scale * self.eps_bound / (1 - self.eps_bound ** 2)
        )

        # Direction of each entity's vector; normalized on lookup in forward().
        self.angles = torch.nn.Embedding(num_entity, dim)
        # Radii are initialized uniformly within 90% of the allowed annulus,
        # centered at 0.5, and kept inside (eps, 1-eps) via BoundedTemp.
        initial_radius_range = 0.9 * (1 - 2 * self.eps_bound)
        initial_radius = 0.5 + initial_radius_range * (torch.rand(num_entity) - 0.5)
        self.radii = BoundedTemp(
            num_entity, initial_radius, self.eps_bound, 1 - self.eps_bound,
        )

    def forward(self, idxs: LongTensor) -> Tensor:
        """
        Returns the score of edges between the nodes in `idxs`.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        :return: score
        """
        # Unit directions scaled by bounded radii give points in the annulus.
        angles = F.normalize(self.angles(idxs), p=2, dim=-1)
        radii = self.radii(idxs)
        vectors = radii[..., None] * angles

        # test_vectors_radii = torch.linalg.norm(vectors, dim=-1)
        # assert (test_vectors_radii > self.eps_bound).all()
        # assert (test_vectors_radii < 1 - self.eps_bound).all()
        # assert torch.isclose(test_vectors_radii, radii).all()

        radii_squared = radii ** 2
        euclidean_dot_products = (vectors[..., 0, :] * vectors[..., 1, :]).sum(dim=-1)
        euclidean_distances = torch.linalg.norm(
            vectors[..., 0, :] - vectors[..., 1, :], dim=-1
        )

        # Half-aperture of the parent's cone; sin psi(x) = K (1 - ||x||^2) / ||x||
        # (eq. (26) of the paper cited on the class -- TODO confirm numbering).
        parent_aperature_angle_sin = (
            self.cone_aperature_scale * (1 - radii_squared[..., 0]) / radii[..., 0]
        )
        # assert (parent_aperature_angle_sin >= -1).all()
        # assert (parent_aperature_angle_sin <= 1).all()
        parent_aperature_angle = torch.arcsin(parent_aperature_angle_sin)

        # Angle at the parent between its axis and the child vector; the 1e-22
        # guards against division by zero when the two points coincide.
        min_angle_parent_rotation_cos = (
            euclidean_dot_products * (1 + radii_squared[..., 0])
            - radii_squared[..., 0] * (1 + radii_squared[..., 1])
        ) / (
            radii[..., 0]
            * euclidean_distances
            * torch.sqrt(
                1
                + radii_squared[..., 0] * radii_squared[..., 1]
                - 2 * euclidean_dot_products
            )
            + 1e-22
        )
        # assert (min_angle_parent_rotation_cos >= -1).all()
        # assert (min_angle_parent_rotation_cos <= 1).all()
        # original implementation clamps this value from -1+eps to 1-eps, however it seems as though [-1, 1] is all that
        # is required.
        min_angle_parent_rotation = torch.arccos(
            min_angle_parent_rotation_cos.clamp(-1, 1)
        )
        # The energy in the original formulation is clamped, which means gradients for negative examples may be squashed.
        return (parent_aperature_angle - min_angle_parent_rotation).clamp_max(0)
        # Two potential alternatives:
        # return -F.softplus(-parent_aperature_angle + min_angle_parent_rotation, beta=40)
        # (beta would have to be tuned)
        # The following is recommended in https://arxiv.org/pdf/1902.04335.pdf, however this requires a corresponding
        # adjustment to the loss function.
        # return parent_aperature_angle - min_angle_parent_rotation
class OE(Module):
    """
    Order-embedding model: scores an edge by how badly the first entity fails to
    coordinate-wise dominate the second (0 when e1 <= e2 everywhere).
    """

    def __init__(self, num_entity, dim):
        super().__init__()
        self.embeddings = torch.nn.Embedding(num_entity, dim)

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
        :return: log prob of shape (..., )
        """
        child = self.embeddings(idxs[..., 0])
        parent = self.embeddings(idxs[..., 1])
        # Violation per dimension: how far child exceeds parent (zero if below).
        violation = torch.max(child, parent) - parent
        return -violation.square().sum(-1)


class POE(Module):
    """
    Probabilistic order-embedding model: entities are axis-aligned regions whose
    log-volume is the negated coordinate sum; an edge is scored by the
    log-conditional log vol(e1 ^ e2) - log vol(e2).
    """

    def __init__(self, num_entity, dim):
        super().__init__()
        self.embeddings = torch.nn.Embedding(num_entity, dim)

    def log_volume(self, e):
        """Log volume of the region parameterized by `e`."""
        return -e.sum(-1)

    def intersection(self, e1, e2):
        """Coordinate-wise intersection of the two regions."""
        return torch.max(e1, e2)

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
        :return: log prob of shape (..., )
        """
        e1 = self.embeddings(idxs[..., 0])
        e2 = self.embeddings(idxs[..., 1])
        joint = self.intersection(e1, e2)
        return self.log_volume(joint) - self.log_volume(e2)
class BoundedTemp(Module):
    """
    A learnable temperature constrained to the open interval (min, max).

    The constraint is enforced by storing an unconstrained parameter and mapping
    it through a scaled sigmoid in forward(); the parameter is initialized via
    the inverse transform (logit) so that forward() initially returns `init`.
    """

    def __init__(
        self,
        shape: TorchShape,
        init: Union[float, Tensor] = 1.0,
        min: float = 0.0,
        max: float = 10.0,
    ):
        super().__init__()
        self.max = max
        self.min = min
        self.shape = shape
        # A scalar init is broadcast to the full shape; tensors are used as-is.
        self.init = torch.ones(shape) * init if isinstance(init, float) else init
        # Inverse of the sigmoid reparameterization used in forward().
        self.temp = Parameter(
            torch.logit((self.init - self.min) / (self.max - self.min))
        )
        # Round-trip check: warn if logit/sigmoid precision shifted the value.
        forward_check = self.forward()
        recovered_exactly = torch.allclose(
            forward_check, self.init, atol=(self.max - self.min) * 1e-8, rtol=0
        )
        if not recovered_exactly:
            logger.warning(
                f"BoundedTemp with min={self.min}, max={self.max}, and init={self.init} has numerical issue which "
                f"results in a slightly different initialization (max error is {torch.max(self.init - forward_check)})"
            )

    def forward(
        self, idx: Union[LongTensor, slice] = slice(None, None, None)
    ) -> FloatTensor:
        """Return the bounded temperature value(s) selected by `idx`."""
        return (self.max - self.min) * torch.sigmoid(self.temp[idx]) + self.min
class ConstTemp(Module):
    """A fixed (non-trainable) global temperature."""

    def __init__(self, init: float = 1.0, **kwargs):
        super().__init__()
        # requires_grad=False: this temperature is a constant, not learned.
        self.temp = Parameter(torch.tensor([init]), requires_grad=False)

    def forward(self, idxs: LongTensor) -> FloatTensor:
        """
        Return a global temp with the appropriate shape to broadcast against box tensors.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        """
        broadcast_shape = [1] * (idxs.ndim + 2)
        return self.temp.view(broadcast_shape)


class GlobalTemp(Module):
    """A single trainable temperature, shared by all entities and dimensions."""

    def __init__(self, init: float, min: float, max: float, **kwargs):
        super().__init__()
        self.temp = BoundedTemp(shape=1, init=init, min=min, max=max)

    def forward(self, idxs: LongTensor) -> FloatTensor:
        """
        Return a global temp with the appropriate shape to broadcast against box tensors.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        """
        broadcast_shape = [1] * (idxs.ndim + 2)
        return self.temp(0).view(broadcast_shape)


class PerDimTemp(Module):
    """One trainable temperature per embedding dimension, shared by all entities."""

    def __init__(self, init: float, min: float, max: float, *, dim: int, **kwargs):
        super().__init__()
        self.dim = dim
        self.temp = BoundedTemp(shape=dim, init=init, min=min, max=max)

    def forward(self, idxs: LongTensor) -> FloatTensor:
        """
        Return a per-dim temp with the appropriate shape to broadcast against box tensors.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        """
        # All axes are singleton except the trailing (dim) axis.
        broadcast_shape = [1] * (idxs.ndim + 1) + [-1]
        return self.temp().view(broadcast_shape)
class PerEntityTemp(Module):
    """One trainable temperature per entity, shared across dimensions."""

    def __init__(
        self, init: float, min: float, max: float, *, num_entities: int, **kwargs
    ):
        super().__init__()
        self.num_entities = num_entities
        self.temp = BoundedTemp(shape=num_entities, init=init, min=min, max=max)

    def forward(self, idxs: LongTensor) -> FloatTensor:
        """
        Return a per-entity temp with the appropriate shape to broadcast against box tensors.
        :param idxs: Tensor of shape (..., 2) indicating edges, i.e. [...,0] -> [..., 1] is an edge
        """
        # Append two singleton axes (corner and dim) for broadcasting.
        return self.temp(idxs).view(*idxs.shape, 1, 1)


class PerEntityPerDimTemp(Module):
    """One trainable temperature per entity and per embedding dimension."""

    def __init__(
        self,
        init: float,
        min: float,
        max: float,
        *,
        num_entities: int,
        dim: int,
        **kwargs,
    ):
        super().__init__()
        self.num_entities = num_entities
        self.temp = BoundedTemp(shape=(num_entities, dim), init=init, min=min, max=max)

    def forward(self, idxs: LongTensor) -> FloatTensor:
        """
        Return a per-entity, per-dim temp with the appropriate shape to broadcast against box tensors.
        :param idxs: Tensor of shape (..., l).
            Often, idxs may indicate edges, in which case we should have l=2, and idxs[...,0] -> idxs[..., 1] is an edge.
        """
        # Singleton corner axis, trailing dim axis kept.
        return self.temp(idxs).view(*idxs.shape, 1, -1)
def convert_float_to_const_temp(temp: Union[Module, float]) -> Module:
    """Wrap a bare float in a ConstTemp module; Modules pass through unchanged."""
    return temp if isinstance(temp, Module) else ConstTemp(temp)
class VectorSim(Module):
    """
    Dot-product similarity model: score(u, v) = <u_in, v_out> + bias.

    :param num_entity: number of entities to embed
    :param dim: embedding dimension
    :param separate_io: if True, use distinct input/output embedding tables;
        otherwise the same table is used for both endpoints
    :param use_bias: if True, add a learned scalar bias to every score
    """

    def __init__(self, num_entity, dim, separate_io=True, use_bias=False):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        if separate_io:
            self.embeddings_out = torch.nn.Embedding(num_entity, dim)
        else:
            self.embeddings_out = self.embeddings_in
        # Idiomatic truthiness test (was `use_bias == True`); behavior unchanged.
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(1,))
        else:
            # Constant 0.0 so forward() can add it unconditionally.
            self.bias = 0.0

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) indicating edges
        :return: logits of shape (...,)
        """
        e1 = self.embeddings_in(idxs[..., 0])
        e2 = self.embeddings_out(idxs[..., 1])
        logits = torch.sum(e1 * e2, dim=-1) + self.bias
        return logits


class VectorDist(Module):
    """
    Negative squared euclidean distance model: score(u, v) = -||u_in - v_out||^2.

    :param separate_io: if True, use distinct input/output embedding tables
    """

    def __init__(self, num_entity, dim, separate_io=True):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        if separate_io:
            self.embeddings_out = torch.nn.Embedding(num_entity, dim)
        else:
            self.embeddings_out = self.embeddings_in

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) indicating edges
        :return: log probs of shape (...,)
        """
        e1 = self.embeddings_in(idxs[..., 0])
        e2 = self.embeddings_out(idxs[..., 1])
        log_probs = -torch.sum(torch.square(e1 - e2), dim=-1)
        return log_probs


class BilinearVector(Module):
    """
    Bilinear scoring model: score(u, v) = u_in^T W v_out (+ bias), with W a
    learned torch.nn.Bilinear layer.

    :param separate_io: if True, use distinct input/output embedding tables
    :param use_bias: passed through to torch.nn.Bilinear
    """

    def __init__(self, num_entity, dim, separate_io=True, use_bias=False):
        super().__init__()
        self.embeddings_in = torch.nn.Embedding(num_entity, dim)
        if separate_io:
            self.embeddings_out = torch.nn.Embedding(num_entity, dim)
        else:
            self.embeddings_out = self.embeddings_in
        self.bilinear_layer = torch.nn.Bilinear(dim, dim, 1, use_bias)
        self.use_bias = use_bias

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) indicating edges
        :return: logits of shape (...,)
        """
        e1 = self.embeddings_in(idxs[..., 0])
        e2 = self.embeddings_out(idxs[..., 1])
        # Bilinear outputs (..., 1); drop the trailing singleton axis.
        logits = self.bilinear_layer(e1, e2).squeeze(-1)
        return logits


class ComplexVector(Module):
    """
    ComplEx-style scoring with complex-valued embeddings split into real and
    imaginary tables, combined through a learned relation vector `w` (w[0] is
    the real part, w[1] the imaginary part).
    """

    def __init__(self, num_entity, dim):
        super().__init__()
        self.embeddings_re = torch.nn.Embedding(num_entity, dim)
        self.embeddings_im = torch.nn.Embedding(num_entity, dim)
        self.w = Parameter(torch.randn((2, dim)))

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) indicating edges
        :return: logits of shape (...,)
        """
        entities_re = self.embeddings_re(idxs)  # (..., 2, dim)
        entities_im = self.embeddings_im(idxs)  # (..., 2, dim)
        s_r, s_i = entities_re[..., 0, :], entities_im[..., 0, :]
        o_r, o_i = entities_re[..., 1, :], entities_im[..., 1, :]

        # Re(<w, s, conj(o)>) expanded into its four real-valued terms.
        logits = (
            (s_r * self.w[0] * o_r).sum(-1)
            + (s_i * self.w[0] * o_i).sum(-1)
            + (s_r * self.w[1] * o_i).sum(-1)
            - (s_i * self.w[1] * o_r).sum(-1)
        )
        return logits
class IntOrPercent(click.ParamType):
    # Click parameter type accepting either a float in [0,1] (interpreted as a
    # percentage) or a whole number (interpreted as an absolute count).
    name = "click_union"

    def convert(self, value, param, ctx):
        """
        Convert `value` to a float in [0,1] or an int; call self.fail (which
        raises) for anything else.
        """
        try:
            float_value = float(value)
            if 0 <= float_value <= 1:
                return float_value
            elif float_value == int(float_value):
                return int(float_value)
            else:
                self.fail(
                    f"expected float between [0,1] or int, got {float_value}",
                    param,
                    ctx,
                )
        except TypeError:
            self.fail(
                "expected string for int() or float() conversion, got "
                f"{value!r} of type {type(value).__name__}",
                param,
                ctx,
            )
        except ValueError:
            self.fail(f"{value!r} is not a valid integer or float", param, ctx)


@click.command(context_settings=dict(show_default=True),)
# --- data and model selection ---
@click.option(
    "--data_path",
    type=click.Path(),
    help="directory or file with graph data (eg. data/graph/some_tree)",
    required=True,
)
@click.option(
    "--model_type",
    type=click.Choice(
        [
            "tbox",
            "gumbel_box",
            "order_embeddings",
            "partial_order_embeddings",
            "vector_sim",
            "vector_dist",
            "bilinear_vector",
            "complex_vector",
            "lorentzian_distance",
            "lorentzian_score",
            "lorentzian",
            "hyperbolic_entailment_cones",
        ],
        case_sensitive=False,
    ),
    default="tbox",
    help="model architecture to use",
)
# --- sampling and training hyperparameters ---
@click.option(
    "--negatives_permutation_option",
    type=click.Choice(["none", "head", "tail"], case_sensitive=False),
    default="none",
    help="whether to use permuted negatives during training, and if so whether to permute head or tail",
)
@click.option(
    "--undirected / --directed",
    default=None,
    help="whether to train using an undirected or directed graph (default is model dependent)",
    show_default=False,
)
@click.option(
    "--dim", type=int, default=4, help="dimension for embedding space",
)
@click.option(
    "--log_batch_size",
    type=int,
    default=10,
    help="batch size for training will be 2**LOG_BATCH_SIZE",
)  # Using batch sizes which are 2**n for some integer n may help optimize GPU efficiency
@click.option(
    "--log_eval_batch_size",
    type=int,
    default=15,
    help="batch size for eval will be 2**LOG_EVAL_BATCH_SIZE",
)  # Using batch sizes which are 2**n for some integer n may help optimize GPU efficiency
@click.option(
    "--learning_rate", type=float, default=0.01, help="learning rate",
)
@click.option(
    "--negative_weight", type=float, default=0.9, help="weight of negative loss",
)
@click.option(
    "--margin",
    type=float,
    default=1.0,
    help="margin for MaxMarginWithLogitsNegativeSamplingLoss or BCEWithDistancesNegativeSamplingLoss (unused otherwise)",
)
@click.option(
    "--negative_ratio",
    type=int,
    default=128,
    help="number of negative samples for each positive",
)
@click.option(
    "--epochs", type=int, default=1_000, help="maximum number of epochs to train"
)
@click.option(
    "--patience",
    type=int,
    default=11,
    help="number of log_intervals without decreased loss before stopping training",
)
@click.option(
    "--log_interval",
    type=IntOrPercent(),
    default=0.1,
    help="interval or percentage (as float in [0,1]) of examples to train between logging training metrics",
)
# --- runtime / output options ---
@click.option(
    "--eval / --no_eval",
    default=True,
    help="whether or not to evaluate the model at the end of training",
)
@click.option(
    "--cuda / --no_cuda", default=True, help="enable/disable CUDA (eg. no nVidia GPU)",
)
@click.option(
    "--save_prediction / --no_save_prediction",
    default=False,
    help="enable/disable saving predicted adjacency matrix",
)
@click.option(
    "--seed", type=int, help="seed for random number generator",
)
@click.option(
    "--wandb / --no_wandb",
    default=False,
    help="enable/disable logging to Weights and Biases",
)
# --- model-specific hyperparameters ---
@click.option(
    "--vector_separate_io / --vector_no_separate_io",
    default=True,
    help="enable/disable using separate input/output representations for vector / bilinear vector model",
)
@click.option(
    "--vector_use_bias / --vector_no_use_bias",
    default=False,
    help="enable/disable using bias term in vector / bilinear",
)
@click.option(
    "--lorentzian_alpha",
    type=float,
    default=5.0,
    help="penalty for distance, where higher alpha emphasises distance as a determining factor in edge direction more",
)
@click.option(
    "--lorentzian_beta",
    type=float,
    default=1.0,
    help="-1/curvature of the space, if beta is higher the space is less curved / more euclidean",
)
@click.option(
    "--hyperbolic_entailment_cones_relative_cone_aperture_scale",
    type=float,
    default=1.0,
    help="float in (0,1) representing relative scale of cone apertures with respect to radius (K = relative_cone_aperature_scale * eps_bound / (1 - eps_bound^2))",
)
@click.option(
    "--hyperbolic_entailment_cones_eps_bound",
    type=float,
    default=0.1,
    help="restrict vectors to be parameterized in an annulus from eps to 1-eps",
)
@click.option(
    "--box_intersection_temp",
    type=float,
    default=0.01,
    help="temperature of intersection calculation (hyperparameter for gumbel_box, initialized value for tbox)",
)
@click.option(
    "--box_volume_temp",
    type=float,
    default=1.0,
    help="temperature of volume calculation (hyperparameter for gumbel_box, initialized value for tbox)",
)
@click.option(
    "--tbox_temperature_type",
    type=click.Choice(["global", "per_dim", "per_entity", "per_entity_per_dim"]),
    default="per_entity_per_dim",
    help="type of learned temperatures (for tbox model)",
)
@click.option(
    "--output_dir",
    type=str,
    default=None,
    help="output directory for recording current hyper-parameters and results",
)
@click.option(
    "--save_model / --no_save_model",
    type=bool,
    default=False,
    help="whether or not to save the model to disk",
)
def train(**config):
    """Train a graph embedding representation"""
    # Import inside the command so the CLI loads without pulling in training deps.
    from .train import training

    training(config)
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from __future__ import annotations 18 | 19 | import time 20 | from typing import * 21 | 22 | import attr 23 | import numpy as np 24 | import torch 25 | from loguru import logger 26 | from scipy.sparse import coo_matrix 27 | from torch.nn import Module 28 | from torch.utils.data import DataLoader 29 | from tqdm.autonotebook import trange, tqdm 30 | 31 | from pytorch_utils.exceptions import StopLoopingException 32 | from pytorch_utils.loggers import Logger 33 | from pytorch_utils.training import IntervalConditional 34 | from .metrics import * 35 | 36 | __all__ = [ 37 | "TrainLooper", 38 | "EvalLooper", 39 | ] 40 | 41 | 42 | @attr.s(auto_attribs=True) 43 | class TrainLooper: 44 | name: str 45 | model: Module 46 | dl: DataLoader 47 | opt: torch.optim.Optimizer 48 | loss_func: Callable 49 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None 50 | eval_loopers: Iterable[EvalLooper] = attr.ib(factory=tuple) 51 | early_stopping: Callable = lambda z: None 52 | logger: Logger = attr.ib(factory=Logger) 53 | summary_func: Callable[Dict] = lambda z: None 54 | save_model: Callable[Module] = lambda z: None 55 | log_interval: Optional[Union[IntervalConditional, int]] = attr.ib( 56 | default=None, converter=IntervalConditional.interval_conditional_converter 57 | ) 58 | 59 | def __attrs_post_init__(self): 60 | if 
isinstance(self.eval_loopers, EvalLooper): 61 | self._eval_loopers = (self.eval_loopers,) 62 | self.looper_metrics = {"Total Examples": 0} 63 | if self.log_interval is None: 64 | # by default, log every batch 65 | self.log_interval = IntervalConditional(0) 66 | self.running_losses = [] 67 | 68 | self.best_metrics_comparison_functions = {"Mean Loss": min} 69 | self.best_metrics = {} 70 | self.previous_best = None 71 | 72 | def loop(self, epochs: int): 73 | try: 74 | self.running_losses = [] 75 | for epoch in trange(epochs, desc=f"[{self.name}] Epochs"): 76 | self.model.train() 77 | with torch.enable_grad(): 78 | self.train_loop(epoch) 79 | except StopLoopingException as e: 80 | logger.warning(str(e)) 81 | finally: 82 | self.logger.commit() 83 | 84 | # load in the best model 85 | previous_device = next(iter(self.model.parameters())).device 86 | self.model.load_state_dict(self.save_model.best_model_state_dict()) 87 | self.model.to(previous_device) 88 | 89 | # evaluate 90 | metrics = [] 91 | predictions_coo = [] 92 | for eval_looper in self.eval_loopers: 93 | metric, prediction_coo = eval_looper.loop() 94 | metrics.append(metric) 95 | predictions_coo.append(prediction_coo) 96 | return metrics, predictions_coo 97 | 98 | def train_loop(self, epoch: Optional[int] = None): 99 | """ 100 | Internal loop for a single epoch of training 101 | :return: list of losses per batch 102 | """ 103 | examples_this_epoch = 0 104 | examples_in_single_epoch = len(self.dl.dataset) 105 | last_time_stamp = time.time() 106 | num_batch_passed = 0 107 | for iteration, batch_in in enumerate( 108 | tqdm(self.dl, desc=f"[{self.name}] Batch", leave=False) 109 | ): 110 | self.opt.zero_grad() 111 | 112 | batch_out = self.model(batch_in) 113 | loss = self.loss_func(batch_out) 114 | 115 | # This is not always going to be the right thing to check. 116 | # In a more general setting, we might want to consider wrapping the DataLoader in some way 117 | # with something which stores this information. 
118 | num_in_batch = len(loss) 119 | 120 | loss = loss.sum(dim=0) 121 | 122 | self.looper_metrics["Total Examples"] += num_in_batch 123 | examples_this_epoch += num_in_batch 124 | 125 | if torch.isnan(loss).any(): 126 | raise StopLoopingException("NaNs in loss") 127 | self.running_losses.append(loss.detach().item()) 128 | loss.backward() 129 | 130 | for param in self.model.parameters(): 131 | if param.grad is not None: 132 | if torch.isnan(param.grad).any(): 133 | raise StopLoopingException("NaNs in grad") 134 | 135 | num_batch_passed += 1 136 | # TODO: Refactor the following 137 | self.opt.step() 138 | # If you have a scheduler, keep track of the learning rate 139 | if self.scheduler is not None: 140 | self.scheduler.step() 141 | if len(self.opt.param_groups) == 1: 142 | self.looper_metrics[f"Learning Rate"] = self.opt.param_groups[0][ 143 | "lr" 144 | ] 145 | else: 146 | for i, param_group in enumerate(self.opt.param_groups): 147 | self.looper_metrics[f"Learning Rate (Group {i})"] = param_group[ 148 | "lr" 149 | ] 150 | 151 | # Check performance every self.log_interval number of examples 152 | last_log = self.log_interval.last 153 | 154 | if self.log_interval(self.looper_metrics["Total Examples"]): 155 | current_time_stamp = time.time() 156 | time_spend = (current_time_stamp - last_time_stamp) / num_batch_passed 157 | last_time_stamp = current_time_stamp 158 | num_batch_passed = 0 159 | self.logger.collect({"avg_time_per_batch": time_spend}) 160 | 161 | self.logger.collect(self.looper_metrics) 162 | mean_loss = sum(self.running_losses) / ( 163 | self.looper_metrics["Total Examples"] - last_log 164 | ) 165 | metrics = {"Mean Loss": mean_loss} 166 | self.logger.collect( 167 | { 168 | **{ 169 | f"[{self.name}] {metric_name}": value 170 | for metric_name, value in metrics.items() 171 | }, 172 | "Epoch": epoch + examples_this_epoch / examples_in_single_epoch, 173 | } 174 | ) 175 | self.logger.commit() 176 | self.running_losses = [] 177 | 
self.update_best_metrics_(metrics) 178 | self.save_if_best_(self.best_metrics["Mean Loss"]) 179 | self.early_stopping(self.best_metrics["Mean Loss"]) 180 | 181 | def update_best_metrics_(self, metrics: Dict[str, float]) -> None: 182 | for name, comparison in self.best_metrics_comparison_functions.items(): 183 | if name not in self.best_metrics: 184 | self.best_metrics[name] = metrics[name] 185 | else: 186 | self.best_metrics[name] = comparison( 187 | metrics[name], self.best_metrics[name] 188 | ) 189 | self.summary_func( 190 | { 191 | f"[{self.name}] Best {name}": val 192 | for name, val in self.best_metrics.items() 193 | } 194 | ) 195 | 196 | def save_if_best_(self, best_metric) -> None: 197 | if best_metric != self.previous_best: 198 | self.save_model(self.model) 199 | self.previous_best = best_metric 200 | 201 | 202 | @attr.s(auto_attribs=True) 203 | class EvalLooper: 204 | name: str 205 | model: Module 206 | dl: DataLoader 207 | batchsize: int 208 | logger: Logger = attr.ib(factory=Logger) 209 | summary_func: Callable[Dict] = lambda z: None 210 | 211 | @torch.no_grad() 212 | def loop(self) -> Dict[str, Any]: 213 | self.model.eval() 214 | 215 | logger.debug("Evaluating model predictions on full adjacency matrix") 216 | time1 = time.time() 217 | previous_device = next(iter(self.model.parameters())).device 218 | num_nodes = self.dl.dataset.num_nodes 219 | ground_truth = np.zeros((num_nodes, num_nodes)) 220 | pos_index = self.dl.dataset.edges.cpu().numpy() 221 | # release RAM 222 | del self.dl.dataset 223 | 224 | ground_truth[pos_index[:, 0], pos_index[:, 1]] = 1 225 | 226 | prediction_scores = np.zeros((num_nodes, num_nodes)) # .to(previous_device) 227 | 228 | input_x, input_y = np.indices((num_nodes, num_nodes)) 229 | input_x, input_y = input_x.flatten(), input_y.flatten() 230 | input_list = np.stack([input_x, input_y], axis=-1) 231 | number_of_entries = len(input_x) 232 | 233 | with torch.no_grad(): 234 | pbar = tqdm( 235 | desc=f"[{self.name}] Evaluating", 
leave=False, total=number_of_entries 236 | ) 237 | cur_pos = 0 238 | while cur_pos < number_of_entries: 239 | last_pos = cur_pos 240 | cur_pos += self.batchsize 241 | if cur_pos > number_of_entries: 242 | cur_pos = number_of_entries 243 | 244 | ids = torch.tensor(input_list[last_pos:cur_pos], dtype=torch.long) 245 | cur_preds = self.model(ids.to(previous_device)).cpu().numpy() 246 | prediction_scores[ 247 | input_x[last_pos:cur_pos], input_y[last_pos:cur_pos] 248 | ] = cur_preds 249 | pbar.update(self.batchsize) 250 | 251 | prediction_scores_no_diag = prediction_scores[~np.eye(num_nodes, dtype=bool)] 252 | ground_truth_no_diag = ground_truth[~np.eye(num_nodes, dtype=bool)] 253 | 254 | time2 = time.time() 255 | logger.debug(f"Evaluation time: {time2 - time1}") 256 | 257 | # TODO: release self.dl from gpu 258 | del input_x, input_y 259 | 260 | logger.debug("Calculating optimal F1 score") 261 | metrics = calculate_optimal_F1(ground_truth_no_diag, prediction_scores_no_diag) 262 | time3 = time.time() 263 | logger.debug(f"F1 calculation time: {time3 - time2}") 264 | logger.info(f"Metrics: {metrics}") 265 | 266 | self.logger.collect({f"[{self.name}] {k}": v for k, v in metrics.items()}) 267 | self.logger.commit() 268 | 269 | predictions = (prediction_scores > metrics["threshold"]) * ( 270 | ~np.eye(num_nodes, dtype=bool) 271 | ) 272 | 273 | return metrics, coo_matrix(predictions) 274 | -------------------------------------------------------------------------------- /src/graph_modeling/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | from torch import Tensor 19 | from torch.nn import Module 20 | from torch.nn import functional as F 21 | 22 | from box_embeddings.common.utils import log1mexp 23 | 24 | __all__ = [ 25 | "BCEWithLogsLoss", 26 | "BCEWithLogsNegativeSamplingLoss", 27 | "BCEWithLogitsNegativeSamplingLoss", 28 | "BCEWithDistancesNegativeSamplingLoss", 29 | "MaxMarginWithLogitsNegativeSamplingLoss", 30 | "MaxMarginOENegativeSamplingLoss", 31 | "MaxMarginDiskEmbeddingNegativeSamplingLoss", 32 | ] 33 | 34 | 35 | class BCEWithLogsLoss(Module): 36 | def forward(self, input: Tensor, target: Tensor, *args, **kwargs) -> Tensor: 37 | """ 38 | :param input: log probabilities 39 | :param target: target probabilities 40 | """ 41 | return -(target * input + (1 - target) * log1mexp(input)) 42 | 43 | 44 | class BCEWithLogsNegativeSamplingLoss(Module): 45 | def __init__(self, negative_weight: float = 0.5): 46 | super().__init__() 47 | self.negative_weight = negative_weight 48 | 49 | def forward(self, log_prob_scores: Tensor) -> Tensor: 50 | """ 51 | Returns a weighted BCE loss where: 52 | (1 - negative_weight) * pos_loss + negative_weight * weighted_average(neg_loss) 53 | 54 | :param log_prob_scores: Tensor of shape (..., 1+K) where [...,0] is the score for positive examples and [..., 1:] are negative 55 | :return: weighted BCE loss 56 | """ 57 | log_prob_pos = log_prob_scores[..., 0] 58 | log_prob_neg = log_prob_scores[..., 1:] 59 | pos_loss = -log_prob_pos 60 | neg_loss = 
-log1mexp(log_prob_neg) 61 | logit_prob_neg = log_prob_neg + neg_loss 62 | weights = F.softmax(logit_prob_neg, dim=-1) 63 | weighted_average_neg_loss = (weights * neg_loss).sum(dim=-1) 64 | return ( 65 | 1 - self.negative_weight 66 | ) * pos_loss + self.negative_weight * weighted_average_neg_loss 67 | 68 | 69 | class BCEWithLogitsNegativeSamplingLoss(Module): 70 | """ 71 | Refer to NCE from Word2vec [1] + self-adversarial negative sampling [2]. 72 | this loss optimize similarity scores. 73 | [1] Mikolov, Tomas, et al. "Distributed representations of words and phrases and their compositionality." 74 | Advances in neural information processing systems 26 (2013): 3111-3119. 75 | [2] Sun, Zhiqing, et al. "Rotate: Knowledge graph embedding by relational rotation in complex space." 76 | arXiv preprint arXiv:1902.10197 (2019). 77 | """ 78 | 79 | def __init__(self, negative_weight: float = 0.5): 80 | super().__init__() 81 | self.negative_weight = negative_weight 82 | self.logsigmoid = torch.nn.LogSigmoid() 83 | 84 | def forward(self, logits: Tensor) -> Tensor: 85 | """ 86 | Returns a weighted BCE loss where: 87 | (1 - negative_weight) * pos_loss + negative_weight * weighted_average(neg_loss) 88 | 89 | :param logit: Tensor of shape (..., 1+K) where [...,0] is the logit for positive examples 90 | and [..., 1:] are logits for negatives 91 | :return: weighted BCE loss 92 | """ 93 | pos_scores = logits[..., 0] 94 | neg_scores = logits[..., 1:] 95 | pos_loss = -self.logsigmoid(pos_scores) 96 | neg_loss = -self.logsigmoid(-neg_scores) # sigmoid(-x) = 1 - sigmoid(x) 97 | 98 | weights = F.softmax(neg_scores, dim=-1) 99 | weighted_average_neg_loss = (weights * neg_loss).sum(dim=-1) 100 | return ( 101 | 1 - self.negative_weight 102 | ) * pos_loss + self.negative_weight * weighted_average_neg_loss 103 | 104 | 105 | class BCEWithDistancesNegativeSamplingLoss(Module): 106 | """ 107 | Refer to RotatE [1], this loss can effectively optimize distance-based models 108 | [1] Sun, Zhiqing, et 
al. "Rotate: Knowledge graph embedding by relational rotation in complex space." 109 | arXiv preprint arXiv:1902.10197 (2019). 110 | """ 111 | 112 | def __init__(self, negative_weight: float = 0.5, margin=1.0): 113 | super().__init__() 114 | self.negative_weight = negative_weight 115 | self.margin = margin 116 | self.logsigmoid = torch.nn.LogSigmoid() 117 | 118 | def forward(self, distance: Tensor) -> Tensor: 119 | """ 120 | :param distance: Tensor of shape (..., 1+K) where [...,0] is a distance for positive examples and [..., 1:] is 121 | a distance-based score for a negative example. 122 | :return: weighted BCE loss 123 | 124 | 125 | """ 126 | pos_dists = distance[..., 0] 127 | neg_dists = distance[..., 1:] 128 | pos_loss = -self.logsigmoid(self.margin + pos_dists) 129 | neg_loss = -self.logsigmoid(-neg_dists - self.margin) 130 | weights = F.softmax(-neg_dists, dim=-1) 131 | weighted_average_neg_loss = (weights * neg_loss).sum(dim=-1) 132 | return ( 133 | 1 - self.negative_weight 134 | ) * pos_loss + self.negative_weight * weighted_average_neg_loss 135 | 136 | 137 | class MaxMarginWithLogitsNegativeSamplingLoss(Module): 138 | def __init__(self, margin: float = 1.0): 139 | super().__init__() 140 | self.max_margin = torch.nn.MarginRankingLoss(margin, reduction="none") 141 | 142 | def forward(self, logits: Tensor) -> Tensor: 143 | """ 144 | Returns a max margin loss: max(0, margin - pos + neg) 145 | 146 | :param logits: Tensor of shape (..., 1+K) where [...,0] is the score for positive examples 147 | and [..., 1:] are scores for negatives 148 | :return: max margin loss 149 | 150 | """ 151 | pos_scores = logits[..., [0]] 152 | neg_scores = logits[..., 1:] 153 | loss = self.max_margin( 154 | pos_scores, neg_scores, torch.ones_like(neg_scores) 155 | ).mean(dim=-1) 156 | return loss 157 | 158 | 159 | class MaxMarginOENegativeSamplingLoss(Module): 160 | def __init__(self, negative_weight: float = 0.5, margin: float = 1.0): 161 | super().__init__() 162 | self.margin = 
margin 163 | self.negative_weight = negative_weight 164 | 165 | def forward(self, logits: Tensor) -> Tensor: 166 | """ 167 | Returns a margin loss for order embedding: loss = - pos + max(0, margin + neg) 168 | 169 | :param logits: Tensor of shape (..., 1+K) where [...,0] is the score for positive examples 170 | and [..., 1:] are scores for negatives 171 | :return: max margin loss 172 | 173 | """ 174 | pos_scores = logits[..., [0]] 175 | neg_scores = logits[..., 1:] 176 | loss = -(1 - self.negative_weight) * pos_scores.mean( 177 | dim=-1 178 | ) + self.negative_weight * torch.maximum( 179 | torch.zeros_like(neg_scores), self.margin + neg_scores 180 | ).mean( 181 | -1 182 | ) 183 | return loss 184 | 185 | 186 | class MaxMarginDiskEmbeddingNegativeSamplingLoss(Module): 187 | def __init__(self, negative_weight: float = 0.5, margin: float = 1.0): 188 | super().__init__() 189 | self.margin = margin 190 | self.negative_weight = negative_weight 191 | 192 | def forward(self, logits: Tensor) -> Tensor: 193 | """ 194 | Returns a margin loss: loss = max(0, -pos) + max(0, margin + neg) 195 | This was recommended for HEC in https://arxiv.org/pdf/1902.04335.pdf. 196 | 197 | :param scores: Tensor of shape (..., 1+K) where [...,0] is the score for positive examples 198 | and [..., 1:] are scores for negatives. Note higher score means more positive. 199 | :return: max margin loss 200 | 201 | """ 202 | pos_scores = logits[..., [0]] 203 | neg_scores = logits[..., 1:] 204 | loss = (1 - self.negative_weight) * (-pos_scores).clamp_min(0).mean( 205 | dim=-1 206 | ) + self.negative_weight * (self.margin + neg_scores).clamp_min(0).mean(dim=-1) 207 | return loss 208 | -------------------------------------------------------------------------------- /src/graph_modeling/training/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import * 18 | 19 | import numpy as np 20 | import sklearn.metrics 21 | from loguru import logger 22 | 23 | __all__ = [ 24 | "calculate_optimal_threshold", 25 | "calculate_optimal_F1_threshold", 26 | "calculate_metrics", 27 | "numpy_metrics", 28 | "calculate_optimal_F1", 29 | ] 30 | 31 | 32 | def calculate_optimal_threshold(targets: np.ndarray, scores: np.ndarray) -> float: 33 | fpr, tpr, thresholds = sklearn.metrics.roc_curve( 34 | targets, scores, drop_intermediate=False 35 | ) 36 | return thresholds[np.argmax(tpr - fpr)] 37 | 38 | 39 | def calculate_optimal_F1_threshold(targets: np.ndarray, scores: np.ndarray) -> float: 40 | fpr, tpr, thresholds = sklearn.metrics.roc_curve( 41 | targets, scores, drop_intermediate=False 42 | ) 43 | num_pos = targets.sum() 44 | num_neg = (1 - targets).sum() 45 | logger.debug(f"Calculating F1 with {num_pos} positive, {num_neg} negative") 46 | f1 = 2 * tpr / (1 + tpr + fpr * num_neg / num_pos) 47 | return thresholds[np.argmax(f1)] 48 | 49 | 50 | def calculate_metrics( 51 | targets: np.ndarray, scores: np.ndarray, threshold: float 52 | ) -> Dict[str, float]: 53 | scores_hard = scores > threshold 54 | return { 55 | "Accuracy": (scores_hard == targets).mean(), 56 | "F1": sklearn.metrics.f1_score(targets, scores_hard), 57 | } 58 | 59 | 60 | def numpy_metrics(targets: np.ndarray, 
predictions: np.ndarray) -> Dict[str, float]: 61 | assert targets.dtype == np.bool 62 | assert predictions.dtype == np.bool 63 | 64 | true_positives = (predictions & targets).sum() 65 | false_positives = (predictions & ~targets).sum() 66 | false_negatives = (~predictions & targets).sum() 67 | return { 68 | "Accuracy": (predictions == targets).mean(), 69 | "F1": true_positives 70 | / (true_positives + (false_positives + false_negatives) / 2), 71 | } 72 | 73 | 74 | def calculate_optimal_F1(targets, scores) -> Dict[str, float]: 75 | fpr, tpr, thresholds = sklearn.metrics.roc_curve( 76 | targets, scores, drop_intermediate=False 77 | ) 78 | auc = sklearn.metrics.auc(fpr, tpr) 79 | num_pos = targets.sum() 80 | num_neg = (1 - targets).sum() 81 | f1 = 2 * tpr / (1 + tpr + fpr * num_neg / num_pos) 82 | threshold = thresholds[np.argmax(f1)] 83 | return {"F1": float(np.max(f1)), "AUC": float(auc), "threshold": threshold} 84 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import logging 18 | import pytest 19 | from _pytest.logging import caplog as _caplog 20 | from loguru import logger 21 | 22 | 23 | @pytest.fixture 24 | def caplog(_caplog): 25 | class PropogateHandler(logging.Handler): 26 | def emit(self, record): 27 | logging.getLogger(record.name).handle(record) 28 | 29 | handler_id = logger.add(PropogateHandler(), format="{message} {extra}") 30 | yield _caplog 31 | logger.remove(handler_id) 32 | -------------------------------------------------------------------------------- /tests/data/kronecker_graph.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iesl/geometric-graph-embedding/46a4ed4406bff18c9570273fce99178d0e5820c8/tests/data/kronecker_graph.npz -------------------------------------------------------------------------------- /tests/data/kronecker_graph.toml: -------------------------------------------------------------------------------- 1 | type = "kronecker_graph" 2 | seed = 3179112883 3 | a = 1.0 4 | b = 0.8 5 | c = 0.6 6 | d = 0.5 7 | log_num_nodes = 8 8 | transitive_closure = false 9 | -------------------------------------------------------------------------------- /tests/data/random_tree.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iesl/geometric-graph-embedding/46a4ed4406bff18c9570273fce99178d0e5820c8/tests/data/random_tree.npz -------------------------------------------------------------------------------- /tests/data/random_tree.param.toml: -------------------------------------------------------------------------------- 1 | type = "random_tree" 2 | num_nodes = 1000 3 | transitive_closure = false 4 | -------------------------------------------------------------------------------- /tests/models/test_box.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The 
Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | import pytest 19 | from hypothesis import given, strategies as st 20 | from graph_modeling.models.box import * 21 | from graph_modeling.models.temps import * 22 | 23 | 24 | @given( 25 | num_entities=st.integers(100, 200), 26 | dim=st.integers(20, 100), 27 | batch_size=st.integers(1, 100), 28 | ) 29 | @pytest.mark.parametrize( 30 | "IntersectionTempClass", 31 | [ConstTemp, GlobalTemp, PerDimTemp, PerEntityTemp, PerEntityPerDimTemp], 32 | ) 33 | @pytest.mark.parametrize( 34 | "VolumeTempClass", 35 | [ConstTemp, GlobalTemp, PerDimTemp, PerEntityTemp, PerEntityPerDimTemp], 36 | ) 37 | def test_tbox(IntersectionTempClass, VolumeTempClass, num_entities, dim, batch_size): 38 | """Verify the performant forward pass is accurate using the naive implementation""" 39 | box_model = TBox( 40 | num_entities, 41 | dim, 42 | intersection_temp=IntersectionTempClass( 43 | 0.01, min=0.0, max=10.0, num_entities=num_entities, dim=dim 44 | ), 45 | volume_temp=VolumeTempClass( 46 | 1.0, min=0.0, max=100.0, num_entities=num_entities, dim=dim 47 | ), 48 | ) 49 | box_model.train() 50 | optim = torch.optim.SGD(box_model.parameters(), lr=1e-2) 51 | idxs = torch.randint(num_entities, size=[batch_size, 2]) 52 | output = box_model(idxs) 53 | loss = output.sum() 54 | optim.step() 55 | 
-------------------------------------------------------------------------------- /tests/models/test_temps.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from pytorch_utils import is_broadcastable 18 | from graph_modeling.models.temps import * 19 | import torch 20 | import pytest 21 | from hypothesis import given, strategies as st 22 | 23 | 24 | @given( 25 | num_entities=st.integers(100, 200), 26 | dim=st.integers(20, 100), 27 | batch_size=st.integers(1, 100), 28 | axes=st.lists(st.integers(1, 10), max_size=4), 29 | ) 30 | @pytest.mark.parametrize( 31 | "TempClass", [ConstTemp, GlobalTemp, PerDimTemp, PerEntityTemp, PerEntityPerDimTemp] 32 | ) 33 | def test_broadcastability(TempClass, num_entities, dim, batch_size, axes): 34 | temp_module = TempClass( 35 | init=1.0, min=0.0, max=10.0, num_entities=num_entities, dim=dim 36 | ) 37 | random_size = [batch_size, *axes] 38 | idxs = torch.randint(num_entities, size=tuple(random_size)) 39 | output = temp_module(idxs) 40 | assert is_broadcastable(output.shape, random_size + [2, dim]) 41 | 42 | 43 | @pytest.mark.parametrize("init", [10.0 ** i for i in range(-8, 3)]) 44 | @pytest.mark.parametrize("max", [0.1, 1.0, 10.0, 100.0, 1000.0]) 45 | def test_bounded_temp(init, max): 46 | """ 47 | 
Make sure BoundedTemp initializes the values correctly. 48 | Note: in general, there will always be settings of (init, min, max) for which the BoundedTemp implementation will 49 | not be initialized exactly correctly due to numerical issues. 50 | """ 51 | if init < max: 52 | bounded_temp = BoundedTemp(1, init, 0.0, max) 53 | assert torch.allclose( 54 | torch.tensor(init), bounded_temp(), atol=max * 1e-8, rtol=0 55 | ) 56 | bounded_temp.temp.data += 1e10 57 | assert torch.allclose( 58 | torch.tensor(max), bounded_temp(), atol=max * 1e-8, rtol=0 59 | ) 60 | bounded_temp.temp.data -= 1e20 61 | assert torch.allclose( 62 | torch.tensor(0.0), bounded_temp(), atol=max * 1e-8, rtol=0 63 | ) 64 | 65 | 66 | def test_bounded_warning(caplog): 67 | """Make sure BoundedTemp logs a warning if we cannot initialize as requested due to numerical issues""" 68 | bounded_temp = BoundedTemp(1, 0.0, -0.16589745133188674, 2.2204460492503134e-08) 69 | assert "BoundedTemp" in caplog.text 70 | -------------------------------------------------------------------------------- /tests/training/test_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Geometric Graph Embedding Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
#
# SPDX-License-Identifier: Apache-2.0

from graph_modeling.training.dataset import *
from graph_modeling.enums import PermutationOption
from pytorch_utils.tensordataloader import TensorDataLoader
from hypothesis import given, strategies as st
from hypothesis.extra import numpy as hnp
import numpy as np
import torch
from torch import LongTensor
import os
import pytest
from pathlib import Path
from typing import *

TEST_DIR = Path(os.path.dirname(__file__)).parent  # top-level test directory


class GraphData(NamedTuple):
    # Lightweight fixture record describing a stored test graph:
    # `path` points at an .npz file under tests/data; `num_nodes`/`num_edges`
    # are the expected values for assertions. `edges` is unused by the
    # fixtures below but allows passing in-memory edge lists.
    path: Optional[Path] = None
    edges: Union[None, LongTensor, np.ndarray] = None
    num_nodes: Optional[int] = None
    num_edges: Optional[int] = None


@pytest.fixture()
def random_tree_npz() -> GraphData:
    """A stored random tree on 1000 nodes (999 edges, see tests/data)."""
    return GraphData(
        path=TEST_DIR / "data/random_tree.npz", num_nodes=1000, num_edges=999
    )


@pytest.fixture()
def kronecker_graph_npz() -> GraphData:
    """A stored Kronecker graph on 256 nodes (4981 edges, see tests/data)."""
    return GraphData(
        path=TEST_DIR / "data/kronecker_graph.npz", num_nodes=256, num_edges=4981,
    )


@pytest.fixture(
    params=["random_tree_npz", "kronecker_graph_npz",]
)
def arbitrary_graph(request):
    """Parametrized fixture yielding each stored graph fixture in turn."""
    return request.getfixturevalue(request.param)


def test_edges_and_num_nodes_from_npz(arbitrary_graph: GraphData):
    """Simply test if GraphDatasetUniformNegatives can load from npz without error"""
    edges, num_nodes = edges_and_num_nodes_from_npz(arbitrary_graph.path)
    assert len(edges) == arbitrary_graph.num_edges
    assert num_nodes == arbitrary_graph.num_nodes


def test_get_data_from_graph_dataset(arbitrary_graph):
    """GraphDataset should yield (batch_size, 2) edge batches via TensorDataLoader."""
    g = GraphDataset(*edges_and_num_nodes_from_npz(arbitrary_graph.path))
    t = TensorDataLoader(g, batch_size=100, shuffle=True)
    batch = next(iter(t))
    assert batch.shape == (100, 2)


@pytest.mark.parametrize(
    "permutation_option",
    [PermutationOption.none, PermutationOption.head, PermutationOption.tail],
)
def test_uniform_negatives_dataset(
    arbitrary_graph: GraphData, permutation_option: PermutationOption
):
    """With RandomNegativeEdges attached, indexing GraphDataset should return the
    positive edge in slot 0 followed by true negatives (edges absent from the graph)."""
    positives, num_nodes = edges_and_num_nodes_from_npz(arbitrary_graph.path)
    g = GraphDataset(
        positives,
        num_nodes,
        RandomNegativeEdges(
            num_nodes=num_nodes,
            negative_ratio=100,
            avoid_edges=positives,
            permutation_option=permutation_option,
        ),
    )
    # Dense adjacency matrix used as ground truth for the negative checks below.
    full_graph = torch.zeros((g.num_nodes, g.num_nodes))
    full_graph[positives[:, 0], positives[:, 1]] = 1
    batch = g[torch.arange(len(positives))]
    # Check if [:,0] is where positives are located
    assert (batch[:, 0] == positives).all()
    # Check that all the negatives are zero in the ground-truth graph
    assert (full_graph[batch[:, 1:, 0], batch[:, 1:, 1]] == 0).all()
    # For permuted variants, check that the head or tail is correct
    if permutation_option == PermutationOption.head:
        assert (batch[:, :, 0] == positives[:, None, 0]).all()
    if permutation_option == PermutationOption.tail:
        assert (batch[:, :, 1] == positives[:, None, 1]).all()


@given(
    num_entities=st.integers(100, 2000),
    negative_ratio=st.integers(1, 100),
    batch_size=st.integers(1, 100),
    axes=st.lists(st.integers(1, 10), max_size=2),
)
def test_uniform_random_edges(num_entities, negative_ratio, batch_size, axes):
    """RandomEdges should produce negative_ratio edges per positive, for any batch shape."""
    random_negatives = RandomEdges(num_entities, negative_ratio)
    random_size = [batch_size, *axes, 2]
    pos_idxs = torch.randint(num_entities, size=tuple(random_size))
    output = random_negatives(pos_idxs)
    assert output.shape == tuple((*pos_idxs.shape[:-1], negative_ratio, 2))


@given(
    num_entities=st.integers(100, 2000),
    negative_ratio=st.integers(1, 100),
    batch_size=st.integers(1, 100),
    axes=st.lists(st.integers(1, 10), max_size=2),
)
def test_permuted_random_edges(num_entities, negative_ratio, batch_size, axes):
    """With permuted=True, each sampled edge should keep the head or the tail of its positive."""
    random_negatives = RandomEdges(num_entities, negative_ratio, permuted=True)
    random_size = [batch_size, *axes, 2]
    pos_idxs = torch.randint(num_entities, size=tuple(random_size))
    output = random_negatives(pos_idxs)
    assert output.shape == tuple((*pos_idxs.shape[:-1], negative_ratio, 2))
    # Broadcast positives against the negative_ratio axis: at least one endpoint
    # of every sampled edge must coincide with the corresponding positive's endpoint.
    pos_idxs = pos_idxs[..., None, :]
    assert (output == pos_idxs).any(dim=-1).all()


@st.composite
def generate_edges(
    draw,
    num_nodes=st.integers(min_value=2, max_value=1_000_000_000),
    num_edges=st.integers(min_value=1, max_value=100_000),
):
    """Hypothesis strategy: draw (edges, num_nodes) with arbitrary (possibly repeated) edges."""
    num_nodes = draw(num_nodes)
    edges = draw(
        hnp.arrays(
            dtype=np.int64,
            shape=(draw(num_edges), 2),
            elements=st.integers(min_value=0, max_value=num_nodes - 1),
        )
    )
    return torch.from_numpy(edges), num_nodes


@given(edges_and_num_nodes=generate_edges())
def test_edges_to_ints_and_back(edges_and_num_nodes):
    """convert_edges_to_ints / convert_ints_to_edges should be exact inverses."""
    edges, num_nodes = edges_and_num_nodes
    ints = convert_edges_to_ints(edges, num_nodes)
    # One int per edge: the trailing (head, tail) axis is collapsed.
    assert ints.shape == edges.shape[:-1]
    edges_back = convert_ints_to_edges(ints, num_nodes)
    assert edges_back.shape == edges.shape
    assert (edges == edges_back).all()


@st.composite
def generate_positive_edges(
    draw, num_nodes=st.integers(min_value=2, max_value=1_000_000_000)
):
    """Generates edges with at least one negative still possible"""
    num_nodes = draw(num_nodes)
    max_edges = num_nodes ** 2
    # max_edges - 1 keeps at least one pair free to serve as a negative.
    num_edges = draw(st.integers(min_value=1, max_value=min(1_000, max_edges - 1)))
    # Draw unique edge ids (an edge id encodes a (head, tail) pair as a single int).
    ints = draw(
        hnp.arrays(
            dtype=np.int64,
            shape=(num_edges,),
            elements=st.integers(min_value=0, max_value=max_edges - 1),
            unique=True,
        )
    )
    return convert_ints_to_edges(torch.from_numpy(ints), num_nodes), num_nodes


@given(
    edges_and_num_nodes=generate_positive_edges(),
    negative_ratio=st.integers(min_value=1, max_value=128),
    batch_size=st.integers(1, 32),
    axes=st.lists(st.integers(1, 16), max_size=2),
)
def test_uniform_random_negative_edges(
    edges_and_num_nodes, negative_ratio, batch_size, axes
):
    """RandomNegativeEdges (uniform variant) must never emit an avoided edge."""
    avoid_edges, num_nodes = edges_and_num_nodes
    random_edges = RandomNegativeEdges(num_nodes, negative_ratio, avoid_edges)
    # It doesn't matter what the positive edges are, since we are doing uniform negative sampling.
    # We just need any tensor with the given shape.
    positive_edges = torch.empty(size=(batch_size, *axes, 2), dtype=torch.long)
    sample_negative_edges = random_edges(positive_edges)
    negative_shape = (*positive_edges.shape[:-1], negative_ratio, 2)
    assert sample_negative_edges.dtype == torch.long
    assert sample_negative_edges.shape == negative_shape
    # Compare in integer-id space: no sampled negative may appear in avoid_edges.
    sample_edges_ints = convert_edges_to_ints(sample_negative_edges, num_nodes).numpy()
    avoid_edges_ints = convert_edges_to_ints(avoid_edges, num_nodes).numpy()
    assert not np.isin(sample_edges_ints, avoid_edges_ints).any()


@st.composite
def generate_positive_edges_for_permuted_negatives(
    draw,
    num_nodes=st.integers(min_value=2, max_value=1_000),
    permutation_option: PermutationOption = PermutationOption.head,
):
    """
    Generates edges with at least one negative per tail/head still possible (depending on permutation_option).
    Unfortunately, this is a bit circular, as the logic used to generate such edges is similar to that which will
    generate the negatives. One could do this in small cases using rejection sampling, but it is very slow.
    On the other hand, the logic used to generate the valid permutation of edges is basically the same as
    RandomIntsAvoid, which can (and has) already been rigorously tested with more naive sampling approaches.
    """
    permutation_option = PermutationOption(permutation_option)
    num_nodes = draw(num_nodes)
    # first, generate some true negative nodes, one per tail
    # (stored as flat edge ids: tail * num_nodes + head)
    true_negatives = (
        draw(
            hnp.arrays(
                dtype=np.int64,
                shape=(num_nodes,),
                elements=st.integers(min_value=0, max_value=num_nodes - 1),
                unique=False,
            )
        )
        + np.arange(num_nodes) * num_nodes
    )

    # After this, there are actually at most num_nodes * (num_nodes - 1) edges remaining
    max_edges = num_nodes * (num_nodes - 1)
    num_edges = draw(st.integers(min_value=1, max_value=min(1_000, max_edges)))
    # we will draw this many edges, and then use the method of RandomIntsAvoid from pytorch_utils.random
    # to generate the actual integers
    ints = draw(
        hnp.arrays(
            dtype=np.int64,
            shape=(num_edges,),
            elements=st.integers(min_value=0, max_value=max_edges - 1),
            unique=True,
        )
    )
    # Shift each drawn int past the reserved true-negative ids below it:
    # bucketize(right=True) counts how many reserved ids are <= the drawn value,
    # so adding that count maps the draw into the complement of true_negatives.
    buckets = torch.from_numpy(true_negatives - np.arange(num_nodes))
    ints = torch.from_numpy(ints)
    ints += torch.bucketize(ints, buckets, right=True)

    edges = convert_ints_to_edges(ints, num_nodes)
    if permutation_option == PermutationOption.tail:
        # Swap (head, tail) so the guaranteed-free slot is per-head instead of per-tail.
        edges = edges[..., [1, 0]]
    return edges, num_nodes


@st.composite
def generate_positive_edges_and_sample(
    draw,
    edges_and_num_nodes=generate_positive_edges(),
    batch_size=st.integers(1, 32),
    axes=st.lists(st.integers(1, 16), max_size=2),
):
    """Draw (avoid_edges, num_nodes, batch_size, axes, positive_batch) where
    positive_batch is a batch sampled (with replacement) from avoid_edges."""
    avoid_edges, num_nodes = draw(edges_and_num_nodes)
    # For the purposes of these tests, we will treat all avoid_edges as positives. In practice, avoid_edges and
    # positives may be different sets of edges (eg. if avoid_edges contains diagonal, but positives does not).
    _batch_size = draw(batch_size)
    _axes = draw(axes)
    positive_batch = avoid_edges[
        torch.randint(len(avoid_edges), size=(_batch_size, *_axes,))
    ]
    # Positives would also generally be chosen without replacement, but that also shouldn't matter for these tests.
    return avoid_edges, num_nodes, _batch_size, _axes, positive_batch


@given(
    edges_and_sample=generate_positive_edges_and_sample(
        edges_and_num_nodes=generate_positive_edges_for_permuted_negatives(
            permutation_option=PermutationOption.head
        )
    ),
    negative_ratio=st.integers(min_value=1, max_value=128),
)
def test_permuted_head_random_negative_edges(edges_and_sample, negative_ratio):
    """Test that we can generate random negatives for a given tail node"""
    avoid_edges, num_nodes, batch_size, axes, positive_batch = edges_and_sample
    random_edges = RandomNegativeEdges(
        num_nodes, negative_ratio, avoid_edges, permutation_option="head"
    )
    # Here the positive batch matters: head-permuted negatives must keep each positive's head.
    sample_negative_edges = random_edges(positive_batch)
    negative_shape = (*positive_batch.shape[:-1], negative_ratio, 2)
    assert sample_negative_edges.dtype == torch.long
    assert sample_negative_edges.shape == negative_shape
    # No sampled negative may collide with an avoided edge (compared in integer-id space).
    sample_edges_ints = convert_edges_to_ints(sample_negative_edges, num_nodes).numpy()
    avoid_edges_ints = convert_edges_to_ints(avoid_edges, num_nodes).numpy()
    assert not np.isin(sample_edges_ints, avoid_edges_ints).any()
    assert (positive_batch[..., None, 0] == sample_negative_edges[..., 0]).all()


@given(
    edges_and_sample=generate_positive_edges_and_sample(
        edges_and_num_nodes=generate_positive_edges_for_permuted_negatives(
            permutation_option=PermutationOption.tail
        )
    ),
    negative_ratio=st.integers(min_value=1, max_value=128),
)
def test_permuted_tail_random_negative_edges(edges_and_sample, negative_ratio):
    """Test that we can generate random negatives for a given head node"""
    avoid_edges, num_nodes, batch_size, axes, positive_batch = edges_and_sample
    random_edges = RandomNegativeEdges(
        num_nodes, negative_ratio, avoid_edges, permutation_option="tail"
    )
    # Here the positive batch matters: tail-permuted negatives must keep each positive's tail.
    sample_negative_edges = random_edges(positive_batch)
    negative_shape = (*positive_batch.shape[:-1], negative_ratio, 2)
    assert sample_negative_edges.dtype == torch.long
    assert sample_negative_edges.shape == negative_shape
    # No sampled negative may collide with an avoided edge (compared in integer-id space).
    sample_edges_ints = convert_edges_to_ints(sample_negative_edges, num_nodes).numpy()
    avoid_edges_ints = convert_edges_to_ints(avoid_edges, num_nodes).numpy()
    assert not np.isin(sample_edges_ints, avoid_edges_ints).any()
    assert (positive_batch[..., None, 1] == sample_negative_edges[..., 1]).all()