├── galax ├── data │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── nodes │ │ │ ├── __init__.py │ │ │ ├── ogb.py │ │ │ ├── graphsage.py │ │ │ └── planetoid.py │ │ └── graphs │ │ │ └── gin.py │ └── dataloader.py ├── nn │ ├── zoo │ │ ├── __init__.py │ │ ├── gcn.py │ │ ├── graphsage.py │ │ └── gat.py │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── early_stopping.py │ └── module.py ├── __init__.py ├── tests │ ├── test_jittable.py │ └── test_dgl_consistency.py ├── core.py ├── view.py ├── batch.py ├── function.py ├── _version.py └── graph_index.py ├── requirements.txt ├── pytest.ini ├── benchmarks ├── gcn │ ├── README.md │ ├── cora.py │ ├── pubmed.py │ └── citeseer.py ├── gat │ ├── ppi.sh │ ├── ppi.py │ ├── cora.py │ ├── pubmed.py │ └── citeseer.py ├── graphsage │ └── reddit.py └── speed │ ├── kernel │ └── kernel.py │ └── training │ ├── gat │ ├── pubmed.py │ ├── citeseer.py │ └── cora.py │ └── gcn │ ├── cora.py │ ├── pubmed.py │ └── citeseer.py ├── docs ├── autosummary │ ├── galax.function.copy_e.rst │ ├── galax.function.copy_u.rst │ ├── galax.function.e_add_u.rst │ ├── galax.function.e_add_v.rst │ ├── galax.function.e_div_u.rst │ ├── galax.function.e_div_v.rst │ ├── galax.function.e_dot_u.rst │ ├── galax.function.e_dot_v.rst │ ├── galax.function.e_mul_u.rst │ ├── galax.function.e_mul_v.rst │ ├── galax.function.e_sub_u.rst │ ├── galax.function.e_sub_v.rst │ ├── galax.function.u_add_e.rst │ ├── galax.function.u_add_v.rst │ ├── galax.function.u_div_e.rst │ ├── galax.function.u_div_v.rst │ ├── galax.function.u_dot_e.rst │ ├── galax.function.u_dot_v.rst │ ├── galax.function.u_mul_e.rst │ ├── galax.function.u_mul_v.rst │ ├── galax.function.u_sub_e.rst │ ├── galax.function.u_sub_v.rst │ ├── galax.function.v_add_e.rst │ ├── galax.function.v_add_u.rst │ ├── galax.function.v_div_e.rst │ ├── galax.function.v_div_u.rst │ ├── galax.function.v_dot_e.rst │ ├── galax.function.v_dot_u.rst │ ├── galax.function.v_mul_e.rst │ ├── galax.function.v_mul_u.rst │ ├── 
galax.function.v_sub_e.rst │ ├── galax.function.v_sub_u.rst │ ├── galax.heterograph.graph.rst │ ├── galax.core.message_passing.rst │ ├── galax.function.apply_edges.rst │ ├── galax.function.apply_nodes.rst │ ├── galax.function.segment_mean.rst │ ├── galax.core.rst │ ├── galax.heterograph.rst │ ├── galax.heterograph.EdgeSpace.rst │ ├── galax.heterograph.NodeSpace.rst │ ├── galax.function.ReduceFunction.rst │ ├── galax.function.rst │ └── galax.heterograph.HeteroGraph.rst ├── api.rst ├── _templates │ ├── README.md │ ├── custom-class-template.rst │ └── custom-module-template.rst ├── _static │ └── README.md ├── index.rst ├── README.md ├── Makefile ├── make.bat └── conf.py ├── setup.cfg ├── .github └── workflows │ ├── DOC.yml │ └── CI.yml ├── LICENSE ├── README.md ├── .gitignore └── setup.py /galax/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data utility.""" 2 | -------------------------------------------------------------------------------- /galax/nn/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | """Model zoo.""" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | flax 4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --doctest-modules 3 | -------------------------------------------------------------------------------- /galax/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmark datasets.""" 2 | -------------------------------------------------------------------------------- /galax/data/datasets/nodes/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Node property prediction.""" 2 | -------------------------------------------------------------------------------- /galax/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Neural networks functionalities.""" 2 | 3 | from .module import ApplyNodes 4 | -------------------------------------------------------------------------------- /galax/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Neural networks utilities.""" 2 | from .early_stopping import EarlyStopping 3 | -------------------------------------------------------------------------------- /benchmarks/gcn/README.md: -------------------------------------------------------------------------------- 1 | [Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) 2 | ============================= 3 | -------------------------------------------------------------------------------- /docs/autosummary/galax.function.copy_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.copy\_e 2 | ====================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: copy_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.copy_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.copy\_u 2 | ====================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: copy_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_add_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_add\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | ..
autofunction:: e_add_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_add_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_add\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_add_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_div_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_div\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_div_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_div_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_div\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_div_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_dot_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_dot\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_dot_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_dot_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_dot\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. 
autofunction:: e_dot_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_mul_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_mul\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_mul_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_mul_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_mul\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_mul_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_sub_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_sub\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_sub_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.e_sub_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.e\_sub\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: e_sub_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_add_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_add\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. 
autofunction:: u_add_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_add_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_add\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_add_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_div_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_div\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_div_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_div_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_div\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_div_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_dot_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_dot\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_dot_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_dot_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_dot\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. 
autofunction:: u_dot_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_mul_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_mul\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_mul_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_mul_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_mul\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_mul_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_sub_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_sub\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_sub_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.u_sub_v.rst: -------------------------------------------------------------------------------- 1 | galax.function.u\_sub\_v 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: u_sub_v -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_add_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_add\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. 
autofunction:: v_add_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_add_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_add\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_add_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_div_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_div\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_div_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_div_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_div\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_div_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_dot_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_dot\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_dot_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_dot_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_dot\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. 
autofunction:: v_dot_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_mul_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_mul\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_mul_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_mul_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_mul\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_mul_u -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_sub_e.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_sub\_e 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_sub_e -------------------------------------------------------------------------------- /docs/autosummary/galax.function.v_sub_u.rst: -------------------------------------------------------------------------------- 1 | galax.function.v\_sub\_u 2 | ======================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: v_sub_u -------------------------------------------------------------------------------- /docs/autosummary/galax.heterograph.graph.rst: -------------------------------------------------------------------------------- 1 | galax.heterograph.graph 2 | ======================= 3 | 4 | .. currentmodule:: galax.heterograph 5 | 6 | .. 
autofunction:: graph -------------------------------------------------------------------------------- /docs/autosummary/galax.core.message_passing.rst: -------------------------------------------------------------------------------- 1 | galax.core.message\_passing 2 | =========================== 3 | 4 | .. currentmodule:: galax.core 5 | 6 | .. autofunction:: message_passing -------------------------------------------------------------------------------- /docs/autosummary/galax.function.apply_edges.rst: -------------------------------------------------------------------------------- 1 | galax.function.apply\_edges 2 | =========================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: apply_edges -------------------------------------------------------------------------------- /docs/autosummary/galax.function.apply_nodes.rst: -------------------------------------------------------------------------------- 1 | galax.function.apply\_nodes 2 | =========================== 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: apply_nodes -------------------------------------------------------------------------------- /docs/autosummary/galax.function.segment_mean.rst: -------------------------------------------------------------------------------- 1 | galax.function.segment\_mean 2 | ============================ 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autofunction:: segment_mean -------------------------------------------------------------------------------- /galax/__init__.py: -------------------------------------------------------------------------------- 1 | from .heterograph import graph, from_dgl 2 | from .core import message_passing 3 | from . 
import function, nn 4 | from .nn.module import ApplyNodes 5 | from .function import apply_nodes, apply_edges 6 | from .batch import batch, pad 7 | -------------------------------------------------------------------------------- /benchmarks/gat/ppi.sh: -------------------------------------------------------------------------------- 1 | #BSUB -q gpuqueue 2 | #BSUB -o %J.stdout 3 | #BSUB -gpu "num=1:j_exclusive=yes" 4 | #BSUB -R "rusage[mem=10] span[ptile=1]" 5 | #BSUB -W 0:59 6 | #BSUB -n 1 7 | #BSUB -gpu "num=1/task:j_exclusive=yes:mode=shared" 8 | 9 | python ppi.py 10 | 11 | 12 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================= 3 | 4 | .. autosummary:: 5 | :toctree: autosummary 6 | :template: custom-module-template.rst 7 | :recursive: 8 | 9 | galax.heterograph 10 | galax.function 11 | galax.core 12 | galax.data 13 | galax.nn 14 | -------------------------------------------------------------------------------- /docs/autosummary/galax.core.rst: -------------------------------------------------------------------------------- 1 | galax.core 2 | ========== 3 | 4 | .. automodule:: galax.core 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | :toctree: 16 | 17 | message_passing 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /galax/data/datasets/nodes/ogb.py: -------------------------------------------------------------------------------- 1 | """`OGB datasets. 
`__ 2 | """ 3 | 4 | def arxiv(): 5 | from ogb.nodeproppred import DglNodePropPredDataset 6 | import galax 7 | g = DglNodePropPredDataset(name="ogbn-arxiv")[0][0] 8 | g.ndata['h'] = g.ndata['feat'] 9 | del g.ndata['feat'] 10 | g = galax.from_dgl(g) 11 | return g 12 | -------------------------------------------------------------------------------- /galax/nn/module.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Callable, Optional 3 | from flax import linen as nn 4 | from ..function import apply_nodes, apply_edges 5 | 6 | class ApplyNodes(nn.Module): 7 | layer: Callable 8 | 9 | def __call__(self, graph, field="h"): 10 | graph = graph.ndata.set(field, self.layer(graph.ndata[field])) 11 | return graph 12 | -------------------------------------------------------------------------------- /docs/_templates/README.md: -------------------------------------------------------------------------------- 1 | # Templates Doc Directory 2 | 3 | Add any paths that contain templates here, relative to 4 | the `conf.py` file's directory. 5 | They are copied after the builtin template files, 6 | so a file named "page.html" will overwrite the builtin "page.html". 7 | 8 | The path to this folder is set in the Sphinx `conf.py` file in the line: 9 | ```python 10 | templates_path = ['_templates'] 11 | ``` 12 | 13 | ## Examples of file to add to this directory 14 | * HTML extensions of stock pages like `page.html` or `layout.html` 15 | -------------------------------------------------------------------------------- /docs/_static/README.md: -------------------------------------------------------------------------------- 1 | # Static Doc Directory 2 | 3 | Add any paths that contain custom static files (such as style sheets) here, 4 | relative to the `conf.py` file's directory. 5 | They are copied after the builtin static files, 6 | so a file named "default.css" will overwrite the builtin "default.css". 
7 | 8 | The path to this folder is set in the Sphinx `conf.py` file in the line: 9 | ```python 10 | html_static_path = ['_static'] 11 | ``` 12 | 13 | ## Examples of file to add to this directory 14 | * Custom Cascading Style Sheets 15 | * Custom JavaScript code 16 | * Static logo images 17 | -------------------------------------------------------------------------------- /docs/autosummary/galax.heterograph.rst: -------------------------------------------------------------------------------- 1 | galax.heterograph 2 | ================= 3 | 4 | .. automodule:: galax.heterograph 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | :toctree: 16 | 17 | graph 18 | 19 | 20 | 21 | 22 | 23 | .. rubric:: Classes 24 | 25 | .. autosummary:: 26 | :toctree: 27 | :template: custom-class-template.rst 28 | 29 | EdgeSpace 30 | HeteroGraph 31 | NodeSpace 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. galax documentation master file, created by 2 | sphinx-quickstart on Thu Mar 15 13:55:56 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | GALAX: Graph Learning with JAX 7 | ========================================================= 8 | 9 | .. mdinclude:: ../README.md 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | api 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/autosummary/galax.heterograph.EdgeSpace.rst: -------------------------------------------------------------------------------- 1 | galax.heterograph.EdgeSpace 2 | =========================== 3 | 4 | ..
currentmodule:: galax.heterograph 5 | 6 | .. autoclass:: EdgeSpace 7 | :members: 8 | :show-inheritance: 9 | :inherited-members: 10 | 11 | 12 | .. automethod:: __init__ 13 | 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | 19 | ~EdgeSpace.__init__ 20 | ~EdgeSpace.count 21 | ~EdgeSpace.index 22 | 23 | 24 | 25 | 26 | 27 | .. rubric:: Attributes 28 | 29 | .. autosummary:: 30 | 31 | ~EdgeSpace.data 32 | 33 | -------------------------------------------------------------------------------- /docs/autosummary/galax.heterograph.NodeSpace.rst: -------------------------------------------------------------------------------- 1 | galax.heterograph.NodeSpace 2 | =========================== 3 | 4 | .. currentmodule:: galax.heterograph 5 | 6 | .. autoclass:: NodeSpace 7 | :members: 8 | :show-inheritance: 9 | :inherited-members: 10 | 11 | 12 | .. automethod:: __init__ 13 | 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | 19 | ~NodeSpace.__init__ 20 | ~NodeSpace.count 21 | ~NodeSpace.index 22 | 23 | 24 | 25 | 26 | 27 | .. rubric:: Attributes 28 | 29 | .. autosummary:: 30 | 31 | ~NodeSpace.data 32 | 33 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Compiling galax's Documentation 2 | 3 | The docs for this project are built with [Sphinx](http://www.sphinx-doc.org/en/master/). 4 | To compile the docs, first ensure that Sphinx and the ReadTheDocs theme are installed. 5 | 6 | 7 | ```bash 8 | conda install sphinx sphinx_rtd_theme 9 | ``` 10 | 11 | 12 | Once installed, you can use the `Makefile` in this directory to compile static HTML pages by 13 | ```bash 14 | make html 15 | ``` 16 | 17 | The compiled docs will be in the `_build` directory and can be viewed by opening `index.html` (which may itself 18 | be inside a directory called `html/` depending on what version of Sphinx is installed). 
19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = galax 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /benchmarks/graphsage/reddit.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | from flax import linen as nn 4 | import optax 5 | import galax 6 | from importlib import import_module 7 | 8 | def run(args): 9 | from galax.data.datasets.nodes.graphsage import reddit 10 | G = reddit() 11 | OUT_FEATURES = Y_RED.max() + 1 12 | Y_REF = jax.nn.one_hot(G.ndata['label'], OUT_FEATURES) 13 | 14 | from galax.nn.zoo.graphsage import GraphSAGE 15 | model = nn.Sequential( 16 | GraphSAGE(args.features), 17 | ) 18 | 19 | 20 | 21 | 22 | if __name__ == "__main__": 23 | import argparse 24 | parser = argparse.ArgumentParser() 25 | args = parser.parse_args() 26 | run(args) 27 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Helper file to handle all configs 2 | 3 | [coverage:run] 4 | # .coveragerc to control coverage.py and pytest-cov 5 | omit = 6 | 
# Omit the tests 7 | */tests/* 8 | # Omit generated versioneer 9 | galax/_version.py 10 | 11 | [yapf] 12 | # YAPF, in .style.yapf files this shows up as "[style]" header 13 | COLUMN_LIMIT = 119 14 | INDENT_WIDTH = 4 15 | USE_TABS = False 16 | 17 | [flake8] 18 | # Flake8, PyFlakes, etc 19 | max-line-length = 119 20 | 21 | [versioneer] 22 | # Automatic version numbering scheme 23 | VCS = git 24 | style = pep440 25 | versionfile_source = galax/_version.py 26 | versionfile_build = galax/_version.py 27 | tag_prefix = '' 28 | 29 | [aliases] 30 | test = pytest 31 | -------------------------------------------------------------------------------- /docs/autosummary/galax.function.ReduceFunction.rst: -------------------------------------------------------------------------------- 1 | galax.function.ReduceFunction 2 | ============================= 3 | 4 | .. currentmodule:: galax.function 5 | 6 | .. autoclass:: ReduceFunction 7 | :members: 8 | :show-inheritance: 9 | :inherited-members: 10 | 11 | 12 | .. automethod:: __init__ 13 | 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | 19 | ~ReduceFunction.__init__ 20 | ~ReduceFunction.count 21 | ~ReduceFunction.index 22 | 23 | 24 | 25 | 26 | 27 | .. rubric:: Attributes 28 | 29 | .. autosummary:: 30 | 31 | ~ReduceFunction.msg_field 32 | ~ReduceFunction.op 33 | ~ReduceFunction.out_field 34 | 35 | -------------------------------------------------------------------------------- /docs/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :show-inheritance: 8 | :inherited-members: 9 | 10 | {% block methods %} 11 | .. automethod:: __init__ 12 | 13 | {% if methods %} 14 | .. rubric:: {{ _('Methods') }} 15 | 16 | .. 
autosummary:: 17 | {% for item in methods %} 18 | ~{{ name }}.{{ item }} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | 23 | {% block attributes %} 24 | {% if attributes %} 25 | .. rubric:: {{ _('Attributes') }} 26 | 27 | .. autosummary:: 28 | {% for item in attributes %} 29 | ~{{ name }}.{{ item }} 30 | {%- endfor %} 31 | {% endif %} 32 | {% endblock %} 33 | -------------------------------------------------------------------------------- /galax/data/datasets/nodes/graphsage.py: -------------------------------------------------------------------------------- 1 | """Reddit and PPI. 2 | 3 | Reference: http://snap.stanford.edu/graphsage/ 4 | """ 5 | 6 | def reddit(): 7 | import galax 8 | from dgl.data import RedditDataset 9 | g = RedditDataset()[0] 10 | g.ndata['h'] = g.ndata['feat'] 11 | del g.ndata['feat'] 12 | g = galax.from_dgl(g) 13 | return g 14 | 15 | def ppi(): 16 | import galax 17 | from dgl.data import PPIDataset 18 | gs_tr = PPIDataset("train") 19 | gs_vl = PPIDataset("valid") 20 | gs_te = PPIDataset("test") 21 | 22 | def fn(g): 23 | g.ndata['h'] = g.ndata['feat'] 24 | del g.ndata['feat'] 25 | return galax.from_dgl(g) 26 | 27 | gs_tr = tuple((fn(g) for g in gs_tr)) 28 | gs_vl = tuple((fn(g) for g in gs_vl)) 29 | gs_te = tuple((fn(g) for g in gs_te)) 30 | return gs_tr, gs_vl, gs_te 31 | -------------------------------------------------------------------------------- /.github/workflows/DOC.yml: -------------------------------------------------------------------------------- 1 | name: "DOC" 2 | on: 3 | - push 4 | 5 | jobs: 6 | docs: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.7 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.7 14 | - name: Install dependencies and compile 15 | run: | 16 | python -m pip install --upgrade pip 17 | python -m pip install flake8 pytest versioneer sphinx sphinx_rtd_theme pyyaml networkx m2r2 18 | if [ -f requirements.txt ]; then pip install -r 
requirements.txt; fi 19 | python setup.py install 20 | cd docs && make html && cd _build/html && echo 'galax.wangyq.net' > CNAME 21 | - name: Deploy 22 | uses: peaceiris/actions-gh-pages@v3 23 | with: 24 | github_token: ${{ secrets.GITHUB_TOKEN }} 25 | publish_dir: docs/_build/html 26 | -------------------------------------------------------------------------------- /galax/data/datasets/nodes/planetoid.py: -------------------------------------------------------------------------------- 1 | """Cora, citeseer, pubmed dataset. 2 | 3 | Following dataset loading and preprocessing code from tkipf/gcn 4 | https://github.com/tkipf/gcn/blob/master/gcn/utils.py 5 | """ 6 | import galax 7 | 8 | def cora(): 9 | from dgl.data import CoraGraphDataset 10 | g = CoraGraphDataset()[0] 11 | g.ndata['h'] = g.ndata['feat'] 12 | del g.ndata['feat'] 13 | g = galax.from_dgl(g) 14 | return g 15 | 16 | def citeseer(): 17 | from dgl.data import CiteseerGraphDataset 18 | g = CiteseerGraphDataset()[0] 19 | g.ndata['h'] = g.ndata['feat'] 20 | del g.ndata['feat'] 21 | g = galax.from_dgl(g) 22 | return g 23 | 24 | def pubmed(): 25 | from dgl.data import PubmedGraphDataset 26 | g = PubmedGraphDataset()[0] 27 | g.ndata['h'] = g.ndata['feat'] 28 | del g.ndata['feat'] 29 | g = galax.from_dgl(g) 30 | return g 31 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=malt 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yuanqing Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docs/autosummary/galax.function.rst: -------------------------------------------------------------------------------- 1 | galax.function 2 | ============== 3 | 4 | .. automodule:: galax.function 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | :toctree: 16 | 17 | apply_edges 18 | apply_nodes 19 | copy_e 20 | copy_u 21 | e_add_u 22 | e_add_v 23 | e_div_u 24 | e_div_v 25 | e_dot_u 26 | e_dot_v 27 | e_mul_u 28 | e_mul_v 29 | e_sub_u 30 | e_sub_v 31 | segment_mean 32 | u_add_e 33 | u_add_v 34 | u_div_e 35 | u_div_v 36 | u_dot_e 37 | u_dot_v 38 | u_mul_e 39 | u_mul_v 40 | u_sub_e 41 | u_sub_v 42 | v_add_e 43 | v_add_u 44 | v_div_e 45 | v_div_u 46 | v_dot_e 47 | v_dot_u 48 | v_mul_e 49 | v_mul_u 50 | v_sub_e 51 | v_sub_u 52 | 53 | 54 | 55 | 56 | 57 | .. rubric:: Classes 58 | 59 | .. autosummary:: 60 | :toctree: 61 | :template: custom-class-template.rst 62 | 63 | ReduceFunction 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /galax/nn/utils/early_stopping.py: -------------------------------------------------------------------------------- 1 | class EarlyStopping(object): 2 | """Early stopping. 3 | 4 | Parameters 5 | ---------- 6 | patience : int = 10 7 | Patience for early stopping. 
8 | 9 | """ 10 | 11 | best_losses = None 12 | params = None 13 | counter = 0 14 | 15 | def __init__(self, patience: int = 10): 16 | self.patience = patience 17 | 18 | def __call__(self, losses, params): 19 | if self.best_losses is None: 20 | self.best_losses = losses 21 | self.counter = 0 22 | 23 | elif any( 24 | loss <= best_loss 25 | for loss, best_loss in zip(losses, self.best_losses) 26 | ): 27 | if all( 28 | loss <= best_loss 29 | for loss, best_loss in zip(losses, self.best_losses) 30 | ): 31 | self.params = params 32 | self.best_losses = [ 33 | min(loss, best_loss) 34 | for loss, best_loss in zip(losses, self.best_losses) 35 | ] 36 | self.counter = 0 37 | 38 | else: 39 | self.counter += 1 40 | if self.counter == self.patience: 41 | return True 42 | 43 | return False 44 | -------------------------------------------------------------------------------- /galax/data/datasets/graphs/gin.py: -------------------------------------------------------------------------------- 1 | """Datasets used in How Powerful Are Graph Neural Networks? 
2 | (chen jun) 3 | Datasets include: 4 | MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, 5 | PTC, REDDITBINARY, REDDITMULTI5K 6 | https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip 7 | """ 8 | import sys 9 | 10 | DATASETS = [ 11 | "MUTAG", "COLLAB", "IMDBBINARY", "IMDBMULTI", "NCI1", "PROTEINS", 12 | "PTC", "REDDITBINARY", "REDDITMULTI5K", 13 | ] 14 | 15 | def get_dataset_function(dataset): 16 | def transform(g, y): 17 | import galax 18 | g.ndata['h'] = g.ndata['attr'] 19 | del g.ndata['h'] 20 | g = galax.from_dgl(g) 21 | g = g.gdata.set("label", y) 22 | return g 23 | 24 | def fn(): 25 | import dgl 26 | from dgl.data import GINDataset 27 | import galax 28 | _dataset = GINDataset(dataset, self_loop=False) 29 | gs, ys = zip(*[_dataset[idx] for idx in range(len(_dataset))]) 30 | gs = tuple([transform(g, y) for g, y in zip(gs, ys)]) 31 | return gs 32 | 33 | fn.__name__ = dataset.lower() 34 | return fn 35 | 36 | for dataset in DATASETS: 37 | fn = get_dataset_function(dataset) 38 | setattr(sys.modules[__name__], dataset.lower(), fn) 39 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade 
pip flake8 pytest pyyaml networkx dgl torch 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | python setup.py install 31 | # - name: Lint with flake8 32 | # run: | 33 | # # stop the build if there are Python syntax errors or undefined names 34 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest galax 40 | -------------------------------------------------------------------------------- /docs/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module Attributes 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | {% for item in functions %} 24 | {{ item }} 25 | {%- endfor %} 26 | {% endif %} 27 | {% endblock %} 28 | 29 | {% block classes %} 30 | {% if classes %} 31 | .. rubric:: {{ _('Classes') }} 32 | 33 | .. autosummary:: 34 | :toctree: 35 | :template: custom-class-template.rst 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {% endif %} 40 | {% endblock %} 41 | 42 | {% block exceptions %} 43 | {% if exceptions %} 44 | .. rubric:: {{ _('Exceptions') }} 45 | 46 | .. autosummary:: 47 | :toctree: 48 | {% for item in exceptions %} 49 | {{ item }} 50 | {%- endfor %} 51 | {% endif %} 52 | {% endblock %} 53 | 54 | {% block modules %} 55 | {% if modules %} 56 | .. rubric:: Modules 57 | 58 | .. 
autosummary:: 59 | :toctree: 60 | :template: custom-module-template.rst 61 | :recursive: 62 | 63 | {% for item in modules %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Graph Learning with JAX 2 | ======================== 3 | [//]: # (Badges) 4 | [![CI](https://github.com/yuanqing-wang/galax/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/yuanqing-wang/galax/actions/workflows/CI.yml) 5 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/yuanqing-wang/galax.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/yuanqing-wang/galax/context:python) 6 | [![pypi](https://img.shields.io/pypi/v/g3x.svg)](https://pypi.org/project/g3x/) 7 | [![docs stable](https://img.shields.io/badge/docs-stable-5077AB.svg?logo=read%20the%20docs)](https://galax.wangyq.net/) 8 | 9 | Galax is a graph-centric, high-performance library for graph modeling with JAX. 10 | 11 | ## Installation 12 | ``` 13 | > pip install g3x 14 | ``` 15 | 16 | ## Design principle 17 | * Pure JAX: end-to-end differentiable and jittable. 18 | * Graphs (including heterographs with multiple node types), metagraphs, and node and edge data are simply pytrees (or more precisely, namedtuples), **and are thus immutable**. 19 | * All transforms (including neural networks inhereted from [flax](https://github.com/google/flax)) take and return graphs. 20 | * Grammar highly resembles [DGL](https://www.dgl.ai), except being purely functional. 21 | 22 | ## Quick start 23 | Implement a simple graph convolution in six lines. 
24 | ```python 25 | >>> import jax.numpy as jnp; import galax 26 | >>> g = galax.graph(([0, 1], [1, 2])) 27 | >>> g = g.ndata.set("h", jnp.ones((3, 16))) 28 | >>> g = g.update_all(galax.function.copy_u("h", "m"), galax.function.sum("m", "h")) 29 | >>> W = jnp.random.normal(key=jax.random.PRNGKey(2666), shape=(16, 16)) 30 | >>> g = g.apply_nodes(lambda node: {"h": node.data["h"] @ W}) 31 | ``` 32 | -------------------------------------------------------------------------------- /benchmarks/speed/kernel/kernel.py: -------------------------------------------------------------------------------- 1 | import ogb 2 | import time 3 | import numpy as onp 4 | import jax 5 | import jax.numpy as jnp 6 | import galax 7 | 8 | n_cold_start = 2 9 | 10 | def bench_spmm(G, binary_op, reduce_op): 11 | print("SPMM\n----------------------------") 12 | 13 | for n_hid in [1, 2, 4, 8, 16, 32, 64, 128]: 14 | nfeat = jnp.array(onp.random.normal(size=(G.number_of_nodes(), n_hid))) 15 | efeat = jnp.array(onp.random.normal(size=(G.number_of_edges(), n_hid))) 16 | g = G.ndata.set("h", nfeat) 17 | g = G.edata.set("h", efeat) 18 | accum_time = 0 19 | 20 | @jax.jit 21 | def fn(g): 22 | fn_msg = getattr(galax.function, binary_op)("h", "m") 23 | fn_rdc = getattr(galax.function, reduce_op)("m", "h") 24 | _g = g.update_all(fn_msg, fn_rdc) 25 | return _g.ndata['h'] 26 | 27 | for n_times in range(10): 28 | time0 = time.time() 29 | _g = fn(G).block_until_ready() 30 | time1 = time.time() 31 | if n_times >= n_cold_start: 32 | accum_time += (time1 - time0) 33 | avg_time = accum_time / (n_times - n_cold_start) 34 | print('hidden size: {}, avg time: {}'.format( 35 | n_hid, avg_time)) 36 | 37 | if __name__ == '__main__': 38 | import argparse 39 | parser = argparse.ArgumentParser("Benchmark DGL kernels") 40 | parser.add_argument('--spmm-binary', type=str, default='copy_u') 41 | parser.add_argument('--spmm-reduce', type=str, default='sum') 42 | 43 | args = parser.parse_args() 44 | 45 | # for dataset in 
['reddit', 'arxiv', 'proteins']: 46 | from galax.data.datasets.nodes.ogb import arxiv 47 | G = arxiv() 48 | print(G) 49 | 50 | # SPMM 51 | bench_spmm(G, args.spmm_binary, args.spmm_reduce) 52 | -------------------------------------------------------------------------------- /docs/autosummary/galax.heterograph.HeteroGraph.rst: -------------------------------------------------------------------------------- 1 | galax.heterograph.HeteroGraph 2 | ============================= 3 | 4 | .. currentmodule:: galax.heterograph 5 | 6 | .. autoclass:: HeteroGraph 7 | :members: 8 | :show-inheritance: 9 | :inherited-members: 10 | 11 | 12 | .. automethod:: __init__ 13 | 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | 19 | ~HeteroGraph.__init__ 20 | ~HeteroGraph.add_edges 21 | ~HeteroGraph.add_nodes 22 | ~HeteroGraph.adj 23 | ~HeteroGraph.adjacency_matrix 24 | ~HeteroGraph.apply_nodes 25 | ~HeteroGraph.canonical_etypes 26 | ~HeteroGraph.count 27 | ~HeteroGraph.find_edges 28 | ~HeteroGraph.get_etype_id 29 | ~HeteroGraph.get_meta_edge 30 | ~HeteroGraph.get_ntype_id 31 | ~HeteroGraph.has_edges_between 32 | ~HeteroGraph.has_nodes 33 | ~HeteroGraph.in_degrees 34 | ~HeteroGraph.inc 35 | ~HeteroGraph.incidence_matrix 36 | ~HeteroGraph.index 37 | ~HeteroGraph.init 38 | ~HeteroGraph.is_homogeneous 39 | ~HeteroGraph.is_multigraph 40 | ~HeteroGraph.number_of_edges 41 | ~HeteroGraph.number_of_nodes 42 | ~HeteroGraph.out_degrees 43 | ~HeteroGraph.remove_edges 44 | ~HeteroGraph.remove_nodes 45 | ~HeteroGraph.set_edata 46 | ~HeteroGraph.set_ndata 47 | ~HeteroGraph.to_canonincal_etype 48 | 49 | 50 | 51 | 52 | 53 | .. rubric:: Attributes 54 | 55 | .. 
autosummary:: 56 | 57 | ~HeteroGraph.dstdata 58 | ~HeteroGraph.edata 59 | ~HeteroGraph.edge_frames 60 | ~HeteroGraph.edges 61 | ~HeteroGraph.etypes 62 | ~HeteroGraph.gidx 63 | ~HeteroGraph.metamap 64 | ~HeteroGraph.ndata 65 | ~HeteroGraph.node_frames 66 | ~HeteroGraph.nodes 67 | ~HeteroGraph.ntypes 68 | ~HeteroGraph.srcdata 69 | 70 | -------------------------------------------------------------------------------- /galax/tests/test_jittable.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def test_graph_index_jit(): 6 | from galax.graph_index import GraphIndex 7 | 8 | g = GraphIndex( 9 | 2, 10 | jnp.array([0, 1], dtype=jnp.int32), 11 | jnp.array([0, 2], dtype=jnp.int32), 12 | ) 13 | 14 | @jax.jit 15 | def fn(g): 16 | g = g._replace(src=g.src + g.dst) 17 | return g 18 | 19 | _g = fn(g) 20 | assert _g.n_nodes == g.n_nodes 21 | assert (_g.dst == g.dst).all() 22 | 23 | 24 | def test_heterograph_index_jit(): 25 | from galax.graph_index import GraphIndex 26 | from galax.heterograph_index import HeteroGraphIndex 27 | 28 | metagraph = GraphIndex(3, jnp.array([0, 1]), jnp.array([1, 2])) 29 | n_nodes = jnp.array([3, 2, 1]) 30 | edges = ((jnp.array([0, 1]), jnp.array([1, 2])), (), ()) 31 | g = HeteroGraphIndex( 32 | metagraph=metagraph, 33 | n_nodes=n_nodes, 34 | edges=edges, 35 | ) 36 | 37 | @jax.jit 38 | def fn(g): 39 | _g = g._replace(n_nodes=g.n_nodes**2) 40 | return _g 41 | 42 | _g = fn(g) 43 | assert (_g.n_nodes == g.n_nodes**2).all() 44 | 45 | 46 | def test_graph_jit(): 47 | import galax 48 | import jax 49 | import jax.numpy as jnp 50 | 51 | g = galax.graph(((0, 1), (1, 2))) 52 | g = g.set_ndata("h", jnp.ones(3)) 53 | g = g.set_edata("he", jnp.ones(2)) 54 | 55 | @jax.jit 56 | def fn(g): 57 | return g.edata["he"] ** 2 58 | 59 | fn(g) 60 | 61 | 62 | def test_message_passing_jit(): 63 | import galax 64 | import jax 65 | import jax.numpy as jnp 66 | 67 | g = galax.graph(((0, 1), (1, 
2))) 68 | g = g.set_ndata("h", jnp.ones(3)) 69 | 70 | @jax.jit 71 | def fn(g): 72 | mfunc = galax.function.copy_u("h", "m") 73 | rfunc = galax.function.sum("m", "h1") 74 | _g = galax.message_passing(g, mfunc, rfunc) 75 | return _g 76 | 77 | _g = fn(g) 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Cache 132 | .DS_Store 133 | 134 | # data 135 | ogbn_* 136 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | espaloma 3 | Extensible Surrogate Potential of Ab initio Learned and Optimized by Message-passing Algorithm 4 | """ 5 | import sys 6 | 7 | from setuptools import find_packages, setup 8 | 9 | import versioneer 10 | 11 | short_description = __doc__.split("\n") 12 | 13 | # from https://github.com/pytest-dev/pytest-runner#conditional-requirement 14 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) 15 | pytest_runner = ['pytest-runner'] if needs_pytest else [] 16 | 17 | try: 18 | with open("README.md", "r") as handle: 19 | long_description = handle.read() 20 | except: 21 | long_description = "\n".join(short_description[2:]) 22 | 23 | def local_scheme(version): 24 | return "" 25 | 26 | setup( 27 | # Self-descriptive entries which should always be present 28 | name='g3x', 29 | author='Yuanqing Wang', 30 | author_email='wangyq@wangyq.net', 31 | 
description=short_description[0], 32 | long_description=long_description, 33 | long_description_content_type="text/markdown", 34 | # use_scm_version={"local_scheme": local_scheme}, 35 | # version=versioneer.get_version(), 36 | version="0.0", 37 | # cmdclass=versioneer.get_cmdclass(), 38 | license='MIT', 39 | 40 | # Which Python importable modules should be included when your package is installed 41 | # Handled automatically by setuptools. Use 'exclude' to prevent some specific 42 | # subpackage(s) from being added, if needed 43 | packages=find_packages(), 44 | 45 | # Optional include package data to ship with your package 46 | # Customize MANIFEST.in if the general case does not suit your needs 47 | # Comment out this line to prevent the files from being packaged with your software 48 | include_package_data=True, 49 | 50 | # Allows `setup.py test` to work correctly with pytest 51 | setup_requires=[] + pytest_runner, 52 | 53 | # Additional entries you may want simply uncomment the lines you want and fill in the data 54 | # url='http://www.my_package.com', # Website 55 | # install_requires=[], # Required packages, pulls from pip if needed; do not use for Conda deployment 56 | # platforms=['Linux', 57 | # 'Mac OS-X', 58 | # 'Unix', 59 | # 'Windows'], # Valid platforms your code works on, adjust to your flavor 60 | # python_requires=">=3.5", # Python version restrictions 61 | 62 | # Manual control if final package is compressible or not, set False to prevent the .egg from being made 63 | # zip_safe=False, 64 | 65 | ) 66 | -------------------------------------------------------------------------------- /benchmarks/speed/training/gat/pubmed.py: -------------------------------------------------------------------------------- 1 | """Reference: 79.0; Reproduction: 78.1""" 2 | 3 | from functools import partial 4 | import jax 5 | from flax import linen as nn 6 | import optax 7 | import galax 8 | 9 | def run(): 10 | from galax.data.datasets.nodes.planetoid import pubmed 11 | G 
= pubmed() 12 | G = G.add_self_loop() 13 | Y_REF = jax.nn.one_hot(G.ndata['label'], 3) 14 | 15 | from galax.nn.zoo.gat import GAT 16 | ConcatenationPooling = galax.ApplyNodes(lambda x: x.reshape(*x.shape[:-2], -1)) 17 | AveragePooling = galax.ApplyNodes(lambda x: x.mean(-2)) 18 | 19 | model = nn.Sequential( 20 | ( 21 | GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=jax.nn.elu), 22 | ConcatenationPooling, 23 | GAT(3, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=None), 24 | AveragePooling, 25 | ), 26 | ) 27 | 28 | model_eval = nn.Sequential( 29 | ( 30 | GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=jax.nn.elu), 31 | ConcatenationPooling, 32 | GAT(3, 8, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=None), 33 | AveragePooling, 34 | ), 35 | ) 36 | 37 | key = jax.random.PRNGKey(2666) 38 | key, key_dropout = jax.random.split(key) 39 | 40 | params = model.init({"params": key, "dropout": key}, G) 41 | mask = jax.tree_map(lambda x: (x != 0).any(), params) 42 | 43 | from flax.core import FrozenDict 44 | 45 | optimizer = optax.adam(0.01) 46 | 47 | from flax.training.train_state import TrainState 48 | state = TrainState.create( 49 | apply_fn=model.apply, params=params, tx=optimizer, 50 | ) 51 | 52 | def loss(params, key): 53 | g = model.apply(params, G, rngs={"dropout": key}) 54 | y = g.ndata['h'] 55 | return optax.softmax_cross_entropy( 56 | y[g.ndata['train_mask']], 57 | Y_REF[g.ndata['train_mask']], 58 | ).mean() 59 | 60 | @jax.jit 61 | def step(state, key): 62 | key, new_key = jax.random.split(key) 63 | grad_fn = jax.grad(partial(loss, key=new_key)) 64 | grads = grad_fn(state.params) 65 | state = state.apply_gradients(grads=grads) 66 | return state, key 67 | 68 | _, __ = step(state, key) 69 | 70 | import time 71 | time0 = time.time() 72 | for _ in range(200): 73 | state, key = step(state, key) 74 | state = jax.block_until_ready(state) 75 | time1 = time.time() 76 | print(time1 - time0) 77 | 
78 | if __name__ == "__main__": 79 | import argparse 80 | run() 81 | -------------------------------------------------------------------------------- /benchmarks/speed/training/gat/citeseer.py: -------------------------------------------------------------------------------- 1 | """Reference: 72.5; Reproduction: 71.5""" 2 | 3 | from functools import partial 4 | import jax 5 | from flax import linen as nn 6 | import optax 7 | import galax 8 | 9 | def run(): 10 | from galax.data.datasets.nodes.planetoid import citeseer 11 | G = citeseer() 12 | G = G.add_self_loop() 13 | Y_REF = jax.nn.one_hot(G.ndata['label'], 6) 14 | 15 | from galax.nn.zoo.gat import GAT 16 | ConcatenationPooling = galax.ApplyNodes(lambda x: x.reshape(*x.shape[:-2], -1)) 17 | AveragePooling = galax.ApplyNodes(lambda x: x.mean(-2)) 18 | 19 | model = nn.Sequential( 20 | ( 21 | GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=jax.nn.elu), 22 | ConcatenationPooling, 23 | GAT(6, 1, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=None), 24 | AveragePooling, 25 | ), 26 | ) 27 | 28 | model_eval = nn.Sequential( 29 | ( 30 | GAT(8, 8, attn_drop=0.6, feat_drop=0.6, deterministic=True, activation=jax.nn.elu), 31 | ConcatenationPooling, 32 | GAT(6, 1, attn_drop=0.6, feat_drop=0.6, deterministic=True, activation=None), 33 | AveragePooling, 34 | ), 35 | ) 36 | 37 | key = jax.random.PRNGKey(2666) 38 | key, key_dropout = jax.random.split(key) 39 | 40 | params = model.init({"params": key, "dropout": key}, G) 41 | mask = jax.tree_map(lambda x: (x != 0).any(), params) 42 | 43 | from flax.core import FrozenDict 44 | 45 | optimizer = optax.adam(0.005) 46 | 47 | from flax.training.train_state import TrainState 48 | state = TrainState.create( 49 | apply_fn=model.apply, params=params, tx=optimizer, 50 | ) 51 | 52 | def loss(params, key): 53 | g = model.apply(params, G, rngs={"dropout": key}) 54 | y = g.ndata['h'] 55 | return optax.softmax_cross_entropy( 56 | y[g.ndata['train_mask']], 57 | 
Y_REF[g.ndata['train_mask']], 58 | ).mean() 59 | 60 | @jax.jit 61 | def step(state, key): 62 | key, new_key = jax.random.split(key) 63 | grad_fn = jax.grad(partial(loss, key=new_key)) 64 | grads = grad_fn(state.params) 65 | state = state.apply_gradients(grads=grads) 66 | return state, key 67 | 68 | _, __ = step(state, key) 69 | 70 | import time 71 | time0 = time.time() 72 | for _ in range(200): 73 | state, key = step(state, key) 74 | state = jax.block_until_ready(state) 75 | time1 = time.time() 76 | print(time1 - time0) 77 | 78 | if __name__ == "__main__": 79 | import argparse 80 | run() 81 | -------------------------------------------------------------------------------- /galax/core.py: -------------------------------------------------------------------------------- 1 | """Implementation for core graph computation.""" 2 | from typing import Callable, Optional, Any 3 | from functools import partial 4 | from flax.core import freeze, unfreeze 5 | from . import function 6 | from .function import ReduceFunction 7 | 8 | 9 | def message_passing( 10 | graph: Any, 11 | mfunc: Optional[Callable], 12 | rfunc: Optional[ReduceFunction], 13 | afunc: Optional[Callable] = None, 14 | etype: Optional[Callable] = None, 15 | ): 16 | """Invoke message passing computation on the whole graph. 17 | 18 | Parameters 19 | ---------- 20 | g : HeteroGraph 21 | The input graph. 22 | mfunc : Callable 23 | Message function. 24 | rfunc : Callable 25 | Reduce function. 26 | afunc : Callable 27 | Apply function. 28 | 29 | Returns 30 | ------- 31 | HeteroGraph 32 | The resulting graph. 
33 | 34 | Examples 35 | -------- 36 | >>> import galax 37 | >>> import jax 38 | >>> import jax.numpy as jnp 39 | >>> g = galax.graph(((0, 1), (1, 2))) 40 | >>> g = g.ndata.set("h", jnp.ones(3)) 41 | >>> mfunc = galax.function.copy_u("h", "m") 42 | >>> rfunc = galax.function.sum("m", "h1") 43 | >>> _g = message_passing(g, mfunc, rfunc) 44 | >>> _g.ndata['h1'].flatten().tolist() 45 | [0.0, 1.0, 1.0] 46 | 47 | """ 48 | # TODO(yuanqing-wang): change this restriction in near future 49 | # assert isinstance(rfunc, ReduceFunction), "Only built-in reduce supported. " 50 | if etype is None: 51 | etype = graph.etypes[0] 52 | 53 | # find the edge type 54 | etype_idx = graph.get_etype_id(etype) 55 | 56 | # get number of nodes 57 | _, dsttype_idx = graph.get_meta_edge(etype_idx) 58 | n_dst = next(iter(graph.node_frames[dsttype_idx].values())).shape[0] 59 | 60 | # extract the message 61 | message = mfunc(graph.edges[etype]) 62 | 63 | # reduce by calling jax.ops.segment_ 64 | _rfunc = getattr(function, f"segment_{rfunc.op}") 65 | _rfunc = partial( 66 | _rfunc, 67 | segment_ids=graph.gidx.edges[0][1], 68 | num_segments=n_dst, 69 | ) 70 | reduced = {rfunc.out_field: _rfunc(message[rfunc.msg_field])} 71 | 72 | # apply if so specified 73 | if afunc is not None: 74 | reduced.update(afunc(reduced)) 75 | 76 | # update destination node frames 77 | node_frame = graph.node_frames[dsttype_idx] 78 | node_frame = unfreeze(node_frame) 79 | node_frame.update(reduced) 80 | node_frame = freeze(node_frame) 81 | node_frames = ( 82 | graph.node_frames[:dsttype_idx] 83 | + (node_frame,) 84 | + graph.node_frames[dsttype_idx + 1 :] 85 | ) 86 | 87 | return graph._replace(node_frames=node_frames) 88 | -------------------------------------------------------------------------------- /benchmarks/speed/training/gcn/cora.py: -------------------------------------------------------------------------------- 1 | """Reference: 70.3; Reproduction: 71.4""" 2 | 3 | from functools import partial 4 | import jax 5 | 
from flax import linen as nn
import optax
import galax


def run():
    """Benchmark training speed of a 2-layer GCN on Cora.

    Builds the model, jit-compiles a 10-step inner training loop, warms
    it up twice to trigger compilation, then times 20 invocations
    (200 optimization steps total) and prints elapsed wall-clock seconds.
    """
    from galax.data.datasets.nodes.planetoid import cora
    G = cora()
    G = G.add_self_loop()
    # one-hot labels for the 7 Cora classes
    Y_REF = jax.nn.one_hot(G.ndata['label'], 7)

    from galax.nn.zoo.gcn import GCN
    model = nn.Sequential(
        (
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)),
            GCN(16, activation=jax.nn.relu),
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)),
            GCN(7, activation=None),
        ),
    )

    key = jax.random.PRNGKey(2666)
    key, key_dropout = jax.random.split(key)

    params = model.init({"params": key, "dropout": key_dropout}, G)

    # (removed dead code: an unused eval model and an unused weight-decay
    # mask copied from the accuracy benchmark — this script only times
    # training steps)
    optimizer = optax.adam(1e-2)

    from flax.training.train_state import TrainState
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
    )

    def loss(params, key):
        # forward pass with fresh dropout; cross-entropy on train nodes only
        g = model.apply(params, G, rngs={"dropout": key})
        y = g.ndata['h']
        return optax.softmax_cross_entropy(
            y[g.ndata['train_mask']],
            Y_REF[g.ndata['train_mask']],
        ).mean()

    @jax.jit
    def step(state, key):
        key, new_key = jax.random.split(key)
        grad_fn = jax.grad(partial(loss, key=new_key))
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, key

    @jax.jit
    def steps(state, key):
        # 10 steps unrolled inside one jitted call to amortize dispatch
        for _ in range(10):
            state, key = step(state, key)
        return state, key

    # warm-up: trigger compilation before timing
    for _ in range(2):
        _, __ = steps(state, key)

    import time
    time0 = time.time()
    for _ in range(20):
        state, key = steps(state, key)
    state = jax.block_until_ready(state)
    time1 = time.time()
    print(time1 - time0)


if __name__ == "__main__":
    run()
y[g.ndata['train_mask']], 63 | Y_REF[g.ndata['train_mask']], 64 | ).mean() 65 | 66 | @jax.jit 67 | def step(state, key): 68 | key, new_key = jax.random.split(key) 69 | grad_fn = jax.grad(partial(loss, key=new_key)) 70 | grads = grad_fn(state.params) 71 | state = state.apply_gradients(grads=grads) 72 | return state, key 73 | 74 | @jax.jit 75 | def steps(state, key): 76 | for _ in range(10): 77 | state, key = step(state, key) 78 | return state, key 79 | 80 | for _ in range(2): 81 | _, __ = steps(state, key) 82 | 83 | import time 84 | time0 = time.time() 85 | for _ in range(20): 86 | state, key = steps(state, key) 87 | state = jax.block_until_ready(state) 88 | time1 = time.time() 89 | print(time1 - time0) 90 | 91 | if __name__ == "__main__": 92 | import argparse 93 | run() 94 | -------------------------------------------------------------------------------- /benchmarks/speed/training/gcn/citeseer.py: -------------------------------------------------------------------------------- 1 | """Reference: 70.3; Reproduction: 71.4""" 2 | 3 | from functools import partial 4 | import jax 5 | from flax import linen as nn 6 | import optax 7 | import galax 8 | 9 | 10 | def run(): 11 | from galax.data.datasets.nodes.planetoid import citeseer 12 | G = citeseer() 13 | G = G.add_self_loop() 14 | Y_REF = jax.nn.one_hot(G.ndata['label'], 6) 15 | 16 | from galax.nn.zoo.gcn import GCN 17 | model = nn.Sequential( 18 | ( 19 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 20 | GCN(16, activation=jax.nn.relu), 21 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 22 | GCN(6, activation=None), 23 | ), 24 | ) 25 | 26 | model_eval = nn.Sequential( 27 | ( 28 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 29 | GCN(16, activation=jax.nn.relu), 30 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 31 | GCN(6, activation=None), 32 | ), 33 | ) 34 | 35 | key = jax.random.PRNGKey(2666) 36 | key, key_dropout = jax.random.split(key) 37 | 38 | params = model.init({"params": 
"""Reference: 83.0; Reproduction: 83.3"""

from functools import partial
import jax
from flax import linen as nn
import optax
import galax


def run():
    """Benchmark training speed of a 2-layer GAT on Cora.

    Builds the model, jit-compiles one training step, warms it up, then
    times 200 steps and prints elapsed wall-clock seconds.
    """
    from galax.data.datasets.nodes.planetoid import cora
    G = cora()
    G = G.add_self_loop()
    # one-hot labels for the 7 Cora classes
    Y_REF = jax.nn.one_hot(G.ndata['label'], 7)

    from galax.nn.zoo.gat import GAT
    # flatten the (heads, features) attention output into one feature axis
    ConcatenationPooling = galax.ApplyNodes(lambda x: x.reshape(*x.shape[:-2], -1))
    # average over attention heads
    AveragePooling = galax.ApplyNodes(lambda x: x.mean(-2))

    model = nn.Sequential(
        (
            GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=jax.nn.elu),
            ConcatenationPooling,
            GAT(7, 1, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=None),
            AveragePooling,
        ),
    )

    key = jax.random.PRNGKey(2666)
    key, key_dropout = jax.random.split(key)

    # FIX: seed the dropout stream with the dedicated key; the original
    # reused `key` for both streams and left `key_dropout` unused
    # (the sibling GCN speed scripts all pass `key_dropout` here).
    params = model.init({"params": key, "dropout": key_dropout}, G)

    # (removed dead code: unused eval model, unused `jax.tree_map` mask,
    # unused `FrozenDict` import, and an unused `steps` helper — the
    # timing loop below calls `step` directly)
    optimizer = optax.adam(0.005)

    from flax.training.train_state import TrainState
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
    )

    def loss(params, key):
        # forward pass with fresh dropout; cross-entropy on train nodes only
        g = model.apply(params, G, rngs={"dropout": key})
        y = g.ndata['h']
        return optax.softmax_cross_entropy(
            y[g.ndata['train_mask']],
            Y_REF[g.ndata['train_mask']],
        ).mean()

    @jax.jit
    def step(state, key):
        key, new_key = jax.random.split(key)
        grad_fn = jax.grad(partial(loss, key=new_key))
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, key

    # warm-up: trigger compilation before timing
    for _ in range(2):
        _, __ = step(state, key)

    import time
    time0 = time.time()
    for _ in range(200):
        state, key = step(state, key)
    state = jax.block_until_ready(state)
    time1 = time.time()

    print(time1 - time0)


if __name__ == "__main__":
    run()
"""`Graph Convolutional network. `__"""

from typing import Callable, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from ... import function as fn


class GCN(nn.Module):
    r"""Graph convolutional layer from
    `Semi-Supervised Classification with Graph Convolutional
    Networks `__

    Mathematically it is defined as follows:

    .. math::
        h_i^{(l+1)} =
        \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}
        \frac{1}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
    :math:`c_{ji}` is the product of the square root of node degrees
    (i.e., :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
    and :math:`\sigma` is an activation function.

    Parameters
    ----------
    features : int
        Output feature size.
    use_bias : bool
        Whether to add a learnable bias to the output.
        Default: ``False``.
    activation : Optional[Callable]
        Activation applied to the transformed features; identity when
        ``None``. Default: ``None``.

    Returns
    -------
    HeteroGraph
        The resulting Graph.

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> import galax
    >>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = g.add_self_loop()
    >>> g = g.set_ndata("h", jnp.ones((6, 10)))
    >>> gcn = GCN(2, use_bias=True)
    >>> params = gcn.init(jax.random.PRNGKey(2666), g)
    >>> g = gcn.apply(params, g)
    >>> x = g.ndata['h']
    >>> x.shape
    (6, 2)
    """
    # output feature size
    features: int
    # add a learnable bias when True
    use_bias: bool = False
    # e.g. jax.nn.relu; identity when None
    activation: Optional[Callable] = None

    @nn.compact
    def __call__(self, graph, field: str = "h"):
        # initialize parameters: dense weight of shape
        # (input features, output features)
        kernel = self.param(
            'kernel',
            jax.nn.initializers.glorot_uniform(),
            (graph.ndata[field].shape[-1], self.features),
        )

        if self.use_bias:
            bias = self.param(
                "bias",
                jax.nn.initializers.zeros,
                (self.features, ),
            )
        else:
            # scalar zero broadcasts as a no-op bias
            bias = 0.0

        activation = self.activation
        if activation is None:
            activation = lambda x: x

        # propagate: sum neighbor features into `field`
        graph = graph.update_all(fn.copy_u(field, "m"), fn.sum("m", field))

        # normalization factor deg^{-1/2}, reshaped so it broadcasts
        # over the trailing feature axes
        degrees = graph.out_degrees()
        norm = degrees ** (-0.5)
        norm_shape = norm.shape + (1, ) * (graph.ndata[field].ndim - 1)
        norm = jnp.reshape(norm, norm_shape)

        # transform: linear map, scaled by `norm` on both sides, plus bias.
        # NOTE(review): both `norm` factors scale the *same* node's
        # aggregated feature (net effect: divide by its degree);
        # textbook symmetric normalization applies one factor before
        # aggregation — confirm this is intended.
        function = lambda h: activation((norm * h @ kernel) * norm + bias)
        graph = fn.apply_nodes(function, field)(graph)
        return graph
class NodeView(object):
    """Read-only view over a graph's node types.

    Indexing with a node-type name resolves that name to its integer
    index and exposes the type's frame through a ``NodeSpace``.
    """

    def __init__(self, graph):
        self.graph = graph

    def __getitem__(self, key):
        idx = self.graph._ntype_invmap[key]
        data_view = NodeDataView(graph=self.graph, ntype_idx=idx)
        return NodeSpace(data=data_view)
" 70 | return self.graph.set_ndata(key=key, data=data, ntype=self.ntype_idx) 71 | 72 | def keys(self): 73 | return self.graph.node_frames[self.ntype_idx].keys() 74 | 75 | class EdgeDataView(object): 76 | def __init__(self, graph, etype_idx, idxs=None): 77 | self.graph = graph 78 | self.etype_idx = etype_idx 79 | self.idxs = idxs 80 | 81 | def __getitem__(self, key): 82 | res = self.graph.edge_frames[self.etype_idx][key] 83 | if self.idxs is not None: 84 | # res = res[self.idxs] 85 | res = jnp.take(res, self.idxs, 0) 86 | return res 87 | 88 | def set(self, key, data): 89 | assert self.idxs is None, "Cannot partially set. " 90 | return self.graph.set_edata(key=key, data=data, etype=self.etype_idx) 91 | 92 | def keys(self): 93 | return self.graph.edge_frames[self.etype_idx].keys() 94 | 95 | class GraphDataView(object): 96 | def __init__(self, graph): 97 | self.graph = graph 98 | 99 | def __getitem__(self, key): 100 | return self.graph.graph_frame[key] 101 | 102 | def set(self, key, data): 103 | return self.graph.set_gdata(key=key, data=data) 104 | 105 | def keys(self): 106 | return self.graph.graph_frame.keys() 107 | -------------------------------------------------------------------------------- /benchmarks/gat/ppi.py: -------------------------------------------------------------------------------- 1 | """Reference: 97.3; Reproduction: 97.4""" 2 | 3 | from functools import partial 4 | import jax 5 | import jax.numpy as jnp 6 | from flax import linen as nn 7 | import optax 8 | import galax 9 | import numpy as onp 10 | from sklearn.metrics import f1_score 11 | 12 | def run(): 13 | from galax.data.datasets.nodes.graphsage import ppi 14 | GS_TR, GS_VL, GS_TE = ppi() 15 | GS_TR = tuple(map(lambda g: g.add_self_loop(), GS_TR)) 16 | GS_VL = tuple(map(lambda g: g.add_self_loop(), GS_VL)) 17 | GS_TE = tuple(map(lambda g: g.add_self_loop(), GS_TE)) 18 | 19 | from galax.data.dataloader import PrixFixeDataLoader 20 | ds_tr = PrixFixeDataLoader(GS_TR, 2) 21 | g_vl = 
galax.batch(GS_VL) 22 | g_te = galax.batch(GS_TE) 23 | 24 | from galax.nn.zoo.gat import GAT 25 | _ConcatenationPooling = lambda x: x.reshape(*x.shape[:-2], -1) 26 | ConcatenationPooling = galax.ApplyNodes(_ConcatenationPooling) 27 | _AveragePooling = lambda x: x.mean(-2) 28 | AveragePooling = galax.ApplyNodes(_AveragePooling) 29 | 30 | class Model(nn.Module): 31 | def setup(self): 32 | self.l0 = GAT(256, 4, activation=jax.nn.elu) 33 | self.l1 = GAT(256, 4, activation=jax.nn.elu) 34 | self.l2 = GAT(121, 6, activation=None) 35 | 36 | def __call__(self, g): 37 | g0 = ConcatenationPooling(self.l0(g)) 38 | g1 = ConcatenationPooling(self.l1(g0)) 39 | g1 = g1.ndata.set("h", g1.ndata['h'] + g0.ndata['h']) 40 | g2 = AveragePooling(self.l2(g1)) 41 | return g2 42 | 43 | model = Model() 44 | 45 | key = jax.random.PRNGKey(2666) 46 | key, key_dropout = jax.random.split(key) 47 | 48 | params = model.init({"params": key, "dropout": key_dropout}, next(iter(ds_tr))) 49 | 50 | optimizer = optax.adam(0.005) 51 | 52 | from flax.training.train_state import TrainState 53 | state = TrainState.create( 54 | apply_fn=model.apply, params=params, tx=optimizer, 55 | ) 56 | 57 | # @jax.jit 58 | def loss(params, g): 59 | g = model.apply(params, g) 60 | _loss = optax.sigmoid_binary_cross_entropy( 61 | g.ndata['h'], g.ndata['label'], 62 | ).mean() 63 | _loss = jnp.where(jnp.expand_dims(g.is_not_dummy(), -1), _loss, 0.0) 64 | _loss = _loss.sum() / len(_loss) 65 | return _loss 66 | 67 | @jax.jit 68 | def step(state, g): 69 | grad_fn = jax.grad(partial(loss, g=g)) 70 | grads = grad_fn(state.params) 71 | state = state.apply_gradients(grads=grads) 72 | return state 73 | 74 | # @jax.jit 75 | def eval(state, g): 76 | g = model.apply(state.params, g) 77 | _loss = optax.sigmoid_binary_cross_entropy( 78 | g.ndata['h'], g.ndata['label'], 79 | ).mean() 80 | _loss = jnp.where(jnp.expand_dims(g.is_not_dummy(), -1), _loss, 0.0) 81 | _loss = _loss.sum() / len(_loss) 82 | 83 | y_hat = jax.nn.sigmoid(g.ndata['h']) 
84 | y = g.ndata['label'] 85 | 86 | y_hat = y_hat[g.is_not_dummy()] 87 | y_hat = 1 * (jax.nn.sigmoid(y_hat) > 0.5) 88 | y = y[g.is_not_dummy()] 89 | 90 | f1 = f1_score(y, y_hat, average="micro") 91 | 92 | return _loss, f1 93 | 94 | from galax.nn.utils import EarlyStopping 95 | early_stopping = EarlyStopping(100) 96 | 97 | import tqdm 98 | for _ in tqdm.tqdm(range(1000)): 99 | for idx, g in enumerate(ds_tr): 100 | state = step(state, g) 101 | loss_vl, f1_vl = eval(state, g_vl) 102 | if early_stopping((-f1_vl, loss_vl), state.params): 103 | state = state.replace(params=early_stopping.params) 104 | break 105 | 106 | _, f1 = eval(state, g_te) 107 | print(f1) 108 | 109 | if __name__ == "__main__": 110 | import argparse 111 | run() 112 | -------------------------------------------------------------------------------- /benchmarks/gcn/cora.py: -------------------------------------------------------------------------------- 1 | """Reference: 81.5; Reproduction: 81.8""" 2 | 3 | from functools import partial 4 | import jax 5 | from flax import linen as nn 6 | import optax 7 | import galax 8 | 9 | 10 | def run(): 11 | from galax.data.datasets.nodes.planetoid import cora 12 | G = cora() 13 | G = G.add_self_loop() 14 | Y_REF = jax.nn.one_hot(G.ndata['label'], 7) 15 | 16 | from galax.nn.zoo.gcn import GCN 17 | model = nn.Sequential( 18 | ( 19 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 20 | GCN(16, activation=jax.nn.relu), 21 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 22 | GCN(7, activation=None), 23 | ), 24 | ) 25 | 26 | model_eval = nn.Sequential( 27 | ( 28 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 29 | GCN(16, activation=jax.nn.relu), 30 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 31 | GCN(7, activation=None), 32 | ), 33 | ) 34 | 35 | key = jax.random.PRNGKey(2666) 36 | key, key_dropout = jax.random.split(key) 37 | 38 | params = model.init({"params": key, "dropout": key_dropout}, G) 39 | 40 | from flax.core import 
FrozenDict 41 | mask = FrozenDict( 42 | {"params": 43 | { 44 | "layers_1": True, 45 | "layers_3": False, 46 | }, 47 | }, 48 | ) 49 | 50 | optimizer = optax.chain( 51 | optax.additive_weight_decay(5e-4, mask=mask), 52 | optax.adam(1e-2), 53 | ) 54 | 55 | from flax.training.train_state import TrainState 56 | state = TrainState.create( 57 | apply_fn=model.apply, params=params, tx=optimizer, 58 | ) 59 | 60 | def loss(params, key): 61 | g = model.apply(params, G, rngs={"dropout": key}) 62 | y = g.ndata['h'] 63 | return optax.softmax_cross_entropy( 64 | y[g.ndata['train_mask']], 65 | Y_REF[g.ndata['train_mask']], 66 | ).mean() 67 | 68 | @jax.jit 69 | def step(state, key): 70 | key, new_key = jax.random.split(key) 71 | grad_fn = jax.grad(partial(loss, key=new_key)) 72 | grads = grad_fn(state.params) 73 | state = state.apply_gradients(grads=grads) 74 | return state, key 75 | 76 | @jax.jit 77 | def eval(state): 78 | params = state.params 79 | g = model_eval.apply(params, G) 80 | y = g.ndata['h'] 81 | accuracy_vl = (Y_REF[g.ndata['val_mask']].argmax(-1) == 82 | y[g.ndata['val_mask']].argmax(-1)).sum() /\ 83 | g.ndata['val_mask'].sum() 84 | loss_vl = optax.softmax_cross_entropy( 85 | y[g.ndata['val_mask']], 86 | Y_REF[g.ndata['val_mask']], 87 | ).mean() 88 | return accuracy_vl, loss_vl 89 | 90 | @jax.jit 91 | def test(state): 92 | params = state.params 93 | g = model_eval.apply(params, G) 94 | y = g.ndata['h'] 95 | accuracy_te = (Y_REF[g.ndata['test_mask']].argmax(-1) == 96 | y[g.ndata['test_mask']].argmax(-1)).sum() /\ 97 | g.ndata['test_mask'].sum() 98 | loss_te = optax.softmax_cross_entropy( 99 | y[g.ndata['test_mask']], 100 | Y_REF[g.ndata['test_mask']], 101 | ).mean() 102 | return accuracy_te, loss_te 103 | 104 | from galax.nn.utils import EarlyStopping 105 | early_stopping = EarlyStopping(10) 106 | 107 | import tqdm 108 | for _ in tqdm.tqdm(range(1000)): 109 | state, key = step(state, key) 110 | accuracy_vl, loss_vl = eval(state) 111 | if early_stopping((-accuracy_vl, 
"""`GraphSAGE `__"""

from typing import Callable, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from ... import function as fn


class GraphSAGE(nn.Module):
    r"""GraphSAGE layer from `Inductive Representation Learning on
    Large Graphs `__

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)

    Parameters
    ----------
    features : int
        Output feature size; i.e., the number of dimensions
        of :math:`h_i^{(l+1)}`.
    aggregator_type : str
        Aggregator type to use: ``"mean"``, ``"gcn"``, or ``"pool"``.
        Default: ``"mean"``.
    use_bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
        NOTE(review): currently not wired into the dense projections
        below (both are created with ``use_bias=False``) — confirm
        intended behavior.
    activation : Optional[Callable]
        If not None, applies an activation function to the updated node
        features. Default: ``None``.

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> import galax
    >>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = g.add_self_loop()
    >>> g = g.set_ndata("h", jnp.ones((6, 10)))
    >>> graphsage = GraphSAGE(2, "pool")
    >>> params = graphsage.init(jax.random.PRNGKey(2666), g)
    >>> x = graphsage.apply(params, g).ndata["h"]
    >>> x.shape
    (6, 2)

    """
    features: int
    aggregator_type: str = "mean"
    use_bias: bool = True
    activation: Optional[Callable] = None

    @nn.compact
    def __call__(self, graph, field: str = "h"):
        # keep the pre-aggregation self features for the skip projection
        h_self = graph.ndata[field]

        if self.aggregator_type == "mean":
            graph = graph.update_all(
                fn.copy_src(field, "m"), fn.mean("m", "neigh"),
            )
            h_neigh = graph.ndata["neigh"]

        elif self.aggregator_type == "gcn":
            # sum over neighbors, then average the self feature in
            # (degree + 1 accounts for the implicit self edge)
            graph = graph.update_all(
                fn.copy_src(field, "m"), fn.sum("m", "neigh"),
            )
            degrees = graph.in_degrees()
            h_neigh = (graph.ndata["neigh"] + graph.ndata[field])\
                / (jnp.expand_dims(degrees, -1) + 1)

        elif self.aggregator_type == "pool":
            # nonlinear projection before element-wise max pooling
            h_pool = jax.nn.relu(
                nn.Dense(graph.ndata[field].shape[-1])(graph.ndata[field]),
            )
            graph = graph.ndata.set(field, h_pool)
            graph = graph.update_all(
                fn.copy_src(field, "m"), fn.max("m", "neigh"),
            )
            h_neigh = graph.ndata["neigh"]

        else:
            # FIX: previously an unknown aggregator fell through and
            # crashed later with a NameError on `h_neigh`; fail fast
            # with a clear message instead
            raise ValueError(
                f"Unsupported aggregator type: {self.aggregator_type}",
            )

        h_neigh = nn.Dense(self.features, use_bias=False)(h_neigh)

        # `gcn` already folds the self feature into the aggregate,
        # so it skips the separate self projection
        if self.aggregator_type == "gcn":
            rst = h_neigh
        else:
            rst = h_neigh + nn.Dense(self.features, use_bias=False)(h_self)

        if self.activation is not None:
            rst = self.activation(rst)

        graph = graph.ndata.set(field, rst)
        return graph
Reproduction: 78.8""" 2 | 3 | from functools import partial 4 | import jax 5 | from flax import linen as nn 6 | import optax 7 | import galax 8 | 9 | 10 | def run(): 11 | from galax.data.datasets.nodes.planetoid import pubmed 12 | G = pubmed() 13 | G = G.add_self_loop() 14 | Y_REF = jax.nn.one_hot(G.ndata['label'], 3) 15 | 16 | from galax.nn.zoo.gcn import GCN 17 | model = nn.Sequential( 18 | ( 19 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 20 | GCN(16, activation=jax.nn.relu), 21 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)), 22 | GCN(3, activation=None), 23 | ), 24 | ) 25 | 26 | model_eval = nn.Sequential( 27 | ( 28 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 29 | GCN(16, activation=jax.nn.relu), 30 | galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)), 31 | GCN(3, activation=None), 32 | ), 33 | ) 34 | 35 | key = jax.random.PRNGKey(2666) 36 | key, key_dropout = jax.random.split(key) 37 | 38 | params = model.init({"params": key, "dropout": key_dropout}, G) 39 | 40 | from flax.core import FrozenDict 41 | mask = FrozenDict( 42 | {"params": 43 | { 44 | "layers_1": True, 45 | "layers_3": False, 46 | }, 47 | }, 48 | ) 49 | 50 | optimizer = optax.chain( 51 | optax.additive_weight_decay(5e-4, mask=mask), 52 | optax.adam(1e-2), 53 | ) 54 | 55 | from flax.training.train_state import TrainState 56 | state = TrainState.create( 57 | apply_fn=model.apply, params=params, tx=optimizer, 58 | ) 59 | 60 | def loss(params, key): 61 | g = model.apply(params, G, rngs={"dropout": key}) 62 | y = g.ndata['h'] 63 | return optax.softmax_cross_entropy( 64 | y[g.ndata['train_mask']], 65 | Y_REF[g.ndata['train_mask']], 66 | ).mean() 67 | 68 | @jax.jit 69 | def step(state, key): 70 | key, new_key = jax.random.split(key) 71 | grad_fn = jax.grad(partial(loss, key=new_key)) 72 | grads = grad_fn(state.params) 73 | state = state.apply_gradients(grads=grads) 74 | return state, key 75 | 76 | @jax.jit 77 | def eval(state): 78 | params = state.params 79 | g = 
"""Reference: 70.3; Reproduction: 71.4"""

from functools import partial

import jax
import optax
from flax import linen as nn

import galax


def run():
    """Train a two-layer GCN on Citeseer with early stopping and print
    the test accuracy.

    Reproduces the Kipf & Welling GCN setup: 16 hidden units, 0.5
    dropout, weight decay 5e-4 on the first layer only, Adam(1e-2).
    """
    from galax.data.datasets.nodes.planetoid import citeseer
    G = citeseer()
    G = G.add_self_loop()
    # Citeseer has 6 classes.
    Y_REF = jax.nn.one_hot(G.ndata['label'], 6)

    from galax.nn.zoo.gcn import GCN
    # Training model: dropout active.
    model = nn.Sequential(
        (
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)),
            GCN(16, activation=jax.nn.relu),
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=False)),
            GCN(6, activation=None),
        ),
    )

    # Evaluation model: same parameter tree, dropout disabled.
    model_eval = nn.Sequential(
        (
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)),
            GCN(16, activation=jax.nn.relu),
            galax.ApplyNodes(nn.Dropout(0.5, deterministic=True)),
            GCN(6, activation=None),
        ),
    )

    key = jax.random.PRNGKey(2666)
    key, key_dropout = jax.random.split(key)

    params = model.init({"params": key, "dropout": key_dropout}, G)

    from flax.core import FrozenDict
    # Apply weight decay to the first GCN layer only (layers_1), as in
    # the original paper.
    mask = FrozenDict(
        {"params":
            {
                "layers_1": True,
                "layers_3": False,
            },
        },
    )

    optimizer = optax.chain(
        optax.additive_weight_decay(5e-4, mask=mask),
        optax.adam(1e-2),
    )

    from flax.training.train_state import TrainState
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
    )

    def loss(params, key):
        # Cross-entropy restricted to training nodes.
        g = model.apply(params, G, rngs={"dropout": key})
        y = g.ndata['h']
        return optax.softmax_cross_entropy(
            y[g.ndata['train_mask']],
            Y_REF[g.ndata['train_mask']],
        ).mean()

    @jax.jit
    def step(state, key):
        # One optimizer step; splits and threads the RNG key.
        key, new_key = jax.random.split(key)
        grad_fn = jax.grad(partial(loss, key=new_key))
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, key

    def _metrics(y, mask):
        # Accuracy and mean cross-entropy over the masked node set.
        accuracy = (Y_REF[mask].argmax(-1) ==
                    y[mask].argmax(-1)).sum() / mask.sum()
        xent = optax.softmax_cross_entropy(y[mask], Y_REF[mask]).mean()
        return accuracy, xent

    # Renamed from `eval` to avoid shadowing the builtin.
    @jax.jit
    def evaluate(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['val_mask'])

    @jax.jit
    def test(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['test_mask'])

    from galax.nn.utils import EarlyStopping
    early_stopping = EarlyStopping(10)

    import tqdm
    for _ in tqdm.tqdm(range(1000)):
        state, key = step(state, key)
        accuracy_vl, loss_vl = evaluate(state)
        # Stop on (accuracy, loss) plateau; restore the best parameters.
        if early_stopping((-accuracy_vl, loss_vl), state.params):
            state = state.replace(params=early_stopping.params)
            break

    accuracy_te, _ = test(state)
    print(accuracy_te)


if __name__ == "__main__":
    run()
"""Reference: 83.0; Reproduction: 83.3"""

from functools import partial

import jax
import optax
from flax import linen as nn

import galax


def run():
    """Train a two-layer GAT on Cora with early stopping and print the
    test accuracy.

    Reproduces the Velickovic et al. setup: 8 heads x 8 features with
    concatenation, then a single-head output layer with head averaging.
    """
    from galax.data.datasets.nodes.planetoid import cora
    G = cora()
    G = G.add_self_loop()
    # Cora has 7 classes.
    Y_REF = jax.nn.one_hot(G.ndata['label'], 7)

    from galax.nn.zoo.gat import GAT
    # Concatenate heads after the hidden layer; average after the output.
    ConcatenationPooling = galax.ApplyNodes(lambda x: x.reshape(*x.shape[:-2], -1))
    AveragePooling = galax.ApplyNodes(lambda x: x.mean(-2))

    # Training model: dropout active.
    model = nn.Sequential(
        (
            GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=jax.nn.elu),
            ConcatenationPooling,
            GAT(7, 1, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=None),
            AveragePooling,
        ),
    )

    # Evaluation model: same parameter tree, dropout disabled.
    model_eval = nn.Sequential(
        (
            GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=jax.nn.elu),
            ConcatenationPooling,
            GAT(7, 1, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=None),
            AveragePooling,
        ),
    )

    key = jax.random.PRNGKey(2666)
    key, key_dropout = jax.random.split(key)

    # Bug fix: use the dedicated dropout key. The original reused `key`
    # for both streams and left `key_dropout` dead, correlating the
    # parameter and dropout RNG streams (the GCN scripts pass
    # `key_dropout` here).
    params = model.init({"params": key, "dropout": key_dropout}, G)
    # Decay every parameter that is not initialized to all-zeros
    # (i.e. skip the zero-initialized biases).
    mask = jax.tree_map(lambda x: (x != 0).any(), params)

    optimizer = optax.chain(
        optax.additive_weight_decay(0.0005, mask=mask),
        optax.adam(0.005),
    )

    from flax.training.train_state import TrainState
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
    )

    def loss(params, key):
        # Cross-entropy restricted to training nodes.
        g = model.apply(params, G, rngs={"dropout": key})
        y = g.ndata['h']
        return optax.softmax_cross_entropy(
            y[g.ndata['train_mask']],
            Y_REF[g.ndata['train_mask']],
        ).mean()

    @jax.jit
    def step(state, key):
        # One optimizer step; splits and threads the RNG key.
        key, new_key = jax.random.split(key)
        grad_fn = jax.grad(partial(loss, key=new_key))
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, key

    def _metrics(y, mask):
        # Accuracy and mean cross-entropy over the masked node set.
        accuracy = (Y_REF[mask].argmax(-1) ==
                    y[mask].argmax(-1)).sum() / mask.sum()
        xent = optax.softmax_cross_entropy(y[mask], Y_REF[mask]).mean()
        return accuracy, xent

    # Renamed from `eval` to avoid shadowing the builtin.
    @jax.jit
    def evaluate(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['val_mask'])

    @jax.jit
    def test(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['test_mask'])

    from galax.nn.utils import EarlyStopping
    early_stopping = EarlyStopping(100)

    import tqdm
    for _ in tqdm.tqdm(range(1000)):
        state, key = step(state, key)
        accuracy_vl, loss_vl = evaluate(state)
        # Stop on (accuracy, loss) plateau; restore the best parameters.
        if early_stopping((-accuracy_vl, loss_vl), state.params):
            state = state.replace(params=early_stopping.params)
            break

    accuracy_te, _ = test(state)
    print(accuracy_te)


if __name__ == "__main__":
    run()
optimizer = optax.chain( 46 | optax.additive_weight_decay(0.001, mask=mask), 47 | optax.adam(0.01), 48 | ) 49 | 50 | from flax.training.train_state import TrainState 51 | state = TrainState.create( 52 | apply_fn=model.apply, params=params, tx=optimizer, 53 | ) 54 | 55 | def loss(params, key): 56 | g = model.apply(params, G, rngs={"dropout": key}) 57 | y = g.ndata['h'] 58 | return optax.softmax_cross_entropy( 59 | y[g.ndata['train_mask']], 60 | Y_REF[g.ndata['train_mask']], 61 | ).mean() 62 | 63 | @jax.jit 64 | def step(state, key): 65 | key, new_key = jax.random.split(key) 66 | grad_fn = jax.grad(partial(loss, key=new_key)) 67 | grads = grad_fn(state.params) 68 | state = state.apply_gradients(grads=grads) 69 | return state, key 70 | 71 | @jax.jit 72 | def eval(state): 73 | params = state.params 74 | g = model_eval.apply(params, G) 75 | y = g.ndata['h'] 76 | accuracy_vl = (Y_REF[g.ndata['val_mask']].argmax(-1) == 77 | y[g.ndata['val_mask']].argmax(-1)).sum() /\ 78 | g.ndata['val_mask'].sum() 79 | loss_vl = optax.softmax_cross_entropy( 80 | y[g.ndata['val_mask']], 81 | Y_REF[g.ndata['val_mask']], 82 | ).mean() 83 | return accuracy_vl, loss_vl 84 | 85 | @jax.jit 86 | def test(state): 87 | params = state.params 88 | g = model_eval.apply(params, G) 89 | y = g.ndata['h'] 90 | accuracy_te = (Y_REF[g.ndata['test_mask']].argmax(-1) == 91 | y[g.ndata['test_mask']].argmax(-1)).sum() /\ 92 | g.ndata['test_mask'].sum() 93 | loss_te = optax.softmax_cross_entropy( 94 | y[g.ndata['test_mask']], 95 | Y_REF[g.ndata['test_mask']], 96 | ).mean() 97 | return accuracy_te, loss_te 98 | 99 | from galax.nn.utils import EarlyStopping 100 | early_stopping = EarlyStopping(100) 101 | 102 | import tqdm 103 | for _ in tqdm.tqdm(range(1000)): 104 | state, key = step(state, key) 105 | accuracy_vl, loss_vl = eval(state) 106 | if early_stopping((-accuracy_vl, loss_vl), state.params): 107 | state = state.replace(params=early_stopping.params) 108 | break 109 | 110 | accuracy_te, _ = test(state) 111 | 
"""Reference: 72.5; Reproduction: 71.5"""

from functools import partial

import jax
import optax
from flax import linen as nn

import galax


def run():
    """Train a two-layer GAT on Citeseer with early stopping and print
    the test accuracy.
    """
    from galax.data.datasets.nodes.planetoid import citeseer
    G = citeseer()
    G = G.add_self_loop()
    # Citeseer has 6 classes.
    Y_REF = jax.nn.one_hot(G.ndata['label'], 6)

    from galax.nn.zoo.gat import GAT
    # Concatenate heads after the hidden layer; average after the output.
    ConcatenationPooling = galax.ApplyNodes(lambda x: x.reshape(*x.shape[:-2], -1))
    AveragePooling = galax.ApplyNodes(lambda x: x.mean(-2))

    # Training model: dropout active.
    model = nn.Sequential(
        (
            GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=jax.nn.elu),
            ConcatenationPooling,
            GAT(6, 1, attn_drop=0.4, feat_drop=0.4, deterministic=False, activation=None),
            AveragePooling,
        ),
    )

    # Evaluation model: same parameter tree, dropout disabled.
    # Consistency fix: the original declared 0.6 dropout rates here while
    # training used 0.4; the rates are inert under deterministic=True,
    # but they now mirror the training model (and the sibling scripts).
    model_eval = nn.Sequential(
        (
            GAT(8, 8, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=jax.nn.elu),
            ConcatenationPooling,
            GAT(6, 1, attn_drop=0.4, feat_drop=0.4, deterministic=True, activation=None),
            AveragePooling,
        ),
    )

    key = jax.random.PRNGKey(2666)
    key, key_dropout = jax.random.split(key)

    # Bug fix: use the dedicated dropout key. The original reused `key`
    # for both streams and left `key_dropout` dead.
    params = model.init({"params": key, "dropout": key_dropout}, G)
    # Decay every parameter that is not initialized to all-zeros
    # (i.e. skip the zero-initialized biases).
    mask = jax.tree_map(lambda x: (x != 0).any(), params)

    optimizer = optax.chain(
        optax.additive_weight_decay(0.0005, mask=mask),
        optax.adam(0.005),
    )

    from flax.training.train_state import TrainState
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
    )

    def loss(params, key):
        # Cross-entropy restricted to training nodes.
        g = model.apply(params, G, rngs={"dropout": key})
        y = g.ndata['h']
        return optax.softmax_cross_entropy(
            y[g.ndata['train_mask']],
            Y_REF[g.ndata['train_mask']],
        ).mean()

    @jax.jit
    def step(state, key):
        # One optimizer step; splits and threads the RNG key.
        key, new_key = jax.random.split(key)
        grad_fn = jax.grad(partial(loss, key=new_key))
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, key

    def _metrics(y, mask):
        # Accuracy and mean cross-entropy over the masked node set.
        accuracy = (Y_REF[mask].argmax(-1) ==
                    y[mask].argmax(-1)).sum() / mask.sum()
        xent = optax.softmax_cross_entropy(y[mask], Y_REF[mask]).mean()
        return accuracy, xent

    # Renamed from `eval` to avoid shadowing the builtin.
    @jax.jit
    def evaluate(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['val_mask'])

    @jax.jit
    def test(state):
        g = model_eval.apply(state.params, G)
        return _metrics(g.ndata['h'], g.ndata['test_mask'])

    from galax.nn.utils import EarlyStopping
    early_stopping = EarlyStopping(100)

    import tqdm
    for _ in tqdm.tqdm(range(1000)):
        state, key = step(state, key)
        accuracy_vl, loss_vl = evaluate(state)
        # Stop on (accuracy, loss) plateau; restore the best parameters.
        if early_stopping((-accuracy_vl, loss_vl), state.params):
            state = state.replace(params=early_stopping.params)
            break

    accuracy_te, _ = test(state)
    print(accuracy_te)


if __name__ == "__main__":
    run()
class GAT(nn.Module):
    r"""Apply a `Graph Attention Network
    <https://arxiv.org/abs/1710.10903>`__ layer over an input signal.

    .. math::
        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score bewteen node :math:`i` and
    node :math:`j`:

    .. math::
        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})

        e_{ij}^{l} &=
        \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)

    Parameters
    ----------
    features : int
        Output feature size per attention head.
    num_heads : int
        Number of attention heads.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node
        features. Default: ``None``.
    deterministic : bool, optional
        If True, disables both dropouts. Default: ``True``.
    use_bias : bool, optional
        If True, adds a learnable per-head bias to the output.
        Default: ``True``.

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> import galax
    >>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = g.add_self_loop()
    >>> g = g.set_ndata("h", jnp.ones((6, 10)))
    >>> gat = GAT(2, 4, deterministic=True)
    >>> params = gat.init(jax.random.PRNGKey(2666), g)
    >>> g = gat.apply(params, g)
    >>> x = g.ndata['h']
    >>> x.shape
    (6, 4, 2)
    """

    features: int
    num_heads: int
    # Annotation fix: defaults are floats and None is never passed, so
    # `Optional[float]` was inaccurate.
    feat_drop: float = 0.0
    attn_drop: float = 0.0
    negative_slope: float = 0.2
    activation: Optional[Callable] = None
    deterministic: bool = True
    use_bias: bool = True

    def setup(self):
        # NOTE: attribute names (fc, attn_l, attn_r, bias, ...) are
        # parameter-tree keys; renaming them would break checkpoints.
        # Single projection producing all heads at once.
        self.fc = nn.Dense(
            self.features * self.num_heads, use_bias=False,
            kernel_init=nn.initializers.variance_scaling(
                3.0, "fan_avg", "uniform"
            ),
        )

        # Per-head scalar attention projections (left/right halves of
        # the concatenated attention vector a^T [Wh_i || Wh_j]).
        self.attn_l = nn.Dense(
            1,
            kernel_init=nn.initializers.variance_scaling(
                3.0, "fan_avg", "uniform"
            ),
        )

        self.attn_r = nn.Dense(
            1,
            kernel_init=nn.initializers.variance_scaling(
                3.0, "fan_avg", "uniform"
            ),
        )

        if self.use_bias:
            self.bias = self.param(
                "bias",
                nn.zeros,
                (self.num_heads, self.features),
            )

        self.dropout_feat = nn.Dropout(self.feat_drop, deterministic=self.deterministic)
        self.dropout_attn = nn.Dropout(self.attn_drop, deterministic=self.deterministic)

    def __call__(self, graph, field="h", etype="E_"):
        """Run one GAT layer on ``graph.ndata[field]`` and return the
        graph with the field replaced by the attended features.
        """
        h = graph.ndata[field]
        # (Removed an unused `h0 = h` alias kept from an earlier draft.)
        h = self.dropout_feat(h)
        h = self.fc(h)
        # (..., num_heads * features) -> (..., num_heads, features)
        h = h.reshape(h.shape[:-1] + (self.num_heads, self.features))
        el = self.attn_l(h)
        er = self.attn_r(h)
        graph = graph.ndata.set(field, h)
        graph = graph.ndata.set("er", er)
        graph = graph.ndata.set("el", el)
        # Unnormalized attention logit per edge: source term + dest term.
        e = graph.edges[etype].src["er"] + graph.edges[etype].dst["el"]
        e = nn.leaky_relu(e, self.negative_slope)
        # Softmax over each destination node's incoming edges.
        a = fn.segment_softmax(e, graph.edges[etype].dst.idxs, graph.number_of_nodes())
        a = self.dropout_attn(a)
        graph = graph.edata.set("a", a)
        # Weighted sum of source features into destinations.
        graph = graph.update_all(
            fn.u_mul_e(field, "a", "m"),
            fn.sum("m", field)
        )

        if self.use_bias:
            graph = fn.apply_nodes(
                lambda x: x + self.bias, in_field=field, out_field=field
            )(graph)

        if self.activation is not None:
            graph = fn.apply_nodes(
                self.activation, in_field=field, out_field=field
            )(graph)
        return graph
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shape", tensor_shapes)
def test_apply_edges(g, shape):
    """galax.apply_edges matches DGL's apply_edges for a dense map."""
    _g = galax.from_dgl(g)

    W = onp.random.normal(size=(shape[-1], 7))
    B = onp.random.normal(size=(7, ))
    X = onp.random.normal(size=(g.number_of_edges(), *shape))

    import torch
    w = torch.tensor(W)
    b = torch.tensor(B)
    x = torch.tensor(X)
    _w = jnp.array(w)
    _b = jnp.array(b)
    _x = jnp.array(x)

    def fn(x):
        return x @ w + b

    def _fn(_x):
        return _x @ _w + _b

    _g = _g.edata.set("h", _x)
    g.edata["h"] = x

    _g = galax.apply_edges(_fn)(_g)
    _y = onp.array(_g.edata['h'])

    g.apply_edges(lambda edges: {'h': fn(edges.data['h'])})
    y = g.edata['h'].detach().numpy()

    assert onp.allclose(y, _y, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize('g', graphs)
@pytest.mark.parametrize('shp', spmm_shapes)
@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'copy_lhs'])
@pytest.mark.parametrize('reducer', ['sum', 'min', 'max'])
def test_message_passing(g, shp, msg, reducer):
    """galax.update_all matches DGL UDF message passing for binary
    message ops and sum/min/max reducers, across broadcastable shapes.
    """
    import torch

    _g = galax.from_dgl(g)

    udf_msg_dgl = {
        'add': lambda edges: {'m': edges.src['x'] + edges.data['w']},
        'sub': lambda edges: {'m': edges.src['x'] - edges.data['w']},
        'mul': lambda edges: {'m': edges.src['x'] * edges.data['w']},
        'div': lambda edges: {'m': edges.src['x'] / edges.data['w']},
        'copy_lhs': lambda edges: {'m': edges.src['x']},
        'copy_rhs': lambda edges: {'m': edges.data['w']},
    }

    udf_reduce_dgl = {
        'sum': lambda nodes: {'v': torch.sum(nodes.mailbox['m'], 1)},
        'min': lambda nodes: {'v': torch.min(nodes.mailbox['m'], 1)[0]},
        'max': lambda nodes: {'v': torch.max(nodes.mailbox['m'], 1)[0]}
    }

    # Shift away from zero so 'div' stays well-conditioned.
    hu = onp.random.normal(size=(g.number_of_nodes(),) + shp[0]) + 1.0
    he = onp.random.normal(size=(g.number_of_edges(),) + shp[1]) + 1.0

    hu_dgl = torch.tensor(hu)
    he_dgl = torch.tensor(he)
    g.ndata['x'] = hu_dgl
    g.edata['w'] = he_dgl

    # Bug fix: feed the jnp arrays to galax (the originals were built
    # and then ignored in favor of the raw numpy arrays).
    hu_glx = jnp.array(hu)
    he_glx = jnp.array(he)
    _g = _g.ndata.set('x', hu_glx)
    _g = _g.edata.set('w', he_glx)

    g.update_all(udf_msg_dgl[msg], udf_reduce_dgl[reducer])

    if msg == "copy_lhs":
        msg_fn_glx = galax.function.copy_u("x", "m")
    elif msg == "copy_rhs":
        msg_fn_glx = galax.function.copy_e("w", "m")
    else:
        msg_fn_glx = getattr(galax.function, f"u_{msg}_e")("x", "w", "m")
    reduce_fn_glx = getattr(galax.function, reducer)("m", "v")
    _g = _g.update_all(msg_fn_glx, reduce_fn_glx)

    y_dgl = g.ndata['v'].detach().numpy()
    y_glx = onp.array(_g.ndata['v'])

    assert onp.allclose(y_dgl, y_glx, rtol=1e-2, atol=1e-2)
# http://www.sphinx-doc.org/en/stable/config

# -- Path setup --------------------------------------------------------------
# Make the package importable when docs are built without installing it.
import os
import sys

sys.path.insert(0, os.path.abspath('..'))

import galax
from galax import heterograph, function, core, nn, data


# -- Project information -----------------------------------------------------

project = 'galax'
copyright = ("2022, Yuanqing Wang")
author = 'Yuanqing Wang'

# The short X.Y version / full release tag (left blank on purpose).
version = ''
release = ''


# -- General configuration ---------------------------------------------------

extensions = [
    'sphinx.ext.autosummary',
    'sphinx.ext.autodoc',
    'sphinx.ext.mathjax',
    'sphinx.ext.viewcode',
    'sphinx.ext.napoleon',
    'sphinx.ext.intersphinx',
    'sphinx.ext.extlinks',
    'sphinx.ext.coverage',
    'm2r2',
    # 'numpydoc',
]

autosummary_generate = True
napoleon_google_docstring = False
napoleon_use_param = False
napoleon_use_ivar = True

templates_path = ['_templates']

# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

master_doc = 'index'

# Fix: Sphinx >= 5 warns on / rejects ``language = None``; use an
# explicit language code instead.
language = 'en'

exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

pygments_style = 'default'


# -- Options for HTML output -------------------------------------------------

html_theme = 'sphinx_rtd_theme'

# html_theme_options = {}

html_static_path = ['_static']

# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

htmlhelp_basename = 'galax_doc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # 'papersize': 'letterpaper',
    # 'pointsize': '10pt',
    # 'preamble': '',
    # 'figure_align': 'htbp',
}

latex_documents = [
    (master_doc, 'galax.tex', 'galax Documentation',
     'galax', 'manual'),
]


# -- Options for manual page output ------------------------------------------

man_pages = [
    (master_doc, 'galax', 'galax Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------
"""Utilities for batching graphs."""
from typing import Sequence, Union
import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from .heterograph import HeteroGraph
from .heterograph_index import HeteroGraphIndex


def batch(graphs: Sequence[HeteroGraph]):
    """Batch a sequence of graphs into one.

    All input graphs must share the same metagraph, node types, and edge
    types. Node indices of later graphs are shifted by the cumulative
    node counts of earlier graphs; per-graph node counts are recorded in
    the graph frame under ``"_batched_num_nodes"`` so the batch can be
    reasoned about later (e.g. by readout or :func:`pad`).

    Parameters
    ----------
    graphs : Sequence[HeteroGraph]
        Sequence of graphs.

    Returns
    -------
    HeteroGraph
        The batched graph.

    Examples
    --------
    >>> import galax
    >>> g = galax.graph(([0, 0, 2], [0, 1, 2]))
    >>> _g = batch([g, g])
    >>> _g.gidx.edges[0][0].tolist()
    [0, 0, 2, 3, 3, 5]
    >>> _g.gidx.edges[0][1].tolist()
    [0, 1, 2, 3, 4, 5]

    >>> g0 = g.ndata.set("h", jnp.zeros(3))
    >>> g1 = g.ndata.set("h", jnp.ones(3))
    >>> _g = batch([g0, g1])
    >>> _g.ndata["h"].tolist()
    [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]

    >>> _g.gdata["_batched_num_nodes"].flatten().tolist()
    [3, 3]

    """
    # make sure the metagraphs are exactly the same
    assert all(graph.ntypes == graphs[0].ntypes for graph in graphs)
    assert all(graph.etypes == graphs[0].etypes for graph in graphs)
    assert all(
        graph.gidx.metagraph == graphs[0].gidx.metagraph
        for graph in graphs
    )
    metagraph = graphs[0].gidx.metagraph

    # ntypes and etypes remain the same
    etypes = graphs[0].etypes
    ntypes = graphs[0].ntypes

    # Per-graph node counts, shape (n_graphs, n_ntypes); also kept as
    # `batched_num_nodes` for the graph frame below.
    n_nodes = jnp.stack([graph.gidx.n_nodes for graph in graphs])
    batched_num_nodes = n_nodes
    # Exclusive prefix sum: offset to add to each graph's node indices.
    offsets = jnp.cumsum(n_nodes[:-1], axis=0)
    offsets = jnp.concatenate(
        [jnp.zeros((1, offsets.shape[-1]), dtype=jnp.int32), offsets]
    )
    # Total node count per node type in the batched graph.
    n_nodes = n_nodes.sum(axis=0)

    # Edge indices with per-graph offsets added, concatenated per etype.
    num_edge_types = len(graphs[0].gidx.edges)
    edges = [[[], []] for _ in range(num_edge_types)]
    for idx_etype in range(num_edge_types):
        for idx_graph, graph in enumerate(graphs):
            src, dst = graph.gidx.edges[idx_etype]
            src = src + offsets[idx_graph]
            dst = dst + offsets[idx_graph]
            edges[idx_etype][0].append(src)
            edges[idx_etype][1].append(dst)
        edges[idx_etype][0] = jnp.concatenate(edges[idx_etype][0])
        edges[idx_etype][1] = jnp.concatenate(edges[idx_etype][1])
        edges[idx_etype] = tuple(edges[idx_etype])
    edges = tuple(edges)
    gidx = HeteroGraphIndex(n_nodes=n_nodes, edges=edges, metagraph=metagraph)

    # Concatenate node/edge feature frames key-by-key.
    # NOTE(review): these are generator expressions, so they can be
    # consumed only once — presumably HeteroGraph.init materializes
    # them; confirm before reusing elsewhere.
    node_frames = (
        FrozenDict(
            {
                key:
                jnp.concatenate(
                    [graph.node_frames[idx][key] for graph in graphs]
                )
                for key in graphs[0].node_frames[idx].keys()
            }
        )
        if graphs[0].node_frames[idx] is not None else None
        for idx in range(len(graphs[0].node_frames))
    )

    edge_frames = (
        FrozenDict(
            {
                key:
                jnp.concatenate(
                    [graph.edge_frames[idx][key] for graph in graphs]
                )
                for key in graphs[0].edge_frames[idx].keys()
            }
        )
        if graphs[0].edge_frames[idx] is not None else None
        for idx in range(len(graphs[0].edge_frames))
    )

    # (n_graphs, n_ntypes)
    # If an input graph was itself already batched, reuse its recorded
    # per-subgraph counts instead of treating it as one graph.
    original_batched_num_nodes = [
        graph.graph_frame.get("_batched_num_nodes", default=None)
        if graph.graph_frame is not None
        else None
        for graph in graphs
    ]

    batched_num_nodes = [
        jnp.expand_dims(batched_num_nodes[idx], 0)
        if original_batched_num_nodes[idx] is None
        else original_batched_num_nodes[idx]
        for idx in range(len(original_batched_num_nodes))
    ]

    batched_num_nodes = jnp.concatenate(batched_num_nodes)

    # Concatenate any remaining graph-level data, excluding the
    # bookkeeping key which is rebuilt above.
    if graphs[0].graph_frame is not None:
        if list(graphs[0].graph_frame.keys()) != ["_batched_num_nodes"]:
            graph_frame = {
                key:
                jnp.concatenate(
                    [graph.graph_frame[key] for graph in graphs]
                )
                for key in graphs[0].graph_frame.keys()
                if key != "_batched_num_nodes"
            }
        else:
            graph_frame = {}
    else:
        graph_frame = {}

    graph_frame.update({"_batched_num_nodes": batched_num_nodes})

    return HeteroGraph.init(
        gidx=gidx,
        ntypes=ntypes,
        etypes=etypes,
        node_frames=node_frames,
        edge_frames=edge_frames,
        graph_frame=graph_frame,
    )


def pad(
    graphs: Union[Sequence[HeteroGraph], HeteroGraph],
    n_nodes: Union[int, jnp.ndarray],
    n_edges: Union[int, jnp.ndarray],
):
    """Pad graphs to desired number of nodes and edges and batch them.

    A "dummy" graph holding the missing nodes (with zero-filled features
    and edges wired ``0 -> 0``) is appended via :func:`batch`, and the
    result is flagged with ``gdata["_has_dummy"]``.

    Parameters
    ----------
    graphs : Union[Sequence[HeteroGraph], HeteroGraph]
        A sequence of graphs, could be already batched.
    n_nodes : Union[int, jnp.ndarray]
        Number of nodes.
    n_edges : Union[int, jnp.ndarray]
        Number of edges.

    Returns
    -------
    HeteroGraph
        Batched graph with desired padding.

    Examples
    --------
    >>> import galax
    >>> g = galax.graph(([0, 0, 2], [0, 1, 2]))
    >>> _g = pad(g, g.number_of_nodes(), g.number_of_edges())
    >>> _g.gidx == g.gidx
    True

    >>> _g = pad(g, 5, 8)
    >>> int(_g.number_of_edges())
    8
    >>> int(_g.number_of_nodes())
    5

    >>> g = g.ndata.set("h", jnp.ones(3))
    >>> g = g.edata.set("h", jnp.ones(3))
    >>> _g = pad(g, 5, 8)
    >>> _g.ndata["h"].tolist()
    [1.0, 1.0, 1.0, 0.0, 0.0]
    >>> _g.edata["h"].tolist()
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    >>> _g.gdata["_batched_num_nodes"].flatten().tolist()
    [3, 2]
    >>> bool(_g.gdata["_has_dummy"])
    True

    """
    if not isinstance(graphs, HeteroGraph):
        graphs = batch(graphs)
    # How many nodes/edges the dummy graph must contribute.
    # NOTE(review): deltas are assumed non-negative, i.e. targets at
    # least as large as the current sizes — not checked here.
    current_n_nodes = graphs.gidx.n_nodes
    current_n_edges = jnp.array([len(edge[0]) for edge in graphs.gidx.edges])
    delta_n_nodes = n_nodes - current_n_nodes
    delta_n_edges = n_edges - current_n_edges
    n_nodes = delta_n_nodes
    # Dummy edges all connect node 0 to node 0 within the dummy graph.
    edges = tuple(
        [
            (jnp.zeros(_n_edge, jnp.int32), jnp.zeros(_n_edge, jnp.int32))
            for _n_edge in delta_n_edges
        ]
    )
    gidx = HeteroGraphIndex(
        metagraph=graphs.gidx.metagraph,
        n_nodes=n_nodes,
        edges=edges,
    )

    # Zero-filled feature frames matching the originals' trailing shape.
    # NOTE(review): generator expressions (single-use), as in batch().
    node_frames = (
        FrozenDict(
            {
                key:
                jnp.zeros(
                    (n_nodes[idx], *graphs.node_frames[idx][key].shape[1:]),
                )
                for key in graphs.node_frames[idx].keys()
            }
        )
        if graphs.node_frames[idx] is not None else None
        for idx in range(len(graphs.node_frames))
    )

    edge_frames = (
        FrozenDict(
            {
                key:
                jnp.zeros(
                    (
                        delta_n_edges[idx],
                        *graphs.edge_frames[idx][key].shape[1:],
                    ),
                )
                for key in graphs.edge_frames[idx].keys()
            }
        )
        if graphs.edge_frames[idx] is not None else None
        for idx in range(len(graphs.edge_frames))
    )

    dummy = HeteroGraph.init(
        gidx=gidx,
        ntypes=graphs.ntypes,
        etypes=graphs.etypes,
        node_frames=node_frames,
        edge_frames=edge_frames,
    )

    # Append the dummy and mark the result so consumers can mask it out.
    g = batch([graphs, dummy])
    g = g.gdata.set("_has_dummy", jnp.array(True))

    return g
out): 37 | """Builtin message function that computes message using source node 38 | feature. 39 | 40 | Parameters 41 | ---------- 42 | u : str 43 | The source feature field. 44 | out : str 45 | The output message field. 46 | """ 47 | return lambda edge: {out: edge.src[u]} 48 | 49 | # create alias 50 | copy_src = copy_u 51 | 52 | def copy_e(e, out): 53 | """Builtin message function that computes message using edge feature. 54 | 55 | Parameters 56 | ---------- 57 | e : str 58 | The edge feature field. 59 | out : str 60 | The output message field. 61 | """ 62 | return lambda edge: {out: edge.data[e]} 63 | 64 | # create alias 65 | copy_edge = copy_e 66 | 67 | def _gen_message_builtin(lhs, rhs, binary_op): 68 | name = "{}_{}_{}".format(lhs, binary_op, rhs) 69 | docstring = """Builtin message function that computes a message on an edge 70 | by performing element-wise {} between features of {} and {} 71 | if the features have the same shape; otherwise, it first broadcasts 72 | the features to a new shape and performs the element-wise operation. 73 | 74 | Broadcasting follows NumPy semantics. Please see 75 | https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html 76 | for more details about the NumPy broadcasting semantics. 77 | 78 | Parameters 79 | ---------- 80 | lhs_field : str 81 | The feature field of {}. 82 | rhs_field : str 83 | The feature field of {}. 84 | out : str 85 | The output message field. 
86 | 87 | """.format( 88 | binary_op, 89 | CODE2STR[lhs], 90 | CODE2STR[rhs], 91 | CODE2STR[lhs], 92 | CODE2STR[rhs], 93 | ) 94 | 95 | # grab data field 96 | lhs_data, rhs_data = CODE2DATA[lhs], CODE2DATA[rhs] 97 | 98 | # define function 99 | def func(lhs_field, rhs_field, out): 100 | def fn(edges): 101 | return {out: 102 | CODE2OP[binary_op]( 103 | getattr(edges, lhs_data)[lhs_field], 104 | getattr(edges, rhs_data)[rhs_field], 105 | ) 106 | } 107 | return fn 108 | 109 | # attach name and doc 110 | func.__name__ = name 111 | func.__doc__ = docstring 112 | return func 113 | 114 | 115 | def _register_builtin_message_func(): 116 | """Register builtin message functions""" 117 | target = ["u", "v", "e"] 118 | for lhs, rhs in product(target, target): 119 | if lhs != rhs: 120 | for binary_op in ["add", "sub", "mul", "div", "dot"]: 121 | func = _gen_message_builtin(lhs, rhs, binary_op) 122 | setattr(sys.modules[__name__], func.__name__, func) 123 | 124 | 125 | _register_builtin_message_func() 126 | 127 | 128 | # ============================================================================= 129 | # REDUCE FUNCTIONS 130 | # ============================================================================= 131 | 132 | ReduceFunction = namedtuple( 133 | "ReduceFunction", ["op", "msg_field", "out_field"] 134 | ) 135 | 136 | sum = partial(ReduceFunction, "sum") 137 | mean = partial(ReduceFunction, "mean") 138 | max = partial(ReduceFunction, "max") 139 | min = partial(ReduceFunction, "min") 140 | 141 | segment_sum = jax.ops.segment_sum 142 | 143 | def segment_max(*args, **kwargs): 144 | """Alias of jax.ops.segment_max with nan_to_num.""" 145 | return jnp.nan_to_num( 146 | jax.ops.segment_max(*args, **kwargs), 147 | nan=0.0, 148 | posinf=0.0, 149 | neginf=0.0, 150 | ) 151 | 152 | def segment_min(*args, **kwargs): 153 | """Alias of jax.ops.segment_min with nan_to_num.""" 154 | return jnp.nan_to_num( 155 | jax.ops.segment_min(*args, **kwargs), 156 | nan=0.0, 157 | posinf=0.0, 158 | 
neginf=0.0, 159 | ) 160 | 161 | def segment_mean( 162 | data: jnp.ndarray, 163 | segment_ids: jnp.ndarray, 164 | num_segments: Optional[int] = None, 165 | indices_are_sorted: bool = False, 166 | unique_indices: bool = False, 167 | ): 168 | """Returns mean for each segment. 169 | 170 | Reference 171 | --------- 172 | * Shamelessly stolen from jraph.utils 173 | 174 | Parameters 175 | ---------- 176 | data : jnp.ndarray 177 | the values which are averaged segment-wise. 178 | segment_ids : jnp.ndarray 179 | indices for the segments. 180 | num_segments : Optional[int] 181 | total number of segments. 182 | indices_are_sorted : bool=False 183 | whether ``segment_ids`` is known to be sorted. 184 | unique_indices : bool=False 185 | whether ``segment_ids`` is known to be free of duplicates. 186 | 187 | Returns 188 | ------- 189 | jnp.ndarray 190 | The data after segmentation sum. 191 | """ 192 | nominator = segment_sum( 193 | data, 194 | segment_ids, 195 | num_segments, 196 | indices_are_sorted=indices_are_sorted, 197 | unique_indices=unique_indices, 198 | ) 199 | denominator = segment_sum( 200 | jnp.ones_like(data), 201 | segment_ids, 202 | num_segments, 203 | indices_are_sorted=indices_are_sorted, 204 | unique_indices=unique_indices, 205 | ) 206 | return nominator / jnp.maximum( 207 | denominator, jnp.ones(shape=[], dtype=denominator.dtype) 208 | ) 209 | 210 | def segment_softmax(data: jnp.ndarray, 211 | segment_ids: jnp.ndarray, 212 | num_segments: Optional[int] = None, 213 | indices_are_sorted: bool = False, 214 | unique_indices: bool = False): 215 | """Computes a segment-wise softmax. 216 | For a given tree of logits that can be divded into segments, computes a 217 | softmax over the segments. 
218 | logits = jnp.ndarray([1.0, 2.0, 3.0, 1.0, 2.0]) 219 | segment_ids = jnp.ndarray([0, 0, 0, 1, 1]) 220 | segment_softmax(logits, segments) 221 | >> DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586], 222 | >> dtype=float32) 223 | Args: 224 | logits: an array of logits to be segment softmaxed. 225 | segment_ids: an array with integer dtype that indicates the segments of 226 | `data` (along its leading axis) to be maxed over. Values can be repeated 227 | and need not be sorted. Values outside of the range [0, num_segments) are 228 | dropped and do not contribute to the result. 229 | num_segments: optional, an int with positive value indicating the number of 230 | segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1, 231 | jnp.max(-segment_ids))`` but since ``num_segments`` determines the size of 232 | the output, a static value must be provided to use ``segment_sum`` in a 233 | ``jit``-compiled function. 234 | indices_are_sorted: whether ``segment_ids`` is known to be sorted 235 | unique_indices: whether ``segment_ids`` is known to be free of duplicates 236 | Returns: 237 | The segment softmax-ed ``logits``. 
238 | """ 239 | # First, subtract the segment max for numerical stability 240 | maxs = segment_max(data, segment_ids, num_segments, indices_are_sorted, 241 | unique_indices) 242 | logits = data - maxs[segment_ids] 243 | # Then take the exp 244 | logits = jnp.exp(logits) 245 | # Then calculate the normalizers 246 | normalizers = segment_sum(logits, segment_ids, num_segments, 247 | indices_are_sorted, unique_indices) 248 | normalizers = normalizers[segment_ids] 249 | softmax = logits / normalizers 250 | return softmax 251 | 252 | 253 | # ============================================================================= 254 | # APPLY FUNCTIONS 255 | # ============================================================================= 256 | def apply_nodes( 257 | function: Callable, 258 | in_field: str = "h", 259 | out_field: Optional[str] = None, 260 | ntype: Optional[str] = None, 261 | ): 262 | """Apply a function to node attributes. 263 | 264 | Parameters 265 | ---------- 266 | function : Callable 267 | Input function. 268 | in_field : str 269 | Input field 270 | out_field : str 271 | Output field. 272 | 273 | Returns 274 | ------- 275 | Callable 276 | Function that takes and returns a graph. 277 | 278 | Examples 279 | -------- 280 | Transform function. 
281 | >>> import jax 282 | >>> import jax.numpy as jnp 283 | >>> import galax 284 | >>> graph = galax.graph(((0, 1), (1, 2))) 285 | >>> graph = graph.ndata.set("h", jnp.ones(3)) 286 | >>> fn = apply_nodes(lambda x: x * 2) 287 | >>> graph = jax.jit(fn)(graph) 288 | >>> graph.ndata['h'].tolist() 289 | [2.0, 2.0, 2.0] 290 | 291 | """ 292 | if out_field is None: 293 | out_field = in_field 294 | 295 | def _fn(graph, in_field=in_field, out_field=out_field, ntype=ntype): 296 | ntype_idx = graph.get_ntype_id(ntype) 297 | node_frame = unfreeze(graph.node_frames[ntype_idx]) 298 | node_frame[out_field] = function(node_frame[in_field]) 299 | node_frame = freeze(node_frame) 300 | node_frames = graph.node_frames[:ntype_idx] + (node_frame,)\ 301 | + graph.node_frames[ntype_idx+1:] 302 | return graph._replace(node_frames=node_frames) 303 | 304 | def __fn(graph, in_field=in_field, out_field=out_field): 305 | graph = graph.ndata.set(out_field, function(graph.ndata[in_field])) 306 | return graph 307 | 308 | if ntype is None: 309 | return __fn 310 | else: 311 | return _fn 312 | 313 | def apply_edges( 314 | function: Callable, 315 | in_field: str = "h", 316 | out_field: Optional[str] = None, 317 | etype: Optional[str] = None, 318 | ): 319 | """Apply a function to edge attributes. 320 | 321 | Parameters 322 | ---------- 323 | function : Callable 324 | Input function. 325 | in_field : str 326 | Input field 327 | out_field : str 328 | Output field. 329 | 330 | Returns 331 | ------- 332 | Callable 333 | Function that takes and returns a graph. 334 | 335 | Examples 336 | -------- 337 | Transform function. 
338 | >>> import jax 339 | >>> import jax.numpy as jnp 340 | >>> import galax 341 | >>> graph = galax.graph(((0, 1), (1, 2))) 342 | >>> graph = graph.edata.set("h", jnp.ones(2)) 343 | >>> fn = apply_edges(lambda x: x * 3) 344 | >>> graph = jax.jit(fn)(graph) 345 | >>> graph.edata['h'].tolist() 346 | [3.0, 3.0] 347 | 348 | """ 349 | if out_field is None: 350 | out_field = in_field 351 | 352 | def _fn(graph, in_field=in_field, out_field=out_field, etype=etype): 353 | etype_idx = graph.get_etype_id(etype) 354 | edge_frame = unfreeze(graph.edge_frames[etype_idx]) 355 | edge_frame[out_field] = function(edge_frame[in_field]) 356 | edge_frame = freeze(edge_frame) 357 | edge_frames = graph.edge_frames[:etype_idx] + (edge_frame, )\ 358 | + graph.edge_frames[etype_idx+1:] 359 | return graph._replace(edge_frames=edge_frames) 360 | 361 | return _fn 362 | -------------------------------------------------------------------------------- /galax/_version.py: -------------------------------------------------------------------------------- 1 | # This file helps to compute a version number in source trees obtained from 2 | # git-archive tarball (such as those provided by githubs download-from-tag 3 | # feature). Distribution tarballs (built by setup.py sdist) and build 4 | # directories (produced by setup.py build) will contain a much shorter file 5 | # that just contains the computed version number. 6 | 7 | # This file is released into the public domain. Generated by 8 | # versioneer-0.18 (https://github.com/warner/python-versioneer) 9 | 10 | """Git implementation of _version.py.""" 11 | 12 | import errno 13 | import os 14 | import re 15 | import subprocess 16 | import sys 17 | 18 | 19 | def get_keywords(): 20 | """Get the keywords needed to look up the version information.""" 21 | # these strings will be replaced by git during git-archive. 22 | # setup.py/versioneer.py will grep for the variable names, so they must 23 | # each be defined on a line of their own. 
    # _version.py will just call get_keywords().
    git_refnames = "$Format:%d$"
    git_full = "$Format:%H$"
    git_date = "$Format:%ci$"
    keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
    return keywords


class VersioneerConfig:
    """Container for Versioneer configuration parameters."""


def get_config():
    """Create, populate and return the VersioneerConfig() object."""
    # these strings are filled in when 'setup.py versioneer' creates
    # _version.py
    cfg = VersioneerConfig()
    cfg.VCS = "git"
    cfg.style = "pep440"
    cfg.tag_prefix = ""
    cfg.parentdir_prefix = "None"
    cfg.versionfile_source = "galax/_version.py"
    cfg.verbose = False
    return cfg


class NotThisMethod(Exception):
    """Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY = {}
HANDLERS = {}


def register_vcs_handler(vcs, method):  # decorator
    """Decorator to mark a method as the handler for a particular VCS."""

    def decorate(f):
        """Store f in HANDLERS[vcs][method]."""
        if vcs not in HANDLERS:
            HANDLERS[vcs] = {}
        HANDLERS[vcs][method] = f
        return f

    return decorate


def run_command(
    commands, args, cwd=None, verbose=False, hide_stderr=False, env=None
):
    """Call the given command(s), returning (stdout, returncode).

    Tries each candidate executable name in ``commands`` in order; returns
    (None, None) when no candidate can be launched.
    """
    assert isinstance(commands, list)
    p = None
    for c in commands:
        try:
            dispcmd = str([c] + args)
            # remember shell=False, so use git.cmd on windows, not just git
            p = subprocess.Popen(
                [c] + args,
                cwd=cwd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=(subprocess.PIPE if hide_stderr else None),
            )
            break
        except EnvironmentError:
            e = sys.exc_info()[1]
            if e.errno == errno.ENOENT:
                continue
            if verbose:
                print("unable to run %s" % dispcmd)
                print(e)
            return None, None
    else:
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None, None
    stdout = p.communicate()[0].strip()
    if sys.version_info[0] >= 3:
        stdout = stdout.decode()
    if p.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
            print("stdout was %s" % stdout)
        return None, p.returncode
    return stdout, p.returncode


def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Try to determine the version from the parent directory name.

    Source tarballs conventionally unpack into a directory that includes both
    the project name and a version string. We will also support searching up
    two directory levels for an appropriately named parent directory
    """
    rootdirs = []

    for i in range(3):
        dirname = os.path.basename(root)
        if dirname.startswith(parentdir_prefix):
            return {
                "version": dirname[len(parentdir_prefix) :],
                "full-revisionid": None,
                "dirty": False,
                "error": None,
                "date": None,
            }
        else:
            rootdirs.append(root)
            root = os.path.dirname(root)  # up a level

    if verbose:
        print(
            "Tried directories %s but none started with prefix %s"
            % (str(rootdirs), parentdir_prefix)
        )
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")


@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Extract version information from the given file."""
    # the code embedded in _version.py can just fetch the value of these
    # keywords. When used from setup.py, we don't want to import _version.py,
    # so we do it with a regexp instead. This function is not used from
    # _version.py.
    keywords = {}
    try:
        f = open(versionfile_abs, "r")
        for line in f.readlines():
            if line.strip().startswith("git_refnames ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["refnames"] = mo.group(1)
            if line.strip().startswith("git_full ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["full"] = mo.group(1)
            if line.strip().startswith("git_date ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["date"] = mo.group(1)
        f.close()
    except EnvironmentError:
        pass
    return keywords


@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Get version information from git keywords."""
    if not keywords:
        raise NotThisMethod("no keywords at all, weird")
    date = keywords.get("date")
    if date is not None:
        # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
        # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
        # -like" string, which we must then edit to make compliant), because
        # it's been around since git-1.5.3, and it's too difficult to
        # discover which version we're using, or to work around using an
        # older one.
        date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
    refs = set([r.strip() for r in refnames.strip("()").split(",")])
    # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
    # just "foo-1.0". If we see a "tag: " prefix, prefer those.
    TAG = "tag: "
    tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
    if not tags:
        # Either we're using git < 1.8.3, or there really are no tags. We use
        # a heuristic: assume all version tags have a digit. The old git %d
        # expansion behaves like git log --decorate=short and strips out the
        # refs/heads/ and refs/tags/ prefixes that would let us distinguish
        # between branches and tags. By ignoring refnames without digits, we
        # filter out many common branch names like "release" and
        # "stabilization", as well as "HEAD" and "master".
        tags = set([r for r in refs if re.search(r"\d", r)])
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))
    for ref in sorted(tags):
        # sorting will prefer e.g. "2.0" over "2.0rc1"
        if ref.startswith(tag_prefix):
            r = ref[len(tag_prefix) :]
            if verbose:
                print("picking %s" % r)
            return {
                "version": r,
                "full-revisionid": keywords["full"].strip(),
                "dirty": False,
                "error": None,
                "date": date,
            }
    # no suitable tags, so version is "0+unknown", but full hex is still there
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {
        "version": "0+unknown",
        "full-revisionid": keywords["full"].strip(),
        "dirty": False,
        "error": "no suitable tags",
        "date": None,
    }


@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
    """Get version from 'git describe' in the root of the source tree.

    This only gets called if the git-archive 'subst' keywords were *not*
    expanded, and _version.py hasn't already been rewritten with a short
    version string, meaning we're inside a checked out source tree.
    """
    GITS = ["git"]
    if sys.platform == "win32":
        GITS = ["git.cmd", "git.exe"]

    out, rc = run_command(
        GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True
    )
    if rc != 0:
        if verbose:
            print("Directory %s not under git control" % root)
        raise NotThisMethod("'git rev-parse --git-dir' returned error")

    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
    # if there isn't one, this yields HEX[-dirty] (no NUM)
    describe_out, rc = run_command(
        GITS,
        [
            "describe",
            "--tags",
            "--dirty",
            "--always",
            "--long",
            "--match",
            "%s*" % tag_prefix,
        ],
        cwd=root,
    )
    # --long was added in git-1.5.5
    if describe_out is None:
        raise NotThisMethod("'git describe' failed")
    describe_out = describe_out.strip()
    full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
    if full_out is None:
        raise NotThisMethod("'git rev-parse' failed")
    full_out = full_out.strip()

    pieces = {}
    pieces["long"] = full_out
    pieces["short"] = full_out[:7]  # maybe improved later
    pieces["error"] = None

    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
    # TAG might have hyphens.
    git_describe = describe_out

    # look for -dirty suffix
    dirty = git_describe.endswith("-dirty")
    pieces["dirty"] = dirty
    if dirty:
        git_describe = git_describe[: git_describe.rindex("-dirty")]

    # now we have TAG-NUM-gHEX or HEX

    if "-" in git_describe:
        # TAG-NUM-gHEX
        mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
        if not mo:
            # unparseable. Maybe git-describe is misbehaving?
            pieces["error"] = (
                "unable to parse git-describe output: '%s'" % describe_out
            )
            return pieces

        # tag
        full_tag = mo.group(1)
        if not full_tag.startswith(tag_prefix):
            if verbose:
                fmt = "tag '%s' doesn't start with prefix '%s'"
                print(fmt % (full_tag, tag_prefix))
            pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
                full_tag,
                tag_prefix,
            )
            return pieces
        pieces["closest-tag"] = full_tag[len(tag_prefix) :]

        # distance: number of commits since tag
        pieces["distance"] = int(mo.group(2))

        # commit: short hex revision ID
        pieces["short"] = mo.group(3)

    else:
        # HEX: no tags
        pieces["closest-tag"] = None
        count_out, rc = run_command(
            GITS, ["rev-list", "HEAD", "--count"], cwd=root
        )
        pieces["distance"] = int(count_out)  # total number of commits

    # commit date: see ISO-8601 comment in git_versions_from_keywords()
    date = run_command(
        GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root
    )[0].strip()
    pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)

    return pieces


def plus_or_dot(pieces):
    """Return a + if we don't already have one, else return a ."""
    if "+" in pieces.get("closest-tag", ""):
        return "."
    return "+"


def render_pep440(pieces):
    """Build up version string, with post-release "local version identifier".

    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty

    Exceptions:
    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += plus_or_dot(pieces)
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def render_pep440_pre(pieces):
    """TAG[.post.devDISTANCE] -- No -dirty.

    Exceptions:
    1: no tags. 0.post.devDISTANCE
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"]:
            rendered += ".post.dev%d" % pieces["distance"]
    else:
        # exception #1
        rendered = "0.post.dev%d" % pieces["distance"]
    return rendered


def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    The ".dev0" means dirty. Note that .dev0 sorts backwards
    (a dirty tree will appear "older" than the corresponding clean one),
    but you shouldn't be releasing software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["dirty"]:
                rendered += ".dev0"
            rendered += plus_or_dot(pieces)
            rendered += "g%s" % pieces["short"]
    else:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
    return rendered


def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["dirty"]:
                rendered += ".dev0"
    else:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
    return rendered


def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Like 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"]:
            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered


def render_git_describe_long(pieces):
    """TAG-DISTANCE-gHEX[-dirty].

    Like 'git describe --tags --dirty --always --long'.
    The distance/hash is unconditional.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered


def render(pieces, style):
    """Render the given version pieces into the requested style."""
    if pieces["error"]:
        return {
            "version": "unknown",
            "full-revisionid": pieces.get("long"),
            "dirty": None,
            "error": pieces["error"],
            "date": None,
        }

    if not style or style == "default":
        style = "pep440"  # the default

    if style == "pep440":
        rendered = render_pep440(pieces)
    elif style == "pep440-pre":
        rendered = render_pep440_pre(pieces)
    elif style == "pep440-post":
        rendered = render_pep440_post(pieces)
    elif style == "pep440-old":
        rendered = render_pep440_old(pieces)
    elif style == "git-describe":
        rendered = render_git_describe(pieces)
    elif style == "git-describe-long":
        rendered = render_git_describe_long(pieces)
    else:
        raise ValueError("unknown style '%s'" % style)

    return {
        "version": rendered,
        "full-revisionid": pieces["long"],
        "dirty": pieces["dirty"],
        "error": None,
        "date": pieces.get("date"),
    }


def get_versions():
    """Get version information or return default if unable to do so."""
    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
    # __file__, we can work backwards from there to the root. Some
    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
    # case we can only use expanded keywords.
521 | 522 | cfg = get_config() 523 | verbose = cfg.verbose 524 | 525 | try: 526 | return git_versions_from_keywords( 527 | get_keywords(), cfg.tag_prefix, verbose 528 | ) 529 | except NotThisMethod: 530 | pass 531 | 532 | try: 533 | root = os.path.realpath(__file__) 534 | # versionfile_source is the relative path from the top of the source 535 | # tree (where the .git directory might live) to this file. Invert 536 | # this to find the root from __file__. 537 | for i in cfg.versionfile_source.split("/"): 538 | root = os.path.dirname(root) 539 | except NameError: 540 | return { 541 | "version": "0+unknown", 542 | "full-revisionid": None, 543 | "dirty": None, 544 | "error": "unable to find root of source tree", 545 | "date": None, 546 | } 547 | 548 | try: 549 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 550 | return render(pieces, cfg.style) 551 | except NotThisMethod: 552 | pass 553 | 554 | try: 555 | if cfg.parentdir_prefix: 556 | return versions_from_parentdir( 557 | cfg.parentdir_prefix, root, verbose 558 | ) 559 | except NotThisMethod: 560 | pass 561 | 562 | return { 563 | "version": "0+unknown", 564 | "full-revisionid": None, 565 | "dirty": None, 566 | "error": "unable to compute version", 567 | "date": None, 568 | } 569 | -------------------------------------------------------------------------------- /galax/graph_index.py: -------------------------------------------------------------------------------- 1 | """Module for graph index class definition. 2 | 3 | Inspired by: dgl.graph_index. 4 | """ 5 | from typing import NamedTuple, Optional, Tuple 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as onp 9 | from jax.experimental.sparse import BCOO 10 | 11 | # @register_pytree_node_class 12 | 13 | 14 | class GraphIndex(NamedTuple): 15 | """Graph index object. 16 | 17 | Attriubtes 18 | ---------- 19 | n_nodes : int 20 | The number of nodes in the graph. 21 | src : jnp.ndarray 22 | The indices of the source nodes, for each edge. 
    dst : jnp.ndarray
        The indices of the destination nodes, for each edge.

    Notes
    -----
    * All transformations return new objects rather than modifying in-place.
    * Not all functions are jittable.

    Examples
    --------
    >>> g = GraphIndex()
    >>> assert g.n_nodes == 0
    >>> assert len(g.src) == 0
    >>> assert len(g.dst) == 0

    >>> g = GraphIndex(n_nodes=2, src=jnp.array([0]), dst=jnp.array([1]))
    >>> assert g.n_nodes == 2

    """

    n_nodes: int = 0
    src: Optional[jnp.ndarray] = None
    dst: Optional[jnp.ndarray] = None

    # Default empty arrays for src and dst.
    # NOTE(review): this runs once, at class-definition time; it rebinds the
    # names in the class namespace, so typing.NamedTuple appears to pick up
    # empty int32 arrays (not None) as the field defaults — confirm this is
    # intentional rather than a leftover from an __init__ body.
    if src is None:
        src = jnp.array([], dtype=jnp.int32)
    if dst is None:
        dst = jnp.array([], dtype=jnp.int32)

    def add_nodes(self, num: int):
        """Add nodes.

        Parameters
        ----------
        num : int
            Number of nodes to be added.

        Examples
        --------
        >>> g = GraphIndex()
        >>> g_new = g.add_nodes(num=1)
        >>> (g.n_nodes, g_new.n_nodes)
        (0, 1)

        """
        # NOTE(review): the message says "positive" but num == 0 is accepted.
        assert num >= 0, "Can only add positive number of nodes."
        return self._replace(
            n_nodes=self.n_nodes + num,
        )

    def add_edge(self, u: int, v: int):
        """Add one edge.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Examples
        --------
        >>> g = GraphIndex()
        >>> g = g.add_nodes(2)
        >>> g = g.add_edge(0, 1)
        >>> g.src.tolist()
        [0]
        >>> g.dst.tolist()
        [1]

        """
        assert self.has_node(u) and self.has_node(v)
        # Append one entry to each of the parallel src/dst arrays.
        return self._replace(
            src=jnp.concatenate([self.src, jnp.array([u])]),
            dst=jnp.concatenate([self.dst, jnp.array([v])]),
        )

    def is_multigraph(self) -> bool:
        """Return whether the graph is a multigraph

        Returns
        -------
        bool
            True if it is a multigraph, False otherwise.

        Examples
        --------
        >>> GraphIndex(
        ...     n_nodes=2, src=jnp.array([0]), dst=jnp.array([1])
        ... ).is_multigraph()
        False

        >>> GraphIndex(
        ...     n_nodes=2, src=jnp.array([0, 0]), dst=jnp.array([1, 1])
        ... ).is_multigraph()
        True

        """
        # A duplicate (src, dst) pair means at least one parallel edge.
        src_and_dst = jnp.stack([self.src, self.dst], axis=-1)
        return (
            jnp.unique(src_and_dst, axis=0).shape[0] != src_and_dst.shape[0]
        )

    def number_of_nodes(self) -> int:
        """Return the number of nodes.

        Returns
        -------
        int
            The number of nodes

        Examples
        --------
        >>> GraphIndex(2666).number_of_nodes()
        2666

        """
        return self.n_nodes

    def number_of_edges(self) -> int:
        """Return the number of edges.

        Returns
        -------
        int
            The number of edges

        Examples
        --------
        >>> g = GraphIndex(2)
        >>> g = g.add_edge(0, 1)
        >>> g.number_of_edges()
        1
        """
        # One edge per entry of the src array (dst is parallel to it).
        return self.src.shape[0]

    def has_node(self, vid: int) -> bool:
        """Return true if the node exists.

        Parameters
        ----------
        vid : int
            The nodes

        Returns
        -------
        bool
            True if the node exists, False otherwise.

        Examples
        --------
        >>> GraphIndex(5).has_node(5)
        False

        >>> GraphIndex(4).has_node(3)
        True

        """
        # Node ids are dense 0..n_nodes-1; negative ids are rejected outright.
        assert vid >= 0, "Node does not exist. "
        return vid < self.number_of_nodes()

    def has_nodes(self, vids: jnp.ndarray) -> jnp.array:
        """Return true if the nodes exist.

        Parameters
        ----------
        vid : jnp.ndarray
            The nodes

        Returns
        -------
        jnp.ndarray
            0-1 array indicating existence

        Examples
        --------
        >>> g = GraphIndex(2)
        >>> vids = jnp.array([0, 1, 2])
        >>> g.has_nodes(vids).tolist()
        [1, 1, 0]

        """
        assert (vids >= 0).all(), "Node does not exist. "
        # Elementwise: 1 where the id is a valid node, 0 otherwise.
        return 1 * (vids < self.number_of_nodes())

    def has_edge_between(self, u: int, v: int) -> bool:
        """Return true if the edge exists.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        bool
            True if the edge exists, False otherwise

        Examples
        --------
        >>> g = GraphIndex(2, src=jnp.array([0]), dst=jnp.array([1]))
        >>> assert g.has_edge_between(0, 1)
        >>> assert ~g.has_edge_between(0, 0)

        """
        assert self.has_node(u) and self.has_node(v), "Node does not exist. "
        u_in_src = u == self.src
        v_in_dst = v == self.dst
        # An edge exists where some position matches in both arrays at once.
        return (u_in_src * v_in_dst).any()

    def has_edges_between(self, u: int, v: int) -> jnp.array:
        """Return true if the edge exists.

        Parameters
        ----------
        u : jnp.ndarray
            The src nodes.
        v : jnp.ndarray
            The dst nodes.

        Returns
        -------
        jnp.ndarray
            0-1 array indicating existence

        Examples
        --------
        >>> g = GraphIndex(
        ...     n_nodes=5, src=jnp.array([0, 1]), dst=jnp.array([1, 2])
        ... )
        >>> g.has_edges_between(
        ...     jnp.array([0, 1, 2]), jnp.array([1, 2, 3]),
        ... ).tolist()
        [1, 1, 0]

        """
        assert (
            self.has_nodes(u).all() and self.has_nodes(v).all()
        ), "Node does not exist. "

        # Broadcast queries against the edge list:
        # (n_queries, 1) against (1, n_edges).
        u = jnp.expand_dims(u, -1)
        v = jnp.expand_dims(v, -1)
        src = jnp.expand_dims(self.src, 0)
        dst = jnp.expand_dims(self.dst, 0)

        u_is_src = u == src
        v_is_dst = v == dst

        return 1 * (u_is_src * v_is_dst).any(axis=-1)

    def eid(self, u: int, v: int) -> jnp.array:
        """Return the id array of all edges between u and v.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        jnp.ndarray
            The edge id array.

        Examples
        --------
        >>> g = GraphIndex(
        ...     5, src=jnp.array([0, 0, 3]), dst=jnp.array([1, 1, 2])
        ... )
        >>> g.eid(0, 1).tolist()
        [0, 1]
        """
        assert self.has_node(u) and self.has_node(v), "Node does not exist. "
        u_in_src = u == self.src
        v_in_dst = v == self.dst
        # Positions where both match are the ids of parallel u->v edges.
        return jnp.where(u_in_src * v_in_dst)[0]

    def find_edge(self, eid: int) -> Tuple[int]:
        """Return the edge tuple of the given id.

        Parameters
        ----------
        eid : int
            The edge id.

        Returns
        -------
        int
            src node id
        int
            dst node id

        Examples
        --------
        >>> g = GraphIndex(
        ...     5, src=jnp.array([0, 0, 3]), dst=jnp.array([1, 1, 2])
        ... )
        >>> src, dst = g.find_edge(0)
        >>> (int(src), int(dst))
        (0, 1)

        """
        # assert eid < len(self.src)
        return self.src[eid], self.dst[eid]

    def find_edges(self, eid: jnp.ndarray) -> Tuple[jnp.array]:
        """Return the source and destination nodes that contain the eids.

        Parameters
        ----------
        eid : jnp.ndarray
            The edge ids.

        Returns
        -------
        jnp.ndarray
            The src nodes.
        jnp.ndarray
            The dst nodes.

        Examples
        --------
        >>> g = GraphIndex(
        ...     10, jnp.array([2, 3]), jnp.array([3, 4]),
        ... )
        >>> src, dst = g.find_edges(jnp.array([0, 1]))
        >>> src.tolist(), dst.tolist()
        ([2, 3], [3, 4])

        """
        assert (eid < len(self.src)).all()
        return self.src[eid], self.dst[eid]

    def in_edges(self, v: int):
        """Return the in edges of the node(s).

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        jnp.ndarray
            The src nodes.
        jnp.ndarray
            The dst nodes.
        jnp.ndarray
            The edge ids.
        """
        assert self.has_node(v), "Node does not exist. "
        # Edges whose destination is v; boolean mask over all edge ids.
        v_is_dst = v == self.dst
        eids = jnp.arange(self.src.shape[0])[v_is_dst]
        src = self.src[eids]
        dst = self.dst[eids]
        return src, dst, eids

    def out_edges(self, v: int):
        """Return the out edges of the node(s).

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        jnp.ndarray
            The src nodes.
        jnp.ndarray
            The dst nodes.
        jnp.ndarray
            The edge ids.
        """
        assert self.has_node(v), "Node does not exist. "
        # Edges whose source is v; boolean mask over all edge ids.
        v_is_src = v == self.src
        eids = jnp.arange(self.src.shape[0])[v_is_src]
        src = self.src[eids]
        dst = self.dst[eids]
        return src, dst, eids

    @staticmethod
    def _reindex_after_remove(
        original_index: jnp.ndarray, removed_index: jnp.ndarray
    ):
        """Reindex an array after removing some indicies.

        Parameters
        ----------
        original_index : jnp.ndarray
            Original indicies.
        removed_index : jnp.ndarray
            Indicies that are removed.

        Returns
        -------
        jnp.ndarray
            New indicies.

        Examples
        --------
        >>> original_index = jnp.array([1, 2, 3, 4, 7, 10])
        >>> removed_index = jnp.array([2, 7])
        >>> new_index = GraphIndex._reindex_after_remove(
        ...     original_index=original_index, removed_index=removed_index,
        ... )
        >>> new_index.tolist()
        [1, 2, 3, 8]
        """
        # Drop entries that were removed, then shift each survivor down by
        # the number of removed indices smaller than it.
        is_removed = (
            jnp.expand_dims(original_index, -1)
            == jnp.expand_dims(removed_index, 0)
        ).any(axis=-1)

        new_index = original_index[~is_removed]

        def get_new_index(old_index):
            offset = (old_index > removed_index).sum()
            return old_index - offset

        new_index = jax.lax.map(get_new_index, new_index)
        return new_index

    def remove_nodes(self, nids: jnp.ndarray):
        """Remove nodes. The edges connected to the nodes will too be removed.

        Parameters
        ----------
        nids : jnp.ndarray
            Nodes to remove.

        Returns
        -------
        GraphIndex
            A new graph with nodes removed.

        Examples
        --------
        >>> g = GraphIndex(
        ...     10, jnp.array([2, 3]), jnp.array([3, 4]),
        ... )
        >>> _g = g.remove_nodes(jnp.array([2]))
        >>> _g.src.tolist()
        [2]
        >>> _g.dst.tolist()
        [3]
        """
        assert self.has_nodes(nids).all(), "Node does not exist. "
        # Drop every edge incident (as src or dst) to a removed node, then
        # renumber the remaining endpoints into the compacted id space.
        v_is_src = jnp.expand_dims(nids, -1) == self.src
        v_is_dst = jnp.expand_dims(nids, -1) == self.dst
        v_is_in_edge = (v_is_src + v_is_dst).any(axis=0)
        eids = jnp.where(v_is_in_edge)[0]
        src = jnp.delete(self.src, eids)
        dst = jnp.delete(self.dst, eids)
        src = self._reindex_after_remove(src, nids)
        dst = self._reindex_after_remove(dst, nids)
        # NOTE(review): assumes nids contains no duplicates — confirm.
        n_nodes = self.n_nodes - len(nids)
        return self.__class__(
            n_nodes,
            src=src,
            dst=dst,
        )

    def remove_node(self, nid: int):
        """Remove a node.
        The edges connected to the nodes will too be removed.

        Parameters
        ----------
        nid : int
            Node to remove.

        Returns
        -------
        GraphIndex
            A new graph with nodes removed.
        """
        nids = jnp.array([nid])
        return self.remove_nodes(nids)

    def remove_edges(self, eids: jnp.ndarray):
        """Remove edges.

        Parameters
        ----------
        eids : jnp.ndarray
            Edges to remove.

        Returns
        -------
        GraphIndex
            A new graph with edges removed.

        """
        assert (eids < len(self.src)).all(), "Edge does not exist. "
        # Node count is untouched; only the parallel edge arrays shrink.
        src = jnp.delete(self.src, eids)
        dst = jnp.delete(self.dst, eids)
        return self._replace(
            src=src,
            dst=dst,
        )

    def remove_edge(self, eid: int):
        """Remove a single edge.

        Parameters
        ----------
        eid : int
            Edge to remove.

        Returns
        -------
        GraphIndex
            A new graph with edges removed.

        """
        eids = jnp.array([eid])
        return self.remove_edges(eids)

    def edges(self, order: Optional[str] = None) -> Tuple[jnp.array]:
        """Return all the edges.

        Parameters
        ----------
        order : string
            The order of the returned edges. Currently support:
            - 'srcdst' : sorted by their src and dst ids.
            - 'eid' : sorted by edge Ids.
            - None : the arbitrary order.

        Returns
        -------
        jnp.ndarray
            The src nodes.
        jnp.ndarray
            The dst nodes.
        jnp.ndarray
            The edge ids.

        Examples
        --------
        >>> g = GraphIndex(6, jnp.array([0, 1, 2]), jnp.array([3, 4, 5]))
        >>> src, dst, eid = g.edges()
        >>> src.tolist(), dst.tolist(), eid.tolist()
        ([0, 1, 2], [3, 4, 5], [0, 1, 2])

        """
        src, dst, eid = self.src, self.dst, jnp.arange(len(self.src))
        if order == "srcdst":
            # NOTE(review): lexsort treats the LAST key as primary, so this
            # sorts by dst first, then src — 'srcdst' reads as src-major;
            # confirm which ordering is intended.
            idxs = jnp.lexsort((src, dst))
            src, dst, eid = src[idxs], dst[idxs], eid[idxs]
        return src, dst, eid

    # Alias kept for DGL API parity.
    all_edges = edges

    def in_degree(self, v: int) -> int:
        """Return the in degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The in degree.

        Examples
        --------
        >>> g = GraphIndex(6, jnp.array([0, 1, 2]), jnp.array([3, 3, 3]))
        >>> int(g.in_degree(3))
        3

        """
        assert self.has_node(v), "Node does not exist. "
        # Count edges whose destination is v.
        return (v == self.dst).sum()

    def in_degrees(self, v: jnp.array) -> jnp.array:
        """Return the in degrees of the nodes.

        Parameters
        ----------
        v : jnp.ndarray
            The nodes.

        Returns
        -------
        tensor
            The in degree array.

        Examples
        --------
        >>> g = GraphIndex(6, jnp.array([0, 1, 2]), jnp.array([3, 3, 3]))
        >>> g.in_degrees(jnp.array([0, 1, 2, 3])).tolist()
        [0, 0, 0, 3]
        """
        assert self.has_nodes(v).all(), "Node does not exist. "
        # Broadcast each queried node against every edge destination.
        v = jnp.expand_dims(v, -1)
        dst = jnp.expand_dims(self.dst, 0)

        # (len(v), len(dst))
        v_is_dst = v == dst

        return v_is_dst.sum(axis=-1)

    def out_degree(self, v: int) -> int:
        """Return the out degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The out degree.

        Examples
        --------
        >>> g = GraphIndex(6, jnp.array([0, 0, 0]), jnp.array([1, 2, 3]))
        >>> int(g.out_degree(0))
        3
        """
        assert self.has_node(v), "Node does not exist. "
        # Count edges whose source is v.
        return (v == self.src).sum()

    def out_degrees(self, v: jnp.array) -> jnp.array:
        """Return the out degrees of the nodes.

        Parameters
        ----------
        v : jnp.ndarray
            The nodes.

        Returns
        -------
        tensor
            The out degree array.

        Examples
        --------
        >>> g = GraphIndex(6, jnp.array([0, 0, 0]), jnp.array([1, 2, 3]))
        >>> g.out_degrees(jnp.array([0, 0, 1])).tolist()
        [3, 3, 0]
        """
        assert self.has_nodes(v).all(), "Node does not exist. "
        # Broadcast each queried node against every edge source.
        v = jnp.expand_dims(v, -1)
        src = jnp.expand_dims(self.src, 0)

        # (len(v), len(src))
        # NOTE(review): variable name says "dst" but this compares against
        # src; the out-degree logic is correct, only the name is misleading.
        v_is_dst = v == src

        return v_is_dst.sum(axis=-1)

    def edge_ids(self, u: int, v: int):
        """Return the edge id between two nodes.

        Parameters
        ----------
        u : int
            Source node.
        v : int
            Destination node.

        Returns
        -------
        jnp.ndarray
            Edge ids.
        """
        assert self.has_node(v) and self.has_node(u), "Node does not exist. "
        # All ids of parallel u->v edges (possibly empty).
        return jnp.where((self.src == u) * (self.dst == v))[0]

    def adjacency_matrix_scipy(
        self,
        transpose: bool = False,
        fmt: str = "coo",
        return_edge_ids: Optional[bool] = None,
    ):
        """Return the scipy adjacency matrix representation of this graph.

        By default, a row of returned adjacency matrix represents the
        destination of an edge and the column represents the source.

        When transpose is True, a row represents the source and a column
        represents a destination.

        Parameters
        ----------
        transpose : bool, default=False
            A flag to transpose the returned adjacency matrix.
        fmt : str, default="coo"
            Indicates the format of returned adjacency matrix.
        return_edge_ids : bool
            Indicates whether to return edge IDs or 1 as elements.

        Returns
        -------
        scipy.sparse.spmatrix
            The scipy representation of adjacency matrix.
725 | 726 | Examples 727 | -------- 728 | >>> g = GraphIndex(4, jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) 729 | >>> adj = g.adjacency_matrix_scipy() 730 | >>> adj.toarray() 731 | array([[0, 0, 0, 0], 732 | [1, 0, 0, 0], 733 | [0, 1, 0, 0], 734 | [0, 0, 1, 0]], dtype=int32) 735 | """ 736 | if return_edge_ids is None: 737 | return_edge_ids = False 738 | 739 | if fmt is not "coo": 740 | raise NotImplementedError 741 | 742 | n = self.number_of_nodes() 743 | m = self.number_of_edges() 744 | if transpose: 745 | row, col = onp.array(self.src), onp.array(self.dst) 746 | else: 747 | row, col = onp.array(self.dst), onp.array(self.src) 748 | data = onp.arange(0, m) if return_edge_ids else onp.ones_like(row) 749 | import scipy 750 | 751 | return scipy.sparse.coo_matrix((data, (row, col)), shape=(n, n)) 752 | 753 | def adjacency_matrix( 754 | self, 755 | transpose: bool = False, 756 | ) -> BCOO: 757 | """Return the adjacency matrix representation of this graph. 758 | 759 | By default, a row of returned adjacency matrix represents the destination 760 | of an edge and the column represents the source. 761 | 762 | When transpose is True, a row represents the source and a column represents 763 | a destination. 764 | 765 | Parameters 766 | ---------- 767 | transpose : bool 768 | A flag to transpose the returned adjacency matrix. 769 | 770 | 771 | Returns 772 | ------- 773 | SparseTensor 774 | The adjacency matrix. 775 | jnp.ndarray 776 | A index for data shuffling due to sparse format change. Return None 777 | if shuffle is not required. 

        Examples
        --------
        >>> g = GraphIndex(4, jnp.array([0, 1, 2]), jnp.array([1, 2, 3]))
        >>> adj = g.adjacency_matrix()
        >>> onp.array(adj.todense())
        array([[0, 0, 0, 0],
               [1, 0, 0, 0],
               [0, 1, 0, 0],
               [0, 0, 1, 0]], dtype=int32)
        """
        m = self.number_of_edges()
        n = self.number_of_nodes()
        # Rows index the destination by default; `transpose` flips that.
        if transpose:
            row, col = onp.array(self.src), onp.array(self.dst)
        else:
            row, col = onp.array(self.dst), onp.array(self.src)
        idx = jnp.stack([row, col], axis=-1)
        data = jnp.ones((m,), dtype=jnp.int32)
        shape = (n, n)
        spmat = BCOO((data, idx), shape=shape)
        return spmat

    # Alias kept for DGL API parity.
    adj = adjacency_matrix

    def incidence_matrix(
        self,
        typestr: str,
    ) -> BCOO:
        """Return the incidence matrix representation of this graph.

        An incidence matrix is an n x m sparse matrix, where n is
        the number of nodes and m is the number of edges. Each nnz
        value indicating whether the edge is incident to the node
        or not.

        There are three types of an incidence matrix `I`:
        * "in":
            - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);
            - I[v, e] = 0 otherwise.
        * "out":
            - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);
            - I[v, e] = 0 otherwise.
        * "both":
            - I[v, e] = 1 if e is the in-edge of v;
            - I[v, e] = -1 if e is the out-edge of v;
            - I[v, e] = 0 otherwise (including self-loop).

        Parameters
        ----------
        typestr : str
            Can be either "in", "out" or "both"

        Returns
        -------
        BCOO
            The incidence matrix.
838 | 839 | Examples 840 | -------- 841 | >>> g = GraphIndex(4, jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) 842 | >>> adj = g.incidence_matrix("in") 843 | >>> onp.array(adj.todense()) 844 | array([[0, 0, 0], 845 | [1, 0, 0], 846 | [0, 1, 0], 847 | [0, 0, 1]], dtype=int32) 848 | 849 | >>> adj = g.incidence_matrix("out") 850 | >>> onp.array(adj.todense()) 851 | array([[1, 0, 0], 852 | [0, 1, 0], 853 | [0, 0, 1], 854 | [0, 0, 0]], dtype=int32) 855 | 856 | >>> adj = g.incidence_matrix("both") 857 | >>> onp.array(adj.todense()) 858 | array([[-1, 0, 0], 859 | [ 1, -1, 0], 860 | [ 0, 1, -1], 861 | [ 0, 0, 1]], dtype=int32) 862 | 863 | """ 864 | src, dst, eid = self.edges() 865 | n = self.number_of_nodes() 866 | m = self.number_of_edges() 867 | if typestr == "in": 868 | row, col = dst, eid 869 | idx = jnp.stack([row, col], axis=-1) 870 | dat = jnp.ones((m,), dtype=jnp.int32) 871 | inc = BCOO((dat, idx), shape=(n, m)) 872 | elif typestr == "out": 873 | row, col = src, eid 874 | idx = jnp.stack([row, col], axis=-1) 875 | dat = jnp.ones((m,), dtype=jnp.int32) 876 | inc = BCOO((dat, idx), shape=(n, m)) 877 | elif typestr == "both": 878 | # first remove entries for self loops 879 | mask = src != dst 880 | src = src[mask] 881 | dst = dst[mask] 882 | eid = eid[mask] 883 | n_entries = src.shape[0] 884 | # create index 885 | row = jnp.concatenate([src, dst], axis=0) 886 | col = jnp.concatenate([eid, eid], axis=0) 887 | idx = jnp.stack([row, col], axis=-1) 888 | # FIXME(minjie): data type 889 | x = -jnp.ones((n_entries,), dtype=jnp.int32) 890 | y = jnp.ones((n_entries,), dtype=jnp.int32) 891 | dat = jnp.concatenate([x, y], axis=0) 892 | inc = BCOO((dat, idx), shape=(n, m)) 893 | return inc 894 | 895 | inc = incidence_matrix 896 | 897 | def to_networkx(self): 898 | """Convert to networkx graph. 899 | 900 | The edge id will be saved as the 'id' edge attribute. 
901 | 902 | Returns 903 | ------- 904 | networkx.DiGraph 905 | The nx graph 906 | 907 | Examples 908 | -------- 909 | >>> g = GraphIndex(4, jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) 910 | >>> import networkx as nx 911 | >>> g_nx = g.to_networkx() 912 | >>> assert isinstance(g_nx, nx.DiGraph) 913 | >>> g_nx.number_of_nodes() 914 | 4 915 | >>> g_nx.number_of_edges() 916 | 3 917 | 918 | """ 919 | src, dst, eid = self.edges() 920 | # xiangsx: Always treat graph as multigraph 921 | import networkx as nx 922 | 923 | ret = nx.MultiDiGraph() 924 | ret.add_nodes_from(range(self.number_of_nodes())) 925 | for u, v, e in zip(src, dst, eid): 926 | u, v, e = int(u), int(v), int(e) 927 | ret.add_edge(u, v, id=e) 928 | return ret 929 | 930 | def reverse(self): 931 | """Reverse the heterogeneous graph adjacency. 932 | 933 | Returns 934 | ------- 935 | GraphIndex 936 | A new graph index. 937 | """ 938 | return self.__class__( 939 | n_nodes=n_nodes, 940 | src=self.dst, 941 | dst=self.src, 942 | ) 943 | 944 | @classmethod 945 | def from_dgl(cls, graph): 946 | src, dst, _ = graph.edges() 947 | n_nodes = int(graph.number_of_nodes()) 948 | src, dst = jnp.array(src), jnp.array(dst) 949 | return cls(n_nodes=n_nodes, src=src, dst=dst) 950 | 951 | def __eq__(self, other): 952 | """Return if two graph index are identical.""" 953 | if type(self) != type(other): 954 | return False 955 | if len(self.src) != len(other.src): 956 | return False 957 | if len(self.dst) != len(other.dst): 958 | return False 959 | return ( 960 | other.n_nodes == self.n_nodes 961 | and (self.src == other.src).all() 962 | and (self.dst == other.dst).all() 963 | ) 964 | 965 | 966 | def from_coo( 967 | num_nodes: int, src: jnp.ndarray, dst: jnp.ndarray 968 | ) -> GraphIndex: 969 | """Convert from coo arrays. 970 | 971 | Parameters 972 | ---------- 973 | num_nodes : int 974 | Number of nodes. 975 | src : Tensor 976 | Src end nodes of the edges. 977 | dst : Tensor 978 | Dst end nodes of the edges. 
979 | 980 | Returns 981 | ------- 982 | GraphIndex 983 | The graph index. 984 | 985 | Examples 986 | -------- 987 | >>> g = from_coo(4, jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) 988 | >>> assert isinstance(g, GraphIndex) 989 | >>> g.number_of_nodes() 990 | 4 991 | >>> g.number_of_edges() 992 | 3 993 | """ 994 | return GraphIndex( 995 | n_nodes=num_nodes, 996 | src=src, 997 | dst=dst, 998 | ) 999 | 1000 | 1001 | def from_networkx(nx_graph): 1002 | """Convert from networkx graph. 1003 | 1004 | If 'id' edge attribute exists, the edge will be added follows 1005 | the edge id order. Otherwise, order is undefined. 1006 | 1007 | Parameters 1008 | ---------- 1009 | nx_graph : networkx.DiGraph 1010 | The nx graph or any graph that can be converted to nx.DiGraph 1011 | 1012 | Returns 1013 | ------- 1014 | GraphIndex 1015 | The graph index. 1016 | """ 1017 | if not isinstance(nx_graph, nx.Graph): 1018 | nx_graph = nx.DiGraph(nx_graph) 1019 | else: 1020 | if not nx_graph.is_directed(): 1021 | # to_directed creates a deep copy of the networkx graph even if 1022 | # the original graph is already directed and we do not want to do it. 1023 | nx_graph = nx_graph.to_directed() 1024 | num_nodes = nx_graph.number_of_nodes() 1025 | 1026 | # nx_graph.edges(data=True) returns src, dst, attr_dict 1027 | if nx_graph.number_of_edges() > 0: 1028 | has_edge_id = "id" in next(iter(nx_graph.edges(data=True)))[-1] 1029 | else: 1030 | has_edge_id = False 1031 | 1032 | if has_edge_id: 1033 | num_edges = nx_graph.number_of_edges() 1034 | src = np.zeros((num_edges,), dtype=np.int64) 1035 | dst = np.zeros((num_edges,), dtype=np.int64) 1036 | for u, v, attr in nx_graph.edges(data=True): 1037 | eid = attr["id"] 1038 | src[eid] = u 1039 | dst[eid] = v 1040 | else: 1041 | src = [] 1042 | dst = [] 1043 | for e in nx_graph.edges: 1044 | src.append(e[0]) 1045 | dst.append(e[1]) 1046 | # We store edge Ids as an edge attribute. 
1047 | src = jnp.array(src) 1048 | dst = jnp.array(dst) 1049 | return from_coo(num_nodes, src, dst) 1050 | 1051 | 1052 | def from_scipy_sparse_matrix(adj): 1053 | """Convert from scipy sparse matrix. 1054 | 1055 | Parameters 1056 | ---------- 1057 | adj : scipy sparse matrix 1058 | 1059 | Returns 1060 | ------- 1061 | GraphIndex 1062 | The graph index. 1063 | """ 1064 | num_nodes = max(adj.shape[0], adj.shape[1]) 1065 | adj_coo = adj.tocoo() 1066 | return from_coo(num_nodes, adj_coo.row, adj_coo.col) 1067 | 1068 | 1069 | def from_edge_list(elist, readonly): 1070 | """Convert from an edge list. 1071 | 1072 | Parameters 1073 | --------- 1074 | elist : list, tuple 1075 | List of (u, v) edge tuple, or a tuple of src/dst lists 1076 | 1077 | """ 1078 | if isinstance(elist, tuple): 1079 | src, dst = elist 1080 | else: 1081 | src, dst = zip(*elist) 1082 | src_ids = jnp.asarray(src) 1083 | dst_ids = jnp.asarray(dst) 1084 | num_nodes = max(src.max(), dst.max()) + 1 1085 | return from_coo(num_nodes, src_ids, dst_ids) 1086 | 1087 | 1088 | def create_graph_index(graph_data): 1089 | """Create a graph index object. 1090 | 1091 | Parameters 1092 | ---------- 1093 | graph_data : graph data 1094 | Data to initialize graph. Same as networkx's semantics. 
1095 | 1096 | """ 1097 | if isinstance(graph_data, GraphIndex): 1098 | # FIXME(minjie): this return is not correct for mutable graph index 1099 | return graph_data 1100 | 1101 | if graph_data is None: 1102 | return GraphIndex() 1103 | elif isinstance(graph_data, (list, tuple)): 1104 | # edge list 1105 | return from_edge_list(graph_data) 1106 | elif isinstance(graph_data, scipy.sparse.spmatrix): 1107 | # scipy format 1108 | return from_scipy_sparse_matrix(graph_data) 1109 | else: 1110 | try: 1111 | gidx = from_networkx(graph_data) 1112 | except Exception: 1113 | raise RuntimeError( 1114 | "Error while creating graph from input of type %s" 1115 | % type(graph_data) 1116 | ) 1117 | return gidx 1118 | --------------------------------------------------------------------------------