├── .github
│   └── workflows
│       └── pypi-publish.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── api.rst
│   ├── conf.py
│   ├── index.rst
│   └── requirements.txt
├── images
│   └── logo.png
├── jraph
│   ├── __init__.py
│   ├── _src
│   │   ├── graph.py
│   │   ├── models.py
│   │   ├── models_test.py
│   │   ├── utils.py
│   │   └── utils_test.py
│   ├── examples
│   │   ├── basic.py
│   │   ├── e_voting.py
│   │   ├── game_of_life.py
│   │   ├── hamiltonian_graph_network.py
│   │   ├── higgs_detection.py
│   │   ├── lstm.py
│   │   ├── sat.py
│   │   └── zacharys_karate_club.py
│   ├── experimental
│   │   ├── sharded_graphnet.py
│   │   └── sharded_graphnet_test.py
│   └── ogb_examples
│       ├── data_utils.py
│       ├── data_utils_test.py
│       ├── test_data
│       │   ├── edge-feat.csv.gz
│       │   ├── edge.csv.gz
│       │   ├── graph-label.csv.gz
│       │   ├── master.csv
│       │   ├── node-feat.csv.gz
│       │   ├── num-edge-list.csv.gz
│       │   ├── num-node-list.csv.gz
│       │   └── train.csv.gz
│       ├── train.py
│       ├── train_flax.py
│       ├── train_flax_test.py
│       ├── train_pmap.py
│       ├── train_pmap_test.py
│       └── train_test.py
├── readthedocs.yml
├── requirements.txt
└── setup.py

/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 | 
3 | on:
4 |   release:
5 |     types: [created]
6 | 
7 | jobs:
8 |   deploy:
9 |     runs-on: ubuntu-latest
10 |     steps:
11 |     - uses: actions/checkout@v2
12 |     - name: Set up Python
13 |       uses: actions/setup-python@v1
14 |       with:
15 |         python-version: '3.x'
16 |     - name: Check consistency between the package version and release tag
17 |       run: |
18 |         RELEASE_VER=${GITHUB_REF#refs/*/}
19 |         PACKAGE_VER="v`python setup.py --version`"
20 |         if [ $RELEASE_VER != $PACKAGE_VER ]
21 |         then
22 |           echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
23 |         fi
24 |     - name: Install dependencies
25 |       run: |
26 |         python -m pip install --upgrade pip
27 |         pip install setuptools wheel twine
28 |     - name: Build and publish
29 |       env:
30 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
31 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
32 |       run: |
33 |         python setup.py sdist bdist_wheel
34 |         twine upload dist/*
35 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 | 
3 | Thank you for your interest in contributing to Jraph! If you have improvements
4 | to the library, we would love to accept your pull requests. In particular,
5 | we are always very happy to have new model implementations in our model zoo.
6 | 
7 | ## How to become a contributor and submit your own code
8 | 
9 | ### Contributor License Agreements
10 | 
11 | Please fill out either the individual or corporate Contributor License Agreement (CLA).
12 | 
13 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
14 | 
15 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
16 | 
17 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
18 | 
19 | ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
20 | 21 | ### Contributing code 22 | If you have improvements to this library, send us your pull requests! For those just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/). 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![logo](images/logo.png) 3 | # Jraph - A library for graph neural networks in jax. 4 | 5 | 6 | ## New! PMAP Examples and Data Loading. 7 | 8 | We have added a pmap [example](https://github.com/deepmind/jraph/tree/master/jraph/ogb_examples/train_pmap.py). 9 | 10 | Our friends at instadeep, Jama Hussein Mohamud and Tom Makkink 11 | have put together a nice guide to using pytorch data loading. 
Find it [here](https://colab.research.google.com/drive/1_X2su92_nS52RNl4m-WYvmkvUSrFE4xQ).
12 | 
13 | 
14 | ## New! Support For Large Distributed MPNNs
15 | 
16 | We have released a distributed graph network implementation that allows you to
17 | distribute a very large (millions of edges) graph network with explicit edge
18 | messages across multiple devices. [**Check it out!**](https://github.com/deepmind/jraph/tree/master/jraph/experimental)
19 | 
20 | ## New! Interactive Jraph Colabs
21 | 
22 | We have two new colabs to help you get to grips with Jraph.
23 | 
24 | The first is an educational colab with an amazing introduction to graph neural networks and graph theory, and shows you how to use Jraph to solve a number of problems. Check it out [**here**.](https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb)
25 | 
26 | The second is a fully working example of using Jraph with OGBG-MOLPCBA, with best practices and some great visualizations. Check it out [**here**.](https://github.com/google/flax/tree/main/examples/ogbg_molpcba)
27 | 
28 | Thank you to Lisa Wang, Nikola Jovanović & Ameya Daigavane.
29 | 
30 | ## Quick Start
31 | 
32 | [**Quick Start**](#quick-start) | [**Documentation**](https://jraph.readthedocs.io/en/latest/)
33 | 
34 | Jraph (pronounced "giraffe") is a lightweight library for working with graph
35 | neural networks in jax. It provides a data structure for graphs, a set of
36 | utilities for working with graphs, and a 'zoo' of forkable graph neural network
37 | models.
38 | 
39 | ## Installation
40 | 
41 | ```pip install jraph```
42 | 
43 | Alternatively, Jraph can be installed directly from GitHub using the following command:
44 | 
45 | ```pip install git+git://github.com/deepmind/jraph.git```
46 | 
47 | The examples require additional dependencies. To install them please run:
48 | 
49 | ```pip install "jraph[examples, ogb_examples] @ git+git://github.com/deepmind/jraph.git"```
50 | 
51 | ## Overview
52 | 
53 | Jraph is designed to provide utilities for working with graphs in jax, but
54 | doesn't prescribe a way to write or develop graph neural networks.
55 | 
56 | * `graph.py` provides a lightweight data structure, `GraphsTuple`, for working
57 |   with graphs.
58 | * `utils.py` provides utilities for working with `GraphsTuples` in jax.
59 |   * Utilities for batching datasets of `GraphsTuples`.
60 |   * Utilities to support jit compilation of variable shaped graphs via
61 |     padding and masking.
62 |   * Utilities for defining losses on partitions of inputs.
63 | * `models.py` provides examples of different types of graph neural network
64 |   message passing. These are designed to be lightweight, easy to fork and
65 |   adapt. They do not manage parameters for you - for that, consider using
66 |   `haiku` or `flax`. See the examples for more details.
67 | 
68 | 
69 | ## Quick Start
70 | 
71 | Jraph takes inspiration from the Tensorflow [graph_nets library](https://github.com/deepmind/graph_nets) in defining a `GraphsTuple`
72 | data structure, which is a namedtuple that contains one or more directed graphs.
73 | 
74 | ### Representing Graphs - The `GraphsTuple`
75 | 
76 | ```python
77 | import jraph
78 | import jax.numpy as jnp
79 | 
80 | # Define a three node graph, each node has a single feature.
81 | node_features = jnp.array([[0.], [1.], [2.]])
82 | 
83 | # We will construct a graph for which there is a directed edge between each node
84 | # and its successor. We define this with `senders` (source nodes) and `receivers`
85 | # (destination nodes).
86 | senders = jnp.array([0, 1, 2])
87 | receivers = jnp.array([1, 2, 0])
88 | 
89 | # You can optionally add edge attributes.
90 | edges = jnp.array([[5.], [6.], [7.]])
91 | 
92 | # We then save the number of nodes and the number of edges.
93 | # This information is used to make running GNNs over multiple graphs
94 | # in a GraphsTuple possible.
95 | n_node = jnp.array([3])
96 | n_edge = jnp.array([3])
97 | 
98 | # Optionally you can add `global` information, such as a graph label.
99 | 
100 | global_context = jnp.array([[1]])
101 | graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
102 |                           edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)
103 | ```
104 | 
105 | A `GraphsTuple` can have more than one graph.
106 | 
107 | ```python
108 | two_graph_graphstuple = jraph.batch([graph, graph])
109 | ```
110 | 
111 | The node and edge features are stacked on the leading axis.
112 | 
113 | ```python
114 | jraph.batch([graph, graph]).nodes
115 | >>> DeviceArray([[0.],
116 |                  [1.],
117 |                  [2.],
118 |                  [0.],
119 |                  [1.],
120 |                  [2.]], dtype=float32)
121 | ```
122 | 
123 | You can tell which nodes are from which graph by looking at `n_node`.
124 | 
125 | ```python
126 | jraph.batch([graph, graph]).n_node
127 | >>> DeviceArray([3, 3], dtype=int32)
128 | ```
129 | 
130 | You can store nests of features in `nodes`, `edges` and `globals`. This makes
131 | it possible to store multiple sets of features for each node, edge or graph, with
132 | potentially different types and semantically different meanings (for example
133 | 'training' and 'testing' nodes). The only requirement is that all arrays within
134 | each nest have a common leading dimension size, matching the total number
135 | of nodes, edges or graphs within the `GraphsTuple` respectively.
136 | 
137 | ```python
138 | node_targets = jnp.array([[True], [False], [True]])
139 | graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})
140 | ```
141 | 
142 | ### Using the Model Zoo
143 | 
144 | Jraph provides a set of implemented reference models for you to use.
145 | 
146 | A Jraph model defines a message passing algorithm between the nodes, edges and
147 | global attributes of a graph. The user defines `update` functions that update graph features, which are typically neural networks but can be arbitrary jax functions.
148 | 
149 | Let's go through a `GraphNetwork` [(paper)](https://arxiv.org/abs/1806.01261) example.
150 | A `GraphNetwork`'s first update function updates the edges using `edge` features,
151 | the node features of the `sender` and `receiver` and the `global` features.
152 | 
153 | 
154 | ```python
155 | # As one example, we just pass the edge features straight through.
156 | def update_edge_fn(edge, sender, receiver, globals_):
157 |   return edge
158 | ```
159 | 
160 | Often we use the concatenation of these features, and `jraph` provides an easy
161 | way of doing this with the `concatenated_args` decorator.
162 | 
163 | ```python
164 | @jraph.concatenated_args
165 | def update_edge_fn(concatenated_features):
166 |   return concatenated_features
167 | ```
168 | Typically, a learned model such as a Multi-Layer Perceptron is used within an
169 | update function.
170 | 
171 | The user similarly defines functions that update the nodes and globals. These
172 | are then used to configure a `GraphNetwork`. To see the arguments to the node
173 | and global `update_fns` please take a look at the model zoo.
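For illustration, here is one minimal way to define the remaining update functions, in the same pass-through style as the edge update above. In practice these would typically be small neural networks (e.g. `haiku` or `flax` MLPs); the pass-through versions below are just a sketch that makes the snippet that follows self-contained.

```python
# A minimal, illustrative sketch: pass-through node and global updates.
# `jraph.concatenated_args` concatenates each update function's arguments
# (e.g. node features, aggregated edge features and globals) along the
# last axis before calling the wrapped function.
@jraph.concatenated_args
def update_node_fn(concatenated_features):
  return concatenated_features

@jraph.concatenated_args
def update_global_fn(concatenated_features):
  return concatenated_features
```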
174 | 
175 | ```python
176 | net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
177 |                          update_node_fn=update_node_fn,
178 |                          update_global_fn=update_global_fn)
179 | ```
180 | 
181 | `net` is a function that sends messages according to the `GraphNetwork` algorithm
182 | and applies the `update_fn`s. It takes a graph, and returns a graph.
183 | 
184 | ```python
185 | updated_graph = net(graph)
186 | ```
187 | 
188 | 
189 | ## Examples
190 | 
191 | For a deeper dive, the best place to start is the examples. In particular:
192 | 
193 | * `examples/basic.py` provides an introduction to the features of the library.
194 | * `ogb_examples/train.py` provides an end-to-end
195 |   example of training a `GraphNet` on the `molhiv` Open Graph Benchmark dataset.
196 |   Please note, you need to have downloaded the dataset to run this example.
197 | 
198 | The rest of the examples are short scripts demonstrating how to use various
199 | models from our model zoo, as well as making models go fast with `jax.jit`, and
200 | how to deal with Jax's static shape requirement.
201 | 
202 | 
203 | ## Citing Jraph
204 | 
205 | To cite this repository:
206 | 
207 | ```
208 | @software{jraph2020github,
209 |   author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
210 |   title = {{J}raph: {A} library for graph neural networks in jax.},
211 |   url = {http://github.com/deepmind/jraph},
212 |   version = {0.0.1.dev},
213 |   year = {2020},
214 | }
215 | ```
216 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS    ?=
7 | SPHINXBUILD   ?= sphinx-build
8 | SOURCEDIR     = .
9 | BUILDDIR      = _build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | Jraph API
2 | =========
3 | 
4 | .. currentmodule:: jraph
5 | 
6 | GraphsTuple
7 | -----------
8 | 
9 | .. autoclass:: GraphsTuple
10 | 
11 | Batching & Padding Utilities
12 | ----------------------------
13 | 
14 | .. autofunction:: batch
15 | 
16 | .. autofunction:: unbatch
17 | 
18 | .. autofunction:: pad_with_graphs
19 | 
20 | .. autofunction:: get_number_of_padding_with_graphs_graphs
21 | 
22 | .. autofunction:: get_number_of_padding_with_graphs_nodes
23 | 
24 | .. autofunction:: get_number_of_padding_with_graphs_edges
25 | 
26 | .. autofunction:: unpad_with_graphs
27 | 
28 | .. autofunction:: get_node_padding_mask
29 | 
30 | .. autofunction:: get_edge_padding_mask
31 | 
32 | .. autofunction:: get_graph_padding_mask
33 | 
34 | Segment Utilities
35 | -----------------
36 | 
37 | .. autofunction:: segment_mean
38 | 
39 | .. autofunction:: segment_max
40 | 
41 | .. autofunction:: segment_softmax
42 | 
43 | .. 
autofunction:: partition_softmax 44 | 45 | Misc Utilities 46 | ----------------- 47 | 48 | .. autofunction:: concatenated_args 49 | 50 | Models 51 | ====== 52 | 53 | .. autofunction:: GraphNetwork 54 | 55 | .. autofunction:: InteractionNetwork 56 | 57 | .. autofunction:: GraphMapFeatures 58 | 59 | .. autofunction:: RelationNetwork 60 | 61 | .. autofunction:: DeepSets 62 | 63 | .. autofunction:: GraphNetGAT 64 | 65 | .. autofunction:: GAT 66 | 67 | .. autofunction:: GraphConvolution 68 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Configuration file for the Sphinx documentation builder.""" 16 | 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | # 27 | # pylint: disable=g-bad-import-order 28 | # pylint: disable=g-import-not-at-top 29 | import inspect 30 | import os 31 | import sys 32 | import typing 33 | import jraph 34 | 35 | 36 | def _add_annotations_import(path): 37 | """Appends a future annotations import to the file at the given path.""" 38 | with open(path) as f: 39 | contents = f.read() 40 | # If we run sphinx multiple times then we will append the future import 41 | # multiple times too, so this check is here to prevent that. 42 | if contents.startswith('from __future__ import annotations'): 43 | return 44 | 45 | assert contents.startswith('#'), (path, contents.split('\n')[0]) 46 | with open(path, 'w') as f: 47 | # NOTE: This is subtle and not unit tested, we're prefixing the first line 48 | # in each Python file with this future import. It is important to prefix 49 | # not insert a newline such that source code locations are accurate (we link 50 | # to GitHub). The assertion above ensures that the first line in the file is 51 | # a comment so it is safe to prefix it. 
52 | f.write('from __future__ import annotations ') 53 | f.write(contents) 54 | 55 | 56 | def _recursive_add_annotations_import(): 57 | for path, _, files in os.walk('../jraph/'): 58 | for file in files: 59 | if file.endswith('.py'): 60 | _add_annotations_import(os.path.abspath(os.path.join(path, file))) 61 | 62 | if 'READTHEDOCS' in os.environ: 63 | _recursive_add_annotations_import() 64 | 65 | typing.get_type_hints = lambda obj, *unused: obj.__annotations__ 66 | sys.path.insert(0, os.path.abspath('../')) 67 | sys.path.append(os.path.abspath('ext')) 68 | 69 | # -- Project information ----------------------------------------------------- 70 | 71 | project = 'Jraph' 72 | copyright = '2021, Jraph Authors' # pylint: disable=redefined-builtin 73 | author = 'Jraph Authors' 74 | 75 | # The full version, including alpha/beta/rc tags 76 | release = '0.0.1.dev' 77 | 78 | 79 | # -- General configuration --------------------------------------------------- 80 | 81 | # Add any Sphinx extension module names here, as strings. They can be 82 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 83 | # ones. 84 | extensions = [ 85 | 'sphinx.ext.autodoc', 86 | 'sphinx.ext.autosummary', 87 | 'sphinx.ext.doctest', 88 | 'sphinx.ext.inheritance_diagram', 89 | 'sphinx.ext.linkcode', 90 | 'sphinx.ext.napoleon', 91 | 'sphinxcontrib.bibtex', 92 | 'sphinx_autodoc_typehints', 93 | ] 94 | 95 | pygments_style = 'sphinx' 96 | 97 | # Add any paths that contain templates here, relative to this directory. 98 | templates_path = ['_templates'] 99 | 100 | # List of patterns, relative to source directory, that match files and 101 | # directories to ignore when looking for source files. 102 | # This pattern also affects html_static_path and html_extra_path. 103 | exclude_patterns = ['_build'] 104 | 105 | # -- Options for autodoc ----------------------------------------------------- 106 | 107 | autodoc_default_options = { 108 | 'member-order': 'bysource', 109 | 'special-members': True, 110 | } 111 | 112 | # -- Options for HTML output ------------------------------------------------- 113 | 114 | # The theme to use for HTML and HTML Help pages. See the documentation for 115 | # a list of builtin themes. 116 | # 117 | html_theme = 'sphinx_rtd_theme' 118 | 119 | # Add any paths that contain custom static files (such as style sheets) here, 120 | # relative to this directory. They are copied after the builtin static files, 121 | # so a file named "default.css" will overwrite the builtin "default.css". 
122 | html_static_path = ['_static']
123 | 
124 | # -- Source code links -------------------------------------------------------
125 | 
126 | 
127 | def linkcode_resolve(domain, info):
128 |   """Resolve a GitHub URL corresponding to Python object."""
129 |   if domain != 'py':
130 |     return None
131 | 
132 |   try:
133 |     mod = sys.modules[info['module']]
134 |   except ImportError:
135 |     return None
136 | 
137 |   obj = mod
138 |   try:
139 |     for attr in info['fullname'].split('.'):
140 |       obj = getattr(obj, attr)
141 |   except AttributeError:
142 |     return None
143 |   else:
144 |     obj = inspect.unwrap(obj)
145 | 
146 |   try:
147 |     filename = inspect.getsourcefile(obj)
148 |   except TypeError:
149 |     return None
150 | 
151 |   try:
152 |     source, lineno = inspect.getsourcelines(obj)
153 |   except OSError:
154 |     return None
155 | 
156 |   return 'https://github.com/deepmind/jraph/blob/master/jraph/%s#L%d#L%d' % (
157 |       os.path.relpath(filename, start=os.path.dirname(
158 |           jraph.__file__)), lineno, lineno + len(source) - 1)
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/deepmind/jraph/tree/master/docs
2 | 
3 | Welcome to Jraph's documentation!
4 | =================================
5 | 
6 | Jraph (pronounced "giraffe") is a lightweight library for working with graph
7 | neural networks in jax. It provides a data structure for graphs, a set of
8 | utilities for working with graphs, and a 'zoo' of forkable graph neural network
9 | models.
10 | 
11 | .. toctree::
12 |    :caption: API Documentation:
13 |    :maxdepth: 1
14 | 
15 |    api
16 | 
17 | Overview
18 | --------
19 | 
20 | Jraph is designed to provide utilities for working with graphs in jax, but
21 | doesn't prescribe a way to write or develop graph neural networks.
22 | 
23 | * ``graph.py`` provides a lightweight data structure, ``GraphsTuple``, for working with graphs.
24 | * ``utils.py`` provides utilities for working with ``GraphsTuples`` in jax.
25 | 
26 |   * Utilities for batching datasets of ``GraphsTuples``.
27 |   * Utilities to support jit compilation of variable shaped graphs via
28 |     padding and masking.
29 |   * Utilities for defining losses on partitions of inputs.
30 | * ``models.py`` provides examples of different types of graph neural network
31 |   message passing. These are designed to be lightweight, easy to fork and
32 |   adapt. They do not manage parameters for you - for that, consider using
33 |   ``haiku`` or ``flax``. See the examples for more details.
34 | 
35 | 
36 | Installation
37 | ------------
38 | 
39 | See https://github.com/google/jax#pip-installation for instructions on
40 | installing JAX.
41 | 
42 | Jraph can be installed directly from GitHub using the following command:
43 | 
44 | ``pip install git+git://github.com/deepmind/jraph.git``
45 | 
46 | Quick Start
47 | ===========
48 | 
49 | Representing Graphs - The ``GraphsTuple``
50 | ------------------------------------------
51 | 
52 | Jraph takes inspiration from the Tensorflow `graph_nets library <https://github.com/deepmind/graph_nets>`_
53 | in defining a ``GraphsTuple`` data structure, which is a ``namedtuple`` that contains
54 | one or more directed graphs.
55 | 
56 | .. code-block:: python
57 | 
58 |    import jraph
59 |    import jax.numpy as jnp
60 | 
61 |    # Define a three node graph, each node has a single feature.
62 |    node_features = jnp.array([[0.], [1.], [2.]])
63 | 
64 |    # We will construct a graph for which there is a directed edge between each node
65 |    # and its successor. We define this with `senders` (source nodes) and `receivers`
66 |    # (destination nodes).
67 |    senders = jnp.array([0, 1, 2])
68 |    receivers = jnp.array([1, 2, 0])
69 | 
70 |    # You can optionally add edge attributes.
71 |    edges = jnp.array([[5.], [6.], [7.]])
72 | 
73 |    # We then save the number of nodes and the number of edges.
74 |    # This information is used to make running GNNs over multiple graphs
75 |    # in a GraphsTuple possible.
76 |    n_node = jnp.array([3])
77 |    n_edge = jnp.array([3])
78 | 
79 |    # Optionally you can add `global` information, such as a graph label.
80 | 
81 |    global_context = jnp.array([[1]])  # Same feature dimensions as nodes and edges.
82 |    graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
83 |                              edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)
84 | 
85 | A ``GraphsTuple`` can have more than one graph.
86 | 
87 | .. code-block:: python
88 | 
89 |    two_graph_graphstuple = jraph.batch([graph, graph])
90 | 
91 | 
92 | The ``node`` and ``edge`` features are stacked on the leading axis.
93 | 
94 | .. code-block:: python
95 | 
96 |    jraph.batch([graph, graph]).nodes
97 |    >> DeviceArray([[0.],
98 |                    [1.],
99 |                    [2.],
100 |                    [0.],
101 |                    [1.],
102 |                    [2.]], dtype=float32)
103 | 
104 | 
105 | You can tell which nodes are from which graph by looking at ``n_node``.
106 | 
107 | .. code-block:: python
108 | 
109 |    jraph.batch([graph, graph]).n_node
110 |    >> DeviceArray([3, 3], dtype=int32)
111 | 
112 | 
113 | You can store nests of features in ``nodes``, ``edges`` and ``globals``. This makes
114 | it possible to store multiple sets of features for each node, edge or graph, with
115 | potentially different types and semantically different meanings (for example
116 | 'training' and 'testing' nodes). The only requirement is that all arrays within
117 | each nest have a common leading dimension size, matching the total number
118 | of nodes, edges or graphs within the ``GraphsTuple`` respectively.
119 | 
120 | .. code-block:: python
121 | 
122 |    node_targets = jnp.array([[True], [False], [True]])
123 |    graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})
124 | 
125 | 
126 | Using the Model Zoo
127 | -------------------
128 | 
129 | Jraph provides a set of implemented reference models for you to use.
130 | 
131 | A Jraph model defines a message passing algorithm between the nodes, edges and
132 | global attributes of a graph. The user defines ``update`` functions that update graph features, which are typically neural networks but can be arbitrary jax functions.
133 | 
134 | Let's go through a ``GraphNetwork`` (`paper <https://arxiv.org/abs/1806.01261>`_) example.
135 | A ``GraphNetwork``'s first update function updates the edges using ``edge`` features,
136 | the node features of the ``sender`` and ``receiver`` and the ``global`` features.
137 | 
138 | 
139 | .. code-block:: python
140 | 
141 |    # As one example, we just pass the edge features straight through.
142 |    def update_edge_fn(edge, sender, receiver, globals_):
143 |      return edge
144 | 
145 | 
146 | Often we use the concatenation of these features, and ``jraph`` provides an easy
147 | way of doing this with the ``concatenated_args`` decorator.
148 | 
149 | .. 
code-block:: python 150 | 151 | @jraph.concatenated_args 152 | def update_edge_fn(concatenated_features): 153 | return concatenated_features 154 | 155 | 156 | Typically, a learned model such as a Multi-Layer Perceptron is used within an 157 | update function. 158 | 159 | The user similarly defines functions that update the nodes and globals. These 160 | are then used to configure a `GraphNetwork`. To see the arguments to the node 161 | and global `update_fns` please take a look at the model zoo. 162 | 163 | .. code-block:: python 164 | 165 | net = jraph.GraphNetwork(update_edge_fn=update_edge_fn, 166 | update_node_fn=update_node_fn, 167 | update_global_fn=update_global_fn) 168 | 169 | 170 | ``net`` is a function that sends messages according to the ``GraphNetwork`` algorithm 171 | and applies the ``update_fn``. It takes a graph, and returns a graph. 172 | 173 | .. code-block:: python 174 | 175 | updated_graph = net(graph) 176 | 177 | Contribute 178 | ---------- 179 | 180 | Please read ``CONTRIBUTING.md``. 181 | 182 | - Issue tracker: https://github.com/deepmind/jraph/issues 183 | - Source code: https://github.com/deepmind/jraph/tree/master 184 | 185 | Support 186 | ------- 187 | 188 | If you are having issues, please let us know by filing an issue on our 189 | `issue tracker `_. 190 | 191 | License 192 | ------- 193 | 194 | Jraph is licensed under the Apache 2.0 License. 195 | 196 | 197 | Indices and tables 198 | ================== 199 | 200 | * :ref:`genindex` 201 | * :ref:`modindex` 202 | * :ref:`search` 203 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==3.3.0 2 | sphinx_rtd_theme==0.5.0 3 | sphinxcontrib-katex==0.7.1 4 | sphinxcontrib-bibtex==1.0.0 5 | sphinx-autodoc-typehints==1.11.1 6 | nbsphinx==0.8.0 7 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/images/logo.png -------------------------------------------------------------------------------- /jraph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Jraph.""" 16 | 17 | 18 | from jraph._src.graph import GraphsTuple 19 | from jraph._src.models import AggregateEdgesToGlobalsFn 20 | from jraph._src.models import AggregateEdgesToNodesFn 21 | from jraph._src.models import AggregateNodesToGlobalsFn 22 | from jraph._src.models import AttentionLogitFn 23 | from jraph._src.models import AttentionReduceFn 24 | from jraph._src.models import DeepSets 25 | from jraph._src.models import EmbedEdgeFn 26 | from jraph._src.models import EmbedGlobalFn 27 | from jraph._src.models import EmbedNodeFn 28 | from jraph._src.models import GAT 29 | from jraph._src.models import GATAttentionLogitFn 30 | from jraph._src.models import GATAttentionQueryFn 31 | from jraph._src.models import GATNodeUpdateFn 32 | from jraph._src.models import GNUpdateEdgeFn 33 | from jraph._src.models import GNUpdateGlobalFn 34 | from jraph._src.models import GNUpdateNodeFn 35 | from jraph._src.models import GraphConvolution 36 | from jraph._src.models import GraphMapFeatures 37 | from jraph._src.models import GraphNetGAT 38 | from jraph._src.models import GraphNetwork 39 | from jraph._src.models import InteractionNetwork 40 | from jraph._src.models import InteractionUpdateEdgeFn 41 | from jraph._src.models import InteractionUpdateNodeFn 42 | from jraph._src.models import NodeFeatures 43 | from jraph._src.models import RelationNetwork 44 | from jraph._src.utils import ArrayTree 45 | from jraph._src.utils import batch 46 | from jraph._src.utils import batch_np 47 | from jraph._src.utils import concatenated_args 48 | from jraph._src.utils import dynamically_batch 49 | from jraph._src.utils import get_edge_padding_mask 50 | from jraph._src.utils import get_fully_connected_graph 51 | from jraph._src.utils import get_graph_padding_mask 52 | from jraph._src.utils import get_node_padding_mask 53 | from jraph._src.utils import get_number_of_padding_with_graphs_edges 54 | from jraph._src.utils import get_number_of_padding_with_graphs_graphs 55 | from jraph._src.utils import get_number_of_padding_with_graphs_nodes 56 | from jraph._src.utils import pad_with_graphs 57 | from jraph._src.utils import partition_softmax 58 | from jraph._src.utils import segment_max 59 | from jraph._src.utils import segment_max_or_constant 60 | from jraph._src.utils import segment_mean 61 | from jraph._src.utils import segment_min 62 | from jraph._src.utils import segment_min_or_constant 63 | from jraph._src.utils import segment_normalize 64 | from jraph._src.utils import segment_softmax 65 | from jraph._src.utils import segment_sum 66 | from jraph._src.utils import segment_variance 67 | from jraph._src.utils import sparse_matrix_to_graphs_tuple 68 | from jraph._src.utils import unbatch 69 | from jraph._src.utils import unbatch_np 70 | from jraph._src.utils import unpad_with_graphs 71 | from jraph._src.utils import with_zero_out_padding_outputs 72 | from jraph._src.utils import zero_out_padding 73 | 74 | 75 | __version__ = "0.0.6.dev0" 76 | 77 | __all__ = ("ArrayTree", "DeepSets", "GraphConvolution", "GraphMapFeatures", 78 | "InteractionNetwork", "RelationNetwork", "GraphNetGAT", "GAT", 79 | "GraphsTuple", "GraphNetwork", "NodeFeatures", 80 | "AggregateEdgesToNodesFn", "AggregateNodesToGlobalsFn", 81 | "AggregateEdgesToGlobalsFn", "AttentionLogitFn", "AttentionReduceFn", 82 | "GNUpdateEdgeFn", "GNUpdateNodeFn", "GNUpdateGlobalFn", 83 | "InteractionUpdateNodeFn", "InteractionUpdateEdgeFn", "EmbedEdgeFn", 84 | "EmbedNodeFn", "EmbedGlobalFn", "GATAttentionQueryFn", 85 | "GATAttentionLogitFn", "GATNodeUpdateFn", 
"batch", "batch_np", 86 | "unbatch", "unbatch_np", "pad_with_graphs", 87 | "get_number_of_padding_with_graphs_graphs", 88 | "get_number_of_padding_with_graphs_nodes", 89 | "get_number_of_padding_with_graphs_edges", "unpad_with_graphs", 90 | "get_node_padding_mask", "get_edge_padding_mask", 91 | "get_graph_padding_mask", "segment_max", "segment_max_or_constant", 92 | "segment_min_or_constant", "segment_softmax", "segment_sum", 93 | "partition_softmax", "concatenated_args", 94 | "get_fully_connected_graph", "dynamically_batch", 95 | "with_zero_out_padding_outputs", "zero_out_padding", 96 | "sparse_matrix_to_graphs_tuple") 97 | 98 | # _________________________________________ 99 | # / Please don't use symbols in `_src` they \ 100 | # \ are not part of the Jraph public API. / 101 | # ----------------------------------------- 102 | # \ ^__^ 103 | # \ (oo)\_______ 104 | # (__)\ )\/\ 105 | # ||----w | 106 | # || || 107 | # 108 | try: 109 | del _src # pylint: disable=undefined-variable 110 | except NameError: 111 | pass 112 | -------------------------------------------------------------------------------- /jraph/_src/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Graph Data Structures.""" 16 | 17 | from typing import Any, NamedTuple, Iterable, Mapping, Union, Optional 18 | import jax.numpy as jnp 19 | 20 | 21 | # As of 04/2020 pytype doesn't support recursive types. 22 | # pytype: disable=not-supported-yet 23 | ArrayTree = Union[jnp.ndarray, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] 24 | 25 | 26 | class GraphsTuple(NamedTuple): 27 | """An ordered collection of graphs in a sparse format. 28 | 29 | The values of ``nodes``, ``edges`` and ``globals`` can be ``ArrayTrees`` - 30 | nests of features with ``jax`` compatible values. For example, ``nodes`` in a 31 | graph may have more than one type of attribute. 32 | 33 | However, the GraphsTuple typically takes the following form for a batch of 34 | `n` graphs: 35 | 36 | - n_node: The number of nodes per graph. It is a vector of integers with shape 37 | `[n_graphs]`, such that ``graph.n_node[i]`` is the number of nodes in the 38 | i-th graph. 39 | 40 | - n_edge: The number of edges per graph. It is a vector of integers with shape 41 | `[n_graphs]`, such that ``graph.n_edge[i]`` is the number of edges in the 42 | i-th graph. 43 | 44 | - nodes: The nodes features. It is either ``None`` (the graph has no node 45 | features), or a vector of shape `[n_nodes] + node_shape`, where 46 | ``n_nodes = sum(graph.n_node)`` is the total number of nodes in the batch of 47 | graphs, and `node_shape` represents the shape of the features of each node. 48 | The relative index of a node from the batched version can be recovered from 49 | the ``graph.n_node`` property. 
For instance, the second node of the third 50 | graph will have its features in the 51 | `1 + graph.n_node[0] + graph.n_node[1]`-th slot of graph.nodes. 52 | Observe that having a ``None`` value for this field does not mean that the 53 | graphs have no nodes, only that they do not have node features. 54 | 55 | - edges: The edges features. It is either ``None`` (the graph has no edge 56 | features), or a vector of shape `[n_edges] + edge_shape`, where 57 | ``n_edges = sum(graph.n_edge)`` is the total number of edges in the batch of 58 | graphs, and ``edge_shape`` represents the shape of the features of each 59 | edge. 60 | 61 | The relative index of an edge from the batched version can be recovered from 62 | the ``graph.n_edge`` property. For instance, the third edge of the third 63 | graph will have its features in the `2 + graph.n_edge[0] + graph.n_edge[1]`- 64 | th slot of graph.edges. 65 | 66 | Having a ``None`` value for this field does not necessarily mean that the 67 | graph has no edges, only that they do not have edge features. 68 | 69 | - receivers: The indices of the receiver nodes, for each edge. It is either 70 | ``None`` (if the graph has no edges), or a vector of integers of shape 71 | `[n_edges]`, such that ``graph.receivers[i]`` is the index of the node 72 | receiving from the i-th edge. 73 | 74 | Observe that the index is absolute (in other words, cumulative), i.e. 75 | ``graphs.receivers`` take value in `[0, n_nodes]`. For instance, an edge 76 | connecting the vertices with relative indices 2 and 3 in the second graph of 77 | the batch would have a ``receivers`` value of `3 + graph.n_node[0]`. 78 | If `graphs.receivers` is ``None``, then ``graphs.edges`` and 79 | ``graphs.senders`` should also be ``None``. 80 | 81 | - senders: The indices of the sender nodes, for each edge. It is either 82 | ``None`` (if the graph has no edges), or a vector of integers of shape 83 | `[n_edges]`, such that ``graph.senders[i]`` is the index of the node 84 | sending from the i-th edge. 85 | 86 | Observe that the index is absolute, i.e. ``graphs.senders`` take value in 87 | `[0, n_nodes]`. For instance, an edge connecting the vertices with relative 88 | indices 1 and 3 in the third graph of the batch would have a ``senders`` 89 | value of `1 + graph.n_node[0] + graph.n_node[1]`. 90 | 91 | If ``graphs.senders`` is ``None``, then ``graphs.edges`` and 92 | ``graphs.receivers`` should also be ``None``. 93 | 94 | - globals: The global features of the graph. It is either ``None`` (the graph 95 | has no global features), or a vector of shape `[n_graphs] + global_shape` 96 | representing graph level features. 97 | 98 | 99 | """ 100 | nodes: Optional[ArrayTree] 101 | edges: Optional[ArrayTree] 102 | receivers: Optional[jnp.ndarray] # with integer dtype 103 | senders: Optional[jnp.ndarray] # with integer dtype 104 | globals: Optional[ArrayTree] 105 | n_node: jnp.ndarray # with integer dtype 106 | n_edge: jnp.ndarray # with integer dtype 107 | -------------------------------------------------------------------------------- /jraph/_src/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for jraph.models.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import jax.tree_util as tree 22 | 23 | from jraph._src import graph 24 | from jraph._src import models 25 | from jraph._src import utils 26 | import numpy as np 27 | 28 | 29 | def _get_random_graph(max_n_graph=10): 30 | n_graph = np.random.randint(1, max_n_graph + 1) 31 | n_node = np.random.randint(0, 10, n_graph) 32 | n_edge = np.random.randint(0, 20, n_graph) 33 | # We cannot have any edges if there are no nodes. 34 | n_edge[n_node == 0] = 0 35 | 36 | senders = [] 37 | receivers = [] 38 | offset = 0 39 | for n_node_in_graph, n_edge_in_graph in zip(n_node, n_edge): 40 | if n_edge_in_graph != 0: 41 | senders += list( 42 | np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset) 43 | receivers += list( 44 | np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset) 45 | offset += n_node_in_graph 46 | 47 | return graph.GraphsTuple( 48 | n_node=jnp.asarray(n_node), 49 | n_edge=jnp.asarray(n_edge), 50 | nodes=jnp.asarray(np.random.random(size=(np.sum(n_node), 4))), 51 | edges=jnp.asarray(np.random.random(size=(np.sum(n_edge), 3))), 52 | globals=jnp.asarray(np.random.random(size=(n_graph, 5))), 53 | senders=jnp.asarray(senders), 54 | receivers=jnp.asarray(receivers)) 55 | 56 | 57 | def _get_graph_network(graphs_tuple): 58 | # Our test update functions are just identity functions. 59 | update_node_fn = lambda n, se, re, g: n 60 | update_edge_fn = lambda e, sn, rn, g: e 61 | update_global_fn = lambda gn, ge, g: g 62 | net = models.GraphNetwork(update_edge_fn, 63 | update_node_fn, 64 | update_global_fn) 65 | return net(graphs_tuple) 66 | 67 | 68 | def _get_graph_network_no_global_update(graphs_tuple): 69 | # Our test update functions are just identity functions. 70 | update_node_fn = lambda n, se, re, g: n 71 | update_edge_fn = lambda e, sn, rn, g: e 72 | update_global_fn = None 73 | net = models.GraphNetwork(update_edge_fn, 74 | update_node_fn, 75 | update_global_fn) 76 | return net(graphs_tuple) 77 | 78 | 79 | def _get_graph_network_no_node_update(graphs_tuple): 80 | # Our test update functions are just identity functions. 81 | update_node_fn = None 82 | update_edge_fn = lambda e, sn, rn, g: e 83 | update_global_fn = lambda gn, ge, g: g 84 | net = models.GraphNetwork(update_edge_fn, 85 | update_node_fn, 86 | update_global_fn) 87 | return net(graphs_tuple) 88 | 89 | 90 | def _get_graph_network_no_edge_update(graphs_tuple): 91 | # Our test update functions are just identity functions. 92 | update_node_fn = lambda n, se, re, g: n 93 | update_edge_fn = None 94 | update_global_fn = lambda gn, ge, g: g 95 | net = models.GraphNetwork(update_edge_fn, 96 | update_node_fn, 97 | update_global_fn) 98 | return net(graphs_tuple) 99 | 100 | 101 | def _get_attention_graph_network(graphs_tuple): 102 | # Our test update functions are just identity functions. 
103 | update_node_fn = lambda n, se, re, g: n 104 | update_edge_fn = lambda e, sn, rn, g: e 105 | update_global_fn = lambda gn, ge, g: g 106 | # Our attention logits are just one in this case. 107 | attention_logit_fn = lambda e, sn, rn, g: jnp.array(1.0) 108 | # We use a custom apply function here, which just returns the edge unchanged. 109 | attention_reduce_fn = lambda e, w: e 110 | net = models.GraphNetwork(update_edge_fn, 111 | update_node_fn, 112 | update_global_fn, 113 | attention_logit_fn=attention_logit_fn, 114 | attention_reduce_fn=attention_reduce_fn) 115 | return net(graphs_tuple) 116 | 117 | 118 | def _get_graph_gat(graphs_tuple): 119 | # Our test update functions are just identity functions. 120 | update_node_fn = lambda n, se, re, g: n 121 | update_edge_fn = lambda e, sn, rn, g: e 122 | update_global_fn = lambda gn, ge, g: g 123 | # Our attention logits are just one in this case. 124 | attention_logit_fn = lambda e, sn, rn, g: jnp.array(1.0) 125 | # We use a custom apply function here, which just returns the edge unchanged. 126 | attention_reduce_fn = lambda e, w: e 127 | net = models.GraphNetGAT(update_edge_fn, 128 | update_node_fn, 129 | attention_logit_fn, 130 | attention_reduce_fn, 131 | update_global_fn) 132 | return net(graphs_tuple) 133 | 134 | 135 | def _get_multi_head_attention_graph_network(graphs_tuple): 136 | # Our test update functions are just identity functions. 137 | update_node_fn = lambda n, se, re, g: n 138 | update_global_fn = lambda gn, ge, g: g 139 | # With multi-head attention we have to return multiple edge features. 140 | # Here we define 3 heads, all with the same message. 141 | def update_edge_fn(e, unused_sn, unused_rn, unused_g): 142 | return tree.tree_map(lambda e_: jnp.stack([e_, e_, e_]), e) 143 | # Our attention logits are just the sum of the edge features of each head. 144 | def attention_logit_fn(e, unused_sn, unused_rn, unused_g): 145 | return tree.tree_map(lambda e_: jnp.sum(e_, axis=-1), e) 146 | # For multi-head attention we need a custom apply attention function. 147 | # In this we return the first edge. 
148 | def attention_reduce_fn(e, unused_w): 149 | return tree.tree_map(lambda e_: e_[0], e) 150 | net = models.GraphNetwork(jax.vmap(update_edge_fn), 151 | jax.vmap(update_node_fn), 152 | update_global_fn, 153 | attention_logit_fn=jax.vmap(attention_logit_fn), 154 | attention_reduce_fn=jax.vmap(attention_reduce_fn)) 155 | return net(graphs_tuple) 156 | 157 | 158 | def _get_interaction_network(graphs_tuple): 159 | update_node_fn = lambda n, r: jnp.concatenate((n, r), axis=-1) 160 | update_edge_fn = lambda e, s, r: jnp.concatenate((e, s, r), axis=-1) 161 | out = models.InteractionNetwork(update_edge_fn, update_node_fn)(graphs_tuple) 162 | nodes, edges, receivers, senders, _, _, _ = graphs_tuple 163 | expected_edges = jnp.concatenate( 164 | (edges, nodes[senders], nodes[receivers]), axis=-1) 165 | aggregated_nodes = utils.segment_sum( 166 | expected_edges, receivers, num_segments=len(graphs_tuple.nodes)) 167 | expected_nodes = jnp.concatenate( 168 | (nodes, aggregated_nodes), axis=-1) 169 | expected_out = graphs_tuple._replace( 170 | edges=expected_edges, nodes=expected_nodes) 171 | return out, expected_out 172 | 173 | 174 | def _get_graph_independent(graphs_tuple): 175 | embed_fn = lambda x: x * 2 176 | out = models.GraphMapFeatures(embed_fn, embed_fn, embed_fn)(graphs_tuple) 177 | expected_out = graphs_tuple._replace(nodes=graphs_tuple.nodes*2, 178 | edges=graphs_tuple.edges*2, 179 | globals=graphs_tuple.globals*2) 180 | return out, expected_out 181 | 182 | 183 | def _get_relation_network(graphs_tuple): 184 | edge_fn = lambda s, r: jnp.concatenate((s, r), axis=-1) 185 | global_fn = lambda e: e*2 186 | out = models.RelationNetwork(edge_fn, global_fn)(graphs_tuple) 187 | expected_edges = jnp.concatenate( 188 | (graphs_tuple.nodes[graphs_tuple.senders], 189 | graphs_tuple.nodes[graphs_tuple.receivers]), axis=-1) 190 | num_graphs = len(graphs_tuple.n_edge) 191 | edge_gr_idx = jnp.repeat(jnp.arange(num_graphs), 192 | graphs_tuple.n_edge, 193 | total_repeat_length=graphs_tuple.edges.shape[0]) 194 | aggregated_edges = utils.segment_sum( 195 | expected_edges, edge_gr_idx, num_segments=num_graphs) 196 | expected_out = graphs_tuple._replace( 197 | edges=expected_edges, globals=aggregated_edges*2) 198 | return out, expected_out 199 | 200 | 201 | def _get_deep_sets(graphs_tuple): 202 | node_fn = lambda n, g: jnp.concatenate((n, g), axis=-1) 203 | global_fn = lambda n: n*2 204 | out = models.DeepSets(node_fn, global_fn)(graphs_tuple) 205 | num_graphs = len(graphs_tuple.n_node) 206 | num_nodes = len(graphs_tuple.nodes) 207 | broadcasted_globals = jnp.repeat(graphs_tuple.globals, graphs_tuple.n_node, 208 | total_repeat_length=num_nodes, axis=0) 209 | expected_nodes = jnp.concatenate( 210 | (graphs_tuple.nodes, broadcasted_globals), axis=-1) 211 | node_gr_idx = jnp.repeat(jnp.arange(num_graphs), 212 | graphs_tuple.n_node, 213 | total_repeat_length=num_nodes) 214 | expected_out = graphs_tuple._replace( 215 | nodes=expected_nodes, 216 | globals=utils.segment_sum( 217 | expected_nodes, node_gr_idx, num_segments=num_graphs)*2) 218 | return out, expected_out 219 | 220 | 221 | def _get_gat(graphs_tuple): 222 | # With multi-head attention the query function has to return features for multiple heads. 223 | # Here we define 3 heads, all sharing the same node features.
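# Stacking the node features three times along a new axis simulates three identical heads; `node_update_fn` below averages over that axis, recovering the original features.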
224 | def attention_query_fn(n): 225 | return tree.tree_map(lambda n_: jnp.stack([n_, n_, n_], axis=2), n) 226 | # Our attention logits are 1 for self edges and a large negative value otherwise. 227 | def attention_logit_fn(s, r, e_): 228 | del e_ 229 | return (s == r)*1 + (s != r)*-1e10 230 | 231 | def node_update_fn(nodes): 232 | return jnp.mean(nodes, axis=2) 233 | 234 | net = models.GAT(attention_query_fn, attention_logit_fn, node_update_fn) 235 | 236 | # Cast nodes to floats since GAT will output floats from the softmax 237 | # attention. 238 | graphs_tuple = graphs_tuple._replace( 239 | nodes=jnp.array(graphs_tuple.nodes, jnp.float32)) 240 | return net(graphs_tuple), graphs_tuple 241 | 242 | 243 | class ModelsTest(parameterized.TestCase): 244 | 245 | def _make_nest(self, array): 246 | """Returns a nest given an array.""" 247 | return {'a': array, 248 | 'b': [jnp.ones_like(array), {'c': jnp.zeros_like(array)}]} 249 | 250 | def _get_list_and_batched_graph(self): 251 | """Returns a list of individual graphs and a batched version. 252 | 253 | This test-case includes the following corner-cases: 254 | - single node, 255 | - multiple nodes, 256 | - no edges, 257 | - single edge, 258 | - and multiple edges. 259 | """ 260 | batched_graph = graph.GraphsTuple( 261 | n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]), 262 | n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]), 263 | nodes=self._make_nest(jnp.arange(14).reshape(7, 2)), 264 | edges=self._make_nest(jnp.arange(24).reshape(8, 3)), 265 | globals=self._make_nest(jnp.arange(14).reshape(7, 2)), 266 | senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]), 267 | receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5])) 268 | 269 | list_graphs = [ 270 | graph.GraphsTuple( 271 | n_node=jnp.array([1]), 272 | n_edge=jnp.array([2]), 273 | nodes=self._make_nest(jnp.array([[0, 1]])), 274 | edges=self._make_nest(jnp.array([[0, 1, 2], [3, 4, 5]])), 275 | globals=self._make_nest(jnp.array([[0, 1]])), 276 | senders=jnp.array([0, 0]), 277 | receivers=jnp.array([0, 0])), 278 | graph.GraphsTuple( 279 | n_node=jnp.array([3]), 280 | n_edge=jnp.array([5]), 281 | nodes=self._make_nest(jnp.array([[2, 3], [4, 5], [6, 7]])), 282 | edges=self._make_nest( 283 | jnp.array([[6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17], 284 | [18, 19, 20]])), 285 | globals=self._make_nest(jnp.array([[2, 3]])), 286 | senders=jnp.array([0, 0, 1, 2, 2]), 287 | receivers=jnp.array([1, 0, 2, 1, 0])), 288 | graph.GraphsTuple( 289 | n_node=jnp.array([1]), 290 | n_edge=jnp.array([0]), 291 | nodes=self._make_nest(jnp.array([[8, 9]])), 292 | edges=self._make_nest(jnp.zeros((0, 3))), 293 | globals=self._make_nest(jnp.array([[4, 5]])), 294 | senders=jnp.array([]), 295 | receivers=jnp.array([])), 296 | graph.GraphsTuple( 297 | n_node=jnp.array([0]), 298 | n_edge=jnp.array([0]), 299 | nodes=self._make_nest(jnp.zeros((0, 2))), 300 | edges=self._make_nest(jnp.zeros((0, 3))), 301 | globals=self._make_nest(jnp.array([[6, 7]])), 302 | senders=jnp.array([]), 303 | receivers=jnp.array([])), 304 | graph.GraphsTuple( 305 | n_node=jnp.array([2]), 306 | n_edge=jnp.array([1]), 307 | nodes=self._make_nest(jnp.array([[10, 11], [12, 13]])), 308 | edges=self._make_nest(jnp.array([[21, 22, 23]])), 309 | globals=self._make_nest(jnp.array([[8, 9]])), 310 | senders=jnp.array([1]), 311 | receivers=jnp.array([0])), 312 | graph.GraphsTuple( 313 | n_node=jnp.array([0]), 314 | n_edge=jnp.array([0]), 315 | nodes=self._make_nest(jnp.zeros((0, 2))), 316 | edges=self._make_nest(jnp.zeros((0, 3))), 317 | globals=self._make_nest(jnp.array([[10, 11]])), 318 | senders=jnp.array([]), 319 | 
receivers=jnp.array([])), 320 | graph.GraphsTuple( 321 | n_node=jnp.array([0]), 322 | n_edge=jnp.array([0]), 323 | nodes=self._make_nest(jnp.zeros((0, 2))), 324 | edges=self._make_nest(jnp.zeros((0, 3))), 325 | globals=self._make_nest(jnp.array([[12, 13]])), 326 | senders=jnp.array([]), 327 | receivers=jnp.array([])) 328 | ] 329 | 330 | return list_graphs, batched_graph 331 | 332 | @parameterized.parameters(_get_graph_network, 333 | _get_graph_network_no_node_update, 334 | _get_graph_network_no_edge_update, 335 | _get_graph_network_no_global_update, 336 | _get_attention_graph_network, 337 | _get_multi_head_attention_graph_network, 338 | _get_graph_gat) 339 | def test_connect_graphnetwork(self, network_fn): 340 | _, batched_graphs_tuple = self._get_list_and_batched_graph() 341 | with self.subTest('nojit'): 342 | out = network_fn(batched_graphs_tuple) 343 | jax.tree_util.tree_map(np.testing.assert_allclose, out, 344 | batched_graphs_tuple) 345 | with self.subTest('jit'): 346 | out = jax.jit(network_fn)(batched_graphs_tuple) 347 | jax.tree_util.tree_map(np.testing.assert_allclose, out, 348 | batched_graphs_tuple) 349 | 350 | @parameterized.parameters(_get_graph_network, 351 | _get_graph_network_no_node_update, 352 | _get_graph_network_no_edge_update, 353 | _get_graph_network_no_global_update) 354 | def test_connect_graphnetwork_nones(self, network_fn): 355 | batched_graphs_tuple = graph.GraphsTuple( 356 | n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]), 357 | n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]), 358 | nodes=self._make_nest(jnp.arange(14).reshape(7, 2)), 359 | edges=self._make_nest(jnp.arange(24).reshape(8, 3)), 360 | globals=self._make_nest(jnp.arange(14).reshape(7, 2)), 361 | senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]), 362 | receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5])) 363 | 364 | for name, graphs_tuple in [ 365 | ('no_globals', batched_graphs_tuple._replace(globals=None)), 366 | ('empty_globals', batched_graphs_tuple._replace(globals=[])), 367 | ('no_edges', batched_graphs_tuple._replace(edges=None)), 368 | ('empty_edges', batched_graphs_tuple._replace(edges=[])), 369 | ]: 370 | with self.subTest(name + '_nojit'): 371 | out = network_fn(graphs_tuple) 372 | jax.tree_util.tree_map(np.testing.assert_allclose, out, graphs_tuple) 373 | with self.subTest(name + '_jit'): 374 | out = jax.jit(network_fn)(graphs_tuple) 375 | jax.tree_util.tree_map(np.testing.assert_allclose, out, graphs_tuple) 376 | 377 | @parameterized.parameters(_get_interaction_network, 378 | _get_graph_independent, 379 | _get_gat, 380 | _get_relation_network, 381 | _get_deep_sets) 382 | def test_connect_gnns(self, network_fn): 383 | batched_graphs_tuple = graph.GraphsTuple( 384 | n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]), 385 | n_edge=jnp.array([1, 7, 1, 0, 3, 0, 0]), 386 | nodes=jnp.arange(14).reshape(7, 2), 387 | edges=jnp.arange(36).reshape(12, 3), 388 | globals=jnp.arange(14).reshape(7, 2), 389 | senders=jnp.array([0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 3, 6]), 390 | receivers=jnp.array([0, 1, 2, 3, 4, 5, 6, 2, 3, 2, 1, 5])) 391 | with self.subTest('nojit'): 392 | out, expected_out = network_fn(batched_graphs_tuple) 393 | jax.tree_util.tree_map(np.testing.assert_allclose, out, expected_out) 394 | with self.subTest('jit'): 395 | out, expected_out = jax.jit(network_fn)(batched_graphs_tuple) 396 | jax.tree_util.tree_map(np.testing.assert_allclose, out, expected_out) 397 | 398 | def test_graphnetwork_attention_error(self): 399 | with self.assertRaisesRegex( 400 | ValueError, ('attention_logit_fn and attention_reduce_fn ' 401 | 'must both be 
supplied.')): 402 | models.GraphNetwork(update_edge_fn=None, update_node_fn=None, 403 | attention_logit_fn=lambda x: x, 404 | attention_reduce_fn=None) 405 | with self.assertRaisesRegex( 406 | ValueError, ('attention_logit_fn and attention_reduce_fn ' 407 | 'must both be supplied.')): 408 | models.GraphNetwork(update_edge_fn=None, update_node_fn=None, 409 | attention_logit_fn=None, 410 | attention_reduce_fn=lambda x: x) 411 | 412 | 413 | if __name__ == '__main__': 414 | absltest.main() 415 | -------------------------------------------------------------------------------- /jraph/examples/basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""A basic graphnet example. 16 | 17 | This example just explains the bare mechanics of the library. 18 | """ 19 | 20 | import logging 21 | 22 | from absl import app 23 | import jax 24 | import jraph 25 | import numpy as np 26 | 27 | MASK_BROKEN_MSG = ("Support for jax.mask is currently broken. This is not a " 28 | "jraph error.") 29 | 30 | 31 | def run(): 32 | """Runs basic example.""" 33 | 34 | # Creating graph tuples. 35 | 36 | # Creates a GraphsTuple from scratch containing a single graph. 37 | # The graph has 3 nodes and 2 edges. 38 | # Each node has a 4-dimensional feature vector. 39 | # Each edge has a 5-dimensional feature vector. 40 | # The graph itself has a 6-dimensional feature vector. 41 | single_graph = jraph.GraphsTuple( 42 | n_node=np.asarray([3]), n_edge=np.asarray([2]), 43 | nodes=np.ones((3, 4)), edges=np.ones((2, 5)), 44 | globals=np.ones((1, 6)), 45 | senders=np.array([0, 1]), receivers=np.array([2, 2])) 46 | logging.info("Single graph %r", single_graph) 47 | 48 | # Creates a GraphsTuple from scratch containing a single graph with nested 49 | # feature vectors. 50 | # The graph has 3 nodes and 2 edges. 51 | # The feature vector can be arbitrary nested types of dict, list and tuple, 52 | # or any other type you registered with jax.tree_util.register_pytree_node. 53 | nested_graph = jraph.GraphsTuple( 54 | n_node=np.asarray([3]), n_edge=np.asarray([2]), 55 | nodes={"a": np.ones((3, 4))}, edges={"b": np.ones((2, 5))}, 56 | globals={"c": np.ones((1, 6))}, 57 | senders=np.array([0, 1]), receivers=np.array([2, 2])) 58 | logging.info("Nested graph %r", nested_graph) 59 | 60 | # Creates a GraphsTuple from scratch containing 2 graphs using an implicit 61 | # batch dimension. 62 | # The first graph has 3 nodes and 2 edges. 63 | # The second graph has 1 node and 1 edge. 64 | # Each node has a 4-dimensional feature vector. 65 | # Each edge has a 5-dimensional feature vector. 66 | # The graph itself has a 6-dimensional feature vector. 
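  # Note that senders and receivers index into the concatenated node array, so the second graph's self edge uses the global node index 3 (the first node of the second graph).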
67 | implicitly_batched_graph = jraph.GraphsTuple( 68 | n_node=np.asarray([3, 1]), n_edge=np.asarray([2, 1]), 69 | nodes=np.ones((4, 4)), edges=np.ones((3, 5)), 70 | globals=np.ones((2, 6)), 71 | senders=np.array([0, 1, 3]), receivers=np.array([2, 2, 3])) 72 | logging.info("Implicitly batched graph %r", implicitly_batched_graph) 73 | 74 | # Batching graphs can be challenging. There are in general two approaches: 75 | # 1. Implicit batching: Independent graphs are combined into the same 76 | # GraphsTuple first, and the padding is added to the combined graph. 77 | # 2. Explicit batching: Pad all graphs to a maximum size, stack them together 78 | # using an explicit batch dimension followed by jax.vmap. 79 | # Both approaches are shown below. 80 | 81 | # Creates a GraphsTuple from two existing GraphsTuple using an implicit 82 | # batch dimension. 83 | # The GraphsTuple will contain three graphs. 84 | implicitly_batched_graph = jraph.batch( 85 | [single_graph, implicitly_batched_graph]) 86 | logging.info("Implicitly batched graph %r", implicitly_batched_graph) 87 | 88 | # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit 89 | # batch dimension. 90 | graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph) 91 | logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3) 92 | 93 | # Creates a padded GraphsTuple from an existing GraphsTuple. 94 | # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs. 95 | # Three graphs are added for the padding. 96 | # First a dummy graph which contains the padding nodes and edges and secondly 97 | # two empty graphs without nodes or edges to pad out the graphs. 98 | padded_graph = jraph.pad_with_graphs( 99 | single_graph, n_node=10, n_edge=5, n_graph=4) 100 | logging.info("Padded graph %r", padded_graph) 101 | 102 | # Creates a GraphsTuple from an existing padded GraphsTuple. 103 | # The previously added padding is removed. 104 | single_graph = jraph.unpad_with_graphs(padded_graph) 105 | logging.info("Unpadded graph %r", single_graph) 106 | 107 | # Creates a GraphsTuple containing 2 graphs using an explicit batch 108 | # dimension. 109 | # An explicit batch dimension requires more memory, but can simplify 110 | # the definition of functions operating on the graph. 111 | # Explicitly batched graphs require the GraphNetwork to be transformed 112 | # by jax.vmap. 113 | # Using an explicit batch requires padding all feature vectors to 114 | # the maximum size of nodes and edges. 115 | # The first graph has 3 nodes and 2 edges. 116 | # The second graph has 1 node and 1 edge. 117 | # Each node has a 4-dimensional feature vector. 118 | # Each edge has a 5-dimensional feature vector. 119 | # The graph itself has a 6-dimensional feature vector. 120 | explicitly_batched_graph = jraph.GraphsTuple( 121 | n_node=np.asarray([[3], [1]]), n_edge=np.asarray([[2], [1]]), 122 | nodes=np.ones((2, 3, 4)), edges=np.ones((2, 2, 5)), 123 | globals=np.ones((2, 1, 6)), 124 | senders=np.array([[0, 1], [0, -1]]), 125 | receivers=np.array([[2, 2], [0, -1]])) 126 | logging.info("Explicitly batched graph %r", explicitly_batched_graph) 127 | 128 | # Running a graph propagation step. 129 | # First define the update functions for the edges, nodes and globals. 130 | # In this example we use the identity everywhere. 131 | # For Graph neural networks, each update function is typically a neural 132 | # network. 
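  # For illustration only, a learned edge update might look roughly like the
  # sketch below, where `my_mlp` is a hypothetical feature-to-feature network
  # (not defined in this example):
  #
  #   def learned_update_edge_fn(edges, senders, receivers, globals_):
  #     return my_mlp(jax.numpy.concatenate(
  #         [edges, senders, receivers, globals_], axis=-1))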
133 | def update_edge_fn( 134 | edge_features, 135 | sender_node_features, 136 | receiver_node_features, 137 | globals_): 138 | """Returns the updated edge features.""" 139 | del sender_node_features 140 | del receiver_node_features 141 | del globals_ 142 | return edge_features 143 | 144 | def update_node_fn( 145 | node_features, 146 | aggregated_sender_edge_features, 147 | aggregated_receiver_edge_features, 148 | globals_): 149 | """Returns the updated node features.""" 150 | del aggregated_sender_edge_features 151 | del aggregated_receiver_edge_features 152 | del globals_ 153 | return node_features 154 | 155 | def update_globals_fn( 156 | aggregated_node_features, 157 | aggregated_edge_features, 158 | globals_): 159 | """Returns the updated global features.""" 160 | del aggregated_node_features 161 | del aggregated_edge_features 162 | return globals_ 163 | 164 | # Optionally define custom aggregation functions. 165 | # In this example we use the defaults (so no need to define them explicitly). 166 | aggregate_edges_for_nodes_fn = jraph.segment_sum 167 | aggregate_nodes_for_globals_fn = jraph.segment_sum 168 | aggregate_edges_for_globals_fn = jraph.segment_sum 169 | 170 | # Optionally define an attention logit function and an attention reduce 171 | # function. This can be used for graph attention. 172 | # The attention function calculates attention weights, and the apply 173 | # attention function calculates the new edge feature given the weights. 174 | # We don't use graph attention here, and just pass the defaults. 175 | attention_logit_fn = None 176 | attention_reduce_fn = None 177 | 178 | # Creates a new GraphNetwork in its most general form. 179 | # Most of the arguments have defaults and can be omitted if a feature 180 | # is not used. 181 | # There are also predefined GraphNetworks available (see models.py) 182 | network = jraph.GraphNetwork( 183 | update_edge_fn=update_edge_fn, 184 | update_node_fn=update_node_fn, 185 | update_global_fn=update_globals_fn, 186 | attention_logit_fn=attention_logit_fn, 187 | aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn, 188 | aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn, 189 | aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn, 190 | attention_reduce_fn=attention_reduce_fn) 191 | 192 | # Runs graph propagation on (implicitly batched) graphs. 193 | updated_graph = network(single_graph) 194 | logging.info("Updated graph from single graph %r", updated_graph) 195 | 196 | updated_graph = network(nested_graph) 197 | logging.info("Updated graph from nested graph %r", updated_graph) 198 | 199 | updated_graph = network(implicitly_batched_graph) 200 | logging.info("Updated graph from implicitly batched graph %r", updated_graph) 201 | 202 | updated_graph = network(padded_graph) 203 | logging.info("Updated graph from padded graph %r", updated_graph) 204 | 205 | # JIT-compile graph propagation. 206 | # Use padded graphs to avoid re-compilation at every step!
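  # jax.jit traces the function once per distinct input shape, so graphs whose node or edge counts change between calls would trigger a fresh trace every time; padding to fixed sizes keeps the shapes static.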
207 | jitted_network = jax.jit(network) 208 | updated_graph = jitted_network(padded_graph) 209 | logging.info("(JIT) updated graph from padded graph %r", updated_graph) 210 | logging.info("basic.py complete!") 211 | 212 | 213 | def main(argv): 214 | if len(argv) > 1: 215 | raise app.UsageError("Too many command-line arguments.") 216 | run() 217 | 218 | 219 | if __name__ == "__main__": 220 | app.run(main) 221 | -------------------------------------------------------------------------------- /jraph/examples/e_voting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Electronic Voting Example. 16 | 17 | In this example we use DeepSets to estimate the winner of an election. 18 | Each vote is represented by a one-hot encoded vector. 19 | 20 | It goes without saying, but don't use this in a real election! 21 | Seriously, don't! 22 | """ 23 | 24 | import collections 25 | import logging 26 | import random 27 | 28 | from absl import app 29 | import haiku as hk 30 | import jax 31 | import jax.numpy as jnp 32 | import jraph 33 | import numpy as np 34 | import optax 35 | 36 | 37 | Problem = collections.namedtuple("Problem", ("graph", "labels")) 38 | 39 | 40 | def get_voting_problem(min_n_voters: int, max_n_voters: int) -> Problem: 41 | """Creates a set of one-hot vectors representing a randomly generated election. 42 | 43 | Args: 44 | min_n_voters: minimum number of voters in the election. 45 | max_n_voters: maximum number of voters in the election. 46 | 47 | Returns: 48 | A graph containing the votes as nodes, and a one-hot vector encoding the winner. 49 | """ 50 | n_candidates = 20 51 | n_voters = random.randint(min_n_voters, max_n_voters) 52 | votes = np.random.randint(0, n_candidates, size=(n_voters,)) 53 | one_hot_votes = np.eye(n_candidates)[votes] 54 | winner = np.argmax(np.sum(one_hot_votes, axis=0)) 55 | one_hot_winner = np.eye(n_candidates)[winner] 56 | 57 | graph = jraph.GraphsTuple( 58 | n_node=np.asarray([n_voters]), 59 | n_edge=np.asarray([0]), 60 | nodes=one_hot_votes, 61 | edges=None, 62 | globals=np.zeros((1, n_candidates)), 63 | # There are no edges in our graph. 64 | senders=np.array([], dtype=np.int32), 65 | receivers=np.array([], dtype=np.int32)) 66 | 67 | # In order to jit compile our code, we have to pad the nodes and edges of 68 | # the GraphsTuple to a static shape. 69 | graph = jraph.pad_with_graphs(graph, max_n_voters+1, 0) 70 | 71 | return Problem(graph=graph, labels=one_hot_winner) 72 | 73 | 74 | def network_definition( 75 | graph: jraph.GraphsTuple, 76 | num_message_passing_steps: int = 1) -> jraph.ArrayTree: 77 | """Defines a graph neural network. 78 | 79 | Args: 80 | graph: GraphsTuple the network processes. 81 | num_message_passing_steps: number of message passing steps. 82 | 83 | Returns: 84 | globals.
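    (The returned globals are interpreted downstream as per-candidate logits
    from which the winner is predicted.)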
85 | """ 86 | 87 | @jax.vmap 88 | def update_fn(*args): 89 | size = args[0].shape[-1] 90 | return hk.nets.MLP([size, size])(jnp.concatenate(args, axis=-1)) 91 | 92 | for _ in range(num_message_passing_steps): 93 | gn = jraph.DeepSets( 94 | update_node_fn=update_fn, 95 | update_global_fn=update_fn, 96 | aggregate_nodes_for_globals_fn=jraph.segment_mean, 97 | ) 98 | graph = gn(graph) 99 | 100 | return hk.Linear(graph.globals.shape[-1])(graph.globals) 101 | 102 | 103 | def train(num_steps: int): 104 | """Trains a graph neural network on an electronic voting problem.""" 105 | train_dataset = (2, 15) 106 | test_dataset = (16, 20) 107 | random.seed(42) 108 | 109 | network = hk.without_apply_rng(hk.transform(network_definition)) 110 | problem = get_voting_problem(*train_dataset) 111 | params = network.init(jax.random.PRNGKey(42), problem.graph) 112 | 113 | @jax.jit 114 | def prediction_loss(params, problem): 115 | globals_ = network.apply(params, problem.graph) 116 | # We interpret the globals as logits for the winner. 117 | # Only the first graph is real, the second graph is for padding. 118 | log_prob = jax.nn.log_softmax(globals_[0]) * problem.labels 119 | return -jnp.sum(log_prob) 120 | 121 | @jax.jit 122 | def accuracy_loss(params, problem): 123 | globals_ = network.apply(params, problem.graph) 124 | # We interpret the globals as logits for the winner. 125 | # Only the first graph is real, the second graph is for padding. 126 | equal = jnp.argmax(globals_[0]) == jnp.argmax(problem.labels) 127 | return equal.astype(np.int32) 128 | 129 | opt_init, opt_update = optax.adam(2e-4) 130 | opt_state = opt_init(params) 131 | 132 | @jax.jit 133 | def update(params, opt_state, problem): 134 | g = jax.grad(prediction_loss)(params, problem) 135 | updates, opt_state = opt_update(g, opt_state) 136 | return optax.apply_updates(params, updates), opt_state 137 | 138 | for step in range(num_steps): 139 | problem = get_voting_problem(*train_dataset) 140 | params, opt_state = update(params, opt_state, problem) 141 | if step % 1000 == 0: 142 | train_loss = jnp.mean( 143 | jnp.asarray([ 144 | accuracy_loss(params, get_voting_problem(*train_dataset)) 145 | for _ in range(100) 146 | ])).item() 147 | test_loss = jnp.mean( 148 | jnp.asarray([ 149 | accuracy_loss(params, get_voting_problem(*test_dataset)) 150 | for _ in range(100) 151 | ])).item() 152 | logging.info("step %r loss train %r test %r", step, train_loss, test_loss) 153 | 154 | 155 | def main(argv): 156 | if len(argv) > 1: 157 | raise app.UsageError("Too many command-line arguments.") 158 | 159 | train(num_steps=100000) 160 | 161 | 162 | if __name__ == "__main__": 163 | app.run(main) 164 | -------------------------------------------------------------------------------- /jraph/examples/game_of_life.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Implementation of Conway's game of life using jraph.""" 16 | 17 | import time 18 | 19 | from absl import app 20 | import haiku as hk 21 | import jax 22 | import jax.numpy as jnp 23 | import jraph 24 | import numpy as np 25 | 26 | 27 | def conway_mlp(x): 28 | """Implements a MLP representing Conway's game of life rules.""" 29 | w = jnp.array([[0.0, -1.0], [0.0, 1.0], [0.0, 1.0], 30 | [0, -1.0], [1.0, 1.0], [1.0, 1.0]]) 31 | b = jnp.array([3.5, -3.5, -1.5, 1.5, -2.5, -3.5]) 32 | h = jnp.maximum(jnp.dot(w, x) + b, 0.) 33 | w = jnp.array([[2.0, -4.0, 2.0, -4.0, 2.0, -4.0]]) 34 | b = jnp.array([-4.0]) 35 | y = jnp.maximum(jnp.dot(w, h) + b, 0.0) 36 | return y 37 | 38 | 39 | def conway_graph(size) -> jraph.GraphsTuple: 40 | """Returns a graph representing the game field of conway's game of life.""" 41 | # Creates nodes: each node represents a cell in the game. 42 | n_node = size**2 43 | nodes = np.zeros((n_node, 1)) 44 | node_indices = jnp.arange(n_node) 45 | # Creates edges, senders and receivers: 46 | # the senders represent the connections to the 8 neighboring fields. 47 | n_edge = 8 * n_node 48 | edges = jnp.zeros((n_edge, 1)) 49 | senders = jnp.vstack( 50 | [node_indices - size - 1, node_indices - size, node_indices - size + 1, 51 | node_indices - 1, node_indices + 1, 52 | node_indices + size - 1, node_indices + size, node_indices + size + 1]) 53 | senders = senders.T.reshape(-1) 54 | senders = (senders + size**2) % size**2 55 | receivers = jnp.repeat(node_indices, 8) 56 | # Adds a glider to the game 57 | nodes[0, 0] = 1.0 58 | nodes[1, 0] = 1.0 59 | nodes[2, 0] = 1.0 60 | nodes[2 + size, 0] = 1.0 61 | nodes[1 + 2 * size, 0] = 1.0 62 | return jraph.GraphsTuple(n_node=jnp.array([n_node]), 63 | n_edge=jnp.array([n_edge]), 64 | nodes=jnp.asarray(nodes), 65 | edges=edges, 66 | globals=None, 67 | senders=senders, 68 | receivers=receivers) 69 | 70 | 71 | def display_graph(graph: jraph.GraphsTuple): 72 | """Prints the nodes of the graph representing Conway's game of life.""" 73 | size = int(np.sqrt(np.sum(graph.n_node))) 74 | 75 | def _display_node(node): 76 | if node == 1.0: 77 | return 'x' 78 | else: 79 | return ' ' 80 | 81 | nodes = graph.nodes.copy() 82 | output = '\n'.join( 83 | ''.join(_display_node(nodes[i * size + j][0]) 84 | for j in range(size)) 85 | for i in range(size)) 86 | print('-' * size + '\n' + output) 87 | 88 | 89 | def main(_): 90 | 91 | def net_fn(graph: jraph.GraphsTuple): 92 | unf = jraph.concatenated_args(conway_mlp) 93 | net = jraph.InteractionNetwork( 94 | update_edge_fn=lambda e, n_s, n_r: n_s, 95 | update_node_fn=jax.vmap(unf)) 96 | return net(graph) 97 | 98 | net = hk.without_apply_rng(hk.transform(net_fn)) 99 | 100 | cg = conway_graph(size=20) 101 | params = net.init(jax.random.PRNGKey(42), cg) 102 | for _ in range(100): 103 | time.sleep(0.05) 104 | cg = jax.jit(net.apply)(params, cg) 105 | display_graph(cg) 106 | 107 | if __name__ == '__main__': 108 | app.run(main) 109 | -------------------------------------------------------------------------------- /jraph/examples/hamiltonian_graph_network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Example of a Hamiltonian Graph Network (https://arxiv.org/abs/1909.12790). 16 | 17 | In this example: 18 | * The `GraphNetwork` implements the hardcoded formulas of Hooke's 19 | Hamiltonian to return the scalar Hamiltonian of a system 20 | (see `hookes_hamiltonian_from_graph_fn`). 21 | * Then JAX autodiff is used to obtain the function of derivatives of the state 22 | via Hamilton's equations (see `get_state_derivatives_from_hamiltonian_fn`). 23 | * The function of the derivatives of the state is used by a generic integrator 24 | to simulate steps forward in time (see `single_integration_step`). 25 | * `build_hookes_particle_state_graph` is used to sample the initial state 26 | of the system. 27 | 28 | Note this example does not implement a learned Hamiltonian Graph Network, but 29 | a hardcoded Hamiltonian function corresponding to Hooke's potential, 30 | implemented as a `GraphNetwork`. This is to show how natural it is to express the 31 | true Hamiltonian of a particle system as a Graph Network. However, the only 32 | system-specific code here is the hardcoded Hamiltonian in 33 | `hookes_hamiltonian_from_graph_fn` and the data generation provided in 34 | `build_hookes_particle_state_graph`. Everything else is reusable for any other 35 | system. 36 | 37 | To implement a learned Hamiltonian Graph Network, one could closely follow 38 | `hookes_hamiltonian_from_graph_fn` using function approximators (e.g. MLP) in 39 | `edge_update_fn`, `node_update_fn` and `global_update_fn` that take as inputs 40 | the concatenated features (e.g. using `jraph.concatenated_args`), with the only 41 | condition that the `global_update_fn` has an output size of 1, to match 42 | the expected output size for a Hamiltonian. Then the learned Hamiltonian Graph 43 | Network would be trained by simply adding a loss term on the position / momentum 44 | of the next step after applying the integrator. 45 | 46 | 47 | Note it is recommended to use immutable container types to store nested edge, 48 | node and global features to avoid unwanted side effects. In this example we 49 | use `frozendict`s, which we register with `jax.tree_util`. 50 | 51 | """ 52 | 53 | import functools 54 | from typing import Tuple, Callable 55 | 56 | from absl import app 57 | from frozendict import frozendict 58 | import jax 59 | import jax.numpy as jnp 60 | import jraph 61 | import matplotlib.pyplot as plt 62 | import numpy as np 63 | 64 | 65 | # Tell tree_util how to navigate frozendicts. 66 | jax.tree_util.register_pytree_node( 67 | frozendict, 68 | flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())), 69 | unflatten_func=lambda k, xs: frozendict(zip(k, xs))) 70 | 71 | 72 | def hookes_hamiltonian_from_graph_fn( 73 | graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 74 | """Computes Hamiltonian of a Hooke's potential system represented in a graph.
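
  For particles with positions x_i, momenta p_i, masses m_i and spring
  constants k_ij, the Hamiltonian computed here is
  H = sum_i |p_i|^2 / (2 * m_i) + sum_{edges (i, j)} 0.5 * k_ij * |x_i - x_j|^2.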
75 | 76 | While this function hardcodes the Hamiltonian for a Hooke's potential, a 77 | learned Hamiltonian Graph Network (https://arxiv.org/abs/1909.12790) could 78 | be implemented by replacing the hardcoded formulas by learnable MLPs that 79 | take as inputs all of the concatenated features to the edge_fn, node_fn, 80 | and global_fn, and output a single scalar value in the global_fn. 81 | 82 | Args: 83 | graph: `GraphsTuple` where the nodes contain: 84 | - "mass": [num_particles] 85 | - "position": [num_particles, num_dims] 86 | - "momentum": [num_particles, num_dims] 87 | and the edges contain: 88 | - "spring_constant": [num_interactions] 89 | 90 | Returns: 91 | `GraphsTuple` with features: 92 | - edge features: "hookes_potential" [num_interactions] 93 | - node features: "kinetic_energy" [num_particles] 94 | - global features: "hamiltonian" [batch_size] 95 | 96 | """ 97 | 98 | def update_edge_fn(edges, senders, receivers, globals_): 99 | del globals_ 100 | distance = jnp.linalg.norm(senders["position"] - receivers["position"]) 101 | hookes_potential_per_edge = 0.5 * edges["spring_constant"] * distance ** 2 102 | return frozendict({"hookes_potential": hookes_potential_per_edge}) 103 | 104 | def update_node_fn(nodes, sent_edges, received_edges, globals_): 105 | del sent_edges, received_edges, globals_ 106 | momentum_norm = jnp.linalg.norm(nodes["momentum"]) 107 | kinetic_energy_per_node = momentum_norm ** 2 / (2 * nodes["mass"]) 108 | return frozendict({"kinetic_energy": kinetic_energy_per_node}) 109 | 110 | def update_global_fn(nodes, edges, globals_): 111 | del globals_ 112 | # At this point we will receive node and edge features aggregated (summed) 113 | # for all nodes and edges in each graph. 114 | hamiltonian_per_graph = nodes["kinetic_energy"] + edges["hookes_potential"] 115 | return frozendict({"hamiltonian": hamiltonian_per_graph}) 116 | 117 | gn = jraph.GraphNetwork( 118 | update_edge_fn=update_edge_fn, 119 | update_node_fn=update_node_fn, 120 | update_global_fn=update_global_fn) 121 | 122 | return gn(graph) 123 | 124 | 125 | # Methods for generating the data. 126 | def build_hookes_particle_state_graph(num_particles: int) -> jraph.GraphsTuple: 127 | """Generates a graph representing a Hooke's system in a random state.""" 128 | 129 | mass = np.random.uniform(0, 5, [num_particles]) 130 | velocity = get_random_uniform_norm2d_vectors(0, 0.1, num_particles) 131 | position = get_random_uniform_norm2d_vectors(0, 1, num_particles) 132 | momentum = velocity * np.expand_dims(mass, axis=-1) 133 | # Remove average momentum, so center of mass does not move. 134 | momentum = momentum - momentum.mean(0, keepdims=True) 135 | 136 | # Connect all particles to all particles. 137 | particle_indices = np.arange(num_particles) 138 | senders, receivers = np.meshgrid(particle_indices, particle_indices) 139 | senders, receivers = senders.flatten(), receivers.flatten() 140 | 141 | # Generate a symmetric random matrix of spring constants. 142 | # Generate random elements strictly in the lower triangular part. 143 | spring_constants = np.random.uniform( 144 | 1e-2, 1e-1, [num_particles, num_particles]) 145 | spring_constants = np.tril( 146 | spring_constants) + np.tril(spring_constants, -1).T 147 | spring_constants = spring_constants.flatten() 148 | 149 | # Remove interactions of particles with themselves.
150 | mask = senders != receivers 151 | senders, receivers = senders[mask], receivers[mask] 152 | spring_constants = spring_constants[mask] 153 | num_interactions = receivers.shape[0] 154 | 155 | return jraph.GraphsTuple( 156 | n_node=np.asarray([num_particles]), 157 | n_edge=np.asarray([num_interactions]), 158 | nodes={ 159 | "mass": mass, # Scalar mass for each particle. 160 | "position": position, # 2d position for each particle. 161 | "momentum": momentum, # 2d momentum for each particle. 162 | }, 163 | edges={ 164 | # Scalar spring constant for each interaction 165 | "spring_constant": spring_constants, 166 | }, 167 | globals={}, 168 | senders=senders, 169 | receivers=receivers) 170 | 171 | 172 | def get_random_uniform_norm2d_vectors( 173 | min_norm: float, max_norm: float, num_particles: int) -> np.ndarray: 174 | """Returns 2-d vectors with random norms.""" 175 | norm = np.random.uniform(min_norm, max_norm, [num_particles, 1]) 176 | angle = np.random.uniform(0, 2*np.pi, [num_particles]) 177 | return norm * np.stack([np.cos(angle), np.sin(angle)], axis=-1) 178 | 179 | 180 | def get_fully_connected_senders_and_receivers( 181 | num_particles: int, self_edges: bool = False, 182 | ) -> Tuple[np.ndarray, np.ndarray]: 183 | """Returns senders and receivers for fully connected particles.""" 184 | particle_indices = np.arange(num_particles) 185 | senders, receivers = np.meshgrid(particle_indices, particle_indices) 186 | senders, receivers = senders.flatten(), receivers.flatten() 187 | if not self_edges: 188 | mask = senders != receivers 189 | senders, receivers = senders[mask], receivers[mask] 190 | return senders, receivers 191 | 192 | 193 | # All code below here is general purpose for any system or integrator. 194 | # Utility methods for getting/setting the state of the particles in the graph 195 | # (position and momentum), and for obtaining the static part of the graph 196 | # (connectivity and particle parameters: masses, spring constants). 197 | def set_system_state( 198 | static_graph: jraph.GraphsTuple, 199 | position: np.ndarray, 200 | momentum: np.ndarray) -> jraph.GraphsTuple: 201 | """Sets the non-static parameters of the graph (momentum, position).""" 202 | nodes = static_graph.nodes.copy(position=position, momentum=momentum) 203 | return static_graph._replace(nodes=nodes) 204 | 205 | 206 | def get_system_state(graph: jraph.GraphsTuple) -> Tuple[np.ndarray, np.ndarray]: 207 | return graph.nodes["position"], graph.nodes["momentum"] 208 | 209 | 210 | def get_static_graph(graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 211 | """Returns the graph with the static parts of a system only.""" 212 | nodes = dict(graph.nodes) 213 | del nodes["position"], nodes["momentum"] 214 | return graph._replace(nodes=frozendict(nodes)) 215 | 216 | 217 | # Utility methods to operate with Hamiltonian functions. 218 | def get_hamiltonian_from_state_fn( 219 | static_graph: jraph.GraphsTuple, 220 | hamiltonian_from_graph_fn: Callable[[jraph.GraphsTuple], jraph.GraphsTuple], 221 | ) -> Callable[[np.ndarray, np.ndarray], float]: 222 | """Returns fn such that fn(position, momentum) -> scalar Hamiltonian. 223 | 224 | Args: 225 | static_graph: `GraphsTuple` containing per-particle static parameters and 226 | connectivity, such that a full graph of the state can be built by calling 227 | `set_system_state(static_graph, position, momentum)`. 228 | hamiltonian_from_graph_fn: callable that given an input `GraphsTuple` 229 | returns a `GraphsTuple` with a "hamiltonian" field in the globals.
230 | 231 | Returns: 232 | Function that given a state (position, momentum) returns the scalar 233 | Hamiltonian. 234 | """ 235 | 236 | def hamiltonian_from_state_fn(position, momentum): 237 | # Note we sum along the batch dimension to get the total energy in the batch 238 | # so we can easily get the gradient. 239 | graph = set_system_state(static_graph, position, momentum) 240 | output_graph = hamiltonian_from_graph_fn(graph) 241 | return output_graph.globals["hamiltonian"].sum() 242 | 243 | return hamiltonian_from_state_fn 244 | 245 | 246 | def get_state_derivatives_from_hamiltonian_fn( 247 | hamiltonian_from_state_fn: Callable[[np.ndarray, np.ndarray], float], 248 | ) -> Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]: 249 | """Returns fn(position, momentum, ...) -> (dposition_dt, dmomentum_dt). 250 | 251 | Args: 252 | hamiltonian_from_state_fn: Function that given a state 253 | (position, momentum) returns the scalar Hamiltonian. 254 | 255 | Returns: 256 | Function that given a state (position, momentum) returns the time 257 | derivatives of the state (dposition_dt, dmomentum_dt) by applying 258 | Hamilton's equations. 259 | 260 | """ 261 | 262 | hamiltonian_gradients_fn = jax.grad(hamiltonian_from_state_fn, argnums=[0, 1]) 263 | 264 | def state_derivatives_from_hamiltonian_fn( 265 | position: np.ndarray, momentum: np.ndarray 266 | ) -> Tuple[np.ndarray, np.ndarray]: 267 | # Take the derivatives against position and momentum. 268 | dh_dposition, dh_dmomentum = hamiltonian_gradients_fn(position, momentum) 269 | 270 | # Hamilton's equations. 271 | dposition_dt = dh_dmomentum 272 | dmomentum_dt = - dh_dposition 273 | return dposition_dt, dmomentum_dt 274 | return state_derivatives_from_hamiltonian_fn 275 | 276 | 277 | # Implementations of some general purpose integrators for Hamiltonian states. 278 | StateDerivativesFnType = Callable[ 279 | [np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] 280 | 281 | 282 | def abstract_integrator( 283 | position: np.ndarray, momentum: np.ndarray, time_step: float, 284 | state_derivatives_fn: StateDerivativesFnType, 285 | ) -> Tuple[np.ndarray, np.ndarray]: 286 | """Signature of an abstract integrator. 287 | 288 | An integrator is a function that, given the current state, a time step, 289 | and a `state_derivatives_fn`, returns the next state. 290 | 291 | Args: 292 | position: array with the position at time t. 293 | momentum: array with the momentum at time t. 294 | time_step: integration step size. 295 | state_derivatives_fn: a function fn that returns time derivatives of a 296 | state such that fn(position, momentum) -> (dposition_dt, dmomentum_dt) 297 | where dposition_dt, dmomentum_dt, have the same shapes as 298 | position, momentum. 299 | 300 | Returns: 301 | Tuple with position and momentum at time `t + time_step`.
302 | 303 | """ 304 | raise NotImplementedError("Abstract integrator") 305 | 306 | 307 | def euler_integrator( 308 | position: np.ndarray, momentum: np.ndarray, time_step: float, 309 | state_derivatives_fn: StateDerivativesFnType, 310 | ) -> Tuple[np.ndarray, np.ndarray]: 311 | """Implementation of an Euler integrator (see `abstract_integrator`).""" 312 | dposition_dt, dmomentum_dt = state_derivatives_fn(position, momentum) 313 | next_position = position + dposition_dt * time_step 314 | next_momentum = momentum + dmomentum_dt * time_step 315 | return next_position, next_momentum 316 | 317 | 318 | def verlet_integrator( 319 | position: np.ndarray, momentum: np.ndarray, time_step: float, 320 | state_derivatives_fn: StateDerivativesFnType, 321 | ) -> Tuple[np.ndarray, np.ndarray]: 322 | """Implementation of a Verlet integrator (see `abstract_integrator`).""" 323 | 324 | _, dmomentum_dt = state_derivatives_fn(position, momentum) 325 | aux_momentum = momentum + dmomentum_dt * time_step / 2 326 | 327 | dposition_dt, _ = state_derivatives_fn(position, aux_momentum) 328 | next_position = position + dposition_dt * time_step 329 | 330 | _, dmomentum_dt = state_derivatives_fn(next_position, aux_momentum) 331 | next_momentum = aux_momentum + dmomentum_dt * time_step / 2 332 | 333 | return next_position, next_momentum 334 | 335 | 336 | # Single graph -> graph integration step. 337 | IntegratorType = Callable[ 338 | [np.ndarray, np.ndarray, float, StateDerivativesFnType], 339 | Tuple[np.ndarray, np.ndarray] 340 | ] 341 | 342 | 343 | def single_integration_step( 344 | graph: jraph.GraphsTuple, time_step: float, 345 | integrator_fn: IntegratorType, 346 | hamiltonian_from_graph_fn: Callable[[jraph.GraphsTuple], jraph.GraphsTuple], 347 | ) -> Tuple[float, jraph.GraphsTuple]: 348 | """Updates a graph state by integrating a single step. 349 | 350 | Args: 351 | graph: `GraphsTuple` representing a system state at time t. 352 | time_step: size of the timestep to integrate for. 353 | integrator_fn: Integrator to use. A function fn such that 354 | fn(position_t, momentum_t, time_step, state_derivatives_fn) -> 355 | (position_tp1, momentum_tp1) 356 | hamiltonian_from_graph_fn: Function that given a `GraphsTuple`, returns 357 | another one with a "hamiltonian" global field. 358 | 359 | Returns: 360 | Energy of the next state, and a `GraphsTuple` representing the system state 361 | at time `t + time_step`. 362 | 363 | """ 364 | 365 | # Template graph with particle/interaction parameters and connectivity 366 | # but without the state (position/momentum). 367 | static_graph = get_static_graph(graph) 368 | 369 | # Get the Hamiltonian function, and the function that returns the state 370 | # derivatives. 371 | hamiltonian_fn = get_hamiltonian_from_state_fn( 372 | static_graph=static_graph, 373 | hamiltonian_from_graph_fn=hamiltonian_from_graph_fn) 374 | state_derivatives_fn = get_state_derivatives_from_hamiltonian_fn( 375 | hamiltonian_fn) 376 | 377 | # Get the current state. 378 | position, momentum = get_system_state(graph) 379 | 380 | # Call the integrator to get the next state. 381 | next_position, next_momentum = integrator_fn( 382 | position, momentum, time_step, state_derivatives_fn) 383 | next_graph = set_system_state(static_graph, next_position, next_momentum) 384 | 385 | # Return the energy of the next state too for plotting. 386 | energy = hamiltonian_fn(next_position, next_momentum) 387 | 388 | return energy, next_graph 389 | 390 | 391 | def main(_): 392 | 393 | # Get a step function and jit it.
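  # `functools.partial` bakes the Hamiltonian and the integrator in, so the jitted step function is a pure function of (graph, time_step).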
393 | # We could switch to any other Hamiltonian and any other integrator here. 394 | # e.g. the non-symplectic `euler_integrator`. 395 | step_fn = functools.partial( 396 | single_integration_step, 397 | hamiltonian_from_graph_fn=hookes_hamiltonian_from_graph_fn, 398 | integrator_fn=verlet_integrator) 399 | step_fn = jax.jit(step_fn) 400 | 401 | # Get a graph with the initial state. 402 | num_particles = 10 403 | graph = build_hookes_particle_state_graph(num_particles) 404 | 405 | # Iterate for multiple timesteps. 406 | num_steps = 200 407 | time_step = 0.002 408 | positions_sequence = [] 409 | total_energies = [] 410 | steps = [] 411 | for step_i in range(num_steps): 412 | energy, graph = step_fn(graph, time_step) 413 | total_energies.append(energy) 414 | positions_sequence.append(graph.nodes["position"]) 415 | steps.append(step_i + 1) 416 | 417 | # Plot results (positions and energy as a function of time). 418 | unused_fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 419 | positions_sequence_array = np.stack(positions_sequence, axis=0) 420 | axes[0].plot(positions_sequence_array[..., 0], 421 | positions_sequence_array[..., -1]) 422 | axes[0].set_xlabel("Particle position x") 423 | axes[0].set_ylabel("Particle position y") 424 | 425 | axes[1].plot(steps, total_energies) 426 | axes[1].set_ylim(0, max(total_energies)*1.2) 427 | axes[1].set_xlabel("Simulation step") 428 | axes[1].set_ylabel("Total energy") 429 | plt.show() 430 | 431 | if __name__ == "__main__": 432 | app.run(main) 433 | -------------------------------------------------------------------------------- /jraph/examples/higgs_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Higgs Boson Detection Example. 16 | 17 | One of the decay-channels of the Higgs Boson is Higgs to two photons. 18 | The two photons must have a combined invariant mass of 125 GeV. 19 | In this example we use a relation network to detect if a Higgs Boson 20 | is present in a set of photons. 21 | 22 | There are two situations: 23 | a) Higgs: Two photons with an invariant mass of 125 GeV + an arbitrary number of 24 | uncorrelated photons. 25 | b) No Higgs: Just an arbitrary number of uncorrelated photons.
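
The invariant mass m of a photon pair with four-momenta p1 and p2 satisfies
m^2 = (E1 + E2)^2 - |p1_vec + p2_vec|^2; the analytical edge update further
below computes exactly this quantity and compares it against the Higgs mass.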
26 | """ 27 | 28 | import collections 29 | import logging 30 | import random 31 | 32 | from absl import app 33 | import haiku as hk 34 | import jax 35 | import jax.numpy as jnp 36 | import jraph 37 | import numpy as np 38 | import optax 39 | import scipy.stats 40 | 41 | 42 | Problem = collections.namedtuple("Problem", ("graph", "labels")) 43 | 44 | 45 | def get_random_rotation_matrix(): 46 | rotation = np.eye(4) 47 | rotation[1:, 1:] = scipy.stats.ortho_group.rvs(3) 48 | return rotation 49 | 50 | 51 | def get_random_boost_matrix(): 52 | eta = np.random.uniform(-1, 1) 53 | boost = np.eye(4) 54 | boost[:2, :2] = np.array([[np.cosh(eta), -np.sinh(eta)], 55 | [-np.sinh(eta), np.cosh(eta)]]) 56 | rotation = get_random_rotation_matrix() 57 | return rotation.T @ boost @ rotation 58 | 59 | 60 | def get_random_higgs_photons(): 61 | higgs = 125.18 62 | boost = get_random_boost_matrix() 63 | rotation = get_random_rotation_matrix() 64 | photon1 = boost @ rotation @ np.array([higgs / 2, higgs / 2, 0, 0]) 65 | photon2 = boost @ rotation @ np.array([higgs / 2, -higgs / 2, 0, 0]) 66 | return photon1, photon2 67 | 68 | 69 | def get_random_background_photon(): 70 | boost = get_random_boost_matrix() 71 | rotation = get_random_rotation_matrix() 72 | energy = np.random.uniform(20, 120) 73 | return boost @ rotation @ np.array([energy, energy, 0, 0]) 74 | 75 | 76 | def get_higgs_problem(min_n_photons: int, max_n_photons: int) -> Problem: 77 | """Creates fully connected graph containing the detected photons. 78 | 79 | Args: 80 | min_n_photons: minimum number of photons in the detector. 81 | max_n_photons: maximum number of photons in the detector. 82 | 83 | Returns: 84 | graph, one-hot label whether a higgs was present or not. 85 | """ 86 | assert min_n_photons >= 2, "Number of photons must be at least 2." 87 | n_photons = random.randint(min_n_photons, max_n_photons) 88 | photons = np.stack([get_random_background_photon() for _ in range(n_photons)]) 89 | 90 | # Add a higgs 91 | if random.random() > 0.5: 92 | label = np.eye(2)[0] 93 | photons[:2] = np.stack(get_random_higgs_photons()) 94 | else: 95 | label = np.eye(2)[1] 96 | 97 | # The graph is fully connected. 98 | senders = np.repeat(np.arange(n_photons), n_photons) 99 | receivers = np.tile(np.arange(n_photons), n_photons) 100 | graph = jraph.GraphsTuple( 101 | n_node=np.asarray([n_photons]), 102 | n_edge=np.asarray([len(senders)]), 103 | nodes=photons, 104 | edges=None, 105 | globals=None, 106 | senders=senders, 107 | receivers=receivers) 108 | 109 | # In order to jit compile our code, we have to pad the nodes and edges of 110 | # the GraphsTuple to a static shape. 111 | graph = jraph.pad_with_graphs(graph, max_n_photons + 1, 112 | max_n_photons * max_n_photons) 113 | 114 | return Problem(graph=graph, labels=label) 115 | 116 | 117 | def network_definition( 118 | graph: jraph.GraphsTuple) -> jraph.ArrayTree: 119 | """Defines a graph neural network. 120 | 121 | Args: 122 | graph: Graphstuple the network processes. 123 | 124 | Returns: 125 | globals. 126 | """ 127 | 128 | @jax.vmap 129 | @jraph.concatenated_args 130 | def update_edge_fn(features): 131 | return hk.nets.MLP([30, 30, 30])(features) 132 | 133 | # The correct solution for the edge update function is the invariant mass 134 | # of the photon pair. 135 | # The simple MLP we use here seems to fail to find the correct solution. 136 | # You can ensure that the example works in principle by replacing the 137 | # update_edge_fn below with the following analytical solution. 
138 | @jax.vmap 139 | def unused_update_edge_fn_solution(s, r): 140 | """Calculates invariant mass of photon pair and compares to Higgs mass.""" 141 | t = (s + r)**2 142 | return jnp.array(jnp.abs(t[0] - t[1] - t[2] - t[3] - 125.18**2) < 1, 143 | dtype=jnp.float32)[None] 144 | 145 | gn = jraph.RelationNetwork( 146 | update_edge_fn=update_edge_fn, 147 | update_global_fn=hk.nets.MLP([2]), 148 | aggregate_edges_for_globals_fn=jraph.segment_sum, 149 | ) 150 | graph = gn(graph) 151 | 152 | return graph.globals 153 | 154 | 155 | def train(num_steps: int): 156 | """Trains a graph neural network on the Higgs detection problem.""" 157 | train_dataset = (2, 15) 158 | test_dataset = (16, 20) 159 | random.seed(42) 160 | 161 | network = hk.without_apply_rng(hk.transform(network_definition)) 162 | problem = get_higgs_problem(*train_dataset) 163 | params = network.init(jax.random.PRNGKey(42), problem.graph) 164 | 165 | @jax.jit 166 | def prediction_loss(params, problem): 167 | globals_ = network.apply(params, problem.graph) 168 | # We interpret the globals as logits for the detection. 169 | # Only the first graph is real, the second graph is for padding. 170 | log_prob = jax.nn.log_softmax(globals_[0]) * problem.labels 171 | return -jnp.sum(log_prob) 172 | 173 | @jax.jit 174 | def accuracy_loss(params, problem): 175 | globals_ = network.apply(params, problem.graph) 176 | # We interpret the globals as logits for the detection. 177 | # Only the first graph is real, the second graph is for padding. 178 | equal = jnp.argmax(globals_[0]) == jnp.argmax(problem.labels) 179 | return equal.astype(np.int32) 180 | 181 | opt_init, opt_update = optax.adam(2e-4) 182 | opt_state = opt_init(params) 183 | 184 | @jax.jit 185 | def update(params, opt_state, problem): 186 | g = jax.grad(prediction_loss)(params, problem) 187 | updates, opt_state = opt_update(g, opt_state) 188 | return optax.apply_updates(params, updates), opt_state 189 | 190 | for step in range(num_steps): 191 | problem = get_higgs_problem(*train_dataset) 192 | params, opt_state = update(params, opt_state, problem) 193 | if step % 1000 == 0: 194 | train_loss = jnp.mean( 195 | jnp.asarray([ 196 | accuracy_loss(params, get_higgs_problem(*train_dataset)) 197 | for _ in range(100) 198 | ])).item() 199 | test_loss = jnp.mean( 200 | jnp.asarray([ 201 | accuracy_loss(params, get_higgs_problem(*test_dataset)) 202 | for _ in range(100) 203 | ])).item() 204 | logging.info("step %r loss train %r test %r", step, train_loss, test_loss) 205 | 206 | 207 | def main(argv): 208 | if len(argv) > 1: 209 | raise app.UsageError("Too many command-line arguments.") 210 | 211 | train(num_steps=10000) 212 | 213 | 214 | if __name__ == "__main__": 215 | app.run(main) 216 | -------------------------------------------------------------------------------- /jraph/examples/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Example of how to use recurrent networks (e.g. `LSTM`s) with `GraphNetwork`s. 16 | 17 | Models can use the mechanism for specifying nested node, edge, or global 18 | features to simultaneously keep inputs/embeddings together with a per-node, 19 | per-edge or per-graph recurrent state. 20 | 21 | In this example we show an `InteractionNetwork` that uses an LSTM to keep a 22 | memory of the inputs to the edge model at each step of message passing, by using 23 | separate "embedding" and "state" fields in the edge features. 24 | Following a similar procedure, an LSTM could be added to the `node_update_fn`, 25 | or even the `global_update_fn`, if using a full `GraphNetwork`. 26 | 27 | Note it is recommended to use immutable container types to store nested edge, 28 | node and global features to avoid unwanted side effects. In this example we 29 | use `namedtuple`s. 30 | 31 | """ 32 | 33 | import collections 34 | 35 | from absl import app 36 | import haiku as hk 37 | import jax 38 | import jax.numpy as jnp 39 | import jax.tree_util as tree 40 | import jraph 41 | import numpy as np 42 | 43 | 44 | NUM_NODES = 5 45 | NUM_EDGES = 7 46 | NUM_MESSAGE_PASSING_STEPS = 10 47 | EMBEDDING_SIZE = 32 48 | HIDDEN_SIZE = 128 49 | 50 | # Immutable class for storing nested node/edge features containing an embedding 51 | # and a recurrent state. 52 | StatefulField = collections.namedtuple("StatefulField", ["embedding", "state"]) 53 | 54 | 55 | def get_random_graph() -> jraph.GraphsTuple: 56 | return jraph.GraphsTuple( 57 | n_node=np.asarray([NUM_NODES]), 58 | n_edge=np.asarray([NUM_EDGES]), 59 | nodes=np.random.normal(size=[NUM_NODES, EMBEDDING_SIZE]), 60 | edges=np.random.normal(size=[NUM_EDGES, EMBEDDING_SIZE]), 61 | globals=None, 62 | senders=np.random.randint(0, NUM_NODES, [NUM_EDGES]), 63 | receivers=np.random.randint(0, NUM_NODES, [NUM_EDGES])) 64 | 65 | 66 | def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree: 67 | """`InteractionNetwork` with an LSTM in the edge update.""" 68 | 69 | # LSTM that will keep a memory of the inputs to the edge model. 70 | edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE) 71 | 72 | # MLPs used in the edge and the node model. Note that in this instance 73 | # the output size matches the input size so the same model can be run 74 | # iteratively multiple times. In a real model, this would usually be achieved 75 | # by first using an encoder to map the input data into a common `EMBEDDING_SIZE`. 76 | edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE]) 77 | node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE]) 78 | 79 | # Initialize the edge features to contain both the input edge embedding 80 | # and initial LSTM state. Note for the nodes we only have an embedding since 81 | # in this example nodes do not use a `node_fn_lstm`, but for analogy, we 82 | # still put it in a `StatefulField`. 83 | graph = graph._replace( 84 | edges=StatefulField( 85 | embedding=graph.edges, 86 | state=edge_fn_lstm.initial_state(graph.edges.shape[0])), 87 | nodes=StatefulField(embedding=graph.nodes, state=None), 88 | ) 89 | 90 | def update_edge_fn(edges, sender_nodes, receiver_nodes): 91 | # We will run an LSTM memory on the inputs first, and then 92 | # process the output of the LSTM with an MLP.
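    # hk.LSTM returns both an output and a new state; writing the new state
    # back into the edge features is what threads the memory through
    # successive message passing steps.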
93 |     edge_inputs = jnp.concatenate([edges.embedding,
94 |                                    sender_nodes.embedding,
95 |                                    receiver_nodes.embedding], axis=-1)
96 |     lstm_output, updated_state = edge_fn_lstm(edge_inputs, edges.state)
97 |     updated_edges = StatefulField(
98 |         embedding=edge_fn_mlp(lstm_output), state=updated_state,
99 |     )
100 |     return updated_edges
101 | 
102 |   def update_node_fn(nodes, received_edges):
103 |     # Note `received_edges.state` will also contain the aggregated state for
104 |     # all received edges, which we may choose to use in the node update.
105 |     node_inputs = jnp.concatenate(
106 |         [nodes.embedding, received_edges.embedding], axis=-1)
107 |     updated_nodes = StatefulField(
108 |         embedding=node_fn_mlp(node_inputs),
109 |         state=None)
110 |     return updated_nodes
111 | 
112 |   recurrent_graph_network = jraph.InteractionNetwork(
113 |       update_edge_fn=update_edge_fn,
114 |       update_node_fn=update_node_fn)
115 | 
116 |   # Apply the model recurrently for `NUM_MESSAGE_PASSING_STEPS` steps.
117 |   # If instead we intended to use the LSTM to process a sequence of features
118 |   # for each node/edge, here we would select the corresponding inputs from the
119 |   # sequence along the sequence axis of the nodes/edges features to build the
120 |   # correct input graph for each step of the iteration.
121 |   num_message_passing_steps = NUM_MESSAGE_PASSING_STEPS
122 |   for _ in range(num_message_passing_steps):
123 |     graph = recurrent_graph_network(graph)
124 | 
125 |   return graph
126 | 
127 | 
128 | def main(_):
129 | 
130 |   network = hk.without_apply_rng(hk.transform(network_definition))
131 |   input_graph = get_random_graph()
132 |   params = network.init(jax.random.PRNGKey(42), input_graph)
133 |   output_graph = network.apply(params, input_graph)
134 |   print(tree.tree_map(lambda x: x.shape, output_graph))
135 | 
136 | 
137 | if __name__ == "__main__":
138 |   app.run(main)
139 | 
--------------------------------------------------------------------------------
/jraph/examples/sat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | 
3 | 
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | 
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""2-SAT solver example.
16 | 
17 | Here we train a graph neural network to solve 2-sat problems.
18 | https://en.wikipedia.org/wiki/2-satisfiability
19 | 
20 | For instance, a 2-sat problem with 3 literals would look like this:
21 | (a or b) and (not a or c) and (not b or not c)
22 | 
23 | We represent this problem in form of a bipartite-graph, with edges
24 | connecting the literal-nodes (a, b, c) with the constraint-nodes (O).
25 | The corresponding graph looks like this:
26 | O O O
27 | |\ /\ /|
28 | | \/ \/ |
29 | | /\ /\ |
30 | |/ \/ \|
31 | a b c
32 | 
33 | The nodes are one-hot encoded with literal nodes as (1, 0) and constraint nodes
34 | as (0, 1). The edges are one-hot encoded with (0, 1) if the literal should be
35 | true and (1, 0) if the literal should be false.
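For example, if the generated assignment sets a true and b false, the
constraint built for the pair (a, b) is (a or not b): the edge to a is
encoded as (0, 1) ("a should be true") and the edge to b as (1, 0)
("b should be false").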
36 | 
37 | The graph neural network encodes the nodes and the edges and runs multiple
38 | message passing steps by calculating a message for each edge and aggregating
39 | the messages at each node.
40 | 
41 | The training dataset consists of randomly generated 2-sat problems with 2 to 15
42 | literals.
43 | The test dataset consists of randomly generated 2-sat problems with 16 to 20
44 | literals.
45 | """
46 | 
47 | import collections
48 | import logging
49 | import random
50 | 
51 | from absl import app
52 | import haiku as hk
53 | import jax
54 | import jax.numpy as jnp
55 | import jraph
56 | import numpy as np
57 | import optax
58 | 
59 | 
60 | Problem = collections.namedtuple("Problem", ("graph", "labels", "mask"))
61 | 
62 | 
63 | def get_2sat_problem(min_n_literals: int, max_n_literals: int) -> Problem:
64 |   """Creates bipartite-graph representing a randomly generated 2-sat problem.
65 | 
66 |   Args:
67 |     min_n_literals: minimum number of literals in the 2-sat problem.
68 |     max_n_literals: maximum number of literals in the 2-sat problem.
69 | 
70 |   Returns:
71 |     bipartite-graph, node labels and node mask.
72 |   """
73 |   n_literals = random.randint(min_n_literals, max_n_literals)
74 |   n_literals_true = random.randint(1, n_literals - 1)
75 |   n_constraints = n_literals * (n_literals - 1) // 2
76 | 
77 |   n_node = n_literals + n_constraints
78 |   # 0 indicates a literal node
79 |   # 1 indicates a constraint node.
80 |   nodes = [0 if i < n_literals else 1 for i in range(n_node)]
81 |   edges = []
82 |   senders = []
83 |   for literal_node1 in range(n_literals):
84 |     for literal_node2 in range(literal_node1 + 1, n_literals):
85 |       senders.append(literal_node1)
86 |       senders.append(literal_node2)
87 |       # 1 indicates that the literal must be true for this constraint.
88 |       # 0 indicates that the literal must be false for this constraint.
89 |       # I.e. with literals a and b, we have the following possible constraints:
90 |       # 1, 1 -> a or b
91 |       # 0, 1 -> not a or b
92 |       # 1, 0 -> a or not b
93 |       # 0, 0 -> not a or not b
94 |       edges.append(1 if literal_node1 < n_literals_true else 0)
95 |       edges.append(1 if literal_node2 < n_literals_true else 0)
96 | 
97 |   graph = jraph.GraphsTuple(
98 |       n_node=np.asarray([n_node]),
99 |       n_edge=np.asarray([2 * n_constraints]),
100 |       # One-hot encoding for nodes and edges.
101 |       nodes=np.eye(2)[nodes],
102 |       edges=np.eye(2)[edges],
103 |       globals=None,
104 |       senders=np.asarray(senders),
105 |       receivers=np.repeat(np.arange(n_constraints) + n_literals, 2))
106 | 
107 |   # In order to jit compile our code, we have to pad the nodes and edges of
108 |   # the GraphsTuple to a static shape.
109 |   max_n_constraints = max_n_literals * (max_n_literals - 1) // 2
110 |   max_nodes = max_n_literals + max_n_constraints + 1
111 |   max_edges = 2 * max_n_constraints
112 |   graph = jraph.pad_with_graphs(graph, max_nodes, max_edges)
113 | 
114 |   # The ground truth solution for the 2-sat problem.
115 |   labels = (np.arange(max_nodes) < n_literals_true).astype(np.int32)
116 |   labels = np.eye(2)[labels]
117 | 
118 |   # For the loss calculation we create a mask for the nodes, which masks out
119 |   # the constraint nodes and the padding nodes.
120 |   mask = (np.arange(max_nodes) < n_literals).astype(np.int32)
121 |   return Problem(graph=graph, labels=labels, mask=mask)
122 | 
123 | 
124 | def network_definition(
125 |     graph: jraph.GraphsTuple,
126 |     num_message_passing_steps: int = 5) -> jraph.ArrayTree:
127 |   """Defines a graph neural network.
128 | 
129 |   Args:
130 |     graph: GraphsTuple the network processes.
131 | num_message_passing_steps: number of message passing steps. 132 | 133 | Returns: 134 | Decoded nodes. 135 | """ 136 | embedding = jraph.GraphMapFeatures( 137 | embed_edge_fn=jax.vmap(hk.Linear(output_size=16)), 138 | embed_node_fn=jax.vmap(hk.Linear(output_size=16))) 139 | graph = embedding(graph) 140 | 141 | @jax.vmap 142 | @jraph.concatenated_args 143 | def update_fn(features): 144 | net = hk.Sequential([ 145 | hk.Linear(10), jax.nn.relu, 146 | hk.Linear(10), jax.nn.relu, 147 | hk.Linear(10), jax.nn.relu]) 148 | return net(features) 149 | 150 | for _ in range(num_message_passing_steps): 151 | gn = jraph.InteractionNetwork( 152 | update_edge_fn=update_fn, 153 | update_node_fn=update_fn, 154 | include_sent_messages_in_node_update=True) 155 | graph = gn(graph) 156 | 157 | return hk.Linear(2)(graph.nodes) 158 | 159 | 160 | def train(num_steps: int): 161 | """Trains a graph neural network on a 2-sat problem.""" 162 | train_dataset = (2, 15) 163 | test_dataset = (16, 20) 164 | random.seed(42) 165 | 166 | network = hk.without_apply_rng(hk.transform(network_definition)) 167 | problem = get_2sat_problem(*train_dataset) 168 | params = network.init(jax.random.PRNGKey(42), problem.graph) 169 | 170 | @jax.jit 171 | def prediction_loss(params, problem): 172 | decoded_nodes = network.apply(params, problem.graph) 173 | # We interpret the decoded nodes as a pair of logits for each node. 174 | log_prob = jax.nn.log_softmax(decoded_nodes) * problem.labels 175 | return -jnp.sum(log_prob * problem.mask[:, None]) / jnp.sum(problem.mask) 176 | 177 | opt_init, opt_update = optax.adam(2e-4) 178 | opt_state = opt_init(params) 179 | 180 | @jax.jit 181 | def update(params, opt_state, problem): 182 | g = jax.grad(prediction_loss)(params, problem) 183 | updates, opt_state = opt_update(g, opt_state) 184 | return optax.apply_updates(params, updates), opt_state 185 | 186 | for step in range(num_steps): 187 | problem = get_2sat_problem(*train_dataset) 188 | params, opt_state = update(params, opt_state, problem) 189 | if step % 1000 == 0: 190 | train_loss = jnp.mean( 191 | jnp.asarray([ 192 | prediction_loss(params, get_2sat_problem(*train_dataset)) 193 | for _ in range(100) 194 | ])).item() 195 | test_loss = jnp.mean( 196 | jnp.asarray([ 197 | prediction_loss(params, get_2sat_problem(*test_dataset)) 198 | for _ in range(100) 199 | ])).item() 200 | logging.info("step %r loss train %r test %r", step, train_loss, test_loss) 201 | 202 | 203 | def main(argv): 204 | if len(argv) > 1: 205 | raise app.UsageError("Too many command-line arguments.") 206 | 207 | train(num_steps=10000) 208 | 209 | 210 | if __name__ == "__main__": 211 | app.run(main) 212 | -------------------------------------------------------------------------------- /jraph/examples/zacharys_karate_club.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Zachary's karate club example. 
16 | 17 | Here we train a graph neural network to process Zachary's karate club. 18 | https://en.wikipedia.org/wiki/Zachary%27s_karate_club 19 | 20 | Zachary's karate club is used in the literature as an example of a social graph. 21 | Here we use a graphnet to optimize the assignments of the students in the 22 | karate club to two distinct karate instructors (Mr. Hi and John A). 23 | """ 24 | 25 | import logging 26 | 27 | from absl import app 28 | import haiku as hk 29 | import jax 30 | import jax.numpy as jnp 31 | import jraph 32 | import optax 33 | 34 | 35 | def get_zacharys_karate_club() -> jraph.GraphsTuple: 36 | """Returns GraphsTuple representing Zachary's karate club.""" 37 | social_graph = [ 38 | (1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2), 39 | (4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1), 40 | (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4), 41 | (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2), 42 | (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1), 43 | (21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23), 44 | (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8), 45 | (31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8), 46 | (32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23), 47 | (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13), 48 | (33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22), 49 | (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30), 50 | (33, 31), (33, 32)] 51 | # Add reverse edges. 52 | social_graph += [(edge[1], edge[0]) for edge in social_graph] 53 | n_club_members = 34 54 | 55 | return jraph.GraphsTuple( 56 | n_node=jnp.asarray([n_club_members]), 57 | n_edge=jnp.asarray([len(social_graph)]), 58 | # One-hot encoding for nodes. 59 | nodes=jnp.eye(n_club_members), 60 | # No edge features. 61 | edges=None, 62 | globals=None, 63 | senders=jnp.asarray([edge[0] for edge in social_graph]), 64 | receivers=jnp.asarray([edge[1] for edge in social_graph])) 65 | 66 | 67 | def get_ground_truth_assignments_for_zacharys_karate_club() -> jnp.ndarray: 68 | """Returns ground truth assignments for Zachary's karate club.""" 69 | return jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 70 | 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 71 | 72 | 73 | def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree: 74 | """Implements the GCN from Kipf et al https://arxiv.org/pdf/1609.02907.pdf. 75 | 76 | A' = D^{-0.5} A D^{-0.5} 77 | Z = f(X, A') = A' relu(A' X W_0) W_1 78 | 79 | Args: 80 | graph: GraphsTuple the network processes. 81 | 82 | Returns: 83 | processed nodes. 
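 
  In this example W_0 maps the 34-dimensional one-hot node features to 5
  hidden features and W_1 maps those to 2 logits, one per instructor.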
84 | """ 85 | gn = jraph.GraphConvolution( 86 | update_node_fn=hk.Linear(5, with_bias=False), 87 | add_self_edges=True) 88 | graph = gn(graph) 89 | graph = graph._replace(nodes=jax.nn.relu(graph.nodes)) 90 | gn = jraph.GraphConvolution( 91 | update_node_fn=hk.Linear(2, with_bias=False)) 92 | graph = gn(graph) 93 | return graph.nodes 94 | 95 | 96 | def optimize_club(num_steps: int): 97 | """Solves the karte club problem by optimizing the assignments of students.""" 98 | network = hk.without_apply_rng(hk.transform(network_definition)) 99 | zacharys_karate_club = get_zacharys_karate_club() 100 | labels = get_ground_truth_assignments_for_zacharys_karate_club() 101 | params = network.init(jax.random.PRNGKey(42), zacharys_karate_club) 102 | 103 | @jax.jit 104 | def prediction_loss(params): 105 | decoded_nodes = network.apply(params, zacharys_karate_club) 106 | # We interpret the decoded nodes as a pair of logits for each node. 107 | log_prob = jax.nn.log_softmax(decoded_nodes) 108 | # The only two assignments we know a-priori are those of Mr. Hi (Node 0) 109 | # and John A (Node 33). 110 | return -(log_prob[0, 0] + log_prob[33, 1]) 111 | 112 | opt_init, opt_update = optax.adam(1e-2) 113 | opt_state = opt_init(params) 114 | 115 | @jax.jit 116 | def update(params, opt_state): 117 | g = jax.grad(prediction_loss)(params) 118 | updates, opt_state = opt_update(g, opt_state) 119 | return optax.apply_updates(params, updates), opt_state 120 | 121 | @jax.jit 122 | def accuracy(params): 123 | decoded_nodes = network.apply(params, zacharys_karate_club) 124 | return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels) 125 | 126 | for step in range(num_steps): 127 | logging.info("step %r accuracy %r", step, accuracy(params).item()) 128 | params, opt_state = update(params, opt_state) 129 | 130 | 131 | def main(_): 132 | optimize_club(num_steps=30) 133 | 134 | 135 | if __name__ == "__main__": 136 | app.run(main) 137 | -------------------------------------------------------------------------------- /jraph/experimental/sharded_graphnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Sharded (Data Parallel) Graph Nets.""" 15 | 16 | import functools 17 | from typing import Callable, List, NamedTuple, Optional 18 | import jax 19 | import jax.numpy as jnp 20 | import jax.tree_util as tree 21 | import jraph 22 | from jraph._src import graph as gn_graph 23 | from jraph._src import utils 24 | import numpy as np 25 | 26 | 27 | class ShardedEdgesGraphsTuple(NamedTuple): 28 | """A `GraphsTuple` for use with `ShardedEdgesGraphNetwork`. 29 | 30 | NOTES: 31 | - A ShardedEdgesGraphNetwork is for use with `jax.pmap`. As such, it will have 32 | a leading axis of size `num_devices` on the host, but no such axis on 33 | device. Non-sharded data is replicated on each device. 
 To achieve this with
34 |     `jax.pmap` you can broadcast non-sharded data to have leading axis
35 |     'num_devices' or use the 'in_axes' parameter, which will indicate which
36 |     attributes should be replicated and which should not. Current helper methods
37 |     use the first approach.
38 |   - It is recommended that you construct `ShardedEdgesGraphsTuple`s with
39 |     `graphs_tuple_to_broadcasted_sharded_graphs_tuple`.
40 | 
41 | 
42 |   The values of `nodes`, `device_edges` and `globals` can be gn_graph.ArrayTree
43 |   - nests of features with `jax` compatible values. For example, `nodes` in a
44 |   graph may have more than one type of attribute.
45 | 
46 |   However, the ShardedEdgesGraphsTuple typically takes the following form for a
47 |   batch of `n` graphs:
48 | 
49 |   - n_node: The number of nodes per graph. It is a vector of integers with shape
50 |     `[n_graphs]`, such that `graph.n_node[i]` is the number of nodes in the i-th
51 |     graph.
52 | 
53 |   - n_edge: The number of edges per graph. It is a vector of integers with shape
54 |     `[n_graphs]`, such that `graph.n_edge[i]` is the number of edges in the i-th
55 |     graph.
56 | 
57 |   - nodes: The nodes features. It is either `None` (the graph has no node
58 |     features), or a vector of shape `[n_nodes] + node_shape`, where
59 |     `n_nodes = sum(graph.n_node)` is the total number of nodes in the batch of
60 |     graphs, and `node_shape` represents the shape of the features of each node.
61 |     The relative index of a node from the batched version can be recovered from
62 |     the `graph.n_node` property. For instance, the second node of the third
63 |     graph will have its features in the
64 |     `1 + graph.n_node[0] + graph.n_node[1]`-th slot of graph.nodes.
65 |     Observe that having a `None` value for this field does not mean that the
66 |     graphs have no nodes, only that they do not have node features.
67 | 
68 |   - receivers: The indices of the receiver nodes, for each edge. It is either
69 |     `None` (if the graph has no edges), or a vector of integers of shape
70 |     `[n_edges]`, such that `graph.receivers[i]` is the index of the node
71 |     receiving from the i-th edge.
72 | 
73 |     Observe that the index is absolute (in other words, cumulative), i.e.
74 |     `graphs.receivers` take value in `[0, n_nodes]`. For instance, an edge
75 |     connecting the vertices with relative indices 2 and 3 in the second graph of
76 |     the batch would have a `receivers` value of `3 + graph.n_node[0]`.
77 |     If `graphs.receivers` is `None`, then `graphs.edges` and `graphs.senders`
78 |     should also be `None`.
79 | 
80 |   - senders: The indices of the sender nodes, for each edge. It is either
81 |     `None` (if the graph has no edges), or a vector of integers of shape
82 |     `[n_edges]`, such that `graph.senders[i]` is the index of the node
83 |     sending from the i-th edge.
84 | 
85 |     Observe that the index is absolute, i.e. `graphs.senders` take value in
86 |     `[0, n_nodes]`. For instance, an edge connecting the vertices with relative
87 |     indices 1 and 3 in the third graph of the batch would have a `senders` value
88 |     of `1 + graph.n_node[0] + graph.n_node[1]`.
89 | 
90 |     If `graphs.senders` is `None`, then `graphs.edges` and `graphs.receivers`
91 |     should also be `None`.
92 | 
93 |   - globals: The global features of the graph. It is either `None` (the graph
94 |     has no global features), or a vector of shape `[n_graphs] + global_shape`
95 |     representing graph level features.
96 | 
97 |   The ShardedEdgesGraphsTuple also contains device-local attributes that are
98 |   used for data parallel computation.
 On the host, each of these attributes will
99 |   have an additional leading axis of shape `num_devices` for use with
100 |   `jax.pmap`, but this is omitted in the following documentation.
101 | 
102 |   - device_edges: The subset of the edge features that are on the device.
103 |     It is either `None` (the graph has no edge features), or a vector of
104 |     shape `[num_edges / num_devices] + edge_shape`.
105 | 
106 |     Observe that having a `None` value for this field does not mean that the
107 |     graph has no edges, only that they do not have edge features.
108 | 
109 |   - device_senders: The sender indices of edges on device. This is of length
110 |     num_edges / num_devices.
111 | 
112 |   - device_receivers: The receiver indices of edges on device. This is of length
113 |     num_edges / num_devices.
114 | 
115 |   - device_n_edge: The graph partitions of the edges on device. For example,
116 |     say that there are 2 graphs in the original graphs tuple, with n_edge
117 |     [1, 11], which has been split over 3 devices. The `device_n_edge`s would
118 |     be [[1, 3], [4, 0], [4, 0]]. `0`-valued entries that are padding values
119 |     are not distinguished from graphs with zero edges. Since these attributes
120 |     are used only for `repeat` purposes, this ambiguity does not affect
121 |     the implementation.
122 | 
123 |   - device_graph_idx: The indices of the graphs on device. For example, say
124 |     that there are 5 graphs in the original graphs tuple, and these have been
125 |     split over 3 devices, the device_graph_idxs could be
126 |     [[0, 1, 2], [2, 3, 0], [3, 4, 0]]. In this splitting, the third graph
127 |     is split over 2 devices. If a `0` is the first in `device_graph_idx` then
128 |     that indicates the first graph, otherwise it indicates a padding value.
129 |   """
130 |   nodes: gn_graph.ArrayTree
131 |   device_edges: gn_graph.ArrayTree
132 |   device_receivers: jnp.ndarray  # with integer dtype
133 |   device_senders: jnp.ndarray  # with integer dtype
134 |   receivers: jnp.ndarray  # with integer dtype
135 |   senders: jnp.ndarray  # with integer dtype
136 |   globals: gn_graph.ArrayTree
137 |   device_n_edge: jnp.ndarray  # with integer dtype
138 |   n_node: jnp.ndarray  # with integer dtype
139 |   n_edge: jnp.ndarray  # with integer dtype
140 |   device_graph_idx: jnp.ndarray  # with integer dtype
141 | 
142 | 
143 | def graphs_tuple_to_broadcasted_sharded_graphs_tuple(
144 |     graphs_tuple: jraph.GraphsTuple,
145 |     num_shards: int) -> ShardedEdgesGraphsTuple:
146 |   """Converts a `GraphsTuple` to a `ShardedEdgesGraphsTuple` to use with `pmap`.
147 | 
148 |   For a given number of shards this will compute device-local edge and graph
149 |   attributes, and add a batch axis of size num_shards. You can then use
150 |   `ShardedEdgesGraphNetwork` with `jax.pmap`.
151 | 
152 |   Args:
153 |     graphs_tuple: The `GraphsTuple` to be converted to a sharded `GraphsTuple`.
154 |     num_shards: The number of devices to shard over.
155 | 
156 |   Returns:
157 |     A ShardedEdgesGraphsTuple over the number of shards.
158 |   """
159 |   # Note: this is not jittable, so to prevent using a device by accident,
160 |   # this is all happening in numpy.
161 |   nodes, edges, receivers, senders, globals_, n_node, n_edge = graphs_tuple
162 |   if np.sum(n_edge) % num_shards != 0:
163 |     raise ValueError(('The number of edges in a `graph.GraphsTuple` must be '
164 |                       'divisible by the number of devices per replica.'))
165 |   if np.sum(np.array(n_edge)) == 0:
166 |     raise ValueError('The input `GraphsTuple` must have edges.')
167 |   # Broadcast replicated features to have a `num_shards` leading axis.
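  # For example, node features of shape [n_nodes, feats] become
  # [num_shards, n_nodes, feats]; under `jax.pmap` each device then sees an
  # identical full copy of the replicated data.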
168 |   # pylint: disable=g-long-lambda
169 |   broadcast = lambda x: np.broadcast_to(x[None, :], (num_shards,) + x.shape)
170 |   # pylint: enable=g-long-lambda
171 | 
172 |   # `edges` will be straightforwardly sharded, with 1/num_shards of
173 |   # the edges on each device.
174 |   def shard_edges(edge_features):
175 |     return np.reshape(edge_features, (num_shards, -1) + edge_features.shape[1:])
176 | 
177 |   edges = jax.tree_map(shard_edges, edges)
178 |   # Our sharding strategy is by edges - which means we need a device local
179 |   # n_edge, senders and receivers to do global aggregations.
180 | 
181 |   # Senders and receivers are easy - 1/num_shards per device.
182 |   device_senders = shard_edges(senders)
183 |   device_receivers = shard_edges(receivers)
184 | 
185 |   # n_edge is a bit more difficult. Let's say we have a graphs tuple with
186 |   # n_edge [2, 8], and we want to distribute this on two devices. Then
187 |   # we will have sharded the edges to [5, 5], so the n_edge per device will be
188 |   # [2,3], and [5]. Since we need to have each of the n_edge the same shape,
189 |   # we will need to pad this to [5,0]. This is a bit dangerous, as the zero
190 |   # here has a different meaning to a graph with zero edges, but we need the
191 |   # zero for the global broadcasting to be correct for aggregation. Since
192 |   # this will only be used in the first instance for global broadcasting on
193 |   # device I think this is ok, but ideally we'd have a more elegant solution.
194 |   # TODO(jonathangodwin): think of a more elegant solution.
195 |   edges_per_device = np.sum(n_edge) // num_shards
196 |   edges_in_current_split = 0
197 |   completed_splits = []
198 |   current_split = {'n_edge': [], 'device_graph_idx': []}
199 |   for device_graph_idx, x in enumerate(n_edge):
200 |     new_edges_in_current_split = edges_in_current_split + x
201 |     if new_edges_in_current_split > edges_per_device:
202 |       # A single graph may be spread across multiple shards, so here we
203 |       # iteratively create new splits until the graph is exhausted.
204 | 
205 |       # How many edges we are trying to allocate.
206 |       carry = x
207 |       # How much room there is in the current split for new edges.
208 |       space_in_current_split = edges_per_device - edges_in_current_split
209 |       while carry > 0:
210 |         if carry >= space_in_current_split:
211 |           # We've encountered a situation where we need to split a graph across
212 |           # >= 2 devices. We compute the number we will carry to the next split,
213 |           # and add a full split.
214 |           carry = carry - space_in_current_split
215 |           # Add the leftover edges to the current split, and complete the split
216 |           # by adding it to completed_splits.
217 |           current_split['n_edge'].append(space_in_current_split)
218 |           current_split['device_graph_idx'].append(device_graph_idx)
219 |           completed_splits.append(current_split)
220 |           # reset the split
221 |           current_split = {'n_edge': [], 'device_graph_idx': []}
222 | 
223 |           space_in_current_split = edges_per_device
224 |           edges_in_current_split = 0
225 |         else:
226 |           current_split = {
227 |               'n_edge': [carry],
228 |               'device_graph_idx': [device_graph_idx]
229 |           }
230 |           edges_in_current_split = carry
231 |           carry = 0
232 |           # Since the total number of edges must be divisible by the number
233 |           # of devices, this code path can only be executed for an intermediate
234 |           # graph, thus it is not a complete split and we never need to add it
235 |           # to `completed splits`.
236 |     else:
237 |       # Add the edges and globals to the current split.
238 |       current_split['n_edge'].append(x)
239 |       current_split['device_graph_idx'].append(device_graph_idx)
240 |       # If we've reached the end of a split, complete it and start a new one.
241 |       if new_edges_in_current_split == edges_per_device:
242 |         completed_splits.append(current_split)
243 |         current_split = {'n_edge': [], 'device_graph_idx': []}
244 |         edges_in_current_split = 0
245 |       else:
246 |         edges_in_current_split = new_edges_in_current_split
247 | 
248 |   # Flatten list of dicts to dict of lists.
249 |   completed_splits = {
250 |       k: [d[k] for d in completed_splits] for k in completed_splits[0]
251 |   }
252 |   pad_split_to = max([len(x) for x in completed_splits['n_edge']])
253 |   pad = lambda x: np.pad(x, (0, pad_split_to - len(x)), mode='constant')
254 |   device_n_edge = np.array([pad(x) for x in completed_splits['n_edge']])
255 |   device_graph_idx = np.array(
256 |       [pad(x) for x in completed_splits['device_graph_idx']])
257 |   return ShardedEdgesGraphsTuple(
258 |       nodes=jax.tree_map(broadcast, nodes),
259 |       device_edges=edges,
260 |       device_receivers=device_receivers,
261 |       device_senders=device_senders,
262 |       receivers=broadcast(receivers),
263 |       senders=broadcast(senders),
264 |       device_graph_idx=device_graph_idx,
265 |       globals=jax.tree_map(broadcast, globals_),
266 |       n_node=broadcast(n_node),
267 |       n_edge=broadcast(n_edge),
268 |       device_n_edge=device_n_edge)
269 | 
270 | 
271 | def broadcasted_sharded_graphs_tuple_to_graphs_tuple(sharded_graphs_tuple):
272 |   """Converts a broadcasted `ShardedEdgesGraphsTuple` to a `GraphsTuple`."""
273 |   # We index the first element of replicated arrays, since they have been
274 |   # repeated. For edges, we reshape to recover all of the edge features.
275 |   unbroadcast = lambda y: tree.tree_map(lambda x: x[0], y)
276 |   unshard = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])
277 |   # TODO(jonathangodwin): check senders and receivers are consistent.
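  # `unshard` merges the [num_devices, edges_per_device] leading axes back
  # into a single edge axis, inverting the reshape done by `shard_edges`.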
278 |   return jraph.GraphsTuple(
279 |       nodes=unbroadcast(sharded_graphs_tuple.nodes),
280 |       edges=tree.tree_map(unshard, sharded_graphs_tuple.device_edges),
281 |       n_node=sharded_graphs_tuple.n_node[0],
282 |       n_edge=sharded_graphs_tuple.n_edge[0],
283 |       globals=unbroadcast(sharded_graphs_tuple.globals),
284 |       senders=sharded_graphs_tuple.senders[0],
285 |       receivers=sharded_graphs_tuple.receivers[0])
286 | 
287 | 
288 | def sharded_segment_sum(data, indices, num_segments, axis_index_groups):
289 |   """Segment sum over data on multiple devices."""
290 |   device_segment_sum = utils.segment_sum(data, indices, num_segments)
291 |   return jax.lax.psum(
292 |       device_segment_sum, axis_name='i', axis_index_groups=axis_index_groups)
293 | 
294 | 
295 | ShardedEdgeFeatures = gn_graph.ArrayTree
296 | AggregateShardedEdgesToGlobalsFn = Callable[
297 |     [ShardedEdgeFeatures, jnp.ndarray, int, jnp.ndarray], gn_graph.ArrayTree]
298 | AggregateShardedEdgesToNodesFn = Callable[
299 |     [gn_graph.ArrayTree, jnp.ndarray, int, List[List[int]]], jraph.NodeFeatures]
300 | 
301 | 
302 | # pylint: disable=invalid-name
303 | def ShardedEdgesGraphNetwork(
304 |     update_edge_fn: Optional[jraph.GNUpdateEdgeFn],
305 |     update_node_fn: Optional[jraph.GNUpdateNodeFn],
306 |     update_global_fn: Optional[jraph.GNUpdateGlobalFn] = None,
307 |     aggregate_edges_for_nodes_fn:
308 |         AggregateShardedEdgesToNodesFn = sharded_segment_sum,
309 |     aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jax.ops
310 |     .segment_sum,
311 |     aggregate_edges_for_globals_fn:
312 |         AggregateShardedEdgesToGlobalsFn = sharded_segment_sum,
313 |     attention_logit_fn: Optional[jraph.AttentionLogitFn] = None,
314 |     attention_reduce_fn: Optional[jraph.AttentionReduceFn] = None,
315 |     num_shards: int = 1):
316 |   """Returns a method that applies a GraphNetwork on a sharded GraphsTuple.
317 | 
318 |   This GraphNetwork is sharded over `edges`, all other features are assumed
319 |   to be replicated on device.
320 |   There are two clear use cases for a ShardedEdgesGraphNetwork. The first is
321 |   where a single graph can't fit on device. The second is when you are compute
322 |   bound on the edge feature calculation, and you'd like to speed up
323 |   training/inference by distributing the compute across devices.
324 | 
325 |   Example usage:
326 | 
327 |   ```
328 |   gn = jax.pmap(ShardedEdgesGraphNetwork(update_edge_function,
329 |                 update_node_function, **kwargs), axis_name='i')
330 |   # Conduct multiple rounds of message passing with the same parameters:
331 |   for _ in range(num_message_passing_steps):
332 |     sharded_graph = gn(sharded_graph)
333 |   ```
334 | 
335 |   Args:
336 |     update_edge_fn: function used to update the edges or None to deactivate edge
337 |       updates.
338 |     update_node_fn: function used to update the nodes or None to deactivate node
339 |       updates.
340 |     update_global_fn: function used to update the globals or None to deactivate
341 |       globals updates.
342 |     aggregate_edges_for_nodes_fn: function used to aggregate messages to each
343 |       node. This must support cross-device aggregations.
344 |     aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
345 |       globals.
346 |     aggregate_edges_for_globals_fn: function used to aggregate the edges for the
347 |       globals. This must support cross-device aggregations.
348 |     attention_logit_fn: function used to calculate the attention weights or None
349 |       to deactivate attention mechanism.
350 |     attention_reduce_fn: function used to apply weights to the edge features or
351 |       None if attention mechanism is not active.
352 |     num_shards: how many devices per replica for sharding.
353 | 
354 |   Returns:
355 |     A method that applies the configured GraphNetwork.
356 |   """
357 |   not_both_supplied = lambda x, y: (x != y) and ((x is None) or (y is None))
358 |   if not_both_supplied(attention_reduce_fn, attention_logit_fn):
359 |     raise ValueError(('attention_logit_fn and attention_reduce_fn must both be'
360 |                       ' supplied.'))
361 | 
362 |   devices = jax.devices()
363 |   num_devices = len(devices)
364 |   assert num_devices % num_shards == 0
365 |   num_replicas = num_devices // num_shards
366 |   # The device IDs, which are grouped into replicas below.
367 |   replica_ids = list(range(num_devices))
368 |   # How the devices are grouped per replica.
369 |   axis_groups = [
370 |       replica_ids[i * num_shards:(i + 1) * num_shards]
371 |       for i in range(num_replicas)
372 |   ]
373 | 
374 |   def _ApplyGraphNet(graph: ShardedEdgesGraphsTuple) -> ShardedEdgesGraphsTuple:
375 |     """Applies a configured GraphNetwork to a sharded graph.
376 | 
377 |     This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
378 | 
379 |     There is one difference. For the nodes update the class aggregates over the
380 |     sender edges and receiver edges separately. This is a bit more general
381 |     than the algorithm described in the paper. The original behaviour can be
382 |     recovered by using only the receiver edge aggregations for the update.
383 | 
384 |     In addition this implementation supports softmax attention over incoming
385 |     edge features.
386 | 
387 | 
388 |     Many popular Graph Neural Networks can be implemented as special cases of
389 |     GraphNets, for more information please see the paper.
390 | 
391 |     Args:
392 |       graph: a `GraphsTuple` containing the graph.
393 | 
394 |     Returns:
395 |       Updated `GraphsTuple`.
396 |     """
397 |     # pylint: disable=g-long-lambda
398 |     nodes, device_edges, device_receivers, device_senders, receivers, senders, globals_, device_n_edge, n_node, n_edge, device_graph_idx = graph
399 |     # Equivalent to jnp.sum(n_node), but jittable.
400 |     sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
401 |     sum_device_n_edge = device_senders.shape[0]
402 |     if not tree.tree_all(
403 |         tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
404 |       raise ValueError(
405 |           'All node arrays in nest must contain the same number of nodes.')
406 | 
407 |     sent_attributes = tree.tree_map(lambda n: n[device_senders], nodes)
408 |     received_attributes = tree.tree_map(lambda n: n[device_receivers], nodes)
409 |     # Here we scatter the global features to the corresponding edges,
410 |     # giving us tensors of shape [num_edges, global_feat].
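    # `device_graph_idx` maps each device-local graph partition to its graph,
    # so `g[device_graph_idx]` gathers the right globals before `jnp.repeat`
    # copies them once per device-local edge.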
411 |     global_edge_attributes = tree.tree_map(
412 |         lambda g: jnp.repeat(
413 |             g[device_graph_idx], device_n_edge, axis=0,
414 |             total_repeat_length=sum_device_n_edge),
415 |         globals_)
416 | 
417 |     if update_edge_fn:
418 |       device_edges = update_edge_fn(device_edges, sent_attributes,
419 |                                     received_attributes, global_edge_attributes)
420 | 
421 |     if attention_logit_fn:
422 |       logits = attention_logit_fn(device_edges, sent_attributes,
423 |                                   received_attributes, global_edge_attributes)
424 |       tree_calculate_weights = functools.partial(
425 |           utils.segment_softmax, segment_ids=receivers, num_segments=sum_n_node)
426 |       weights = tree.tree_map(tree_calculate_weights, logits)
427 |       device_edges = attention_reduce_fn(device_edges, weights)
428 | 
429 |     if update_node_fn:
430 |       # Aggregations over nodes are assumed to take place over devices
431 |       # specified by the axis_groups (e.g. with sharded_segment_sum).
432 |       sent_attributes = tree.tree_map(
433 |           lambda e: aggregate_edges_for_nodes_fn(e, device_senders, sum_n_node,
434 |                                                  axis_groups), device_edges)
435 |       received_attributes = tree.tree_map(
436 |           lambda e: aggregate_edges_for_nodes_fn(
437 |               e, device_receivers, sum_n_node, axis_groups), device_edges)
438 |       # Here we scatter the global features to the corresponding nodes,
439 |       # giving us tensors of shape [num_nodes, global_feat].
440 |       global_attributes = tree.tree_map(
441 |           lambda g: jnp.repeat(
442 |               g, n_node, axis=0, total_repeat_length=sum_n_node), globals_)
443 |       nodes = update_node_fn(nodes, sent_attributes, received_attributes,
444 |                              global_attributes)
445 | 
446 |     if update_global_fn:
447 |       n_graph = n_node.shape[0]
448 |       graph_idx = jnp.arange(n_graph)
449 |       # To aggregate nodes and edges from each graph to global features,
450 |       # we first construct tensors that map the node to the corresponding graph.
451 |       # For example, if you have `n_node=[1,2]`, we construct the tensor
452 |       # [0, 1, 1]. We then do the same for edges.
453 |       node_gr_idx = jnp.repeat(
454 |           graph_idx, n_node, axis=0, total_repeat_length=sum_n_node)
455 |       edge_gr_idx = jnp.repeat(
456 |           device_graph_idx,
457 |           device_n_edge,
458 |           axis=0,
459 |           total_repeat_length=sum_device_n_edge)
460 |       # We use the aggregation function to pool the nodes/edges per graph.
461 |       node_attributes = tree.tree_map(
462 |           lambda n: aggregate_nodes_for_globals_fn(n, node_gr_idx, n_graph),
463 |           nodes)
464 |       edge_attributes = tree.tree_map(
465 |           lambda e: aggregate_edges_for_globals_fn(e, edge_gr_idx, n_graph,
466 |                                                    axis_groups), device_edges)
467 |       # These pooled features are the inputs to the global update fn.
468 |       globals_ = update_global_fn(node_attributes, edge_attributes, globals_)
469 |     # pylint: enable=g-long-lambda
470 |     return ShardedEdgesGraphsTuple(
471 |         nodes=nodes,
472 |         device_edges=device_edges,
473 |         device_senders=device_senders,
474 |         device_receivers=device_receivers,
475 |         receivers=receivers,
476 |         senders=senders,
477 |         device_graph_idx=device_graph_idx,
478 |         globals=globals_,
479 |         n_node=n_node,
480 |         n_edge=n_edge,
481 |         device_n_edge=device_n_edge)
482 | 
483 |   return _ApplyGraphNet
484 | # pylint: enable=invalid-name
485 | 
--------------------------------------------------------------------------------
/jraph/experimental/sharded_graphnet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for sharded graphnet.""" 15 | 16 | import functools 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax.lib import xla_bridge 23 | import jax.tree_util as tree 24 | import jraph 25 | from jraph._src import utils 26 | from jraph.experimental import sharded_graphnet 27 | import numpy as np 28 | 29 | 30 | def _get_graphs_from_n_edge(n_edge): 31 | """Get a graphs tuple from n_edge.""" 32 | graphs = [] 33 | for el in n_edge: 34 | graphs.append( 35 | jraph.GraphsTuple( 36 | nodes=np.random.uniform(size=(128, 2)), 37 | edges=np.random.uniform(size=(el, 2)), 38 | senders=np.random.choice(128, el), 39 | receivers=np.random.choice(128, el), 40 | n_edge=np.array([el]), 41 | n_node=np.array([128]), 42 | globals=np.array([[el]]), 43 | )) 44 | graphs = utils.batch_np(graphs) 45 | return graphs 46 | 47 | 48 | def get_graphs_tuples(n_edge, sharded_n_edge, device_graph_idx): 49 | sharded_n_edge = np.array(sharded_n_edge) 50 | device_graph_idx = np.array(device_graph_idx) 51 | devices = len(sharded_n_edge) 52 | graphs = _get_graphs_from_n_edge(n_edge) 53 | sharded_senders = np.reshape(graphs.senders, [devices, -1]) 54 | sharded_receivers = np.reshape(graphs.receivers, [devices, -1]) 55 | sharded_edges = np.reshape(graphs.edges, 56 | [devices, -1, graphs.edges.shape[-1]]) 57 | # Broadcast replicated features to have a devices leading axis. 
58 | broadcast = lambda x: np.broadcast_to(x[None, :], [devices] + list(x.shape)) 59 | 60 | sharded_graphs = sharded_graphnet.ShardedEdgesGraphsTuple( 61 | device_senders=sharded_senders, 62 | device_receivers=sharded_receivers, 63 | device_edges=sharded_edges, 64 | device_n_edge=sharded_n_edge, 65 | nodes=broadcast(graphs.nodes), 66 | senders=broadcast(graphs.senders), 67 | receivers=broadcast(graphs.receivers), 68 | device_graph_idx=device_graph_idx, 69 | globals=broadcast(graphs.globals), 70 | n_node=broadcast(graphs.n_node), 71 | n_edge=broadcast(graphs.n_edge)) 72 | return graphs, sharded_graphs 73 | 74 | 75 | class ShardedGraphnetTest(parameterized.TestCase): 76 | 77 | def setUp(self): 78 | super().setUp() 79 | os.environ[ 80 | 'XLA_FLAGS'] = '--xla_force_host_platform_device_count=3' 81 | xla_bridge.get_backend.cache_clear() 82 | 83 | @parameterized.named_parameters( 84 | ('split_3_to_4', [3, 5, 4], [[3, 3], [2, 4]], [[0, 1], [1, 2]]), 85 | ('split_zero_last_edge', [1, 2, 5, 4], [[1, 2, 3], [2, 4, 0] 86 | ], [[0, 1, 2], [2, 3, 0]]), 87 | ('split_one_over_multiple', [1, 11], [[1, 3], [4, 0], [4, 0] 88 | ], [[0, 1], [1, 0], [1, 0]])) 89 | def test_get_sharded_graphs_tuple(self, n_edge, sharded_n_edge, 90 | device_graph_idx): 91 | in_tuple, expect_tuple = get_graphs_tuples(n_edge, sharded_n_edge, 92 | device_graph_idx) 93 | out_tuple = sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple( 94 | in_tuple, num_shards=len(expect_tuple.nodes)) 95 | tree.tree_map(np.testing.assert_almost_equal, out_tuple, expect_tuple) 96 | 97 | @parameterized.named_parameters( 98 | ('split_intermediate', [3, 5, 4, 3, 3]), 99 | ('split_zero_last_edge', [1, 2, 5, 4, 6]), 100 | ('split_one_over_multiple', [1, 11])) 101 | def test_sharded_same_as_non_sharded(self, n_edge): 102 | in_tuple = _get_graphs_from_n_edge(n_edge) 103 | devices = 3 104 | sharded_tuple = sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple( 105 | in_tuple, devices) 106 | update_fn = jraph.concatenated_args(lambda x: x) 107 | sharded_gn = sharded_graphnet.ShardedEdgesGraphNetwork( 108 | update_fn, update_fn, update_fn, num_shards=devices) 109 | gn = jraph.GraphNetwork(update_fn, update_fn, update_fn) 110 | sharded_out = jax.pmap(sharded_gn, axis_name='i')(sharded_tuple) 111 | expected_out = gn(in_tuple) 112 | reduced_out = sharded_graphnet.broadcasted_sharded_graphs_tuple_to_graphs_tuple( 113 | sharded_out) 114 | jax.tree_util.tree_map( 115 | functools.partial(np.testing.assert_allclose, atol=1E-5, rtol=1E-5), 116 | expected_out, reduced_out) 117 | 118 | 119 | if __name__ == '__main__': 120 | absltest.main() 121 | -------------------------------------------------------------------------------- /jraph/ogb_examples/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Data loading utils for the Open Graph Benchmark (OGB) Mol-Hiv.""" 16 | 17 | import pathlib 18 | import jraph 19 | import numpy as np 20 | import pandas as pd 21 | 22 | 23 | class DataReader: 24 | """Data Reader for Open Graph Benchmark datasets.""" 25 | 26 | def __init__(self, 27 | data_path, 28 | master_csv_path, 29 | split_path, 30 | batch_size=1, 31 | dynamically_batch=False): 32 | """Initializes the data reader by loading in data.""" 33 | with pathlib.Path(master_csv_path).open("rt") as fp: 34 | self._dataset_info = pd.read_csv(fp, index_col=0)["ogbg-molhiv"] 35 | self._data_path = pathlib.Path(data_path) 36 | # Load edge information, and transpose into (senders, receivers). 37 | with pathlib.Path(data_path, "edge.csv.gz").open("rb") as fp: 38 | sender_receivers = pd.read_csv( 39 | fp, compression="gzip", header=None).values.T.astype(np.int64) 40 | self._senders = sender_receivers[0] 41 | self._receivers = sender_receivers[1] 42 | # Load n_node and n_edge 43 | with pathlib.Path(data_path, "num-node-list.csv.gz").open("rb") as fp: 44 | self._n_node = pd.read_csv(fp, compression="gzip", header=None) 45 | self._n_node = self._n_node.astype(np.int64)[0].tolist() 46 | with pathlib.Path(data_path, "num-edge-list.csv.gz").open("rb") as fp: 47 | self._n_edge = pd.read_csv(fp, compression="gzip", header=None) 48 | self._n_edge = self._n_edge.astype(np.int64)[0].tolist() 49 | # Load node features 50 | with pathlib.Path(data_path, "node-feat.csv.gz").open("rb") as fp: 51 | self._nodes = pd.read_csv( 52 | fp, compression="gzip", header=None).astype(np.float32).values 53 | with pathlib.Path(data_path, "edge-feat.csv.gz").open("rb") as fp: 54 | self._edges = pd.read_csv( 55 | fp, compression="gzip", header=None).astype(np.float32).values 56 | with pathlib.Path(data_path, "graph-label.csv.gz").open("rb") as fp: 57 | self._labels = pd.read_csv( 58 | fp, compression="gzip", header=None).values 59 | 60 | with pathlib.Path(split_path).open("rb") as fp: 61 | self._split_idx = pd.read_csv( 62 | fp, compression="gzip", header=None).values.T[0] 63 | 64 | self._repeat = False 65 | self._batch_size = batch_size 66 | self._generator = self._make_generator() 67 | self._max_nodes = int(np.max(self._n_node)) 68 | self._max_edges = int(np.max(self._n_edge)) 69 | 70 | if dynamically_batch: 71 | self._generator = jraph.dynamically_batch( 72 | self._generator, 73 | # Plus one for the extra padding node. 74 | n_node=self._batch_size * (self._max_nodes) + 1, 75 | # Times two because we want backwards edges. 76 | n_edge=self._batch_size * (self._max_edges) * 2, 77 | n_graph=self._batch_size + 1) 78 | 79 | # If n_node = [1,2,3], we create accumulated n_node [0,1,3,6] for indexing. 80 | self._accumulated_n_nodes = np.concatenate((np.array([0]), 81 | np.cumsum(self._n_node))) 82 | # Same for n_edge 83 | self._accumulated_n_edges = np.concatenate((np.array([0]), 84 | np.cumsum(self._n_edge))) 85 | self._dynamically_batch = dynamically_batch 86 | 87 | @property 88 | def total_num_graphs(self): 89 | return len(self._n_node) 90 | 91 | def repeat(self): 92 | self._repeat = True 93 | 94 | def __iter__(self): 95 | return self 96 | 97 | def __next__(self): 98 | graphs = [] 99 | if self._dynamically_batch: 100 | # If we are using pmap we need each batch to have the same size in both 101 | # number of nodes and number of edges. So we use dynamically batch which 102 | # guarantees this. 
103 |       return next(self._generator)
104 |     else:
105 |       for _ in range(self._batch_size):
106 |         graph = next(self._generator)
107 |         graphs.append(graph)
108 |       return jraph.batch(graphs)
109 | 
110 |   def get_graph_by_idx(self, idx):
111 |     """Gets a graph by an integer index."""
112 |     # Gather the graph information
113 |     label = self._labels[idx]
114 |     n_node = self._n_node[idx]
115 |     n_edge = self._n_edge[idx]
116 |     node_slice = slice(
117 |         self._accumulated_n_nodes[idx], self._accumulated_n_nodes[idx+1])
118 |     edge_slice = slice(
119 |         self._accumulated_n_edges[idx], self._accumulated_n_edges[idx+1])
120 |     nodes = self._nodes[node_slice]
121 |     edges = self._edges[edge_slice]
122 |     senders = self._senders[edge_slice]
123 |     receivers = self._receivers[edge_slice]
124 |     # Molecular graphs are bidirectional, but the serialization only
125 |     # stores one way so we add the missing edges.
126 |     return jraph.GraphsTuple(
127 |         nodes=nodes,
128 |         edges=np.concatenate([edges, edges]),
129 |         n_node=np.array([n_node]),
130 |         n_edge=np.array([n_edge*2]),
131 |         senders=np.concatenate([senders, receivers]),
132 |         receivers=np.concatenate([receivers, senders]),
133 |         globals={"label": label})
134 | 
135 |   def _make_generator(self):
136 |     """Makes a single example generator of the loaded OGB data."""
137 |     idx = 0
138 |     while True:
139 |       # If not repeating, exit when we've cycled through all the graphs.
140 |       # Only return graphs within the split.
141 |       if not self._repeat:
142 |         if idx == self.total_num_graphs:
143 |           return
144 |       else:
145 |         # This will reset the index to 0 if we are at the end of the dataset.
146 |         idx = idx % self.total_num_graphs
147 |       if idx not in self._split_idx:
148 |         idx += 1
149 |         continue
150 |       graph = self.get_graph_by_idx(idx)
151 |       idx += 1
152 |       yield graph
153 | 
--------------------------------------------------------------------------------
/jraph/ogb_examples/data_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | 
3 | 
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | 
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for graph.ogb_examples.data_utils.""" 16 | 17 | import pathlib 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jraph 21 | from jraph.ogb_examples import data_utils 22 | import numpy as np 23 | import tree 24 | 25 | 26 | class DataUtilsTest(parameterized.TestCase): 27 | 28 | def setUp(self): 29 | super(DataUtilsTest, self).setUp() 30 | self._test_graph = jraph.GraphsTuple( 31 | nodes=np.broadcast_to( 32 | np.arange(10, dtype=np.float32)[:, None], (10, 10)), 33 | edges=np.concatenate(( 34 | np.broadcast_to(np.arange(20, dtype=np.float32)[:, None], (20, 4)), 35 | np.broadcast_to(np.arange(20, dtype=np.float32)[:, None], (20, 4)) 36 | )), 37 | receivers=np.concatenate((np.arange(20), np.arange(20))), 38 | senders=np.concatenate((np.arange(20), np.arange(20))), 39 | globals={'label': np.array([1], dtype=np.int32)}, 40 | n_node=np.array([10], dtype=np.int32), 41 | n_edge=np.array([40], dtype=np.int32)) 42 | ogb_path = pathlib.Path(data_utils.__file__).parents[0] 43 | master_csv_path = pathlib.Path(ogb_path, 'test_data', 'master.csv') 44 | split_path = pathlib.Path(ogb_path, 'test_data', 'train.csv.gz') 45 | data_path = master_csv_path.parents[0] 46 | self._reader = data_utils.DataReader( 47 | data_path=data_path, 48 | master_csv_path=master_csv_path, 49 | split_path=split_path) 50 | 51 | def test_total_num_graph(self): 52 | self.assertEqual(self._reader.total_num_graphs, 1) 53 | 54 | def test_expected_graph(self): 55 | graph = next(self._reader) 56 | with self.subTest('test_graph_equality'): 57 | tree.map_structure( 58 | np.testing.assert_almost_equal, graph, self._test_graph) 59 | with self.subTest('stop_iteration'): 60 | # One element in the dataset, so should have stop iteration. 61 | with self.assertRaises(StopIteration): 62 | next(self._reader) 63 | 64 | def test_reader_repeat(self): 65 | self._reader.repeat() 66 | next(self._reader) 67 | graph = next(self._reader) 68 | # One graph in the test dataset so should be the same. 
69 | tree.map_structure(np.testing.assert_almost_equal, graph, self._test_graph) 70 | 71 | 72 | if __name__ == '__main__': 73 | absltest.main() 74 | -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/edge-feat.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/edge-feat.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/edge.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/edge.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/graph-label.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/graph-label.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/master.csv: -------------------------------------------------------------------------------- 1 | ,ogbg-molhiv 2 | name, test_data 3 | -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/node-feat.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/node-feat.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/num-edge-list.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/num-edge-list.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/num-node-list.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/num-node-list.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/test_data/train.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/jraph/51f5990104f7374492f8f3ea1cbc47feb411c69c/jraph/ogb_examples/test_data/train.csv.gz -------------------------------------------------------------------------------- /jraph/ogb_examples/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | 
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Example training script for training OGB molhiv with jax graph-nets.
16 | 
17 | The ogbg-molhiv dataset is a molecular property prediction dataset.
18 | It is adopted from the MoleculeNet [1]. All the molecules are pre-processed
19 | using RDKit [2].
20 | 
21 | Each graph represents a molecule, where nodes are atoms, and edges are chemical
22 | bonds. Input node features are 9-dimensional, containing atomic number and
23 | chirality, as well as other additional atom features such as formal charge and
24 | whether the atom is in the ring or not.
25 | 
26 | The goal is to predict whether a molecule inhibits HIV virus replication or not.
27 | Performance is measured in ROC-AUC.
28 | 
29 | This script uses a GraphNet to learn the prediction task.
30 | 
31 | [1] Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes,
32 | Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, and Vijay Pande.
33 | Moleculenet: a benchmark for molecular machine learning.
34 | Chemical Science, 9(2):513–530, 2018.
35 | 
36 | [2] Greg Landrum et al. RDKit: Open-source cheminformatics, 2006.
37 | 
38 | Example usage:
39 | 
40 | python3 train.py --data_path={DATA_PATH} --master_csv_path={MASTER_CSV_PATH} \
41 | --save_dir={SAVE_DIR} --split_path={SPLIT_PATH}
42 | """
43 | 
44 | import functools
45 | import logging
46 | import pathlib
47 | import pickle
48 | from absl import app
49 | from absl import flags
50 | import haiku as hk
51 | import jax
52 | import jax.numpy as jnp
53 | import jraph
54 | from jraph.ogb_examples import data_utils
55 | import optax
56 | 
57 | 
58 | flags.DEFINE_string('data_path', None, 'Directory of the data.')
59 | flags.DEFINE_string('split_path', None, 'Path to the data split indices.')
60 | flags.DEFINE_string('master_csv_path', None, 'Path to OGB master.csv.')
61 | flags.DEFINE_string('save_dir', None, 'Directory to save parameters to.')
62 | flags.DEFINE_integer('batch_size', 1, 'Number of graphs in batch.')
63 | flags.DEFINE_integer('num_training_steps', 1000, 'Number of training steps.')
64 | flags.DEFINE_enum('mode', 'train', ['train', 'evaluate'], 'Train or evaluate.')
65 | FLAGS = flags.FLAGS
66 | 
67 | 
68 | @jraph.concatenated_args
69 | def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
70 |   """Edge update function for graph net."""
71 |   net = hk.Sequential(
72 |       [hk.Linear(128), jax.nn.relu,
73 |        hk.Linear(128)])
74 |   return net(feats)
75 | 
76 | 
77 | @jraph.concatenated_args
78 | def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
79 |   """Node update function for graph net."""
80 |   net = hk.Sequential(
81 |       [hk.Linear(128), jax.nn.relu,
82 |        hk.Linear(128)])
83 |   return net(feats)
84 | 
85 | 
86 | @jraph.concatenated_args
87 | def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
88 |   """Global update function for graph net."""
89 |   # Molhiv is a binary classification task, so output positive/negative logits.
90 | net = hk.Sequential(
91 | [hk.Linear(128), jax.nn.relu,
92 | hk.Linear(2)])
93 | return net(feats)
94 |
95 |
96 | def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
97 | """Graph net function."""
98 | # Add a global parameter for graph classification.
99 | graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
100 | embedder = jraph.GraphMapFeatures(
101 | hk.Linear(128), hk.Linear(128), hk.Linear(128))
102 | net = jraph.GraphNetwork(
103 | update_node_fn=node_update_fn,
104 | update_edge_fn=edge_update_fn,
105 | update_global_fn=update_global_fn)
106 | return net(embedder(graph))
107 |
108 |
109 | def _nearest_bigger_power_of_two(x: int) -> int:
110 | """Computes the smallest power of two >= x (minimum 2), used for padding."""
111 | y = 2
112 | while y < x:
113 | y *= 2
114 | return y
115 |
116 |
117 | def pad_graph_to_nearest_power_of_two(
118 | graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
119 | """Pads a batched `GraphsTuple` to the nearest power of two.
120 |
121 | For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
122 | would pad the `GraphsTuple` nodes and edges:
123 | 7 nodes --> 8 nodes (2^3)
124 | 5 edges --> 8 edges (2^3)
125 |
126 | And since padding is accomplished using `jraph.pad_with_graphs`, an extra
127 | graph and node are added:
128 | 8 nodes --> 9 nodes
129 | 3 graphs --> 4 graphs
130 |
131 | Args:
132 | graphs_tuple: a batched `GraphsTuple` (can be batch size 1).
133 |
134 | Returns:
135 | A graphs_tuple padded to the nearest power of two.
136 | """
137 | # Add 1 since we need at least one padding node for pad_with_graphs.
138 | pad_nodes_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_node)) + 1
139 | pad_edges_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_edge))
140 | # Add 1 since we need at least one padding graph for pad_with_graphs.
141 | # We do not pad to nearest power of two because the batch size is fixed.
142 | pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
143 | return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
144 | pad_graphs_to)
145 |
146 |
147 | def compute_loss(params, graph, label, net):
148 | """Computes loss."""
149 | pred_graph = net.apply(params, graph)
150 | preds = jax.nn.log_softmax(pred_graph.globals)
151 | targets = jax.nn.one_hot(label, 2)
152 |
153 | # Since we have an extra 'dummy' graph in our batch due to padding, we want
154 | # to mask out any loss associated with the dummy graph.
155 | # Since we padded with `pad_with_graphs` we can recover the mask by using
156 | # get_graph_padding_mask.
157 | mask = jraph.get_graph_padding_mask(pred_graph)
158 |
159 | # Cross entropy loss.
160 | loss = -jnp.mean(preds * targets * mask[:, None])
161 |
162 | # Accuracy taking into account the mask.
163 | accuracy = jnp.sum(
164 | (jnp.argmax(pred_graph.globals, axis=1) == label) * mask)/jnp.sum(mask)
165 | return loss, accuracy
166 |
167 |
168 | def train(data_path, master_csv_path, split_path, batch_size,
169 | num_training_steps, save_dir):
170 | """OGB Training Script."""
171 | # Initialize the dataset reader.
172 | reader = data_utils.DataReader(
173 | data_path=data_path,
174 | master_csv_path=master_csv_path,
175 | split_path=split_path,
176 | batch_size=batch_size)
177 | # Repeat the dataset forever for training.
178 | reader.repeat()
179 |
180 | # Transform impure `net_fn` to pure functions with hk.transform.
181 | net = hk.without_apply_rng(hk.transform(net_fn))
182 | # Get a candidate graph and label to initialize the network.
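# Haiku infers parameter shapes from this sample input, so any GraphsTuple
# with matching feature dimensionality would work; index 0 is simply a
# convenient choice. A minimal sketch of a hand-built equivalent (assuming
# molhiv's 9-dim node features and 3-dim edge features):
#
#   toy_graph = jraph.GraphsTuple(
#       nodes=jnp.zeros((2, 9)), edges=jnp.zeros((1, 3)),
#       senders=jnp.array([0]), receivers=jnp.array([1]),
#       n_node=jnp.array([2]), n_edge=jnp.array([1]), globals=None)
#   toy_params = net.init(jax.random.PRNGKey(42), toy_graph)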
183 | graph = reader.get_graph_by_idx(0)
184 |
185 | # Initialize the network.
186 | logging.info('Initializing network.')
187 | params = net.init(jax.random.PRNGKey(42), graph)
188 | # Initialize the optimizer.
189 | opt_init, opt_update = optax.adam(1e-4)
190 | opt_state = opt_init(params)
191 |
192 | compute_loss_fn = functools.partial(compute_loss, net=net)
193 | # We jit the computation of our loss, since this is the main computation.
194 | # Using jax.jit means that we will use a single accelerator. If you want
195 | # to use more than 1 accelerator, use jax.pmap. More information can be
196 | # found in the jax documentation.
197 | compute_loss_fn = jax.jit(jax.value_and_grad(
198 | compute_loss_fn, has_aux=True))
199 |
200 | for idx in range(num_training_steps):
201 | graph = next(reader)
202 | # Jax will re-jit your graphnet every time a new graph shape is encountered.
203 | # In the limit, this means a new compilation every training step, which
204 | # will result in *extremely* slow training. To prevent this, pad each
205 | # batch of graphs to the nearest power of two. Since jax maintains a cache
206 | # of compiled programs, the compilation cost is amortized.
207 | graph = pad_graph_to_nearest_power_of_two(graph)
208 |
209 | # Extract the label from the graph.
210 | label = graph.globals['label']
211 | graph = graph._replace(globals={})
212 |
213 | (loss, acc), grad = compute_loss_fn(params, graph, label)
214 | updates, opt_state = opt_update(grad, opt_state, params)
215 | params = optax.apply_updates(params, updates)
216 | if idx % 100 == 0:
217 | logging.info('step: %s, loss: %s, acc: %s', idx, loss, acc)
218 | if save_dir is not None:
219 | with pathlib.Path(save_dir, 'molhiv.pkl').open('wb') as fp:
220 | logging.info('Saving model to %s', save_dir)
221 | pickle.dump(params, fp)
222 | logging.info('Training finished')
223 |
224 |
225 | def evaluate(data_path, master_csv_path, split_path, save_dir):
226 | """Evaluation Script."""
227 | logging.info('Evaluating OGB molhiv')
228 | logging.info('Dataset split: %s', split_path)
229 | # Initialize the dataset reader.
230 | reader = data_utils.DataReader(
231 | data_path=data_path,
232 | master_csv_path=master_csv_path,
233 | split_path=split_path,
234 | batch_size=1)
235 | # Transform impure `net_fn` to pure functions with hk.transform.
236 | net = hk.without_apply_rng(hk.transform(net_fn))
237 | with pathlib.Path(save_dir, 'molhiv.pkl').open('rb') as fp:
238 | params = pickle.load(fp)
239 | accumulated_loss = 0
240 | accumulated_accuracy = 0
241 | idx = 0
242 |
243 | # We jit the computation of our loss, since this is the main computation.
244 | # Using jax.jit means that we will use a single accelerator. If you want
245 | # to use more than 1 accelerator, use jax.pmap. More information can be
246 | # found in the jax documentation.
247 | compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))
248 | for graph in reader:
249 |
250 | # Jax will re-jit your graphnet every time a new graph shape is encountered.
251 | # In the limit, this means a new compilation every training step, which
252 | # will result in *extremely* slow training. To prevent this, pad each
253 | # batch of graphs to the nearest power of two. Since jax maintains a cache
254 | # of compiled programs, the compilation cost is amortized.
255 | graph = pad_graph_to_nearest_power_of_two(graph)
256 |
257 | # Extract the labels and remove from the graph.
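# The data reader carries the target as a global feature, so it is read out
# here and the globals cleared; `net_fn` installs its own zero-initialized
# global features in their place.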
258 | label = graph.globals['label']
259 | graph = graph._replace(globals={})
260 | loss, acc = compute_loss_fn(params, graph, label)
261 | accumulated_accuracy += acc
262 | accumulated_loss += loss
263 | idx += 1
264 | if idx % 100 == 0:
265 | logging.info('Evaluated %s graphs', idx)
266 | logging.info('Completed evaluation.')
267 | loss = accumulated_loss / idx
268 | accuracy = accumulated_accuracy / idx
269 | logging.info('Eval loss: %s, accuracy %s', loss, accuracy)
270 | return loss, accuracy
271 |
272 |
273 | def main(_):
274 | if FLAGS.mode == 'train':
275 | train(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path,
276 | FLAGS.batch_size, FLAGS.num_training_steps, FLAGS.save_dir)
277 | elif FLAGS.mode == 'evaluate':
278 | evaluate(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path,
279 | FLAGS.save_dir)
280 |
281 | if __name__ == '__main__':
282 | app.run(main)
283 |
--------------------------------------------------------------------------------
/jraph/ogb_examples/train_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 |
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Example training script for training OGB molhiv with jax graph-nets & flax.
16 |
17 | The ogbg-molhiv dataset is a molecular property prediction dataset.
18 | It is adapted from MoleculeNet [1]. All the molecules are pre-processed
19 | using RDKit [2].
20 |
21 | Each graph represents a molecule, where nodes are atoms, and edges are chemical
22 | bonds. Input node features are 9-dimensional, containing atomic number and
23 | chirality, as well as additional atom features such as formal charge and
24 | whether the atom is in a ring.
25 |
26 | The goal is to predict whether a molecule inhibits HIV replication.
27 | Performance is measured in ROC-AUC.
28 |
29 | This script uses a GraphNet to learn the prediction task.
30 |
31 | [1] Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes,
32 | Caleb Geniesse, Aneesh S Pappu, Karl Leswing, and Vijay Pande.
33 | MoleculeNet: a benchmark for molecular machine learning.
34 | Chemical Science, 9(2):513–530, 2018.
35 |
36 | [2] Greg Landrum et al. RDKit: Open-source cheminformatics, 2006.
37 |
38 | Example usage:
39 |
40 | python3 train_flax.py --data_path={DATA_PATH} \
41 | --master_csv_path={MASTER_CSV_PATH} --save_dir={SAVE_DIR} \
42 | --split_path={SPLIT_PATH}
43 | """
44 |
45 | import functools
46 | import logging
47 | import pathlib
48 | import pickle
49 | from typing import Sequence
50 | from absl import app
51 | from absl import flags
52 | from flax import linen as nn
53 | from flax import optim
54 | import jax
55 | import jax.numpy as jnp
56 | import jraph
57 | from jraph.ogb_examples import data_utils
58 |
59 |
60 | flags.DEFINE_string('data_path', None, 'Directory of the data.')
61 | flags.DEFINE_string('split_path', None, 'Path to the data split indices.')
62 | flags.DEFINE_string('master_csv_path', None, 'Path to OGB master.csv.')
63 | flags.DEFINE_string('save_dir', None, 'Directory to save parameters to.')
64 | flags.DEFINE_integer('batch_size', 1, 'Number of graphs in batch.')
65 | flags.DEFINE_integer('num_training_steps', 1000, 'Number of training steps.')
66 | flags.DEFINE_enum('mode', 'train', ['train', 'evaluate'], 'Train or evaluate.')
67 | FLAGS = flags.FLAGS
68 |
69 |
70 | class ExplicitMLP(nn.Module):
71 | """A flax MLP."""
72 | features: Sequence[int]
73 |
74 | @nn.compact
75 | def __call__(self, inputs):
76 | x = inputs
77 | for i, lyr in enumerate([nn.Dense(feat) for feat in self.features]):
78 | x = lyr(x)
79 | if i != len(self.features) - 1:
80 | x = nn.relu(x)
81 | return x
82 |
83 |
84 | # Functions must be passed to jraph GNNs, but pytype does not recognise
85 | # linen Modules as callables, so here we wrap them in functions.
86 | def make_embed_fn(latent_size):
87 | def embed(inputs):
88 | return nn.Dense(latent_size)(inputs)
89 | return embed
90 |
91 |
92 | def make_mlp(features):
93 | @jraph.concatenated_args
94 | def update_fn(inputs):
95 | return ExplicitMLP(features)(inputs)
96 | return update_fn
97 |
98 |
99 | class GraphNetwork(nn.Module):
100 | """A flax GraphNetwork."""
101 | mlp_features: Sequence[int]
102 | latent_size: int
103 |
104 | @nn.compact
105 | def __call__(self, graph):
106 | # Add a global parameter for graph classification.
107 | graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
108 |
109 | embedder = jraph.GraphMapFeatures(
110 | embed_node_fn=make_embed_fn(self.latent_size),
111 | embed_edge_fn=make_embed_fn(self.latent_size),
112 | embed_global_fn=make_embed_fn(self.latent_size))
113 | net = jraph.GraphNetwork(
114 | update_node_fn=make_mlp(self.mlp_features),
115 | update_edge_fn=make_mlp(self.mlp_features),
116 | # The global update outputs size 2 for binary classification.
117 | update_global_fn=make_mlp(self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
118 | return net(embedder(graph))
119 |
120 |
121 | def _nearest_bigger_power_of_two(x: int) -> int:
122 | """Computes the smallest power of two >= x (minimum 2), used for padding."""
123 | y = 2
124 | while y < x:
125 | y *= 2
126 | return y
127 |
128 |
129 | def pad_graph_to_nearest_power_of_two(
130 | graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
131 | """Pads a batched `GraphsTuple` to the nearest power of two.
132 |
133 | For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
134 | would pad the `GraphsTuple` nodes and edges:
135 | 7 nodes --> 8 nodes (2^3)
136 | 5 edges --> 8 edges (2^3)
137 |
138 | And since padding is accomplished using `jraph.pad_with_graphs`, an extra
139 | graph and node are added:
140 | 8 nodes --> 9 nodes
141 | 3 graphs --> 4 graphs
142 |
143 | Args:
144 | graphs_tuple: a batched `GraphsTuple` (can be batch size 1).
145 |
146 | Returns:
147 | A graphs_tuple padded to the nearest power of two.
148 | """
149 | # Add 1 since we need at least one padding node for pad_with_graphs.
150 | pad_nodes_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_node)) + 1
151 | pad_edges_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_edge))
152 | # Add 1 since we need at least one padding graph for pad_with_graphs.
153 | # We do not pad to nearest power of two because the batch size is fixed.
154 | pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
155 | return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
156 | pad_graphs_to)
157 |
158 |
159 | def compute_loss(params, graph, label, net):
160 | """Computes loss."""
161 | pred_graph = net.apply(params, graph)
162 | preds = jax.nn.log_softmax(pred_graph.globals)
163 | targets = jax.nn.one_hot(label, 2)
164 |
165 | # Since we have an extra 'dummy' graph in our batch due to padding, we want
166 | # to mask out any loss associated with the dummy graph.
167 | # Since we padded with `pad_with_graphs` we can recover the mask by using
168 | # get_graph_padding_mask.
169 | mask = jraph.get_graph_padding_mask(pred_graph)
170 |
171 | # Cross entropy loss.
172 | loss = -jnp.mean(preds * targets * mask[:, None])
173 |
174 | # Accuracy taking into account the mask.
175 | accuracy = jnp.sum(
176 | (jnp.argmax(pred_graph.globals, axis=1) == label) * mask)/jnp.sum(mask)
177 | return loss, accuracy
178 |
179 |
180 | def train_step(optimizer, graph, label, net):
181 | partial_loss_fn = functools.partial(
182 | compute_loss, graph=graph, label=label, net=net)
183 | grad_fn = jax.value_and_grad(partial_loss_fn, has_aux=True)
184 | (loss, accuracy), grad = grad_fn(optimizer.target)
185 | optimizer = optimizer.apply_gradient(grad)
186 | return optimizer, {'loss': loss, 'accuracy': accuracy}
187 |
188 |
189 | def train(data_path, master_csv_path, split_path, batch_size,
190 | num_training_steps, save_dir):
191 | """OGB Training Script."""
192 |
193 | # Initialize the dataset reader.
194 | reader = data_utils.DataReader(
195 | data_path=data_path,
196 | master_csv_path=master_csv_path,
197 | split_path=split_path,
198 | batch_size=batch_size)
199 | # Repeat the dataset forever for training.
200 | reader.repeat()
201 |
202 | net = GraphNetwork(mlp_features=(128, 128), latent_size=128)
203 |
204 | # Get a candidate graph and label to initialize the network.
205 | graph = reader.get_graph_by_idx(0)
206 |
207 | # Initialize the network.
208 | logging.info('Initializing network.')
209 | params = net.init(jax.random.PRNGKey(42), graph)
210 | optimizer = optim.Adam(learning_rate=1e-4).create(params)
211 | optimizer = jax.device_put(optimizer)
212 |
213 | for idx in range(num_training_steps):
214 | graph = next(reader)
215 | # Jax will re-jit your graphnet every time a new graph shape is encountered.
216 | # In the limit, this means a new compilation every training step, which
217 | # will result in *extremely* slow training. To prevent this, pad each
218 | # batch of graphs to the nearest power of two. Since jax maintains a cache
219 | # of compiled programs, the compilation cost is amortized.
220 | graph = pad_graph_to_nearest_power_of_two(graph)
221 |
222 | # Remove the label from the input graph.
223 | label = graph.globals['label']
224 | graph = graph._replace(globals={})
225 |
226 | optimizer, scalars = train_step(optimizer, graph, label, net)
227 | if idx % 100 == 0:
228 | logging.info('step: %s, loss: %s, acc: %s', idx, scalars['loss'],
229 | scalars['accuracy'])
230 | if save_dir is not None:
231 | with pathlib.Path(save_dir, 'molhiv.pkl').open('wb') as fp:
232 | logging.info('Saving model to %s', save_dir)
233 | pickle.dump(optimizer.target, fp)
234 | logging.info('Training finished')
235 |
236 |
237 | def evaluate(data_path, master_csv_path, split_path, save_dir):
238 | """Evaluation Script."""
239 | logging.info('Evaluating OGB molhiv')
240 | logging.info('Dataset split: %s', split_path)
241 | # Initialize the dataset reader.
242 | reader = data_utils.DataReader(
243 | data_path=data_path,
244 | master_csv_path=master_csv_path,
245 | split_path=split_path,
246 | batch_size=1)
247 |
248 | with pathlib.Path(save_dir, 'molhiv.pkl').open('rb') as fp:
249 | params = pickle.load(fp)
250 | accumulated_loss = 0
251 | accumulated_accuracy = 0
252 | idx = 0
253 |
254 | # We jit the computation of our loss, since this is the main computation.
255 | # Using jax.jit means that we will use a single accelerator. If you want
256 | # to use more than 1 accelerator, use jax.pmap. More information can be
257 | # found in the jax documentation.
258 | net = GraphNetwork(mlp_features=[128, 128], latent_size=128)
259 | compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))
260 | for graph in reader:
261 |
262 | # Jax will re-jit your graphnet every time a new graph shape is encountered.
263 | # In the limit, this means a new compilation every training step, which
264 | # will result in *extremely* slow training. To prevent this, pad each
265 | # batch of graphs to the nearest power of two. Since jax maintains a cache
266 | # of compiled programs, the compilation cost is amortized.
267 | graph = pad_graph_to_nearest_power_of_two(graph)
268 |
269 | # Extract the labels and remove from the graph.
270 | label = graph.globals['label']
271 | graph = graph._replace(globals={})
272 | loss, acc = compute_loss_fn(params, graph, label)
273 | accumulated_accuracy += acc
274 | accumulated_loss += loss
275 | idx += 1
276 | if idx % 100 == 0:
277 | logging.info('Evaluated %s graphs', idx)
278 | logging.info('Completed evaluation.')
279 | loss = accumulated_loss / idx
280 | accuracy = accumulated_accuracy / idx
281 | logging.info('Eval loss: %s, accuracy %s', loss, accuracy)
282 | return loss, accuracy
283 |
284 |
285 | def main(_):
286 | if FLAGS.mode == 'train':
287 | train(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path,
288 | FLAGS.batch_size, FLAGS.num_training_steps, FLAGS.save_dir)
289 | elif FLAGS.mode == 'evaluate':
290 | evaluate(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path,
291 | FLAGS.save_dir)
292 |
293 | if __name__ == '__main__':
294 | app.run(main)
295 |
--------------------------------------------------------------------------------
/jraph/ogb_examples/train_flax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 |
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for jraph.ogb_examples.train_flax."""
16 |
17 | import pathlib
18 | from absl.testing import absltest
19 | from jraph.ogb_examples import train_flax
20 |
21 |
22 | class TrainTest(absltest.TestCase):
23 |
24 | def test_train_and_eval_overfit(self):
25 | ogb_path = pathlib.Path(train_flax.__file__).parents[0]
26 | master_csv_path = pathlib.Path(ogb_path, 'test_data', 'master.csv')
27 | split_path = pathlib.Path(ogb_path, 'test_data', 'train.csv.gz')
28 | data_path = master_csv_path.parents[0]
29 | temp_dir = self.create_tempdir().full_path
30 | train_flax.train(data_path, master_csv_path, split_path, 1, 101, temp_dir)
31 | _, accuracy = train_flax.evaluate(
32 | data_path, master_csv_path, split_path, temp_dir)
33 | self.assertEqual(accuracy, 1.0)
34 |
35 |
36 | if __name__ == '__main__':
37 | absltest.main()
38 |
--------------------------------------------------------------------------------
/jraph/ogb_examples/train_pmap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 |
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Example training script for training OGB molhiv with jax graph-nets & pmap.
16 |
17 | The ogbg-molhiv dataset is a molecular property prediction dataset.
18 | It is adapted from MoleculeNet [1]. All the molecules are pre-processed
19 | using RDKit [2].
20 |
21 | Each graph represents a molecule, where nodes are atoms, and edges are chemical
22 | bonds. Input node features are 9-dimensional, containing atomic number and
23 | chirality, as well as additional atom features such as formal charge and
24 | whether the atom is in a ring.
25 |
26 | The goal is to predict whether a molecule inhibits HIV replication.
27 | Performance is measured in ROC-AUC.
28 |
29 | This script uses a GraphNet to learn the prediction task.
30 |
31 | [1] Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes,
32 | Caleb Geniesse, Aneesh S Pappu, Karl Leswing, and Vijay Pande.
33 | MoleculeNet: a benchmark for molecular machine learning.
34 | Chemical Science, 9(2):513–530, 2018.
35 |
36 | [2] Greg Landrum et al. RDKit: Open-source cheminformatics, 2006.
37 |
38 | Example usage:
39 |
40 | python3 train_pmap.py --data_path={DATA_PATH} \
41 | --master_csv_path={MASTER_CSV_PATH} --save_dir={SAVE_DIR} \
42 | --split_path={SPLIT_PATH}
43 | """
44 |
45 | import functools
46 | import logging
47 | import pathlib
48 | import pickle
49 | from typing import Iterator
50 | from absl import app
51 | from absl import flags
52 | import haiku as hk
53 | import jax
54 | import jax.numpy as jnp
55 | import jraph
56 | from jraph.ogb_examples import data_utils
57 | import optax
58 |
59 |
60 | flags.DEFINE_string('data_path', None, 'Directory of the data.')
61 | flags.DEFINE_string('split_path', None, 'Path to the data split indices.')
62 | flags.DEFINE_string('master_csv_path', None, 'Path to OGB master.csv.')
63 | flags.DEFINE_string('save_dir', None, 'Directory to save parameters to.')
64 | flags.DEFINE_integer('batch_size', 1, 'Number of graphs in batch.')
65 | flags.DEFINE_integer('num_training_steps', 1000, 'Number of training steps.')
66 | flags.DEFINE_enum('mode', 'train', ['train', 'evaluate'], 'Train or evaluate.')
67 | FLAGS = flags.FLAGS
68 |
69 |
70 | @jraph.concatenated_args
71 | def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
72 | """Edge update function for graph net."""
73 | net = hk.Sequential(
74 | [hk.Linear(128), jax.nn.relu,
75 | hk.Linear(128)])
76 | return net(feats)
77 |
78 |
79 | @jraph.concatenated_args
80 | def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
81 | """Node update function for graph net."""
82 | net = hk.Sequential(
83 | [hk.Linear(128), jax.nn.relu,
84 | hk.Linear(128)])
85 | return net(feats)
86 |
87 |
88 | @jraph.concatenated_args
89 | def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
90 | """Global update function for graph net."""
91 | # Molhiv is a binary classification task, so output positive/negative logits.
92 | net = hk.Sequential(
93 | [hk.Linear(128), jax.nn.relu,
94 | hk.Linear(2)])
95 | return net(feats)
96 |
97 |
98 | def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
99 | """Graph net function."""
100 | # Add a global parameter for graph classification.
101 | graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
102 | embedder = jraph.GraphMapFeatures(
103 | hk.Linear(128), hk.Linear(128), hk.Linear(128))
104 | net = jraph.GraphNetwork(
105 | update_node_fn=node_update_fn,
106 | update_edge_fn=edge_update_fn,
107 | update_global_fn=update_global_fn)
108 | return net(embedder(graph))
109 |
110 |
111 | def device_batch(
112 | graph_generator: data_utils.DataReader) -> Iterator[jraph.GraphsTuple]:
113 | """Batches a stream of graphs into one batch per local device."""
114 | num_devices = jax.local_device_count()
115 | batch = []
116 | for idx, graph in enumerate(graph_generator):
117 | if idx % num_devices == num_devices - 1:
118 | batch.append(graph)
119 | yield jax.tree_map(lambda *x: jnp.stack(x, axis=0), *batch)
120 | batch = []
121 | else:
122 | batch.append(graph)
123 |
124 |
125 | def compute_loss(params, graph, label, net):
126 | """Computes loss."""
127 | pred_graph = net.apply(params, graph)
128 | preds = jax.nn.log_softmax(pred_graph.globals)
129 | targets = jax.nn.one_hot(label, 2)
130 |
131 | # Since we have an extra 'dummy' graph in our batch due to padding, we want
132 | # to mask out any loss associated with the dummy graph.
133 | # Since we padded with `pad_with_graphs` we can recover the mask by using
134 | # get_graph_padding_mask.
135 | mask = jraph.get_graph_padding_mask(pred_graph)
136 |
137 | # Cross entropy loss.
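# Note that jnp.mean averages over all entries, including those zeroed out
# by the mask, so this is effectively the masked cross entropy scaled by a
# constant; the scaling leaves the gradient direction unchanged.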
138 | loss = -jnp.mean(preds * targets * mask[:, None])
139 |
140 | # Accuracy taking into account the mask.
141 | accuracy = jnp.sum(
142 | (jnp.argmax(pred_graph.globals, axis=1) == label) * mask)/jnp.sum(mask)
143 | return loss, accuracy
144 |
145 |
146 | def train(data_path, master_csv_path, split_path, batch_size,
147 | num_training_steps, save_dir):
148 | """OGB Training Script."""
149 | # Initialize the dataset reader.
150 | reader = data_utils.DataReader(
151 | data_path=data_path,
152 | master_csv_path=master_csv_path,
153 | split_path=split_path,
154 | batch_size=batch_size,
155 | dynamically_batch=True)
156 | # Repeat the dataset forever for training.
157 | reader.repeat()
158 |
159 | # Transform impure `net_fn` to pure functions with hk.transform.
160 | net = hk.without_apply_rng(hk.transform(net_fn))
161 | # Get a candidate graph and label to initialize the network.
162 | graph = reader.get_graph_by_idx(0)
163 |
164 | # Initialize the network.
165 | logging.info('Initializing network.')
166 | params = net.init(jax.random.PRNGKey(42), graph)
167 | # Because we are training with multiple devices, the params need a leading
168 | # device axis.
169 | params = jax.device_put_replicated(params, list(jax.devices()))
170 | # Initialize the optimizer.
171 | opt_init, opt_update = optax.adam(1e-4)
172 | opt_state = jax.pmap(opt_init)(params)
173 |
174 | compute_loss_fn = functools.partial(compute_loss, net=net)
175 | # We pmap the computation of our loss, since this is the main computation.
176 | # Using jax.pmap means that we will use all available accelerators.
177 | # More information can be found in the jax documentation.
178 | @functools.partial(jax.pmap, axis_name='device')
179 | def update_fn(params, graph, label, opt_state):
180 | (loss, acc), grad = jax.value_and_grad(
181 | compute_loss_fn, has_aux=True)(params, graph, label)
182 | # Average gradients across devices.
183 | grad = jax.lax.pmean(grad, axis_name='device')
184 | updates, opt_state = opt_update(grad, opt_state, params)
185 | params = optax.apply_updates(params, updates)
186 | return loss, acc, opt_state, params
187 |
188 | for idx in range(num_training_steps):
189 | graph_batch = next(device_batch(reader))
190 | label = graph_batch.globals['label']
191 | loss, acc, opt_state, params = update_fn(
192 | params, graph_batch, label, opt_state)
193 | if idx % 100 == 0:
194 | logging.info('step: %s, loss: %s, acc: %s', idx, loss, acc)
195 | if save_dir is not None:
196 | with pathlib.Path(save_dir, 'molhiv.pkl').open('wb') as fp:
197 | logging.info('Saving model to %s', save_dir)
198 | pickle.dump(params, fp)
199 | logging.info('Training finished')
200 |
201 |
202 | def evaluate(data_path, master_csv_path, split_path, save_dir):
203 | """Evaluation Script."""
204 | logging.info('Evaluating OGB molhiv')
205 | logging.info('Dataset split: %s', split_path)
206 | # Initialize the dataset reader.
207 | reader = data_utils.DataReader(
208 | data_path=data_path,
209 | master_csv_path=master_csv_path,
210 | split_path=split_path,
211 | batch_size=1,
212 | dynamically_batch=True)
213 | # Transform impure `net_fn` to pure functions with hk.transform.
214 | net = hk.without_apply_rng(hk.transform(net_fn))
215 | with pathlib.Path(save_dir, 'molhiv.pkl').open('rb') as fp:
216 | params = pickle.load(fp)
217 | accumulated_loss = 0
218 | accumulated_accuracy = 0
219 | idx = 0
220 |
221 | # We pmap the computation of our loss, since this is the main computation.
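# jax.pmap maps over a leading device axis: `device_batch` below adds that
# axis by stacking one graph batch per local device, and the unpickled
# params already carry it because train() replicated them with
# jax.device_put_replicated.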
221 | compute_loss_fn = jax.pmap(functools.partial(compute_loss, net=net)) 222 | for graph_batch in device_batch(reader): 223 | label = graph_batch.globals['label'] 224 | loss, acc = compute_loss_fn(params, graph_batch, label) 225 | accumulated_accuracy += jnp.sum(acc) 226 | accumulated_loss += jnp.sum(loss) 227 | total_num_padding_graphs = jnp.sum( 228 | jax.vmap(jraph.get_number_of_padding_with_graphs_graphs)(graph_batch)) 229 | idx += graph_batch.n_node.size - total_num_padding_graphs 230 | if idx % 100 == 0: 231 | logging.info('Evaluated %s graphs', idx) 232 | logging.info('Completed evaluation.') 233 | loss = accumulated_loss / idx 234 | accuracy = accumulated_accuracy / idx 235 | logging.info('Eval loss: %s, accuracy %s', loss, accuracy) 236 | return loss, accuracy 237 | 238 | 239 | def main(_): 240 | if FLAGS.mode == 'train': 241 | train(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path, 242 | FLAGS.batch_size, FLAGS.num_training_steps, FLAGS.save_dir) 243 | elif FLAGS.mode == 'evaluate': 244 | evaluate(FLAGS.data_path, FLAGS.master_csv_path, FLAGS.split_path, 245 | FLAGS.save_dir) 246 | 247 | if __name__ == '__main__': 248 | app.run(main) 249 | -------------------------------------------------------------------------------- /jraph/ogb_examples/train_pmap_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for jraph.ogb_examples.train_pmap.""" 16 | 17 | import pathlib 18 | from absl.testing import absltest 19 | from jraph.ogb_examples import train_pmap 20 | 21 | 22 | class TrainTest(absltest.TestCase): 23 | 24 | def test_train_and_eval_overfit(self): 25 | ogb_path = pathlib.Path(train_pmap.__file__).parents[0] 26 | master_csv_path = pathlib.Path(ogb_path, 'test_data', 'master.csv') 27 | split_path = pathlib.Path(ogb_path, 'test_data', 'train.csv.gz') 28 | data_path = master_csv_path.parents[0] 29 | temp_dir = self.create_tempdir().full_path 30 | train_pmap.train(data_path, master_csv_path, split_path, 1, 101, temp_dir) 31 | _, accuracy = train_pmap.evaluate(data_path, master_csv_path, split_path, 32 | temp_dir) 33 | self.assertEqual(float(accuracy), 1.0) 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /jraph/ogb_examples/train_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for jraph.ogb_examples.train.""" 16 | 17 | import pathlib 18 | from absl.testing import absltest 19 | from jraph.ogb_examples import train 20 | 21 | 22 | class TrainTest(absltest.TestCase): 23 | 24 | def test_train_and_eval_overfit(self): 25 | ogb_path = pathlib.Path(train.__file__).parents[0] 26 | master_csv_path = pathlib.Path(ogb_path, 'test_data', 'master.csv') 27 | split_path = pathlib.Path(ogb_path, 'test_data', 'train.csv.gz') 28 | data_path = master_csv_path.parents[0] 29 | temp_dir = self.create_tempdir().full_path 30 | train.train(data_path, master_csv_path, split_path, 1, 101, temp_dir) 31 | _, accuracy = train.evaluate( 32 | data_path, master_csv_path, split_path, temp_dir) 33 | self.assertEqual(accuracy, 1.0) 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | builder: html 5 | configuration: docs/conf.py 6 | fail_on_warning: false 7 | 8 | python: 9 | version: 3.7 10 | install: 11 | - requirements: requirements.txt 12 | - requirements: docs/requirements.txt 13 | - method: setuptools 14 | path: . 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax>=0.1.55 2 | jaxlib>=0.1.37 3 | numpy>=1.18.0 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Install script for setuptools.""" 16 | 17 | from setuptools import find_namespace_packages 18 | from setuptools import setup 19 | 20 | 21 | def _get_version(): 22 | with open('jraph/__init__.py') as fp: 23 | for line in fp: 24 | if line.startswith('__version__') and '=' in line: 25 | version = line[line.find('=')+1:].strip(' \'"\n') 26 | if version: 27 | return version 28 | raise ValueError('`__version__` not defined in `jraph/__init__.py`') 29 | 30 | 31 | setup( 32 | name='jraph', 33 | version=_get_version(), 34 | url='https://github.com/deepmind/jraph', 35 | license='Apache 2.0', 36 | author='DeepMind', 37 | description=('Jraph: A library for Graph Neural Networks in Jax'), 38 | long_description=open('README.md').read(), 39 | long_description_content_type='text/markdown', 40 | author_email='jax_graph_nets@google.com', 41 | keywords='jax graph neural networks python machine learning', 42 | packages=find_namespace_packages(exclude=['*_test.py']), 43 | package_data={'jraph': ['ogb_examples/test_data/*']}, 44 | python_requires='>=3.6', 45 | install_requires=[ 46 | 'jax>=0.1.55', 47 | 'jaxlib>=0.1.37', 48 | 'numpy>=1.18.0', 49 | ], 50 | extras_require={'examples': ['dm-haiku>=0.0.2', 'absl-py>=0.9', 51 | 'frozendict>=2.0.2', 'optax>=0.0.1', 52 | 'scipy>=1.2.1'], 53 | 'ogb_examples': ['dm-haiku>=0.0.2', 'absl-py>=0.9', 54 | 'optax>=0.0.1', 'pandas>=1.0.5', 55 | 'dm-tree>=0.1.5']}, 56 | classifiers=[ 57 | 'Development Status :: 5 - Production/Stable', 58 | 'Environment :: Console', 59 | 'Intended Audience :: Science/Research', 60 | 'Intended Audience :: Developers', 61 | 'License :: OSI Approved :: Apache Software License', 62 | 'Operating System :: OS Independent', 63 | 'Programming Language :: Python', 64 | 'Programming Language :: Python :: 3', 65 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 66 | 'Topic :: Software Development :: Libraries :: Python Modules', 67 | ], 68 | ) 69 | --------------------------------------------------------------------------------