├── .github
└── workflows
│ └── sar_test.yaml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
├── Makefile
└── source
│ ├── _templates
│ └── autosummary
│ │ ├── distneighborsampler.rst
│ │ └── graphshardmanager.rst
│ ├── comm.rst
│ ├── common_tuples.rst
│ ├── conf.py
│ ├── data_loading.rst
│ ├── full_batch.rst
│ ├── images
│ ├── dom_parallel_naive.png
│ ├── dom_parallel_naive.svg
│ ├── dom_parallel_remat.pdf
│ ├── dom_parallel_remat.png
│ ├── dom_parallel_remat.svg
│ ├── one_shot_aggregation.png
│ ├── one_shot_aggregation.svg
│ ├── papers_gat_memory.png
│ ├── papers_os_scaling.png
│ ├── papers_sage_memory.png
│ ├── papers_train_full_doc.png
│ └── sar_vs_distdgl.png
│ ├── index.rst
│ ├── model_prepare.rst
│ ├── quick_start.rst
│ ├── sampling_training.rst
│ ├── sar_config.rst
│ ├── sar_modes.rst
│ └── shards.rst
├── examples
├── README.md
├── SIGN
│ ├── README.md
│ └── train_sign_with_sar.py
├── correct_and_smooth.py
├── partition_graph.py
├── rgcn-hetero
│ ├── README.md
│ ├── model.py
│ ├── train_heterogeneous_graph.py
│ └── train_heterogeneous_graph_mfg.py
├── train_dist_appnp_with_sar.py
├── train_distdgl_with_sar_inference.py
├── train_homogeneous_graph_advanced.py
├── train_homogeneous_graph_basic.py
└── train_homogeneous_sampling_basic.py
├── pyproject.toml
├── requirements.txt
├── sar
├── __init__.py
├── comm.py
├── common_tuples.py
├── config.py
├── construct_shard_manager.py
├── core
│ ├── __init__.py
│ ├── full_partition_block.py
│ ├── graphshard.py
│ ├── sampling.py
│ └── sar_aggregation.py
├── data_loading.py
├── distributed_bn.py
├── edge_softmax.py
├── logging_setup.py
└── patch_dgl.py
├── security.md
├── setup.py
└── tests
├── base_utils.py
├── conftest.py
├── constants.py
├── models.py
├── multiprocessing_utils.py
├── pytest.ini
├── test_comm.py
├── test_hetero_graph_shard_manager.py
├── test_patch_dgl.py
└── test_sar.py
/.github/workflows/sar_test.yaml:
--------------------------------------------------------------------------------
1 | name: SAR tests
2 | permissions: {}
3 |
4 | on:
5 | pull_request:
6 | branches: [main]
7 | workflow_dispatch:
8 |
9 | jobs:
10 | sar_tests:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Pull SAR
14 | uses: actions/checkout@v3
15 | with:
16 | fetch-depth: 0
17 |
18 | - name: Setup Python
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: '3.10'
22 |
23 | - name: Install requirements
24 | run: |
25 | python -m pip install --upgrade pip
26 | python -m pip install pytest
27 | # oneccl_bind_pt for torch 2.0.0 and python 3.10
28 | wget https://intel-extension-for-pytorch.s3.amazonaws.com/torch_ccl/cpu/oneccl_bind_pt-2.0.0%2Bcpu-cp310-cp310-linux_x86_64.whl
29 | python -m pip install oneccl_bind_pt-2.0.0+cpu-cp310-cp310-linux_x86_64.whl
30 | python -m pip install -e . "torch==2.0.0"
31 | python -m pip install pandas pyyaml pydantic
32 |
33 | - name: Run pytest
34 | run: |
35 | set +e
36 | python -m pytest tests/ -sv
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache
2 | __pycache__/
3 | *.pyc
4 |
5 | # Jupyter notebook checkpoints
6 | .ipynb_checkpoints/
7 |
8 | # Compiled Python files
9 | *.pyc
10 | *.pyo
11 | *.pyd
12 | __pycache__/
13 |
14 | # Build directories
15 | build/
16 | dist/
17 | *.egg-info/
18 |
19 |
20 | # Package distribution
21 | *.egg
22 | *.egg-info
23 |
24 | # IDE and editor files
25 | .vscode/
26 | .idea/
27 | *.iml
28 | *.iws
29 | *.ipr
30 |
31 | # Unit test / coverage reports
32 | htmlcov/
33 | .tox/
34 | .coverage
35 | .coverage.*
36 | .cache
37 | nosetests.xml
38 | coverage.xml
39 | *.cover
40 | .hypothesis/
41 | .pytest_cache/
42 |
43 | # Environments
44 | .env
45 | .venv
46 | env/
47 | venv/
48 | ENV/
49 | env.bak/
50 | venv.bak/
51 |
52 | # Datasets and partitions
53 | dataset/
54 | datasets/
55 | partition_data/
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Required
2 | version: 2
3 |
4 | # Set the version of Python and other tools you might need
5 | build:
6 | os: ubuntu-22.04
7 | tools:
8 | python: "3.9"
9 |
10 | # Build documentation in the docs/ directory with Sphinx
11 | sphinx:
12 | configuration: docs/source/conf.py
13 |
14 |
15 | # Optionally declare the Python requirements required to build your docs
16 | python:
17 | install:
18 | - requirements: requirements.txt
19 |
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Intel Labs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PROJECT NOT UNDER ACTIVE MANAGEMENT #
2 | This project will no longer be maintained by Intel.
3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
4 | Intel no longer accepts patches to this project.
5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
6 |
7 |
8 | [Documentation](https://sar.readthedocs.io/en/latest/) | [Examples](https://github.com/IntelLabs/SAR/tree/main/examples)
9 |
10 | SAR is a pure Python library for distributed training of Graph Neural Networks (GNNs) on large graphs. SAR is built on top of PyTorch and DGL and supports distributed full-batch training as well as distributed sampling-based training. SAR is particularly suited for training GNNs on large graphs as the graph is partitioned across the training machines. In full-batch training, SAR can utilize the [sequential aggregation and rematerialization technique](https://proceedings.mlsys.org/paper_files/paper/2022/hash/1d781258d409a6efc66cd1aa14a1681c-Abstract.html) to guarantee linear memory scaling, i.e., the memory needed to store the GNN activations in each host is guaranteed to go down linearly with the number of hosts, even for densely connected graphs.
11 |
12 | SAR requires minimal changes to existing GNN training code. SAR directly uses the graph partitioning data created by [DGL's partitioning tools](https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html) and can thus be used as a drop-in replacement for DGL's distributed sampling-based training. To get started using SAR, check out [SAR's documentation](https://sar.readthedocs.io/en/latest/) and the examples under the `examples/` folder.
13 |
14 |
15 | ## Installing required packages
16 | ```shell
17 | pip3 install -r requirements.txt
18 | ```
19 | Python 3.8 or higher is required. You also need to install [torch CCL](https://github.com/intel/torch-ccl) if you want to use Intel's OneCCL communication backend.
20 |
21 | ## Full-batch training Performance on ogbn-papers100M
22 | SAR consumes up to 2x less memory when training a 3-layer GraphSage network on ogbn-papers100M (111M nodes, 3.2B edges), and up to 4x less memory when training a 3-layer Graph Attention Network (GAT). SAR achieves near linear scaling for the peak memory requirements per machine. We use a 3-layer GraphSage network with hidden layer size of 256, and a 3-layer GAT network with hidden layer size of 128 and 4 attention heads. We use batch normalization between all layers.
23 |
24 |
25 |
26 |
27 | The run-time of SAR improves as we add more machines. At 128 machines, the epoch time is 3.8s. Each machine is a 2-socket machine with 2 Icelake processors (36 cores each). The machines are connected using Infiniband HDR (200 Gbps) links. After 100 epochs, training has converged. We use a 3-layer GraphSage network with hidden layer size of 256 and batch normalization between all layers. The training curve is the same regardless of the number of machines/partition.
28 |
29 |
30 |
31 |
32 | ## Sampling-based training Performance on ogbn-papers100M
33 | SAR is considerably faster than DistDGL in sampling-based training on CPUs. Each machine is a 2-socket machine with 2 Icelake processors (36 cores each). The machines are connected using Infiniband HDR (200 Gbps) links. We benchmarked using 3-layer GraphSage network with hidden layer size of 256. We used a batch size of 1000 per machine.
34 |
35 |
36 |
37 |
38 |
39 | ## Cite
40 |
41 | If you use SAR in your publication, we would appreciate it if you cite the SAR paper:
42 | ```
43 | @article{mostafa2021sequential,
44 | title={Sequential Aggregation and Rematerialization: Distributed Full-batch Training of Graph Neural Networks on Large Graphs},
45 | author={Mostafa, Hesham},
46 | journal={MLSys},
47 | year={2022}
48 | }
49 | ```
50 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | SOURCEDIR = source
5 | BUILDDIR = build
6 |
7 | # Put it first so that "make" without argument is like "make help".
8 | help:
9 | sphinx-build -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(O)
10 |
11 | .PHONY: help Makefile
12 |
13 | %: Makefile
14 | sphinx-build -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(O)
15 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/distneighborsampler.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline}}
7 |
8 | .. autoclass:: {{ name }}
9 | :members:
10 |
11 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/graphshardmanager.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline}}
7 |
8 | .. autoclass:: {{ name }}
9 | :members:
10 |
11 |
--------------------------------------------------------------------------------
/docs/source/comm.rst:
--------------------------------------------------------------------------------
1 | .. _comm-guide:
2 | .. currentmodule:: sar
3 |
4 |
5 | SAR's communication routines
6 | =============================
7 | SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends: ``ccl``, ``nccl``, ``mpi`` and ``gloo``. (Note: the ``gloo`` backend may be slower than the other backends because it does not support the ``all_to_all`` routine, so SAR falls back to its own implementation, which uses multiple asynchronous sends (``torch.distributed.isend``) between workers.) Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs.
8 |
9 | The ``ccl`` backend uses `Intel's OneCCL <https://github.com/oneapi-src/oneCCL>`_ library. You can install the PyTorch bindings for OneCCL `here <https://github.com/intel/torch-ccl>`_. ``ccl`` is the preferred backend when training on CPUs.
10 |
11 | You can train on CPUs and still use the ``nccl`` backend, or you can train on GPUs and use the ``ccl`` backend. However, you will incur extra overhead to move tensors back and forth between the CPU and GPU in order to provide the right tensors to the communication backend.
12 |
13 | In an environment with a networked file system, initializing ``torch.distributed`` is quite easy: ::
14 |
15 | if backend_name == 'nccl':
16 | comm_device = torch.device('cuda')
17 | else:
18 | comm_device = torch.device('cpu')
19 |
20 | master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file)
21 | sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device)
22 |
23 | ..
24 | :func:`sar.initialize_comms` tries to initialize the torch.distributed process group, but only if it has not been initialized already. Users can also initialize the process group themselves before calling :func:`sar.initialize_comms`.
25 | :func:`sar.nfs_ip_init` communicates the master's ip address to the workers through the file system. In the absence of a networked file system, you should develop your own mechanism to communicate the master's ip address.
26 |
27 | You can specify the name of the socket that will be used for communication with the ``SAR_SOCKET_NAME`` environment variable (if not specified, the first available socket will be selected).
28 |
29 |
30 |
31 | Relevant methods
32 | ---------------------------------------------------------------------------
33 |
34 | .. autosummary::
35 | :toctree: comm package
36 |
37 | initialize_comms
38 | rank
39 | world_size
40 | sync_params
41 | gather_grads
42 | nfs_ip_init
43 |
--------------------------------------------------------------------------------
/docs/source/common_tuples.rst:
--------------------------------------------------------------------------------
1 | .. _common-tuples:
2 |
3 | .. currentmodule:: sar.common_tuples
4 |
5 |
6 | Data tuples
7 | =============================
8 | ``sar.common_tuples`` defines a number of useful ``NamedTuple`` classes that are used to exchange data between the different parts of SAR:
9 |
10 | .. autosummary::
11 | :toctree: generated
12 | :nosignatures:
13 | :template: classtemplate
14 |
15 | PartitionData
16 | ShardEdgesAndFeatures
17 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath('../..'))
4 |
5 |
6 | # -- Project information -----------------------------------------------------
7 |
8 | project = 'SAR'
9 | copyright = '2022, Hesham Mostafa'
10 | author = 'Hesham Mostafa'
11 |
12 | # The full version, including alpha/beta/rc tags
13 | release = '1.0'
14 | import sar # noqa
15 |
16 | add_module_names = False
17 |
18 | extensions = [
19 | 'sphinx.ext.autodoc',
20 | 'sphinx.ext.autosummary',
21 |
22 | ]
23 |
24 | def autodoc_skip_member_handler(app, what, name, obj, skip, options):
25 | return name == 'forward' or name == 'extra_repr'
26 |
27 | #def setup(app):
28 | # app.connect('autodoc-skip-member', autodoc_skip_member_handler)
29 |
30 | templates_path = ['_templates']
31 |
32 |
33 | html_theme = 'sphinx_rtd_theme'
34 |
35 | html_static_path = ['_static']
36 |
--------------------------------------------------------------------------------
/docs/source/data_loading.rst:
--------------------------------------------------------------------------------
1 | .. _data-loading:
2 |
3 |
4 | Data loading and graph construction
5 | ==========================================================
6 | After partitioning the graph using DGL's `partition_graph <https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html>`_ function, SAR can load the graph data using :func:`sar.load_dgl_partition_data`. This yields a :class:`sar.common_tuples.PartitionData` object. The ``PartitionData`` object can then be used to construct various types of graph-like objects that can be passed to GNN models. You can construct graph objects to use for distributed full-batch training or graph objects to use for distributed sampling-based training as follows:
7 |
8 | .. contents:: :local:
9 | :depth: 3
10 |
11 |
12 | Full-batch training
13 | ---------------------------------------------------------------------------------------
14 |
15 | Constructing the full graph for sequential aggregation and rematerialization
16 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17 | Construct a single distributed graph object of type :class:`sar.core.GraphShardManager`::
18 |
19 | shard_manager = sar.construct_full_graph(partition_data)
20 |
21 | ..
22 |
23 | The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLGraph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLGraph``. See the :ref:`the distributed graph limitations section` for some exceptions.
24 |
25 | Constructing Message Flow Graphs (MFGs) for sequential aggregation and rematerialization
26 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
27 | In node classification tasks, gradients only backpropagate from the labeled nodes. DGL uses the concept of message flow graphs to construct layer-specific bi-partite graphs that update only a subset of nodes in each layer. These are the nodes that will ultimately affect the output, assuming each node only aggregates messages from its neighbors in every layer.
28 |
29 | If training a K-layer GNN on a node classification task, you can construct K distributed graph objects that reflect the message flow graphs at each layer using :func:`sar.construct_mfgs`:
30 | ::
31 |
32 | class GNNModel(nn.Module):
33 | def __init__(n_layers: int):
34 | super().__init__()
35 | self.convs = nn.ModuleList([
36 | dgl.nn.SAGEConv(100, 100)
37 | for _ in range(n_layers)
38 | ])
39 |
40 | def forward(blocks: List[sar.GraphShardManager], features: torch.Tensor):
41 | for idx in range(len(self.convs)):
42 | features = self.convs[idx](blocks[idx], features)
43 | return features
44 |
45 | K = 3 # number of layers
46 | gnn_model = GNNModel(K)
47 | train_blocks = sar.construct_mfgs(partition_data,
48 | global_indices_of_labeled_nodes_in_partition,
49 | K)
50 | model_out = gnn_model(train_blocks, local_node_features)
51 |
52 | ..
53 |
54 | Using message flow graphs at each layer can substantially lower run-time and memory consumption in node classification tasks with few labeled nodes.
55 |
56 |
57 | Constructing full graph or MFGs for one-shot aggregation
58 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
59 | As described in :ref:`training modes `, SAR supports doing one-shot distributed aggregation (mode 3). To run in this mode, you should extract the full partition graph from the :class:`sar.core.GraphShardManager` object and use that during training. When using the full graph:
60 | ::
61 |
62 | shard_manager = sar.construct_full_graph(partition_data)
63 | one_shot_graph = shard_manager.get_full_partition_graph()
64 | del shard_manager
65 | ## Use one_shot_graph from now on.
66 |
67 | ..
68 |
69 | When using MFGs:
70 | ::
71 |
72 | train_blocks = sar.construct_mfgs(partition_data,
73 | global_indices_of_labeled_nodes_in_partition,
74 | n_layers)
75 | one_shot_blocks = [block.get_full_partition_graph() for block in train_blocks]
76 | del train_blocks
77 | ## Use one_shot_blocks from now on
78 |
79 | ..
80 |
81 |
82 | Sampling-based training
83 | ---------------------------------------------------------------------------------------
84 |
85 | For sampling-based training, use the dataloader provided by SAR: :func:`sar.DataLoader` to construct globally-sampled graphs. The sampled graphs are vanilla DGL graphs that reside solely on the local machines. SAR provides a global neighbor sampler: :class:`sar.DistNeighborSampler` that defines the sampling process from the distributed graph. A typical use case is:
86 |
87 | ::
88 |
89 | shard_manager = sar.construct_full_graph(partition_data)
90 |
91 | neighbor_sampler = sar.DistNeighborSampler(
92 | [15, 10, 5], #Fanout for every layer
93 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block
94 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block
95 | )
96 |
97 | dataloader = sar.DataLoader(
98 | shard_manager, #Distributed graph
99 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes
100 | neighbor_sampler, #Distributed sampler
101 | batch_size)
102 |
103 | for blocks in dataloader:
104 | output = gnn_model(blocks)
105 | ...
106 |
107 | ..
108 |
109 |
110 | Full-graph inference
111 | ---------------------------------------------------------------------------------------
112 | SAR might also be utilized just for model evaluation. It is preferable to evaluate the model on the entire graph while performing mini-batch distributed training with the DGL package. To accomplish this, SAR can turn a `DistGraph `_ object into a GraphShardManager object, allowing for distributed full-graph inference. The procedure is simple since no further steps are required because the model parameters are already synchronized during inference. You can use :func:`sar.convert_dist_graph` in the following way to perform full-graph inference:
113 | ::
114 |
115 | class GNNModel(nn.Module):
116 | def __init__(n_layers: int):
117 | super().__init__()
118 | self.convs = nn.ModuleList([
119 | dgl.nn.SAGEConv(100, 100)
120 | for _ in range(n_layers)
121 | ])
122 |
123 | # forward function prepared for mini-batch training
124 | def forward(blocks: List[DGLBlock], features: torch.Tensor):
125 | h = features
126 | for idx, (layer, block) in enumerate(zip(self.convs, blocks)):
127 | h = self.convs[idx](blocks[idx], h)
128 | return h
129 |
130 | # implement inference function for full-graph input
131 | def full_graph_inference(graph: sar.GraphShardManager, features: torch.Tensor):
132 | h = features
133 | for idx, layer in enumerate(self.convs):
134 | h = layer(graph, h)
135 | return h
136 |
137 | # model wrapped in pytorch DistributedDataParallel
138 | gnn_model = th.nn.parallel.DistributedDataParallel(GNNModel(3))
139 |
140 | # Convert DistGraph into GraphShardManager
141 | gsm = sar.convert_dist_graph(g).to(device)
142 |
143 | # Access to model through DistributedDataParallel module field
144 | model_out = gnn_model.module.full_graph_inference(gsm, local_node_features)
145 | ..
146 |
147 |
148 | Relevant methods
149 | ---------------------------------------------------------------------------------------
150 |
151 | .. currentmodule:: sar
152 |
153 |
154 | .. autosummary::
155 | :toctree: Data loading and graph construction
156 | :template: distneighborsampler
157 |
158 |
159 | load_dgl_partition_data
160 | construct_full_graph
161 | construct_mfgs
162 | convert_dist_graph
163 | DataLoader
164 | DistNeighborSampler
165 |
--------------------------------------------------------------------------------
/docs/source/full_batch.rst:
--------------------------------------------------------------------------------
1 | Full-batch training
2 | ===================
3 | Distributed full-batch training in SAR may require some changes to your existing GNN model. Check :ref:`preparing your GNN for full-batch training in SAR` for more details. SAR supports multiple :ref:`training modes` that are suitable for different graph sizes and that trade speed for memory efficiency. For more information about SAR core distributed graph objects, check out the :ref:`distributed graph objects` section.
4 |
5 | .. toctree::
6 | :maxdepth: 1
7 | :titlesonly:
8 |
9 | Preparing your GNN for full-batch training in SAR <model_prepare>
10 | Training modes <sar_modes>
11 | Distributed Graph Objects <shards>
12 |
--------------------------------------------------------------------------------
/docs/source/images/dom_parallel_naive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_naive.png
--------------------------------------------------------------------------------
/docs/source/images/dom_parallel_remat.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_remat.pdf
--------------------------------------------------------------------------------
/docs/source/images/dom_parallel_remat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_remat.png
--------------------------------------------------------------------------------
/docs/source/images/one_shot_aggregation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/one_shot_aggregation.png
--------------------------------------------------------------------------------
/docs/source/images/papers_gat_memory.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_gat_memory.png
--------------------------------------------------------------------------------
/docs/source/images/papers_os_scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_os_scaling.png
--------------------------------------------------------------------------------
/docs/source/images/papers_sage_memory.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_sage_memory.png
--------------------------------------------------------------------------------
/docs/source/images/papers_train_full_doc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_train_full_doc.png
--------------------------------------------------------------------------------
/docs/source/images/sar_vs_distdgl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/sar_vs_distdgl.png
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. SAR documentation master file, created by
2 | sphinx-quickstart on Mon Mar 28 07:30:12 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 |
7 | Welcome to SAR's documentation!
8 | ===============================
9 |
10 |
11 | SAR is a pure Python library built on top of `DGL <https://www.dgl.ai/>`_ to accelerate distributed training of Graph Neural Networks (GNNs) on large graphs. SAR supports both full-batch training and sampling-based training. For full-batch training, SAR supports the `Sequential Aggregation and Rematerialization (SAR) <https://proceedings.mlsys.org/paper_files/paper/2022/hash/1d781258d409a6efc66cd1aa14a1681c-Abstract.html>`_ scheme to reduce peak per-machine memory consumption and guarantee that model memory consumption per worker goes down linearly with the number of workers. This is achieved by eliminating most of the data redundancy (due to the halo effect) involved in standard spatially parallel training.
12 |
13 | SAR uses the graph partition data generated by DGL's `partitioning utilities <https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html>`_. It can thus be used as a drop-in replacement for DGL's sampling-based distributed training. SAR enables scalable, distributed training on very large graphs, and supports multiple training modes that balance speed against memory efficiency. SAR requires minimal changes to existing single-host DGL training code. See the quick start guide to get started using SAR.
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 |
18 | Quick start <quick_start>
19 | Data loading and graph construction <data_loading>
20 | Communication routines <comm>
21 | Full-batch training <full_batch>
22 | Sampling-based training <sampling_training>
23 | Data tuples <common_tuples>
24 | SAR Configuration <sar_config>
25 |
26 |
27 |
28 | Index
29 | ==================
30 |
31 | * :ref:`genindex`
32 |
--------------------------------------------------------------------------------
/docs/source/model_prepare.rst:
--------------------------------------------------------------------------------
1 | .. _model-prepare:
2 |
3 |
4 | Preparing your GNN model for SAR
5 | =========================================================
6 |
7 | The basic graph object in SAR is :class:`sar.core.GraphShardManager`. It can typically be used as a drop-in replacement for DGL's native graph object and provided as the input graph to most GNN layers. See :ref:`the distributed graph limitations section` for some important limitations of this approach. There are situations where you need to modify your layer to accommodate :class:`sar.core.GraphShardManager` or to modify your GNN network to take into account the distributed nature of the training. Three such situations are outlined here:
8 |
9 | Edge softmax
10 | ------------------------------------------------------------------------------------
11 | DGL's ``edge_softmax`` function expects a native DGL graph object and will not work with a :class:`sar.core.GraphShardManager` object. Instead, you must use SAR's implementation :func:`sar.edge_softmax` which accepts a :class:`sar.core.GraphShardManager` object. DGL's attention based GNN layers make use of DGL's ``edge_softmax`` function. One solution to be able to use these layers with SAR is to monkey-patch them as shown below:
12 | ::
13 |
14 | import dgl
15 | import sar
16 | def patched_edge_softmax(graph, *args, **kwargs):
17 | if isinstance(graph, sar.GraphShardManager):
18 | return sar.edge_softmax(graph, *args, **kwargs)
19 |
20 | return dgl.nn.edge_softmax(graph, *args, **kwargs) # pylint: disable=no-member
21 |
22 |
23 | dgl.nn.pytorch.conv.gatconv.edge_softmax = patched_edge_softmax
24 | dgl.nn.pytorch.conv.dotgatconv.edge_softmax = patched_edge_softmax
25 | dgl.nn.pytorch.conv.agnnconv.edge_softmax = patched_edge_softmax
26 |
27 | ..
28 |
29 | ``patched_edge_softmax`` dispatches to either DGL's or SAR's implementation depending on the type of the input graph. SAR has the convenience function :func:`sar.patch_dgl` that runs the above code to patch DGL's attention-based GNN layers.
30 |
31 | Parameterized message functions
32 | -----------------------------------------------------------------------------------
33 |
34 | SAR's sequential rematerialization of the computational graph during the backward pass must be aware of any learnable parameters used to create the edge messages. SAR needs to know of these parameters so that it can correctly backpropagate gradients to them. There is no easy way for SAR to automatically detect the learnable parameters used by the message function. It is thus up to the user to use the :func:`sar.core.message_has_parameters` decorator to tell SAR about these parameters. For example, DGL's ``RelGraphConv`` layer uses a message function with learnable parameters. To avoid the need to modify the original code of ``RelGraphConv``, we can subclass it as follows to provide the necessary decorator for the message function, and then use the subclass in the GNN model:
35 | ::
36 |
37 | import dgl
38 | import sar
39 |
40 | class RelGraphConv_sar(dgl.nn.pytorch.conv.RelGraphConv):
41 | def __init__(self, *args, **kwargs):
42 | super().__init__(*args, **kwargs)
43 |
44 | @sar.message_has_parameters(lambda self: tuple(self.linear_r.parameters()))
45 | def message(self, edges):
46 | return super().message(edges)
47 |
48 | ..
49 |
50 | SAR has the convenience function :func:`sar.patch_dgl` that defines a new ``RelGraphConv`` layer as described in the code above and uses it to replace DGL's ``RelGraphConv`` layer.
51 |
52 |
53 | Batch normalization
54 | -----------------------------------------------------------------------------------
55 | The batch normalization layers in PyTorch such as ``torch.nn.BatchNorm1d`` will normalize the GNN node features using statistics obtained only from the node features in the local partition. So the normalizing factors (mean and standard deviation) will be different in each worker, and will depend on the way the graph is partitioned. To normalize using global statistics obtained from all nodes in the graph, you can use :class:`sar.DistributedBN1D`. :class:`sar.DistributedBN1D` has a similar interface as ``torch.nn.BatchNorm1d``. For example::
56 |
57 | norm_layer = sar.DistributedBN1D(out_dim, affine=True)
58 | ..
59 | #Will normalize the features of the nodes in the partition
60 | #by the global node statistics (mean and standard deviation)
61 | normalized_activations = norm_layer(partition_node_features)
62 |
63 | ..
64 |
65 | Relevant methods
66 | ---------------------------------------------------------------------------
67 |
68 | .. autosummary::
69 | :toctree: Adapting GNNs to SAR
70 | :template: graphshardmanager
71 |
72 | sar.core.message_has_parameters
73 | sar.edge_softmax
74 | sar.DistributedBN1D
75 | sar.patch_dgl
76 |
--------------------------------------------------------------------------------
/docs/source/quick_start.rst:
--------------------------------------------------------------------------------
1 | .. _quick-start:
2 |
3 | Quick start guide
4 | ===============================
5 | Follow the steps below to enable distributed training in your DGL code:
6 |
7 | .. contents::
8 | :depth: 2
9 | :local:
10 | :backlinks: top
11 |
12 | Partition the graph
13 | ----------------------------------
14 | Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` (in DGL < 1.0) in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``.
15 |
16 |
17 | An example of partitioning the ogbn-arxiv graph in two parts: ::
18 |
19 | import dgl
20 | import torch
21 | from ogb.nodeproppred import DglNodePropPredDataset
22 |
23 | dataset = DglNodePropPredDataset(name='ogbn-arxiv')
24 | graph = dataset[0][0]
25 | graph = dgl.to_bidirected(graph, copy_ndata=True)
26 | graph = dgl.add_self_loop(graph)
27 |
28 | labels = dataset[0][1].view(-1)
29 | split_idx = dataset.get_idx_split()
30 |
31 |
32 | def _idx_to_mask(idx_tensor):
33 | mask = torch.BoolTensor(graph.number_of_nodes()).fill_(False)
34 | mask[idx_tensor] = True
35 | return mask
36 |
37 |
38 | train_mask, val_mask, test_mask = map(
39 | _idx_to_mask, [split_idx['train'], split_idx['valid'], split_idx['test']])
40 | features = graph.ndata['feat']
41 | graph.ndata.clear()
42 | for name, val in zip(['train_mask', 'val_mask', 'test_mask', 'labels', 'features'],
43 | [train_mask, val_mask, test_mask, labels, features]):
44 | graph.ndata[name] = val
45 |
46 | dgl.distributed.partition_graph(
47 | graph, 'arxiv', 2, './test_partition_data/', num_hops=1) # use reshuffle=True in DGL < 1.0
48 |
49 | ..
50 |
51 | Note that we add the labels, and the train/test/validation masks as node features so that they get split into multiple parts alongside the graph.
52 |
53 |
54 | Initialize communication
55 | ----------------------------------
56 | SAR uses the `torch.distributed `_ package to handle all communication. See the :ref:`Communication Guide ` for more information on the communication routines. We require the IP address of the master worker/machine (the machine with rank 0) to initialize the ``torch.distributed`` package. In an environment with a networked file system where all workers/machines share a common file system, we can communicate the master's IP address through the file system. In that case, use :func:`sar.nfs_ip_init` to obtain the master ip address.
57 |
58 | Initialize the communication through a call to :func:`sar.initialize_comms`, specifying the current worker index, the total number of workers (which should be the same as the number of partitions from step 1), the master's IP address, and the communication device. The latter is the device on which SAR should place the tensors before sending them through the communication backend. For example: ::
59 |
60 | if backend_name == 'nccl':
61 | comm_device = torch.device('cuda')
62 | else:
63 | comm_device = torch.device('cpu')
64 | master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file)
65 | sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device)
66 |
67 | ..
68 |
69 | ``backend_name`` can be ``ccl``, ``nccl``, ``mpi`` or ``gloo``.
70 |
71 |
72 |
73 | Load partition data and construct graph
74 | -----------------------------------------------------------------
75 | Use :func:`sar.load_dgl_partition_data` to load one graph partition from DGL's partition data in each worker. :func:`sar.load_dgl_partition_data` returns a :class:`sar.common_tuples.PartitionData` object that contains all the information about the partition.
76 |
77 | There are several ways to construct a distributed graph-like object from ``PartitionData``. See :ref:`constructing distributed graphs ` for more details. Here we will use the simplest method: :func:`sar.construct_full_graph` which returns a :class:`sar.core.GraphShardManager` object which implements much of the GNN-related functionality of DGL's native graph objects. ``GraphShardManager`` can thus be used as a drop-in replacement for DGL's native graphs or it can be passed to SAR's samplers and data loaders to construct graph mini-batches.
78 |
79 | ::
80 |
81 | partition_data = sar.load_dgl_partition_data(
82 | json_file_path, #Path to .json file created by DGL's partition_graph
83 | rank, #Worker rank
84 | device #Device to place the partition data (CPU or GPU)
85 | )
86 | shard_manager = sar.construct_full_graph(partition_data)
87 |
88 | ..
89 |
90 | Full-batch training
91 | ---------------------------------------------------------------------------
92 | Full-batch training using SAR follows a very similar pattern as single-host training. Instead of using a vanilla DGL graph, we use a :class:`sar.core.GraphShardManager`. After initializing the communication backend, loading graph data and constructing the distributed graph, a simple training loop is ::
93 |
94 | gnn_model = construct_GNN_model(...)
95 | optimizer = torch.optim.Adam(gnn_model.parameters(),..)
96 | sar.sync_params(gnn_model)
97 | for train_iter in range(n_train_iters):
98 | model_out = gnn_model(shard_manager,features)
99 | loss = calculate_loss(model_out,labels)
100 | optimizer.zero_grad()
101 | loss.backward()
102 | sar.gather_grads(gnn_model)
103 | optimizer.step()
104 |
105 | ..
106 |
107 | In a distributed setting, each worker will construct the GNN model. Before training, we should synchronize the model parameters across all workers. :func:`sar.sync_params` is a convenience function that does just that. At the end of every training iteration, each worker needs to gather and sum the parameter gradients from all other workers before making the parameter update. This can be done using :func:`sar.gather_grads`.
108 |
109 | See :ref:`training modes <sar-modes>` for the different full-batch training modes.
110 |
111 | Sampling-based or mini-batch training
112 | ---------------------------------------------------------------------------
113 | A simple sampling-based training loop looks as follows:
114 |
115 | ::
116 |
117 | neighbor_sampler = sar.DistNeighborSampler(
118 | [15, 10, 5], #Fanout for every layer
119 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block
120 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block
121 | )
122 |
123 | dataloader = sar.DataLoader(
124 | shard_manager, #Distributed graph
125 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes
126 | neighbor_sampler, #Distributed sampler
127 | batch_size)
128 |
129 | for blocks in dataloader:
130 | output = gnn_model(blocks)
131 | loss = calculate_loss(output,labels)
132 | optimizer.zero_grad()
133 | loss.backward()
134 | sar.gather_grads(gnn_model)
135 | optimizer.step()
136 |
137 | ..
138 |
139 |
140 | We use :class:`sar.DistNeighborSampler` to construct a distributed sampler and :func:`sar.DataLoader` to construct an iterator that returns standard local DGL blocks constructed from the distributed graph.
141 |
142 |
143 | For complete examples, check the examples folder in the Git repository.
144 |
--------------------------------------------------------------------------------
/docs/source/sampling_training.rst:
--------------------------------------------------------------------------------
1 | .. _sampling:
2 |
3 |
4 | Distributed Sampling-based training
5 | ==================================================================
6 | In addition to distributed full-batch training, the SAR library also supports distributed sampling-based training. The main difference between SAR's distributed sampling-based component and DistDGL is that SAR uses collective communication primitives such as ``all_to_all`` during the distributed mini-batch generation steps, while DistDGL uses point-to-point communication. One common use case in GNN training is to use sampling-based training followed by full-batch inference. Since SAR supports sampling-based as well as full-batch training and inference, this use case is particularly easy to implement. The same GNN model can be used for both full-batch and sampling-based runs. A simple 3-layer GraphSage model:
7 |
8 | ::
9 |
10 | class GNNModel(nn.Module):
11 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
12 | super().__init__()
13 |
14 | self.convs = nn.ModuleList([
15 | dgl.nn.SAGEConv(in_dim, hidden_dim, aggregator_type='mean'),
16 | dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean'),
17 | dgl.nn.SAGEConv(hidden_dim, out_dim, aggregator_type='mean'),
18 | ])
19 |
20 | def forward(self, blocks: List[Union[DGLBlock, sar.GraphShardManager]], features: torch.Tensor):
21 | for idx, conv in enumerate(self.convs):
22 | features = conv(blocks[idx], features)
23 | if idx < len(self.convs) - 1:
24 | features = F.relu(features, inplace=True)
25 |
26 | return features
27 |
28 | ..
29 |
30 | Since :class:`sar.core.GraphShardManager` can be used as a drop-in replacement for DGL's native graph objects, we can use a standard DGL model and either pass it the sampled ``DGLBlock`` objects or the full distributed graph.
31 |
32 | As in full-batch training, we first load the DGL-generated partition data, and construct the full distributed graph. We then define the sampling strategy and dataloader. We use SAR's :class:`sar.DistNeighborSampler` and :func:`sar.DataLoader` to define the sampling strategy and the distributed dataloader, respectively.
33 |
34 | ::
35 |
36 | partition_data = sar.load_dgl_partition_data(
37 | args.partitioning_json_file, args.rank, torch.device('cpu'))
38 |
39 | full_graph_manager = sar.construct_full_graph(
40 | partition_data) # Keep full graph on CPU
41 |
42 |
43 | neighbor_sampler = sar.DistNeighborSampler(
44 | [15, 10, 5], #Fanout for every layer
45 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block
46 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block
47 | )
48 |
49 | dataloader = sar.DataLoader(
50 | full_graph_manager, #Distributed graph
51 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes
52 | neighbor_sampler, #Distributed sampler
53 | batch_size)
54 |
55 | ..
56 |
57 | A typical training loop is shown below.
58 |
59 | ::
60 |
61 | gnn_model = construct_GNN_model(...)
62 | optimizer = torch.optim.Adam(gnn_model.parameters(),..)
63 | sar.sync_params(gnn_model)
64 |
65 |
66 | for epoch in range(n_epochs):
67 | model.train()
68 | for blocks in dataloader:
69 | block_features = blocks[0].srcdata['features']
70 | block_labels = blocks[-1].dstdata['labels']
71 | logits = gnn_model(blocks, block_features)
72 |
73 | output = gnn_model(blocks)
74 | loss = calculate_loss(output, block_labels)
75 | optimizer.zero_grad()
76 | loss.backward()
77 | sar.gather_grads(gnn_model)
78 | optimizer.step()
79 |
80 | # inference
81 | model.eval()
82 | with torch.no_grad():
83 | logits = gnn_model_cpu([full_graph_manager] * n_layers, features)
84 | calculate_loss_accuracy(logits, full_graph_labels)
85 |
86 | ..
87 |
88 | Note that we obtain instances of standard ``DGLBlock`` from the distributed dataloader every training iteration. After every epoch, we run distributed full-graph inference using the :class:`sar.core.GraphShardManager`. We use the same ``GraphShardManager`` object at each layer. Alternatively, as described in the :ref:`data loading section`, we can construct layer-specific distributed message flow graphs (MFGs) to avoid computing redundant node features at each layer. Redundant node features are the node features that do not contribute to the output at the labeled nodes.
89 |
90 |
91 |
92 | Relevant classes and methods
93 | ---------------------------------------------------------------------------
94 |
95 | .. currentmodule:: sar
96 |
97 | .. autosummary::
98 | :toctree: Graph Shard classes
99 | :template: graphshardmanager
100 |
101 | GraphShardManager
102 | DataLoader
103 | DistNeighborSampler
104 |
105 |
--------------------------------------------------------------------------------
/docs/source/sar_config.rst:
--------------------------------------------------------------------------------
1 | .. _sar-config:
2 | .. currentmodule:: sar
3 |
4 | SAR configuration
5 | =======================================
6 | .. autoclass:: sar.Config
7 |
--------------------------------------------------------------------------------
/docs/source/sar_modes.rst:
--------------------------------------------------------------------------------
1 | .. _sar-modes:
2 |
3 | SAR's training modes
4 | =============================
5 | SAR can run the distributed GNN forward and backward pass in three distinct modes:
6 |
7 | .. contents:: :local:
8 | :depth: 2
9 |
10 |
11 | Mode 1: Sequential aggregation and rematerialization
12 | ------------------------------------------------------------------------------------
13 | SAR's main training mode uses sequential aggregation and rematerialization to avoid fully materializing the computational graph at any single worker. The distributed training scheme is illustrated below for 3 workers/partitions. The figure illustrates the forward and backward pass steps for machine/worker 1.
14 |
15 | .. image:: ./images/dom_parallel_remat.png
16 | :alt: SAR training
17 | :width: 500 px
18 |
19 | **Forward pass:** When doing distributed message passing and aggregation, all workers disable PyTorch's autograd to stop PyTorch from creating the computational graph. The forward pass steps in worker 1 are:
20 |
21 | #. Aggregate messages from nodes in the local partition.
22 | #. Fetch neighboring node features from worker 2 and aggregate their messages. Delete fetched nodes.
23 | #. Fetch neighboring node features from worker 3 and aggregate their messages. Delete fetched nodes.
24 |
25 | **Backward pass:** Since no computational graph was constructed in the forward pass, the workers need to reconstruct the computational graph to backpropagate the gradients. The backward pass steps in worker 1 when worker 1 receives errors :math:`e_2,e_3` for nodes :math:`v_2,v_3`:
26 |
27 | #. Re-aggregate messages from nodes in the local partition with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph.
28 | #. Re-fetch neighboring nodes from worker 2 and re-aggregate their messages with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph.
29 | #. Re-fetch neighboring nodes from worker 3 and re-aggregate their messages with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph.
30 |
31 | Note that many GNN layers simply use a sum or mean operation to aggregate the features of their neighbors. In that case, we do not need to reconstruct the computational graph during the backward pass, as the gradients of the input features can be easily obtained from the gradients of the output features. To see this, consider the operation :math:`z = x + y`, and assume we have the gradient of the loss w.r.t. :math:`z`, denoted :math:`e_z`. Then the gradients of the loss w.r.t. :math:`x` and :math:`y` are :math:`e_x = e_y = e_z`. SAR automatically detects this situation and directly pushes the correct gradients to the local and remote nodes without re-fetching remote features or re-constructing the computational graph.
32 |
33 | Mode 1 is the default mode. It runs automatically when you pass a :class:`sar.core.GraphShardManager` object to your GNN model. For example:
34 | ::
35 |
36 | partition_data = sar.load_dgl_partition_data(
37 | json_file_path, #Path to .json file created by DGL's partition_graph
38 | rank, #Worker rank
39 | device #Device to place the partition data (CPU or GPU)
40 | )
41 | shard_manager = sar.construct_full_graph(partition_data)
42 | model_out = gnn_model (shard_manager,local_node_features)
43 | loss_function(model_out).backward()
44 |
45 | ..
46 |
47 | Mode 1 is the most memory-efficient mode. Excluding the GNN parameters which are replicated across all workers, mode 1 guarantees that peak memory consumption per worker will go down linearly with the number of workers used in training, even for densely connected graphs.
48 |
49 | Mode 2: Sequential aggregation
50 | ------------------------------------------------------------------------------------
51 | For many GNN layers, mode 1 must sequentially rematerialize the computational graph during the backward pass. This is the case for GAT layers for example. This introduces extra communication overhead during the backward pass in order to re-fetch the features of remote nodes. It also introduces an extra compute overhead in order to sequentially construct the computational graph (by re-executing the forward pass with autograd enabled).
52 |
53 | Mode 2 avoids this overhead by constructing the computational graph during the forward pass. This can potentially take up a lot of memory. In a densely connected graph for example, this may cause the features of all nodes in the graph to be materialized at every worker. The figure below illustrates the forward and backward pass in mode 2. Only the activity in worker 1 is shown. Note that remote node features are stored in worker 1 during the forward pass as part of the computational graph at worker 1. In the backward pass, SAR has access to the full computational graph and it uses it to backpropagate gradients to remote nodes without any re-fetching.
54 |
55 | .. image:: ./images/dom_parallel_naive.png
56 | :alt: No re-materialization
57 | :width: 500 px
58 |
59 | Mode 2 can be enabled by disabling sequential rematerialization in SAR's configuration object :class:`sar.Config`::
60 |
61 | sar.Config.disable_sr = True
62 | partition_data = sar.load_dgl_partition_data(
63 | json_file_path, #Path to .json file created by DGL's partition_graph
64 | rank, #Worker rank
65 | device #Device to place the partition data (CPU or GPU)
66 | )
67 | shard_manager = sar.construct_full_graph(partition_data)
68 | model_out = gnn_model (shard_manager,local_node_features)
69 | loss_function(model_out).backward()
70 |
71 | ..
72 |
73 | Mode 3: One-shot aggregation
74 | ------------------------------------------------------------------------------------
75 | Modes 1 and 2 follow a sequential aggregation approach where data from remote partitions are sequentially fetched. This might introduce scalability issues since the forward and backward pass in each layer will involve N communication rounds each (where N is the number of workers/partitions). Sequential aggregation thus introduces N synchronization points in each layer's forward and backward passes as each worker needs to wait until every other worker has finished its aggregation step before moving to the next step in the aggregation sequence (See the steps in the figure above).
76 |
77 | In mode 3, the one-shot aggregation mode, each worker fetches all remote data in one communication round and does one aggregation round to aggregate messages from all remotely fetched nodes. This is illustrated in the figure below:
78 |
79 | .. image:: ./images/one_shot_aggregation.png
80 | :alt: One shot aggregation
81 | :width: 500 px
82 |
83 | One advantage of mode 3 is that it only requires one communication round per layer in each of the forward and backward passes. One disadvantage is that mode 3 does not hide the communication latency. Due to the sequential nature of modes 1 and 2, SAR is able to simultaneously process data from one remote partition while pre-fetching data from the next remote partition in the aggregation sequence. Modes 1 and 2 can thus better hide the communication latency than mode 3. The memory requirements of mode 3 are similar to mode 2.
84 |
85 | To train in mode 3, you should extract the full partition graph from the :class:`sar.core.GraphShardManager` object and use that during training.
86 | ::
87 |
88 | partition_data = sar.load_dgl_partition_data(
89 | json_file_path, #Path to .json file created by DGL's partition_graph
90 | rank, #Worker rank
91 | device #Device to place the partition data (CPU or GPU)
92 | )
93 | shard_manager = sar.construct_full_graph(partition_data)
94 | one_shot_graph = shard_manager.get_full_partition_graph()
95 | model_out = gnn_model (one_shot_graph,local_node_features)
96 | loss_function(model_out).backward()
97 |
98 | ..
99 |
--------------------------------------------------------------------------------
/docs/source/shards.rst:
--------------------------------------------------------------------------------
1 | .. _shards:
2 |
3 |
4 | Distributed Graph Representation
5 | ==================================================================
6 | SAR represents the full graph as :math:`N^2` graph shards where :math:`N` is the number of workers/partitions. Each graph shard represents the edges from one partition to another and the features associated with these edges. :class:`sar.core.GraphShard` represents a single graph shard. Each worker stores the :math:`N` graph shards containing the incoming edges for the nodes in the worker's partition. These :math:`N` graph shards are managed by the :class:`sar.core.GraphShardManager` class. :class:`sar.core.GraphShardManager` implements a distributed version of ``update_all`` and ``apply_edges`` which are the main methods used by GNNs to create and exchange messages in the graph. :class:`sar.core.GraphShardManager` implements ``update_all`` and ``apply_edges`` in a sequential manner by iterating through the :math:`N` graph shards to sequentially create and aggregate messages from each partition into the local partition.
7 |
8 | The :meth:`sar.core.GraphShardManager.get_full_partition_graph` method can be used to combine the worker's :math:`N` graph shards into one monolithic graph object that represents all the incoming edges for nodes in the local partition. It returns a :class:`sar.core.DistributedBlock` object. The implementation of ``update_all`` and ``apply_edges`` in :class:`sar.core.DistributedBlock` is not sequential. Instead, it fetches all remote features in one step and aggregates all incoming messages to the local partition in one step.
9 |
10 | In the distributed implementation of the sequential backward pass in ``update_all`` and ``apply_edges`` in :class:`sar.core.GraphShardManager`, it is not possible for SAR to automatically detect if the message function uses any learnable parameters, and thus SAR will not be able to backpropagate gradients to these message parameters. To tell SAR that a message function is parameterized, use the :func:`sar.core.message_has_parameters` decorator to decorate message functions that use learnable parameters.
11 |
12 |
13 | .. _shard-limitations:
14 |
15 | Limitations of the distributed graph objects
16 | ------------------------------------------------------------------------------------
17 | Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not implement the ``successors`` and ``predecessors`` methods. It supports primarily the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. To remain compatible with ``DGLGraph``, :class:`sar.core.GraphShardManager` also provides access to the ``ndata`` member, which works as an alias for ``srcdata``; however, it is not accessible when working with MFGs.
18 |
19 | :class:`sar.core.GraphShardManager` also supports the ``in_degrees`` and ``out_degrees`` members and supports querying the number of nodes and edges in the graph.
20 |
21 | The ``update_all`` method in :class:`sar.core.GraphShardManager` only supports the 4 standard reduce functions in dgl: ``max``, ``min``, ``sum``, and ``mean``. The reason behind this is that SAR runs a sequential reduction of messages and therefore requires that :math:`reduce(msg_1,msg_2,msg_3) = reduce(msg_1,reduce(msg_2,msg_3))`.
22 |
23 | .. currentmodule:: sar.core
24 |
25 |
26 | Relevant classes and methods
27 | ---------------------------------------------------------------------------
28 |
29 | .. autosummary::
30 | :toctree: Graph Shard classes
31 | :template: graphshardmanager
32 |
33 | GraphShard
34 | GraphShardManager
35 | DistributedBlock
36 | message_has_parameters
37 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | ## Graph partitioning
2 |
3 | The ``partition_graph.py`` script can be used to partition both homogeneous and heterogeneous graphs. It utilizes DGL's metis-based partitioning algorithm to divide the graphs into smaller partitions. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that they are correctly partitioned with the graph.
4 | Similarly, edge-related information must be included in the graph's ``edata`` dictionary
5 |
6 | ### Supported datasets:
7 | - ogbn-products, ogbn-arxiv, ogb-mag from [Open Graph Benchmarks](https://ogb.stanford.edu/)
8 | - cora, citeseer, pubmed
9 |
10 | ## Full-batch Training
11 |
12 | The script ``train_homogeneous_graph_basic.py`` demonstrates the basic functionality of SAR. It runs distributed training using a 3-layer GraphSage network on a partitioned graph. If you want to train using ``N`` workers, then you need to launch the script ``N`` times, preferably on separate machines. For example, for ``N=2``, and assuming the two workers are on the same network file system, you can launch the 2 workers using the following two commands:
13 |
14 | ```shell
15 | python3 train_homogeneous_graph_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2
16 | python3 train_homogeneous_graph_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2
17 |
18 | ```
19 | The worker with ``rank=0`` (the master) will write its address to the file specified by the ``--ip-file`` option and the other worker(s) will read this file and connect to the master.
20 |
21 |
22 | The ``train_homogeneous_graph_advanced.py`` script demonstrates more advanced features of SAR such as distributed construction of Message Flow Graphs (MFGs), and the multiple training modes supported by SAR.
23 |
24 | The ``train_heterogeneous_graph.py`` script demonstrates training on a heterogeneous graph (ogbn-mag). The script trains a 3-layer R-GCN.
25 |
26 |
27 | ## Sampling-based Training
28 | The script ``train_homogeneous_sampling_basic.py`` demonstrates distributed sampling-based training on SAR. It demonstrates the unique ability of the SAR library to run distributed sampling-based training followed by memory-efficient distributed full-graph inference. The script uses a 3-layer GraphSage network. Assuming 2 machines, you can launch training on the 2 machines using the following 2 commands, where each is executed on a different machine:
29 |
30 | ```shell
31 | python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2
32 | python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2
33 |
34 | ```
35 |
36 | ## Distributed Mini-Batch Training with Full-Graph inference
37 | The script ``train_distdgl_with_sar_inference.py`` showcases how SAR can be effectively combined with native DGL distributed training. In this particular example, the training process utilizes a sampling approach, while the evaluation phase leverages the SAR library to perform computations on the entire graph.
38 | ```shell
39 | python /home/ubuntu/workspace/dgl/tools/launch.py \
40 | --workspace /home/ubuntu/workspace/SAR/examples \
41 | --num_trainers 1 \
42 | --num_samplers 2 \
43 | --num_servers 1 \
44 | --part_config partition_data/ogbn-products.json \
45 | --ip_config ip_config.txt \
46 | "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 --part_config partition_data/ogbn-products.json"
47 | ```
48 |
49 | ## Correct and Smooth
Example taken from the [DGL implementation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. The code is adjusted to perform distributed training with SAR. The introduced modifications change the way data normalization is performed - workers need to communicate with each other to calculate the mean and standard deviation for the entire dataset (not just their partition). Moreover, workers need to be synchronized with each other to calculate the sigma value required during the "correct" phase.
51 |
52 | For instance, you can run the example with following commands (2 machines scenario):
53 |
54 | * **Plain MLP + C&S**
55 | * Rank 0 machine:
56 | ```shell
57 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale
58 | ```
59 |
60 | * Rank 1 machine:
61 | ```shell
62 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale
63 | ```
64 |
65 | * **Plain Linear + C&S**
66 | * Rank 0 machine:
67 | ```shell
68 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale
69 | ```
70 |
71 | * Rank 1 machine:
72 | ```shell
73 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale
74 | ```
75 |
--------------------------------------------------------------------------------
/examples/SIGN/README.md:
--------------------------------------------------------------------------------
1 | ## SIGN: Scalable Inception Graph Neural Networks
2 |
3 | Original script: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sign
4 |
The provided `train_sign_with_sar.py` script is an example of how to integrate SAR to preprocess graph data for training.
6 |
7 | ### Results
8 | Obtained results for two partitions:
9 | - ogbn-products: 0.7832
10 | - reddit: 0.9639
11 |
12 | ### Run command:
13 |
14 | ```
15 | python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2
16 |
17 | python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 1 --world-size 2
18 | ```
--------------------------------------------------------------------------------
/examples/SIGN/train_sign_with_sar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import dgl
6 | import dgl.function as fn
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | import sar
13 |
def load_dataset(filename, rank, device):
    """Load this worker's DGL partition and wrap it in a SAR full-graph manager.

    Args:
        filename: Path to the partitioning ``.json`` file. When the path
            contains ``'reddit'`` the ``feat``/``label`` keys are looked up,
            otherwise ``features``/``labels``.
        rank: Rank of the calling worker, forwarded to the partition loader.
        device: Device the index tensors, labels, features and graph manager
            are moved to.

    Returns:
        ``(full_graph_manager, n_classes, train_indices, val_indices,
        test_indices)``. The graph manager has ``"feat"`` and ``"label"``
        stored in its ``ndata``.
    """
    partition_data = sar.load_dgl_partition_data(filename, rank, device)
    node_features = partition_data.node_features

    # The masks are stored as node features. Partitioning may prepend the node
    # type to the key, so they are fetched with a suffix lookup and converted
    # from boolean masks to index tensors.
    split_keys = (
        ("train_mask", "train_indices"),
        ("val_mask", "val_indices"),
        ("test_mask", "test_indices"),
    )
    masks = {}
    for mask_name, indices_name in split_keys:
        boolean_mask = sar.suffix_key_lookup(node_features, mask_name)
        indices = boolean_mask.nonzero(as_tuple=False).view(-1)
        masks[indices_name] = indices.to(device)
    print(node_features.keys())

    if 'reddit' in filename:
        feature_name, label_name = 'feat', 'label'
    else:
        feature_name, label_name = 'features', 'labels'

    labels = sar.suffix_key_lookup(node_features, label_name).long().to(device)

    # The class count is the largest label seen by any worker, plus one.
    n_classes = labels.max() + 1
    sar.comm.all_reduce(n_classes, torch.distributed.ReduceOp.MAX, move_to_comm_device=True)
    n_classes = n_classes.item()

    features = sar.suffix_key_lookup(node_features, feature_name).to(device)
    full_graph_manager = sar.construct_full_graph(partition_data).to(device)
    full_graph_manager.ndata["feat"] = features
    full_graph_manager.ndata["label"] = labels

    return (
        full_graph_manager,
        n_classes,
        masks["train_indices"],
        masks["val_indices"],
        masks["test_indices"],
    )
47 |
class FeedForwardNet(nn.Module):
    """Stack of linear layers with PReLU and dropout between them.

    Args:
        in_feats: Width of the input features.
        hidden: Width of every intermediate layer.
        out_feats: Width of the output.
        n_layers: Number of linear layers; ``1`` yields a single
            ``in_feats -> out_feats`` projection.
        dropout: Dropout probability applied after each non-final layer.
    """

    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
        super().__init__()
        self.n_layers = n_layers

        # Describe the network as a chain of widths and create one Linear per
        # adjacent pair. The non-single case always has at least one hidden
        # width, i.e. at least two layers.
        if n_layers == 1:
            widths = [in_feats, out_feats]
        else:
            widths = [in_feats] + [hidden] * max(n_layers - 1, 1) + [out_feats]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths, widths[1:])
        )

        # Activation and dropout only exist when there is something to put
        # them between.
        if self.n_layers > 1:
            self.prelu = nn.PReLU()
            self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all weights (ReLU gain) and zero all biases."""
        relu_gain = nn.init.calculate_gain("relu")
        for linear in self.layers:
            nn.init.xavier_uniform_(linear.weight, gain=relu_gain)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Apply the layers in order; the last one is left without activation."""
        last_activated = self.n_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index >= last_activated:
                continue
            x = self.prelu(x)
            x = self.dropout(x)
        return x
77 |
78 |
class Model(nn.Module):
    """SIGN model: one feed-forward branch per hop followed by a projection.

    Args:
        in_feats: Width of each hop's input features.
        hidden: Hidden width, also the output width of every branch.
        out_feats: Width of the final output.
        R: Number of hops; ``R + 1`` branches are created.
        n_layers: Number of linear layers in every feed-forward net.
        dropout: Dropout probability.
    """

    def __init__(self, in_feats, hidden, out_feats, R, n_layers, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.prelu = nn.PReLU()

        n_branches = R + 1
        self.inception_ffs = nn.ModuleList(
            FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
            for _ in range(n_branches)
        )
        # The concatenated branch outputs are mapped to the output width.
        self.project = FeedForwardNet(
            n_branches * hidden, hidden, out_feats, n_layers, dropout
        )

    def forward(self, feats):
        """Run each branch on its paired entry of ``feats`` and project the result."""
        branch_outputs = [
            branch(feat) for feat, branch in zip(feats, self.inception_ffs)
        ]
        merged = torch.cat(branch_outputs, dim=-1)
        merged = self.dropout(self.prelu(merged))
        return self.project(merged)
100 |
101 |
def calc_weight(g):
    """
    Compute row_normalized(D^(-1/2)AD^(-1/2))

    Works inside ``g.local_scope()`` and returns the ``"weight"`` edge data
    computed there.
    """
    with g.local_scope():
        # Per-node scaling factors: degree ** -0.5 for in- and out-degrees.
        in_scale = g.in_degrees().float().pow(-0.5)
        out_scale = g.out_degrees().float().pow(-0.5)
        g.ndata["in_deg"] = in_scale
        g.ndata["out_deg"] = out_scale

        # Edge weight = source out-degree factor * destination in-degree factor.
        symmetric_norm = fn.u_mul_v("out_deg", "in_deg", "weight")
        g.apply_edges(symmetric_norm)

        # Row-normalise: divide every edge weight by the sum of the weights
        # arriving at its destination node.
        g.update_all(fn.copy_e("weight", "msg"), fn.sum("msg", "norm"))
        g.apply_edges(fn.e_div_v("weight", "norm", "weight"))

        edge_weight = g.edata["weight"]
        return edge_weight
115 |
116 |
def preprocess(g, features, args):
    """
    Pre-compute the average of n-th hop neighbors

    Returns a list of ``args.R + 1`` tensors: the input features followed by
    the result of each propagation hop. The temporary ``feat_<hop>`` entries
    are removed from ``g.ndata``; the ``"weight"`` edge data is left in place.
    """
    with torch.no_grad():
        g.edata["weight"] = calc_weight(g)
        g.ndata["feat_0"] = features

        # Each hop propagates the previous hop's features along weighted edges.
        for hop in range(1, args.R + 1):
            message = fn.u_mul_e(f"feat_{hop-1}", "weight", "msg")
            reducer = fn.sum("msg", f"feat_{hop}")
            g.update_all(message, reducer)

        return [g.ndata.pop(f"feat_{hop}") for hop in range(args.R + 1)]
133 |
134 |
def prepare_data(device, args):
    """Load the partition, pre-compute the SIGN hop features and slice the train set.

    Args:
        device: Device that the graph and index tensors are moved to.
        args: Parsed command-line arguments; ``partitioning_json_file`` and
            ``rank`` are read here, and ``args`` is forwarded to ``preprocess``.

    Returns:
        ``(feats, labels, train_feats, train_labels, in_feats, n_classes,
        train_nid, val_nid, test_nid)``.
    """
    g, n_classes, train_nid, val_nid, test_nid = load_dataset(
        args.partitioning_json_file, args.rank, device
    )
    g = g.to(device)

    node_feat = g.ndata["feat"]
    in_feats = node_feat.shape[1]
    feats = preprocess(g, node_feat, args)
    labels = g.ndata["label"]

    # move to device
    train_nid, val_nid, test_nid = (
        train_nid.to(device),
        val_nid.to(device),
        test_nid.to(device),
    )

    train_feats = [hop_feat[train_nid] for hop_feat in feats]
    train_labels = labels[train_nid]

    return (
        feats,
        labels,
        train_feats,
        train_labels,
        in_feats,
        n_classes,
        train_nid,
        val_nid,
        test_nid,
    )
159 |
def evaluate(args, model, feats, labels, train, val, test):
    """Compute train/val/test accuracy aggregated over all workers.

    Args:
        args: Parsed arguments; ``eval_batch_size <= 0`` selects a single
            full-batch forward pass, otherwise inference runs in slices of
            that many rows.
        model: Model called with a list of per-hop feature tensors.
        feats: List of per-hop feature tensors.
        labels: Label tensor compared against the predictions.
        train, val, test: Index tensors selecting each split.

    Returns:
        ``(train_acc, val_acc, test_acc)`` as 0-dim tensors.
    """
    with torch.no_grad():
        batch_size = args.eval_batch_size
        if batch_size > 0:
            num_nodes = labels.shape[0]
            chunks = []
            for batch_start in range(0, num_nodes, batch_size):
                batch_end = min(batch_start + batch_size, num_nodes)
                batch_feats = [feat[batch_start:batch_end] for feat in feats]
                chunks.append(model(batch_feats))
            pred = torch.cat(chunks)
        else:
            pred = model(feats)

        correct = (torch.argmax(pred, dim=1) == labels).float()

        # Interleave (n_correct, n_elements) for the three splits so a single
        # all-reduce sums both quantities across workers.
        results = []
        for mask in (train, val, test):
            results.append(correct[mask].sum())
            results.append(mask.numel())

        acc_vec = torch.FloatTensor(results)
        # Sum the n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(acc_vec, op=torch.distributed.ReduceOp.SUM, move_to_comm_device=True)

        train_acc = acc_vec[0] / acc_vec[1]
        val_acc = acc_vec[2] / acc_vec[3]
        test_acc = acc_vec[4] / acc_vec[5]

    return train_acc, val_acc, test_acc
194 |
195 |
def main(args):
    """Train the SIGN model on this worker's partition and report accuracies.

    Args:
        args: Parsed command-line arguments (see the parser at the bottom of
            the file).
    """
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    # Set up communication between the workers.
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size, master_ip_address,
                         args.backend)

    (
        feats,
        labels,
        train_feats,
        train_labels,
        in_size,
        num_classes,
        train_nid,
        val_nid,
        test_nid,
    ) = prepare_data(device, args)

    model = Model(
        in_size,
        args.num_hidden,
        num_classes,
        args.R,
        args.ff_layer,
        args.dropout,
    )
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    sar.sync_params(model)

    best_epoch, best_val, best_test = 0, 0, 0

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        model.train()
        loss = loss_fcn(model(train_feats), train_labels)
        optimizer.zero_grad()
        loss.backward()
        # Combine the gradients of all workers before the optimizer step.
        sar.gather_grads(model)
        optimizer.step()

        if epoch % args.eval_every != 0:
            continue

        model.eval()
        train_acc, val_acc, test_acc = evaluate(
            args, model, feats, labels, train_nid, val_nid, test_nid
        )
        # The reported time covers the training step and the evaluation.
        elapsed = time.time() - start
        log = "Epoch {}, Times(s): {:.4f}".format(epoch, elapsed)
        log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(
            train_acc, val_acc, test_acc
        )
        print(log)

        # Keep the test accuracy of the epoch with the best validation accuracy.
        if val_acc > best_val:
            best_val, best_epoch, best_test = val_acc, epoch, test_acc

    print(
        "Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
            best_epoch, best_val, best_test
        )
    )
268 |
269 |
if __name__ == "__main__":
    # Command-line interface of the SIGN example.
    parser = argparse.ArgumentParser(description="SIGN")

    # Distributed setup
    parser.add_argument(
        "--partitioning-json-file",
        default="",
        type=str,
        help="Path to the .json file containing partitioning information",
    )
    parser.add_argument(
        "--ip-file",
        type=str,
        default="./ip_file",
        help="File with ip-address. "
             "Worker 0 creates this file and all others read it",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="ccl",
        choices=["ccl", "nccl", "mpi"],
        help="Communication backend to use",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="Rank of the current worker",
    )
    parser.add_argument(
        "--world-size",
        default=2,
        type=int,
        help="Number of workers ",
    )

    # Model and training hyper-parameters
    parser.add_argument("--num-epochs", type=int, default=1000)
    parser.add_argument("--num-hidden", type=int, default=256)
    parser.add_argument("--R", type=int, default=3, help="number of hops")
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--gpu", type=int, default=-1)
    parser.add_argument("--weight-decay", type=float, default=0)
    parser.add_argument("--eval-every", type=int, default=10)
    parser.add_argument(
        "--eval-batch-size",
        type=int,
        default=250000,
        help="evaluation batch size, -1 for full batch",
    )
    parser.add_argument(
        "--ff-layer",
        type=int,
        default=2,
        help="number of feed-forward layers",
    )

    args = parser.parse_args()

    print(args)
    main(args)
299 |
--------------------------------------------------------------------------------
/examples/partition_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from argparse import ArgumentParser
22 | import dgl # type:ignore
23 | from dgl import AddReverse, Compose, ToSimple
24 | import torch
25 | from ogb.nodeproppred import DglNodePropPredDataset # type:ignore
26 | from dgl.data import (
27 | CiteseerGraphDataset,
28 | CoraGraphDataset,
29 | PubmedGraphDataset,
30 | RedditDataset
31 | )
32 |
# Maps every accepted --dataset-name value to the class used to load it.
SUPPORTED_DATASETS = {
    "cora": CoraGraphDataset,
    "citeseer": CiteseerGraphDataset,
    "pubmed": PubmedGraphDataset,
    "reddit": RedditDataset,
    "ogbn-products": DglNodePropPredDataset,
    "ogbn-arxiv": DglNodePropPredDataset,
    "ogbn-mag": DglNodePropPredDataset,
}

# Command-line interface of the partitioning script.
parser = ArgumentParser(description="Graph partitioning for common graph datasets")

parser.add_argument(
    "--dataset-root",
    type=str,
    default="./datasets/",
    help="The OGB datasets folder",
)

parser.add_argument(
    "--dataset-name",
    type=str,
    default="ogbn-arxiv",
    choices=[
        "ogbn-arxiv", "ogbn-products", "ogbn-mag",
        "cora", "citeseer", "pubmed", "reddit",
    ],
    help="Dataset name",
)

parser.add_argument(
    "--partition-out-path",
    type=str,
    default="./partition_data/",
    help="Path to the output directory for the partition data",
)

parser.add_argument(
    "--num-partitions",
    type=int,
    default=2,
    help="Number of graph partitions to generate",
)
58 |
def get_dataset(args):
    """Instantiate the dataset selected by ``args.dataset_name``.

    The loader classes take different constructor arguments: the citation
    datasets receive only the root directory, Reddit is created with
    ``self_loop=True`` and ``raw_dir``, and the remaining (OGB) datasets
    receive the dataset name followed by the root directory.
    """
    name = args.dataset_name
    loader = SUPPORTED_DATASETS[name]
    root = args.dataset_root

    if name == 'reddit':
        return loader(self_loop=True, raw_dir=root)
    if name in ("cora", "citeseer", "pubmed"):
        return loader(root)
    return loader(name, root)
67 |
def prepare_features(args, dataset, graph):
    """Attach masks, labels and features to ``graph.ndata`` in place.

    For cora/citeseer/pubmed/reddit this only asserts that the three split
    masks are already present. For the other datasets the split indices from
    ``dataset.get_idx_split()`` are turned into boolean masks, the existing
    node data is cleared, and ``train_mask``, ``val_mask``, ``test_mask``,
    ``labels`` and ``features`` are written back. For ``ogbn-mag`` everything
    is restricted to the ``"paper"`` node type.
    """
    mask_names = ['train_mask', 'val_mask', 'test_mask']
    if args.dataset_name in ['cora', 'citeseer', 'pubmed', 'reddit']:
        assert all(name in graph.ndata.keys() for name in mask_names)
        return

    split_idx = dataset.get_idx_split()
    ntype = "paper" if args.dataset_name == "ogbn-mag" else None

    def idx_to_mask(idx_tensor):
        # Boolean mask over the nodes of `ntype`, True at the split's indices.
        mask = torch.zeros(graph.number_of_nodes(ntype), dtype=torch.bool)
        selected = idx_tensor[ntype] if ntype else idx_tensor
        mask[selected] = True
        return mask

    train_mask = idx_to_mask(split_idx["train"])
    val_mask = idx_to_mask(split_idx["valid"])
    test_mask = idx_to_mask(split_idx["test"])

    # Read the features before the node data is cleared below.
    feature_key = "feat" if "feat" in graph.ndata.keys() else "features"
    features = graph.ndata[feature_key]

    graph.ndata.clear()

    labels = dataset[0][1]
    if ntype:
        features = features[ntype]
        labels = labels[ntype]
    labels = labels.view(-1)

    node_data = {
        "train_mask": train_mask,
        "val_mask": val_mask,
        "test_mask": test_mask,
        "labels": labels,
        "features": features,
    }
    for name, val in node_data.items():
        graph.ndata[name] = {ntype: val} if ntype else val
103 |
def main():
    """Load the chosen dataset, normalise its graph and write the partitions."""
    args = parser.parse_args()
    dataset_name = args.dataset_name
    dataset = get_dataset(args)

    # OGB datasets yield (graph, labels) pairs; the others yield the graph.
    graph = dataset[0][0] if dataset_name.startswith("ogbn") else dataset[0]

    if dataset_name == "ogbn-mag":
        transform = Compose([ToSimple(), AddReverse()])
        graph = transform(graph)
    else:
        # Drop existing self loops, make the graph bidirected, then add
        # self loops back.
        graph = dgl.remove_self_loop(graph)
        graph = dgl.to_bidirected(graph, copy_ndata=True)
        graph = dgl.add_self_loop(graph)

    prepare_features(args, dataset, graph)

    # The training mask is passed as balance_ntypes only for these datasets.
    balance_ntypes = None
    if dataset_name in ["ogbn-products", "ogbn-arxiv"]:
        balance_ntypes = graph.ndata["train_mask"]

    dgl.distributed.partition_graph(
        graph,
        args.dataset_name,
        args.num_partitions,
        args.partition_out_path,
        num_hops=1,
        balance_ntypes=balance_ntypes,
        balance_edges=True,
    )
132 |
133 |
# Run the partitioning only when the file is executed as a script.
if __name__ == "__main__":
    main()
136 |
--------------------------------------------------------------------------------
/examples/rgcn-hetero/README.md:
--------------------------------------------------------------------------------
1 | ## Hetero RGCN
2 |
Example script for the ogbn-mag dataset.
4 | Original script: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
5 | You can find here two scripts `train_heterogeneous_graph.py` and `train_heterogeneous_graph_mfg.py`, the former is a simple full graph training and inference. The latter is a training and inference script which utilizes Message Flow Graph (MFG) - this approach is more computationally effective, because it computes embeddings only for nodes which require it, i.e. labeled nodes.
6 |
7 | ### Results
8 | Obtained results for two partitions (ogbn-mag dataset):
9 | - Train Acc: 77.18 ± 2.85%
10 | - Validation Acc: 40.03 ± 0.45%
11 | - Test Acc: 39.06 ± 0.44%
The presented results are the average accuracies obtained after running 10 experiments (1 experiment = 60 epochs). Results from each experiment (train/val/test accuracies) were not necessarily taken from the 60th epoch. The values were taken at the epoch where the validation accuracy was the highest.
13 | (Note: Results achieved in https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py are different, because it uses mini-batch training instead of full-graph like in SAR)
14 |
15 |
16 | ### Run command:
17 |
18 | ```
19 | python examples/rgcn-hetero/train_heterogeneous_graph.py --partitioning-json-file partition_data/ogbn-mag.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 --train-iters 60
20 |
21 | python examples/rgcn-hetero/train_heterogeneous_graph.py --partitioning-json-file partition_data/ogbn-mag.json --ip-file ip_file --backend ccl --rank 1 --world-size 2 --train-iters 60
22 | ```
--------------------------------------------------------------------------------
/examples/rgcn-hetero/model.py:
--------------------------------------------------------------------------------
1 | ################################################################################################################
2 | # File's content taken from https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
3 | ################################################################################################################
4 |
5 | import dgl.nn as dglnn
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from dgl.nn import HeteroEmbedding
9 |
def extract_embed(node_embed, input_nodes):
    """Call ``node_embed`` on ``input_nodes`` with the ``"paper"`` entry removed.

    Args:
        node_embed: Callable that receives a ``{node type: value}`` dict.
        input_nodes: Mapping from node type to the value passed on for it.

    Returns:
        Whatever ``node_embed`` returns for the filtered mapping.
    """
    without_paper = {}
    for ntype in input_nodes:
        if ntype == "paper":
            continue
        without_paper[ntype] = input_nodes[ntype]
    return node_embed(without_paper)
15 |
def rel_graph_embed(graph, embed_size, num_nodes_dict):
    """Create a ``HeteroEmbedding`` for every node type except ``"paper"``.

    Args:
        graph: Object exposing ``ntypes``, the node type names to consider.
        embed_size: Embedding width passed to ``HeteroEmbedding``.
        num_nodes_dict: Mapping from node type to the number of embeddings
            created for that type.
    """
    node_num = {
        ntype: num_nodes_dict[ntype]
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(node_num, embed_size)
24 |
25 |
class RelGraphConvLayer(nn.Module):
    """Relational graph convolution layer with an optional self-loop term.

    Each relation gets a weight-less ``GraphConv`` (``norm="right"``) inside a
    ``HeteroGraphConv``; the per-relation weight matrices are held in separate
    ``nn.Linear`` modules and supplied at call time.

    Args:
        in_feat: Input feature width.
        out_feat: Output feature width.
        ntypes: Node type names; one self-loop ``nn.Linear`` is created per type.
        rel_names: Relation names; one convolution and one weight per relation.
        activation: Optional callable applied to the output of each node type.
        dropout: Dropout probability applied to the output.
        self_loop: Whether to add the self-loop term.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0, self_loop=True
    ):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation
        self.self_loop = self_loop

        per_relation_conv = {}
        for rel in rel_names:
            per_relation_conv[rel] = dglnn.GraphConv(
                in_feat, out_feat, norm="right", weight=False, bias=False
            )
        self.conv = dglnn.HeteroGraphConv(per_relation_conv)

        per_relation_weight = {}
        for rel_name in self.rel_names:
            per_relation_weight[rel_name] = nn.Linear(in_feat, out_feat, bias=False)
        self.weight = nn.ModuleDict(per_relation_weight)

        # weight for self loop
        if self.self_loop:
            per_type_loop = {}
            for ntype in self.ntypes:
                per_type_loop[ntype] = nn.Linear(in_feat, out_feat, bias=True)
            self.loop_weights = nn.ModuleDict(per_type_loop)

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the relation weights and, if present, the self-loop weights."""
        linears = list(self.weight.values())
        if self.self_loop:
            linears.extend(self.loop_weights.values())
        for linear in linears:
            linear.reset_parameters()

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGLGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        with g.local_scope():
            # The transposed Linear weights are handed to each relation's
            # GraphConv through HeteroGraphConv's mod_kwargs.
            wdict = {}
            for rel_name in self.rel_names:
                wdict[rel_name] = {"weight": self.weight[rel_name].weight.T}

            if self.self_loop:
                # The self-loop term uses only the leading rows of each input,
                # as many as there are destination nodes of that type.
                inputs_dst = {}
                for ntype, feat in inputs.items():
                    inputs_dst[ntype] = feat[: g.number_of_dst_nodes(ntype)]

            hs = self.conv(g, inputs, mod_kwargs=wdict)

            outputs = {}
            for ntype, h in hs.items():
                if self.self_loop:
                    h = h + self.loop_weights[ntype](inputs_dst[ntype])
                if self.activation:
                    h = self.activation(h)
                outputs[ntype] = self.dropout(h)
            return outputs
108 |
109 |
class EntityClassify(nn.Module):
    """Two-layer relational GCN: input -> 64 hidden units -> output.

    Args:
        g: Graph object providing ``etypes``, ``ntypes`` and ``tgt_in_src``;
            the latter decides whether the layers add a self-loop term.
        in_dim: Input feature width.
        out_dim: Output width.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Unique relation names in sorted order.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        # i2h
        input_to_hidden = RelGraphConvLayer(
            self.in_dim,
            self.h_dim,
            g.ntypes,
            self.rel_names,
            activation=F.relu,
            dropout=self.dropout,
            self_loop=g.tgt_in_src,
        )
        # h2o
        hidden_to_output = RelGraphConvLayer(
            self.h_dim,
            self.out_dim,
            g.ntypes,
            self.rel_names,
            activation=None,
            self_loop=g.tgt_in_src,
        )

        self.layers = nn.ModuleList()
        self.layers.append(input_to_hidden)
        self.layers.append(hidden_to_output)

    def reset_parameters(self):
        """Re-initialise the parameters of every layer."""
        for conv_layer in self.layers:
            conv_layer.reset_parameters()

    def forward(self, graph, h):
        """Apply the layers to ``h``.

        ``graph`` may be a single graph shared by all layers, or a list of
        blocks (Message Flow Graph) paired with the layers one by one.
        """
        if not isinstance(graph, list):
            for conv_layer in self.layers:
                h = conv_layer(graph, h)
            return h

        # Message Flow Graph
        for conv_layer, block in zip(self.layers, graph):
            h = conv_layer(block, h)
        return h
159 |
--------------------------------------------------------------------------------
/examples/rgcn-hetero/train_heterogeneous_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from argparse import ArgumentParser
22 | from model import EntityClassify, rel_graph_embed, extract_embed
23 | import time
24 | import itertools
25 | import torch
26 | import torch.nn.functional as F
27 | import torch.distributed as dist
28 | import dgl # type: ignore
29 |
30 | import sar
31 |
32 |
# Command-line interface of the heterogeneous full-graph training example.
parser = ArgumentParser(
    description="GNN training on node classification tasks in heterogenous graphs (MFG)"
)

# Distributed setup
parser.add_argument(
    "--partitioning-json-file",
    type=str,
    default="",
    help="Path to the .json file containing partitioning information",
)
parser.add_argument(
    "--ip-file",
    default="./ip_file",
    type=str,
    help="File with ip-address. Worker 0 creates this file and all others read it",
)
parser.add_argument(
    "--backend",
    default="nccl",
    type=str,
    choices=["ccl", "nccl", "mpi", "gloo"],
    help="Communication backend to use",
)
parser.add_argument(
    "--cpu-run",
    action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs",
)

# Training hyper-parameters
parser.add_argument(
    "--train-iters",
    default=60,
    type=int,
    help="number of training iterations",
)
parser.add_argument(
    "--lr",
    type=float,
    default=0.01,
    help="learning rate",
)

# Worker identity
parser.add_argument(
    "--rank",
    default=0,
    type=int,
    help="Rank of the current worker",
)
parser.add_argument(
    "--world-size",
    default=2,
    type=int,
    help="Number of workers",
)

# Model size
parser.add_argument(
    "--features-dim",
    default=128,
    type=int,
    help="Dimension of the node features",
)
63 |
64 |
def main():
    """Train and evaluate the hetero R-GCN on this worker's graph partition.

    Parses the command line, initialises SAR communication, loads the
    partition, builds the embedding layer and the GNN, then runs
    ``args.train_iters`` iterations. Every iteration performs one full-graph
    training step followed by an evaluation pass whose loss and accuracy are
    summed across workers and printed.
    """
    args = parser.parse_args()
    print('args', args)

    # Patch DGL's attention-based layers and RelGraphConv to support distributed graphs
    sar.patch_dgl()

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size, master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)
    # Cast the paper features to float32 in place.
    partition_data.node_features["paper/features"] = partition_data.node_features["paper/features"].float()

    # Obtain train, validation, and test masks.
    # These are stored as node features. Partitioning may prepend
    # the node type to the mask names, so we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type.
    # The train/val/test masks are only defined for nodes with type 'paper'.
    # ``expand_to_all`` is passed as False for the masks and the labels below,
    # and is left at its default for the features.
    # NOTE(review): an earlier version of this comment said the flag was set
    # in order to expand the tensors to all nodes - confirm which is intended.

    bool_masks = {}
    for mask_name in ['train_mask', 'val_mask', 'test_mask']:
        local_mask = sar.suffix_key_lookup(partition_data.node_features,
                                           mask_name,
                                           expand_to_all = False,
                                           type_list = partition_data.node_type_names)
        bool_masks[mask_name] = local_mask.bool()

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'labels',
                                   expand_to_all = False,
                                   type_list = partition_data.node_type_names).long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features,
                                     'features',
                                     type_list = partition_data.node_type_names
                                     ).to(device)

    full_graph_manager = sar.construct_full_graph(partition_data).to(device)

    # Size each embedding table by the largest source-node id of its type, plus one.
    max_num_nodes = {}
    for ntype in partition_data.partition_book.ntypes:
        nodes = full_graph_manager.srcnodes(ntype)
        nodes_max = nodes.max()
        max_num_nodes[ntype] = nodes_max + 1

    embed_layer = rel_graph_embed(full_graph_manager, args.features_dim, max_num_nodes).to(device)
    gnn_model = EntityClassify(full_graph_manager, args.features_dim, num_labels).to(device)

    print('model', gnn_model)
    embed_layer.reset_parameters()
    gnn_model.reset_parameters()

    # Synchronize the model parameters across all workers
    # NOTE(review): only gnn_model is passed to sync_params (and to
    # gather_grads below); embed_layer is optimized but never passed to
    # either - confirm this is intended.
    sar.sync_params(gnn_model)

    # Obtain the number of labeled nodes in the training
    # This will be needed to properly obtain a cross entropy loss
    # normalized by the number of training examples
    n_train_points = torch.LongTensor([bool_masks["train_mask"].sum().item()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    # A single optimizer updates both the GNN and the embedding layer.
    all_params = itertools.chain(
        gnn_model.parameters(), embed_layer.parameters()
    )
    optimizer = torch.optim.Adam(all_params, lr=args.lr)

    for train_iter_idx in range(args.train_iters):
        # Train
        t_1 = time.time()
        gnn_model.train()

        # Learned embeddings for every source node type except 'paper';
        # 'paper' nodes use their stored features instead.
        embeds = extract_embed(embed_layer, {ntype: full_graph_manager.srcnodes(ntype) for ntype in full_graph_manager.srctypes})
        embeds.update({"paper": features[full_graph_manager.srcnodes("paper")]})
        embeds = {k: e.to(device) for k, e in embeds.items()}

        logits = gnn_model(full_graph_manager, embeds)
        logits = logits["paper"].log_softmax(dim=-1)
        train_mask = bool_masks["train_mask"]
        # Summed local loss divided by the global number of training nodes.
        loss = F.nll_loss(logits[train_mask], labels[train_mask], reduction="sum") / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Calculate accuracy for train/validation/test
        results = []
        gnn_model.eval()
        with torch.no_grad():
            embeds = extract_embed(embed_layer, {ntype: full_graph_manager.srcnodes(ntype) for ntype in full_graph_manager.srctypes})
            embeds.update({"paper": features[full_graph_manager.srcnodes("paper")]})
            embeds = {k: e.to(device) for k, e in embeds.items()}

            logits = gnn_model(full_graph_manager, embeds)
            logits = logits["paper"].log_softmax(dim=-1)

            # For each split append (summed loss, n_correct, n_nodes); a
            # worker with no nodes in a split contributes zeros.
            for mask_name in ['train_mask', 'val_mask', 'test_mask']:
                masked_nodes = bool_masks[mask_name]
                if masked_nodes.sum() > 0:
                    active_logits = logits[masked_nodes]
                    active_labels = labels[masked_nodes]
                    loss = F.nll_loss(active_logits, active_labels, reduction="sum")
                    n_correct = (active_logits.argmax(1) == active_labels).float().sum()
                    results.extend([loss.item(), n_correct.item(), masked_nodes.sum().item()])
                else:
                    results.extend([0.0, 0.0, 0.0])

        loss_acc_vec = torch.FloatTensor(results)
        # Sum the loss, n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(loss_acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        # Each triple is (loss, n_correct, n_nodes); dividing by n_nodes
        # yields the mean loss and the accuracy of that split.
        (train_loss, train_acc, val_loss, val_acc, test_loss, test_acc) =  \
            (loss_acc_vec[0] / loss_acc_vec[2],
             loss_acc_vec[1] / loss_acc_vec[2],
             loss_acc_vec[3] / loss_acc_vec[5],
             loss_acc_vec[4] / loss_acc_vec[5],
             loss_acc_vec[6] / loss_acc_vec[8],
             loss_acc_vec[7] / loss_acc_vec[8])

        result_message = (
            f"iteration [{train_iter_idx}/{args.train_iters}] | "
        )
        # The list holds a single string built by implicit concatenation of
        # the adjacent f-strings, so the join adds no separators.
        result_message += ', '.join([
            f"train loss={train_loss:.4f}, "
            f"Accuracy: "
            f"train={100 * train_acc:.4f} "
            f"valid={100 * val_acc:.4f} "
            f"test={100 * test_acc:.4f} "
            f" | train time = {train_time} "
            f" |"
        ])
        print(result_message, flush=True)
216 |
217 |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
220 |
--------------------------------------------------------------------------------
/examples/rgcn-hetero/train_heterogeneous_graph_mfg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from argparse import ArgumentParser
22 | from model import EntityClassify, rel_graph_embed, extract_embed
23 | import time
24 | import itertools
25 | import torch
26 | import torch.nn.functional as F
27 | import torch.distributed as dist
28 | import dgl # type: ignore
29 |
30 | import sar
31 |
32 |
# Command-line interface of the MFG-based heterogeneous graph example.
parser = ArgumentParser(
    description="GNN training on node classification tasks in heterogenous graphs",
)

# Input data and rendezvous file.
parser.add_argument(
    "--partitioning-json-file",
    default="",
    type=str,
    help="Path to the .json file containing partitioning information",
)
parser.add_argument(
    "--ip-file",
    type=str,
    default="./ip_file",
    help="File with ip-address. Worker 0 creates this file and all others read it",
)

# Communication backend and device selection.
parser.add_argument(
    "--backend",
    type=str,
    choices=["ccl", "nccl", "mpi", "gloo"],
    default="nccl",
    help="Communication backend to use",
)
parser.add_argument(
    "--cpu-run",
    action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs",
)

# Optimization settings.
parser.add_argument(
    "--train-iters",
    type=int,
    default=60,
    help="number of training iterations",
)
parser.add_argument(
    "--lr",
    default=0.01,
    type=float,
    help="learning rate",
)

# Position of this worker inside the distributed job.
parser.add_argument(
    "--rank",
    type=int,
    default=0,
    help="Rank of the current worker",
)
parser.add_argument(
    "--world-size",
    type=int,
    default=2,
    help="Number of workers",
)

# Model size.
parser.add_argument(
    "--features-dim",
    type=int,
    default=128,
    help="Dimension of the node features",
)
63 |
64 |
def main():
    """Train ``EntityClassify`` on message flow graphs (MFGs) built by SAR.

    Each worker parses the command line, joins the SAR communication group,
    loads its own DGL partition, builds one list of MFGs for training and one
    for evaluation, and then runs ``args.train_iters`` iterations. After every
    iteration the training loss and the train/validation/test accuracies,
    summed over all workers, are printed.
    """
    args = parser.parse_args()
    print('args', args)

    # Patch DGL's attention-based layers and RelGraphConv to support distributed graphs
    sar.patch_dgl()

    run_on_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if run_on_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank, args.world_size, master_ip_address, args.backend)

    # Load DGL partition data and convert the paper features to floating point
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)
    partition_data.node_features["paper/features"] = \
        partition_data.node_features["paper/features"].float()

    # The train/validation/test masks are stored as node features, and
    # partitioning may prepend the node type to their names;
    # suffix_key_lookup finds them by suffix. They are only defined for
    # nodes of type 'paper'.
    mask_names = ['train_mask', 'val_mask', 'test_mask']

    # Boolean masks looked up with expand_to_all=False.
    bool_masks = {
        name: sar.suffix_key_lookup(partition_data.node_features,
                                    name,
                                    expand_to_all=False,
                                    type_list=partition_data.node_type_names).bool()
        for name in mask_names
    }

    # The same masks looked up with expand_to_all=True (expanded to all nodes
    # in the graph, filled with zeros elsewhere) and turned into index tensors.
    indices_masks = {}
    for name in mask_names:
        expanded_mask = sar.suffix_key_lookup(partition_data.node_features,
                                              name,
                                              expand_to_all=True,
                                              type_list=partition_data.node_type_names)
        indices_key = name.replace('_mask', '_indices')
        indices_masks[indices_key] = expanded_mask.nonzero(as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'labels',
                                   expand_to_all=False,
                                   type_list=partition_data.node_type_names)
    labels = labels.long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features,
                                     'features',
                                     type_list=partition_data.node_type_names)
    features = features.to(device)

    # Create MFGs. The seed indices are shifted by the first entry of this
    # worker's node range before being handed to construct_mfgs.
    # NOTE(review): the last argument (2) is presumably the number of layers - confirm.
    train_seeds = indices_masks['train_indices'] + \
        partition_data.node_ranges[sar.comm.rank()][0]
    train_blocks = sar.construct_mfgs(partition_data, train_seeds, 2)

    eval_seeds = torch.cat((indices_masks['train_indices'],
                            indices_masks['val_indices'],
                            indices_masks['test_indices'])) + \
        partition_data.node_ranges[sar.comm.rank()][0]
    eval_blocks = sar.construct_mfgs(partition_data, eval_seeds, 2)

    train_blocks = [block.to(device) for block in train_blocks]
    eval_blocks = [block.to(device) for block in eval_blocks]

    # One more than the largest source node id of the first evaluation block, per node type
    max_num_nodes = {
        ntype: eval_blocks[0].srcnodes(ntype).max() + 1
        for ntype in partition_data.partition_book.ntypes
    }

    embed_layer = rel_graph_embed(eval_blocks[0], args.features_dim, max_num_nodes).to(device)
    gnn_model = EntityClassify(eval_blocks[0], args.features_dim, num_labels).to(device)

    print('model', gnn_model)
    embed_layer.reset_parameters()
    gnn_model.reset_parameters()

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    # Total number of training nodes over all workers; used to normalize the
    # summed cross entropy loss by the number of training examples.
    n_train_points = torch.LongTensor([indices_masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    # A single optimizer updates both the GNN and the embedding layer.
    optimizer = torch.optim.Adam(
        itertools.chain(gnn_model.parameters(), embed_layer.parameters()),
        lr=args.lr)

    for train_iter_idx in range(args.train_iters):
        # Train
        t_1 = time.time()
        gnn_model.train()

        input_block = train_blocks[0]
        embeds = extract_embed(
            embed_layer,
            {ntype: input_block.srcnodes(ntype) for ntype in input_block.srctypes})
        # 'paper' nodes use their stored features instead of the embedding layer output.
        embeds.update({"paper": features[input_block.srcnodes("paper")]})
        embeds = {k: e.to(device) for k, e in embeds.items()}

        logits = gnn_model(train_blocks, embeds)["paper"].log_softmax(dim=-1)
        train_targets = labels[bool_masks["train_mask"]]
        loss = F.nll_loss(logits, train_targets, reduction="sum") / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Calculate loss and accuracy for train/validation/test
        results = []
        gnn_model.eval()
        with torch.no_grad():
            input_block = eval_blocks[0]
            embeds = extract_embed(
                embed_layer,
                {ntype: input_block.srcnodes(ntype) for ntype in input_block.srctypes})
            embeds.update({"paper": features[input_block.srcnodes("paper")]})
            embeds = {k: e.to(device) for k, e in embeds.items()}

            logits = gnn_model(eval_blocks, embeds)["paper"].log_softmax(dim=-1)

            # Per split: [summed loss, number correct, number of masked nodes]
            for name in mask_names:
                masked_nodes = bool_masks[name]
                if masked_nodes.sum() > 0:
                    active_logits = logits[masked_nodes]
                    active_labels = labels[masked_nodes]
                    split_loss = F.nll_loss(active_logits, active_labels, reduction="sum")
                    n_correct = (active_logits.argmax(1) == active_labels).float().sum()
                    results.extend([split_loss.item(),
                                    n_correct.item(),
                                    masked_nodes.sum().item()])
                else:
                    results.extend([0.0, 0.0, 0.0])

        loss_acc_vec = torch.FloatTensor(results)
        # Sum the losses, n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(loss_acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        train_loss = loss_acc_vec[0] / loss_acc_vec[2]
        train_acc = loss_acc_vec[1] / loss_acc_vec[2]
        val_acc = loss_acc_vec[4] / loss_acc_vec[5]
        test_acc = loss_acc_vec[7] / loss_acc_vec[8]

        result_message = (
            f"iteration [{train_iter_idx}/{args.train_iters}] | "
            f"train loss={train_loss:.4f}, "
            f"Accuracy: "
            f"train={100 * train_acc:.4f} "
            f"valid={100 * val_acc:.4f} "
            f"test={100 * test_acc:.4f} "
            f" | train time = {train_time} "
            f" |"
        )
        print(result_message, flush=True)
238 |
239 |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
242 |
--------------------------------------------------------------------------------
/examples/train_dist_appnp_with_sar.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import dgl # type: ignore
4 | from dgl.nn.pytorch.conv import APPNPConv
5 |
6 | import sar
7 |
8 | import time
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 | import torch.distributed as dist
13 |
# Command-line interface of the distributed APPNP example.
parser = ArgumentParser(description="APPNP example")

# Input data and rendezvous file.
parser.add_argument(
    "--partitioning-json-file",
    default="",
    type=str,
    help="Path to the .json file containing partitioning information",
)
parser.add_argument(
    "--ip-file",
    default="./ip_file",
    type=str,
    help="File with ip-address. Worker 0 creates this file and all others read it",
)

# Communication backend and device selection.
parser.add_argument(
    "--backend",
    default="nccl",
    type=str,
    choices=["ccl", "nccl", "mpi", "gloo"],
    help="Communication backend to use",
)
parser.add_argument(
    "--cpu-run",
    action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs",
)

# Optimization settings.
parser.add_argument(
    "--train-iters",
    default=100,
    type=int,
    help="number of training iterations",
)
parser.add_argument(
    "--lr",
    default=1e-2,
    type=float,
    help="learning rate",
)

# Position of this worker inside the distributed job.
parser.add_argument(
    "--rank",
    default=0,
    type=int,
    help="Rank of the current worker",
)
parser.add_argument(
    "--world-size",
    default=2,
    type=int,
    help="Number of workers",
)

# Model hyper-parameters.
parser.add_argument(
    "--hidden-layer-dim",
    nargs="+",
    default=[64],
    type=int,
    help="Dimension of GNN hidden layer",
)
parser.add_argument(
    "--k",
    default=10,
    type=int,
    help="Number of propagation steps",
)
parser.add_argument(
    "--alpha",
    default=0.1,
    type=float,
    help="Teleport Probability",
)
parser.add_argument(
    "--in-drop",
    default=0.5,
    type=float,
    help="input feature dropout",
)
parser.add_argument(
    "--edge-drop",
    default=0.5,
    type=float,
    help="edge propagation dropout",
)
55 |
class APPNP(nn.Module):
    """Stack of linear layers followed by an ``APPNPConv`` propagation step.

    The graph ``g`` is stored on the module, so ``forward`` only takes the
    node features.

    Args:
        g: graph object handed to the propagation layer on every forward pass.
        in_feats: size of the input features.
        hiddens: sequence with the width of every hidden linear layer
            (at least one entry).
        n_classes: size of the output of the last linear layer.
        activation: callable applied after every linear layer except the last.
        feat_drop: dropout probability for the features; a falsy value
            disables feature dropout.
        edge_drop: passed to ``APPNPConv`` as its third argument.
        alpha: passed to ``APPNPConv`` as its second argument.
        k: passed to ``APPNPConv`` as its first argument.
    """

    def __init__(
        self,
        g,
        in_feats,
        hiddens,
        n_classes,
        activation,
        feat_drop,
        edge_drop,
        alpha,
        k,
    ):
        super().__init__()
        self.g = g

        # Consecutive widths: input -> hidden layers -> output
        widths = [in_feats, hiddens[0], *hiddens[1:], n_classes]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        self.activation = activation
        # Without feature dropout the features are passed through unchanged
        self.feat_drop = nn.Dropout(feat_drop) if feat_drop else (lambda x: x)
        self.propagate = APPNPConv(k, alpha, edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of every linear layer."""
        for linear in self.layers:
            linear.reset_parameters()

    def forward(self, features):
        """Return the propagated predictions for ``features``."""
        *inner_layers, output_layer = self.layers

        # prediction step: dropout is applied to the input and again right
        # before the output layer
        h = self.feat_drop(features)
        for linear in inner_layers:
            h = self.activation(linear(h))
        h = output_layer(self.feat_drop(h))

        # propagation step
        return self.propagate(self.g, h)
102 |
def evaluate(model, features, labels, masks):
    """Compute train/validation/test accuracy summed over all workers.

    Args:
        model: module called as ``model(features)`` to obtain the logits.
        features: input passed to ``model``.
        labels: tensor indexed with the same index tensors as the logits.
        masks: mapping with the keys ``'train_indices'``, ``'val_indices'``
            and ``'test_indices'``, each holding an index tensor.

    Returns:
        Tuple ``(train_acc, val_acc, test_acc)``.
    """
    model.eval()
    split_indices = [masks[key]
                     for key in ('train_indices', 'val_indices', 'test_indices')]

    with torch.no_grad():
        logits = model(features)

        # Flat list of [n_correct, n_nodes] pairs, one pair per split
        counts = []
        for indices in split_indices:
            hits = (logits[indices].argmax(1) == labels[indices]).float().sum()
            counts += [hits, indices.numel()]

        acc_vec = torch.FloatTensor(counts)
        # Sum the n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)

        train_acc = acc_vec[0] / acc_vec[1]
        val_acc = acc_vec[2] / acc_vec[3]
        test_acc = acc_vec[4] / acc_vec[5]

    return train_acc, val_acc, test_acc
123 |
def main():
    """Train the APPNP model on this worker's graph partition with SAR.

    Parses the command line, joins the SAR communication group, loads the
    worker's DGL partition, builds the full-graph manager and the model, and
    runs ``args.train_iters`` full-batch iterations. Every 10th iteration the
    train/validation/test accuracies, summed over all workers, are printed.
    """
    args = parser.parse_args()
    print('args', args)

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size,
                         master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)

    # Obtain train, validation, and test masks.
    # These are stored as node features. Partitioning may prepend
    # the node type to the mask names. So we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type. Each boolean mask is converted to an index tensor.
    masks = {}
    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
                                       ['train_indices', 'val_indices', 'test_indices']):
        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
                                             mask_name)
        masks[indices_name] = boolean_mask.nonzero(
            as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'label').long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features, 'feat').to(device)
    full_graph_manager = sar.construct_full_graph(partition_data).to(device)

    # We do not need the partition data anymore
    del partition_data

    # The model has to live on the same device as the features and the graph;
    # otherwise a GPU run feeds CUDA tensors into CPU linear layers and the
    # forward pass fails with a device mismatch.
    gnn_model = APPNP(
        full_graph_manager,
        features.size(1),
        args.hidden_layer_dim,
        num_labels,
        F.relu,
        args.in_drop,
        args.edge_drop,
        args.alpha,
        args.k).to(device)

    gnn_model.reset_parameters()
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    # Obtain the number of labeled nodes in the training set.
    # This is needed to obtain a cross entropy loss that is
    # normalized by the number of training examples over all workers.
    n_train_points = torch.LongTensor([masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr, weight_decay=5e-4)
    for train_iter_idx in range(args.train_iters):
        # Train
        gnn_model.train()
        t_1 = time.time()
        logits = gnn_model(features)
        loss = F.cross_entropy(logits[masks['train_indices']],
                               labels[masks['train_indices']], reduction='sum') / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Evaluate and report only every 10th iteration
        if (train_iter_idx + 1) % 10 == 0:
            train_acc, val_acc, test_acc = evaluate(gnn_model, features, labels, masks)

            result_message = (
                f"iteration [{train_iter_idx + 1}/{args.train_iters}] | "
                f"train loss={loss:.4f}, "
                f"Accuracy: "
                f"train={train_acc:.4f} "
                f"valid={val_acc:.4f} "
                f"test={test_acc:.4f} "
                f" | train time = {train_time} "
                f" |"
            )
            print(result_message, flush=True)
225 |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
228 |
--------------------------------------------------------------------------------
/examples/train_homogeneous_graph_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import List, Union, Dict
22 | from argparse import ArgumentParser
23 | import os
24 | import logging
25 | import time
26 | import torch
27 | import torch.nn.functional as F
28 | from torch import nn
29 | import torch.distributed as dist
30 | import dgl # type: ignore
31 |
32 | import sar
33 |
34 |
# Command-line interface of the basic homogeneous-graph training example.
parser = ArgumentParser(
    description="GNN training on node classification tasks in homogeneous graphs",
)

# Input data and rendezvous file.
parser.add_argument(
    "--partitioning-json-file",
    default="",
    type=str,
    help="Path to the .json file containing partitioning information ",
)
parser.add_argument(
    "--ip-file",
    type=str,
    default="./ip_file",
    help="File with ip-address. Worker 0 creates this file and all others read it ",
)

# Communication backend and device selection.
parser.add_argument(
    "--backend",
    type=str,
    choices=["ccl", "nccl", "mpi", "gloo"],
    default="nccl",
    help="Communication backend to use ",
)
parser.add_argument(
    "--cpu-run",
    action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs ",
)

# Optimization settings.
parser.add_argument(
    "--train-iters",
    type=int,
    default=100,
    help="number of training iterations ",
)
parser.add_argument(
    "--lr",
    default=1e-2,
    type=float,
    help="learning rate",
)

# Position of this worker inside the distributed job.
parser.add_argument(
    "--rank",
    type=int,
    default=0,
    help="Rank of the current worker ",
)
parser.add_argument(
    "--world-size",
    type=int,
    default=2,
    help="Number of workers ",
)

# Model size.
parser.add_argument(
    "--hidden-layer-dim",
    type=int,
    default=256,
    help="Dimension of GNN hidden layer",
)
79 |
80 |
class GNNModel(nn.Module):
    """Three stacked ``SAGEConv`` layers with mean aggregation.

    Args:
        in_dim: size of the input node features.
        hidden_dim: output size of the first two layers.
        out_dim: output size of the last layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()

        # (input size, output size) of each convolution, in order
        layer_dims = [
            (in_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, out_dim),
        ]
        self.convs = nn.ModuleList([
            dgl.nn.SAGEConv(src_dim, dst_dim, aggregator_type='mean')  # pylint: disable=no-member
            for src_dim, dst_dim in layer_dims
        ])

    def forward(self, graph: sar.HeteroGraphShardManager, features: torch.Tensor):
        """Apply every convolution on ``graph``; ReLU follows all but the last."""
        *hidden_convs, output_conv = self.convs
        for conv in hidden_convs:
            features = F.relu(conv(graph, features), inplace=True)

        return output_conv(graph, features)
101 |
102 |
def main():
    """Train ``GNNModel`` on this worker's partition of a homogeneous graph.

    Parses the command line, joins the SAR communication group, loads the
    worker's DGL partition, builds the full-graph manager and the model, and
    runs ``args.train_iters`` full-batch iterations. After every iteration the
    loss and the train/validation/test accuracies, summed over all workers,
    are printed.
    """
    args = parser.parse_args()
    print('args', args)

    run_on_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if run_on_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank, args.world_size, master_ip_address, args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)

    # The train/validation/test masks are stored as node features and
    # partitioning may prepend the node type to their names, so they are
    # looked up by suffix. Each boolean mask becomes an index tensor.
    masks = {}
    for split in ('train', 'val', 'test'):
        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
                                             f'{split}_mask')
        masks[f'{split}_indices'] = boolean_mask.nonzero(as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features, 'labels')
    labels = labels.long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device)
    full_graph_manager = sar.construct_full_graph(partition_data).to(device)

    # We do not need the partition data anymore
    del partition_data

    gnn_model = GNNModel(features.size(1), args.hidden_layer_dim, num_labels).to(device)
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    # Total number of training nodes over all workers; used to normalize the
    # summed cross entropy loss by the number of training examples.
    n_train_points = torch.LongTensor([masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr)
    split_keys = ['train_indices', 'val_indices', 'test_indices']
    for train_iter_idx in range(args.train_iters):
        # Train
        t_1 = time.time()
        logits = gnn_model(full_graph_manager, features)
        train_indices = masks['train_indices']
        loss = F.cross_entropy(logits[train_indices],
                               labels[train_indices],
                               reduction='sum') / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Calculate accuracy for train/validation/test from the logits of the
        # training forward pass: flat list of [n_correct, n_nodes] pairs.
        results = []
        for key in split_keys:
            indices = masks[key]
            n_correct = (logits[indices].argmax(1) == labels[indices]).float().sum()
            results += [n_correct, indices.numel()]

        acc_vec = torch.FloatTensor(results)
        # Sum the n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        train_acc = acc_vec[0] / acc_vec[1]
        val_acc = acc_vec[2] / acc_vec[3]
        test_acc = acc_vec[4] / acc_vec[5]

        result_message = (
            f"iteration [{train_iter_idx}/{args.train_iters}] | "
            f"train loss={loss:.4f}, "
            f"Accuracy: "
            f"train={train_acc:.4f} "
            f"valid={val_acc:.4f} "
            f"test={test_acc:.4f} "
            f" | train time = {train_time} "
            f" |"
        )
        print(result_message, flush=True)
206 |
207 |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
210 |
--------------------------------------------------------------------------------
/examples/train_homogeneous_sampling_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import List, Union, Dict
22 | from argparse import ArgumentParser
23 | import os
24 | import logging
25 | import psutil
26 | import time
27 | import torch
28 | import torch.nn.functional as F
29 | from torch import nn
30 | import torch.distributed as dist
31 | import dgl # type: ignore
32 | from dgl.heterograph import DGLBlock # type: ignore
33 |
34 |
35 | import sar
36 |
37 |
# Command-line interface of the sampling-based homogeneous-graph example.
parser = ArgumentParser(
    description="GNN training on node classification tasks in homogeneous graphs")


# Input data and rendezvous file.
parser.add_argument(
    "--partitioning-json-file",
    type=str,
    default="",
    help="Path to the .json file containing partitioning information "
)

parser.add_argument('--ip-file', default='./ip_file', type=str,
                    help='File with ip-address. Worker 0 creates this file and all others read it ')


# Communication backend and device selection.
parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'],
                    help='Communication backend to use '
                    )

parser.add_argument(
    "--cpu-run", action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs "
)

# Batch precomputation options.
parser.add_argument(
    "--precompute-batches", action="store_true",
    help="Precompute the batches "
)


parser.add_argument(
    "--optimized-batches-cache",
    type=str,
    default="",
    help="Prefix of the files used to store precomputed batches "
)


# Optimization settings.
parser.add_argument('--train-iters', default=100, type=int,
                    help='number of training iterations ')

parser.add_argument(
    "--lr",
    type=float,
    default=1e-2,
    help="learning rate"
)


# Position of this worker inside the distributed job.
parser.add_argument('--rank', default=0, type=int,
                    help='Rank of the current worker ')

parser.add_argument('--world-size', default=2, type=int,
                    help='Number of workers ')

# Model size.
parser.add_argument('--hidden-layer-dim', default=256, type=int,
                    help='Dimension of GNN hidden layer')

# Mini-batch sampling options.
parser.add_argument('--batch-size', default=5000, type=int,
                    help='per worker batch size ')

parser.add_argument('--num-workers', default=0, type=int,
                    help='number of dataloader workers ')

# No default is given, so args.fanout is None unless the option is passed.
parser.add_argument('--fanout', nargs="+", type=int,
                    help='fanouts for sampling ')

# Upper bound on the amount of data sent in a single collective call.
parser.add_argument('--max-collective-size', default=0, type=int,
                    help='The maximum allowed size of the data in a collective. \
                    If a collective would communicate more than this maximum, it is split into multiple collectives.\
                    Collective calls with large data may cause instabilities in some communication backends ')
109 |
110 |
class GNNModel(nn.Module):
    """Three stacked ``dgl.nn.SAGEConv`` layers using mean aggregation.

    The layer widths are ``in_dim -> hidden_dim -> hidden_dim -> out_dim``.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()

        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        # pylint: disable=no-member
        self.convs = nn.ModuleList(
            dgl.nn.SAGEConv(n_in, n_out, aggregator_type='mean')
            for n_in, n_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, blocks: List[Union[DGLBlock, sar.GraphShardManager]], features: torch.Tensor):
        """Apply layer ``i`` to ``blocks[i]``; every layer but the last is followed by an in-place ReLU."""
        last_idx = len(self.convs) - 1
        for layer_idx, (layer, graph) in enumerate(zip(self.convs, blocks)):
            features = layer(graph, features)
            if layer_idx != last_idx:
                features = F.relu(features, inplace=True)

        return features
131 |
132 |
def main():
    """Train the GraphSage model with distributed neighbor sampling and evaluate it on the full graph.

    Steps: parse the command line, initialize SAR communication, load the local
    DGL partition, build the sampler/dataloader, then for every training
    iteration run one pass over the sampled mini-batches followed by a
    full-graph inference pass (on CPU) that reports train/validation/test
    accuracy aggregated across all workers.

    :raises ValueError: if ``--fanout`` is given with a number of entries that
        differs from the number of GNN layers.
    """
    # psutil.Process().cpu_affinity([8])
    args = parser.parse_args()
    print('args', args)

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda') if use_gpu else torch.device('cpu')

    if args.rank == -1:
        # Try to infer the worker's rank from environment variables
        # created by mpirun or similar MPI launchers
        args.rank = int(os.environ.get("PMI_RANK", -1))
        if args.rank == -1:
            args.rank = int(os.environ["RANK"])

    # os.environ.putenv("GOMP_CPU_AFFINITY", "10,11")
    # os.environ.putenv("OMP_NUM_THREADS", "16")
    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size, master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, torch.device('cpu'))

    # Obtain train,validation, and test masks
    # These are stored as node features. Partitioning may prepend
    # the node type to the mask names. So we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type
    masks = {}
    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
                                       ['train_indices', 'val_indices', 'test_indices']):
        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
                                             mask_name)
        masks[indices_name] = boolean_mask.nonzero(
            as_tuple=False).view(-1).to(device)
        print(f'mask {indices_name} : {masks[indices_name]} ')
    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'labels').long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX,
                        move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(
        partition_data.node_features, 'features')  # keep features always on CPU
    full_graph_manager = sar.construct_full_graph(
        partition_data)  # Keep full graph on CPU

    node_ranges = partition_data.node_ranges
    del partition_data

    gnn_model = GNNModel(features.size(1),
                         args.hidden_layer_dim,
                         num_labels).to(device)

    # gnn_model_cpu will be used for inference
    if use_gpu:
        gnn_model_cpu = GNNModel(features.size(1),
                                 args.hidden_layer_dim,
                                 num_labels)
    else:
        gnn_model_cpu = gnn_model
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr)

    # Honour --fanout when it is supplied; otherwise keep the historical
    # default. One fanout is needed per GNN layer because the model consumes
    # exactly one sampled block per layer.
    fanouts = args.fanout if args.fanout else [15, 10, 5]
    if len(fanouts) != len(gnn_model.convs):
        raise ValueError(
            f'--fanout expects {len(gnn_model.convs)} values (one per GNN layer), '
            f'got {len(fanouts)}')

    neighbor_sampler = sar.core.sampling.DistNeighborSampler(fanouts,
                                                             input_node_features={
                                                                 'features': features},
                                                             output_node_features={
                                                                 'labels': labels},
                                                             output_device=device
                                                             )

    # Offset the local train indices by the start of this worker's node range
    train_nodes = masks['train_indices'] + node_ranges[sar.rank()][0]

    dataloader = sar.core.sampling.DataLoader(
        full_graph_manager,
        train_nodes,
        neighbor_sampler,
        args.batch_size,
        shuffle=True,
        precompute_optimized_batches=args.precompute_batches,
        optimized_batches_cache=(
            args.optimized_batches_cache if args.optimized_batches_cache else None),
        num_workers=args.num_workers)

    print('sampling graph edata', full_graph_manager.sampling_graph.edata)

    for train_iter_idx in range(args.train_iters):
        total_loss = 0
        gnn_model.train()
        # Collectives are not split during sampled training; the user limit is
        # only applied for the full-graph inference below.
        sar.Config.max_collective_size = 0

        train_t1 = dataloader_t1 = time.time()
        n_total = n_correct = 0
        pure_training_time = 0
        for block_idx, blocks in enumerate(dataloader):
            start_t1 = time.time()
            loading_time = time.time() - dataloader_t1
            print(f'in block {block_idx} : {blocks}. Loaded in {loading_time}')
            # print('block edata', [block.edata[dgl.EID] for block in blocks])
            blocks = [b.to(device) for b in blocks]
            block_features = blocks[0].srcdata['features']
            block_labels = blocks[-1].dstdata['labels']
            logits = gnn_model(blocks, block_features)

            loss = F.cross_entropy(logits, block_labels, reduction='mean')
            n_correct += (logits.argmax(1) ==
                          block_labels).float().sum()
            n_total += len(block_labels)

            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            # Do not forget to gather the parameter gradients from all workers
            tg = time.time()
            sar.gather_grads(gnn_model)
            print('gather grad time', time.time() - tg)
            optimizer.step()
            dataloader_t1 = time.time()
            pure_training_time += (time.time() - start_t1)

        train_time = time.time() - train_t1
        print('train time', train_time, flush=True)
        print('pure train time', pure_training_time, flush=True)

        print('loss', total_loss, flush=True)
        print('accuracy ', n_correct/n_total, flush=True)

        # Full graph inference is done on CPUs using sequential
        # aggregation and re-materialization
        gnn_model_cpu.eval()
        if gnn_model_cpu is not gnn_model:
            gnn_model_cpu.load_state_dict(gnn_model.state_dict())
        sar.Config.max_collective_size = args.max_collective_size
        with torch.no_grad():
            # Calculate accuracy for train/validation/test
            logits = gnn_model_cpu(
                [full_graph_manager] * 3, features)
            results = []
            for indices_name in ['train_indices', 'val_indices', 'test_indices']:
                node_idx = masks[indices_name]
                # The logits live on the CPU while the index tensors may live
                # on the GPU; index the CPU tensor with CPU indices.
                n_correct = (logits[node_idx.cpu()].argmax(1) ==
                             labels[node_idx].cpu()).float().sum()
                results.extend([n_correct, node_idx.numel()])

            acc_vec = torch.FloatTensor(results)
            # Sum the n_correct, and number of mask elements across all workers
            sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM,
                                move_to_comm_device=True)
            (train_acc, val_acc, test_acc) = (acc_vec[0] / acc_vec[1],
                                              acc_vec[2] / acc_vec[3],
                                              acc_vec[4] / acc_vec[5])

            result_message = (
                f"iteration [{train_iter_idx}/{args.train_iters}] | "
            )
            result_message += ', '.join([
                f"Accuracy: "
                f"train={train_acc:.4f} "
                f"valid={val_acc:.4f} "
                f"test={test_acc:.4f} "
                f" | train time = {train_time} "
                f" |"
            ])
            print(result_message, flush=True)
311 |
312 |
# Run the training example only when this file is executed as a script.
if __name__ == '__main__':
    main()
315 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 68.2.2", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links=https://data.dgl.ai/wheels/repo.html
2 | dgl>=1.0.0
3 | numpy>=1.22.0
4 | torch>=1.10.0
5 | ifaddr>=0.1.7
--------------------------------------------------------------------------------
/sar/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 |
22 | '''
23 | Top-level SAR package
24 | '''
25 | from . import core
26 |
27 | from .comm import initialize_comms, rank, world_size, comm_device,\
28 | nfs_ip_init, sync_params, gather_grads
29 | from .core import GraphShardManager, HeteroGraphShardManager, message_has_parameters, DistributedBlock,\
30 | DistNeighborSampler, DataLoader
31 | from .construct_shard_manager import construct_mfgs, construct_full_graph, convert_dist_graph
32 | from .data_loading import load_dgl_partition_data, suffix_key_lookup
33 | from .distributed_bn import DistributedBN1D
34 | from .config import Config
35 | from .edge_softmax import edge_softmax
36 | from .patch_dgl import patch_dgl, patched_edge_softmax, RelGraphConv
37 | from .logging_setup import logging_setup, logger
38 |
39 |
# Names re-exported by ``from sar import *``; each one is imported above from
# a submodule of this package.
__all__ = ['initialize_comms', 'rank', 'world_size', 'nfs_ip_init',
           'comm_device', 'DistributedBN1D',
           'construct_mfgs', 'construct_full_graph', 'convert_dist_graph', 'GraphShardManager', 'HeteroGraphShardManager',
           'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax',
           'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader',
           'logging_setup', 'logger', 'RelGraphConv', 'sync_params', 'gather_grads', 'patch_dgl', 'patched_edge_softmax']
46 |
--------------------------------------------------------------------------------
/sar/common_tuples.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | '''
22 | Tuples for grouping related data
23 | '''
24 | from typing import NamedTuple, Dict, Tuple, List, Optional, Any, TYPE_CHECKING
25 | from enum import Enum
26 | from torch import Tensor
27 | import dgl # type: ignore
28 |
29 | if TYPE_CHECKING:
30 | from .core.graphshard import GraphShardManager
31 | from .core.sar_aggregation import BackwardManager
32 |
33 |
class TensorPlace(Enum):
    '''
    Tag paired with a tensor name in ``AggregationData.all_input_names``.

    NOTE(review): presumably the members say whether a tensor belongs to the
    source nodes, the destination nodes, the edges, or is a parameter --
    confirm against the aggregation code.
    '''
    SRC = 0
    DST = 1
    EDGE = 2
    PARAM = 3
39 |
40 |
class ShardEdgesAndFeatures(NamedTuple):
    '''
    Stores the edge information for all edges connecting nodes in one partition to
    nodes in another partition. For an N-way partition, each worker will have N ShardEdgesAndFeatures objects,
    where each object contains data for incoming edges from one partition (including the worker's own
    partition)


    .. py:attribute:: edges : Tuple[Tensor,Tensor]

        The source and destination global node ids for each edge in the shard


    .. py:attribute:: edge_features : Dict[str,Tensor]

        A dictionary of the edge features

    '''
    edges: Tuple[Tensor, Tensor]
    edge_features: Dict[str, Tensor]
61 |
62 |
class GraphShardManagerData(NamedTuple):
    '''
    Groups the per-shard edges with the node ranges and seed nodes of a graph shard manager.

    NOTE(review): the field meanings below are inferred from names and types
    only -- confirm against the code that builds this tuple.
    '''
    # One ShardEdgesAndFeatures entry per shard
    all_shard_edges: List[ShardEdgesAndFeatures]
    # One (start, end) pair per source partition -- presumably global node ids
    src_node_ranges: List[Tuple[int, int]]
    # (start, end) pair for the target nodes -- presumably the local partition
    tgt_node_range: Tuple[int, int]
    tgt_seed_nodes: Tensor
    local_src_seed_nodes: Tensor
69 |
70 |
class PartitionData(NamedTuple):
    '''
    Stores all the data for the local partition


    .. py:attribute:: all_shard_edges : List[ShardEdgesAndFeatures]

        A list of ShardEdgesAndFeatures objects. One for edges incoming from each partition


    .. py:attribute:: node_ranges : List[Tuple[int,int]]

        node_ranges[i] is a tuple of the start and end global node indices for nodes in partition i.

    .. py:attribute:: node_features : Dict[str,Tensor]

        Dictionary of node features for nodes in local partition

    .. py:attribute:: node_type_names : List[str]

        List of node type names. Use in conjunction with dgl.NTYPE node features to get\
        the node type of each node

    .. py:attribute:: edge_type_names : List[str]

        List of edge type names. Use in conjunction with dgl.ETYPE edge features to get\
        the edge type of each edge

    .. py:attribute:: partition_book : dgl.distributed.GraphPartitionBook

        The graph partition information


    '''

    all_shard_edges: List[ShardEdgesAndFeatures]
    node_ranges: List[Tuple[int, int]]
    node_features: Dict[str, Tensor]
    node_type_names: List[str]
    edge_type_names: List[str]
    partition_book: dgl.distributed.GraphPartitionBook
112 |
113 |
class AggregationData(NamedTuple):
    '''
    Groups the arguments describing one aggregation over a ``GraphShardManager``.

    NOTE(review): field semantics are inferred from names and types only --
    confirm against ``sar.core.sar_aggregation``.
    '''
    # The shard manager the aggregation runs on (forward reference, imported
    # only under TYPE_CHECKING)
    graph_shard_manager: "GraphShardManager"
    message_func: Any
    reduce_func: Any
    etype: Any
    # (place, name) pairs identifying each input tensor
    all_input_names: List[Tuple[TensorPlace, str]]
    n_params: int
    grad_enabled: bool
    remote_data: bool
123 |
124 |
class ShardInfo(NamedTuple):
    '''
    Index of a shard together with its source-node, target-node and edge ranges.

    NOTE(review): presumably every range is a (start, end) pair -- confirm
    whether the end is inclusive where these tuples are created.
    '''
    shard_idx: int
    src_node_range: Tuple[int, int]
    tgt_node_range: Tuple[int, int]
    edge_range: Tuple[int, int]
130 |
131 |
class SocketInfo(NamedTuple):
    '''
    A name paired with an IP address.

    NOTE(review): presumably ``name`` is a network interface name -- confirm
    against the communication setup code.
    '''
    name: str
    ip_addr: str
135 |
--------------------------------------------------------------------------------
/sar/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | import torch
22 |
23 |
class Config(object):
    '''
    General configuration for the SAR library. The settings are plain class
    attributes, so they are read and written directly on the class
    (e.g. ``Config.max_collective_size = 0``).


    .. py:attribute:: disable_sr : bool

        Disables sequential re-materialization of the computational graph during the backward pass.\
        The computational graph is constructed normally during the forward pass. default : False



    .. py:attribute:: max_collective_size : int

        Limits the maximum size of data in torch.distributed.all_to_all collective calls. If non-zero,\
        the sar.comm.all_to_all wrapper method will break down the collective call into multiple torch.distributed.all_to_all\
        calls so that the size of the data in each call is below max_collective_size. default : 0

    .. py:attribute:: pipeline_depth : int

        Sets the communication pipeline depth when doing sequential aggregation or sequential re-materialization.\
        In a separate thread, SAR will pre-fetch up to ``pipeline_depth`` remote partitions into a data queue that will then\
        be processed by the compute thread. Higher values will increase memory consumption but may hide\
        communication latency. default : 1


    '''

    disable_sr: bool = False
    max_collective_size: int = 0
    pipeline_depth: int = 1
55 |
--------------------------------------------------------------------------------
/sar/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | '''
22 | Modules for sharded data representation and management
23 | '''
24 | from .graphshard import GraphShard, GraphShardManager, HeteroGraphShardManager
25 | from .sar_aggregation import message_has_parameters
26 | from .full_partition_block import DistributedBlock
27 | from .sampling import DistNeighborSampler, DataLoader
28 |
# Names re-exported by ``from sar.core import *``.
__all__ = ['GraphShard', 'GraphShardManager', 'HeteroGraphShardManager', 'message_has_parameters',
           'DistributedBlock', 'DistNeighborSampler', 'DataLoader']
31 |
--------------------------------------------------------------------------------
/sar/core/full_partition_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import List, Dict, Optional, Tuple
22 | import logging
23 | from collections.abc import MutableMapping
24 | import torch
25 | import dgl # type: ignore
26 | from torch import Tensor
27 | import torch.distributed as dist
28 | from torch.autograd import profiler
29 |
30 |
31 | from ..comm import all_to_all, world_size, rank, all_reduce
32 |
# Module-level logger. The NullHandler keeps the library silent unless the
# application attaches its own handler; the level is opened up to DEBUG so
# that such a handler receives every record emitted by this module.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
logger.setLevel(logging.DEBUG)
36 |
37 |
class ProxyDataView(MutableMapping):
    """A distributed dictionary.

    Reads, deletions, iteration and ``len`` are forwarded unchanged to
    ``base_dict``. Assigning a tensor first exchanges its rows with the other
    workers through ``tensor_exchange_op`` and stores the exchanged result
    under the key.

    :param tensor_sz: Required size of dimension 0 of every assigned tensor
    :type tensor_sz: int
    :param base_dict: The mapping that actually holds the stored tensors
    :type base_dict: MutableMapping
    :param indices_required_from_me: For every worker, the indices of the rows \
    of an assigned tensor that are sent to that worker
    :type indices_required_from_me: List[Tensor]
    :param sizes_expected_from_others: For every worker, the number of rows \
    received from that worker
    :type sizes_expected_from_others: List[int]
    """

    def __init__(self, tensor_sz: int, base_dict: MutableMapping,
                 indices_required_from_me: List[Tensor],
                 sizes_expected_from_others: List[int]):
        self.base_dict = base_dict
        self.tensor_sz = tensor_sz
        self.indices_required_from_me = indices_required_from_me
        self.sizes_expected_from_others = sizes_expected_from_others

    def set_base_dict(self, new_base_dict: MutableMapping):
        """Replace the backing mapping (used after the wrapped block is moved to another device)."""
        self.base_dict = new_base_dict

    def __setitem__(self, key: str, value: Tensor):
        assert value.size(0) == self.tensor_sz, \
            f'Tensor size {value.size()} does not match graph data size {self.tensor_sz}'
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug('Distributing item %s among all DistributedBlocks', key)

        with profiler.record_function("COMM_FETCH"):
            exchange_result = tensor_exchange_op(
                value, self.indices_required_from_me, self.sizes_expected_from_others)

        self.base_dict[key] = exchange_result

    def __getitem__(self, key: str):
        return self.base_dict[key]

    def __delitem__(self, key: str):
        del self.base_dict[key]

    def __iter__(self):
        return iter(self.base_dict)

    def __len__(self):
        return len(self.base_dict)
74 |
75 |
class DistributedBlock:
    """
    A wrapper around a dgl.DGLBlock object. The DGLBlock object represents all the edges incoming
    to the local partition. It communicates with remote partitions to implement one-shot communication and
    aggregation in the forward and backward passes . You should not construct DistributedBlock directly,
    but instead use :meth:`GraphShardManager.get_full_partition_graph`

    Attributes that are not defined on the wrapper are looked up on the wrapped block.

    :param block: A DGLBlock object representing all edges incoming to the local partition
    :type block: dgl.DGLBlock
    :param indices_required_from_me: The local node indices required by every other partition to carry out\
    one-hop aggregation
    :type indices_required_from_me: List[Tensor]
    :param sizes_expected_from_others: The number of remote indices that we need to fetch\
    from remote partitions to update the features of the nodes in the local partition
    :type sizes_expected_from_others: List[int]
    :param src_ranges: The global node ids of the start node and end node in each partition. Nodes in each\
    partition have consecutive indices
    :type src_ranges: List[Tuple[int, int]]
    :param unique_src_nodes: The absolute node indices of the source nodes in each remote partition
    :type unique_src_nodes: List[Tensor]
    :param input_nodes: The indices of the input nodes relative to the starting node index of the local partition\
    The input nodes are the nodes needed to produce the output node features assuming one-hop aggregation
    :type input_nodes: Tensor
    :param seeds: The node indices of the output nodes relative to the starting node index of the local partition
    :type seeds: Tensor
    :param edge_type_names: A list of edge type names
    :type edge_type_names: List[str]

    """

    def __init__(self, block, indices_required_from_me: List[Tensor],
                 sizes_expected_from_others: List[int],
                 src_ranges: List[Tuple[int, int]],
                 unique_src_nodes: List[Tensor],
                 input_nodes: Tensor,
                 seeds: Tensor,
                 edge_type_names: List[str]):

        self._block = block
        self.indices_required_from_me = indices_required_from_me
        self.sizes_expected_from_others = sizes_expected_from_others
        self.src_ranges = src_ranges
        self.unique_src_nodes = unique_src_nodes
        self.edge_type_names = edge_type_names
        self.input_nodes = input_nodes
        self.seeds = seeds

        # Assigning into srcdata distributes the tensor across workers
        self.srcdata = ProxyDataView(input_nodes.size(0),
                                     block.srcdata, indices_required_from_me, sizes_expected_from_others)

        # Out-degrees of the local nodes, cached per edge type
        self.out_degrees_cache: Dict[Optional[str], Tensor] = {}

    def out_degrees(self, vertices=dgl.ALL, etype=None) -> Tensor:
        """Return the out-degrees of the local partition's nodes, summed over all workers.

        The first call for a given ``etype`` runs one all-reduce per partition
        and caches this worker's result; degrees of zero are stored as one.

        :param vertices: ``dgl.ALL`` for every local node, otherwise an index \
        into the cached degree tensor
        :param etype: Edge type forwarded to the wrapped block's ``out_degrees``
        :returns: The (possibly indexed) cached out-degree tensor
        """
        if etype not in self.out_degrees_cache:
            src_out_degrees = self._block.out_degrees(etype=etype)
            src_out_degrees_split = torch.split(src_out_degrees, self.sizes_expected_from_others)

            # Every worker must take part in every round's all_reduce, even
            # though it only keeps the result of its own round.
            for comm_round in range(world_size()):
                out_degrees = torch.zeros(
                    self.src_ranges[comm_round][1] - self.src_ranges[comm_round][0],
                    dtype=self._block.idtype).to(self._block.device)

                out_degrees[self.unique_src_nodes[comm_round] - self.src_ranges[comm_round][0]
                            ] = src_out_degrees_split[comm_round]
                all_reduce(out_degrees, op=dist.ReduceOp.SUM, move_to_comm_device=True)
                if comm_round == rank():
                    out_degrees[out_degrees == 0] = 1
                    self.out_degrees_cache[etype] = out_degrees.to(self._block.device)

        # Check the type first so that a tensor argument is never compared
        # element-wise against the dgl.ALL sentinel.
        if isinstance(vertices, str) and vertices == dgl.ALL:
            return self.out_degrees_cache[etype]

        return self.out_degrees_cache[etype][vertices]

    def to(self, device: torch.device):
        """Move the wrapped block to ``device``, re-point ``srcdata`` at it, and return ``self``."""
        self._block = self._block.to(device)
        self.srcdata.set_base_dict(self._block.srcdata)
        return self

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_block`` itself
        # is missing (e.g. on an instance created by copy/pickle before its
        # state is restored), delegating would recurse without end.
        if name == '_block':
            raise AttributeError(name)
        return getattr(self._block, name)
157 |
158 |
class TensorExchangeOp(torch.autograd.Function):  # pylint: disable = abstract-method
    """Differentiable all-to-all exchange of tensor rows between workers.

    The forward pass sends selected rows of ``val`` to every worker and
    returns the concatenation of the rows received from all workers. The
    backward pass routes the incoming gradient back the same way and adds it
    into the rows the values came from.
    """
    @ staticmethod
    # pylint: disable = arguments-differ,unused-argument
    def forward(ctx, val: Tensor, indices_required_from_me: Tensor,  # type: ignore
                sizes_expected_from_others: Tensor) -> Tensor:  # type: ignore
        """Exchange rows of ``val`` and return the received rows concatenated.

        Despite the annotations, ``indices_required_from_me`` is iterated as a
        sequence of index tensors (one per worker) and
        ``sizes_expected_from_others`` as a sequence of row counts (one per
        worker).
        """
        # Saved for the backward pass
        ctx.sizes_expected_from_others = sizes_expected_from_others
        ctx.indices_required_from_me = indices_required_from_me
        ctx.input_size = val.size()

        send_tensors = [val[indices] for indices in indices_required_from_me]
        # Uninitialized receive buffers with the same trailing dimensions as val
        recv_tensors = [val.new(sz_from_others, *val.size()[1:])
                        for sz_from_others in sizes_expected_from_others]

        all_to_all(recv_tensors, send_tensors, move_to_comm_device=True)

        return torch.cat(recv_tensors)

    @ staticmethod
    # pylint: disable = arguments-differ
    # type: ignore
    def backward(ctx, grad):
        """Send each gradient chunk back to its source worker and accumulate it into the input gradient."""
        # Split the gradient into the per-worker chunks produced by forward
        send_tensors = list(torch.split(grad, ctx.sizes_expected_from_others))
        recv_tensors = [grad.new(len(indices), *grad.size()[1:])
                        for indices in ctx.indices_required_from_me]
        all_to_all(recv_tensors, send_tensors, move_to_comm_device=True)

        input_grad = grad.new(ctx.input_size).zero_()
        # NOTE(review): indexed += does not accumulate repeated indices within
        # a single index tensor -- presumably each per-worker index tensor is
        # free of duplicates; confirm at the call sites.
        for r_tensor, indices in zip(recv_tensors, ctx.indices_required_from_me):
            input_grad[indices] += r_tensor

        # Only ``val`` receives a gradient; the two bookkeeping arguments do not
        return input_grad, None, None


# Functional entry point for the exchange operation
tensor_exchange_op = TensorExchangeOp.apply
193 |
--------------------------------------------------------------------------------
/sar/data_loading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import List, Tuple, Dict, Optional
22 | import torch
23 | from torch import Tensor
24 | import dgl # type: ignore
25 | from dgl.distributed.partition import load_partition # type: ignore
26 | from .common_tuples import PartitionData, ShardEdgesAndFeatures
27 |
28 |
def suffix_key_lookup(feature_dict: Dict[str, Tensor], key: str,
                      expand_to_all: bool = False,
                      type_list: Optional[List[str]] = None) -> Tensor:
    """
    Finds the single entry of ``feature_dict`` whose name ends with ``key``.

    Suffix matching makes it possible to look up a feature by its bare name in DGL's
    partition data, where the names have the form ``{node or edge type name}/{feature_name}``.
    In heterogeneous graphs a feature may exist only for one node/edge type; with
    ``expand_to_all`` set, the result is expanded to cover all nodes/edges in the graph and is
    zero wherever the feature is not present.


    :param feature_dict: Node or edge feature dictionary
    :type feature_dict: Dict[str, Tensor]
    :param key: Key to look up.
    :type key: str
    :param expand_to_all: Expand feature tensor to all nodes/edges.
    :type expand_to_all: bool
    :param type_list: List of edge or node type names. Required if ``expand_to_all`` is ``True``
    :type type_list: Optional[List[str]]
    :returns: The matched (possibly expanded) feature tensor. An empty ``LongTensor`` if\
    no name matches

    """
    candidates = [name for name in feature_dict if name.endswith(key)]
    if not candidates:
        return torch.LongTensor([])
    assert len(candidates) == 1

    full_name = candidates[0]
    found = feature_dict[full_name]
    if not expand_to_all:
        return found

    assert type_list is not None
    # Expansion only applies when there are several types and the per-item
    # type ids are available in the dictionary
    if len(type_list) <= 1 or dgl.NTYPE not in feature_dict:
        return found

    type_ids = feature_dict[dgl.NTYPE]
    owner_type = full_name.split('/')[0]
    owner_idx = type_list.index(owner_type)

    expanded = found.new(type_ids.size(0), *found.size()[1:]).zero_()
    expanded[type_ids == owner_idx] = found
    return expanded
71 |
72 |
def _mask_features_dict(edge_features: Dict[str, Tensor],
                        mask: Tensor, device: torch.device) -> Dict[str, Tensor]:
    """
    Selects the masked rows of the edge-type entry and moves them to ``device``.

    Only the ``dgl.ETYPE`` entry of ``edge_features`` is kept; every other
    edge feature is dropped from the result.

    :param edge_features: Edge feature dictionary
    :param mask: Index or boolean mask used to index the edge-type tensor
    :param device: Device on which the selected rows are placed
    :returns: A dictionary holding at most the ``dgl.ETYPE`` entry
    """
    #TODO(kpietkun): allow using edge features
    if dgl.ETYPE not in edge_features:
        return {}
    return {dgl.ETYPE: edge_features[dgl.ETYPE][mask].to(device)}
77 |
78 |
79 | def _get_type_ordered_edges(edge_mask: Tensor, edge_types: Tensor,
80 | n_edge_types: int) -> Tensor:
81 | reordered_edge_mask: List[Tensor] = []
82 | for edge_type_idx in range(n_edge_types):
83 | edge_mask_typed = torch.logical_and(
84 | edge_mask, edge_types == edge_type_idx)
85 | reordered_edge_mask.append(
86 | edge_mask_typed.nonzero(as_tuple=False).view(-1))
87 |
88 | return torch.cat(reordered_edge_mask)
89 |
90 |
def create_partition_data(graph: dgl.DGLGraph,
                          own_partition_idx: int,
                          node_features: Dict[str, torch.Tensor],
                          edge_features: Dict[str, Tensor],
                          partition_book: dgl.distributed.GraphPartitionBook,
                          node_type_list: List[str],
                          edge_type_list: List[str],
                          device: torch.device) -> PartitionData:
    """
    Creates SAR's PartitionData object based on a graph partition and its features.

    Note that ``node_features`` and ``edge_features`` are modified in place:
    a ``dgl.NTYPE`` / ``dgl.ETYPE`` entry is added to each, and in heterogeneous
    graphs every edge feature whose key contains ``'reltype'`` is removed.

    :param graph: The graph partition structure for specific ``own_partition_idx``
    :type graph: dgl.DGLGraph
    :param own_partition_idx: The index of the partition to create. This is typically the
        worker/machine rank
    :type own_partition_idx: int
    :param node_features: Dictionary containing node features for graph partition
    :type node_features: Dict[str, Tensor]
    :param edge_features: Dictionary containing edge features for graph partition
    :type edge_features: Dict[str, Tensor]
    :param partition_book: The graph partition information
    :type partition_book: dgl.distributed.GraphPartitionBook
    :param node_type_list: List of node types
    :type node_type_list: List[str]
    :param edge_type_list: List of edge types
    :type edge_type_list: List[str]
    :param device: Device on which to place the per-shard edge features
    :type device: torch.device
    :returns: The loaded partition data
    :raises AssertionError: If an inner edge has a target node outside the node range
        of ``own_partition_idx``
    """
    is_heterogeneous = (len(edge_type_list) > 1)
    # Delete redundant edge features with keys {relation name}/reltype.
    # graph.edata[dgl.ETYPE] already contains the edge type in a heterogeneous graph
    if is_heterogeneous:
        for edge_feat_key in list(edge_features.keys()):
            if 'reltype' in edge_feat_key:
                del edge_features[edge_feat_key]

    # Obtain the node ranges in each partition in the homogenized graph
    start_node_idx = 0
    node_ranges: List[Tuple[int, int]] = []
    for part_metadata in partition_book.metadata():
        end_node_idx = start_node_idx + part_metadata['num_nodes']
        node_ranges.append((start_node_idx, end_node_idx))
        start_node_idx = end_node_idx

    # Include the node types in the node feature dictionary
    inner_node_mask = graph.ndata['inner_node'].bool()
    if dgl.NTYPE in graph.ndata:
        node_features[dgl.NTYPE] = graph.ndata[dgl.NTYPE][inner_node_mask]
    else:
        node_features[dgl.NTYPE] = torch.zeros(
            graph.num_nodes(), dtype=torch.int32)[inner_node_mask]

    # Include the edge types in the edge feature dictionary
    inner_edge_mask = graph.edata['inner_edge'].bool()
    if dgl.ETYPE in graph.edata:
        edge_features[dgl.ETYPE] = graph.edata[dgl.ETYPE][inner_edge_mask]
    else:
        edge_features[dgl.ETYPE] = torch.zeros(
            graph.num_edges(), dtype=torch.int32)[inner_edge_mask]

    # Obtain the inner edges. These are the partition edges
    local_partition_edges = torch.stack(graph.all_edges())[:, inner_edge_mask]
    # Use global node ids in partition_edges
    partition_edges = graph.ndata[dgl.NID][local_partition_edges]

    # Check that all target nodes lie in the current partition. The check is
    # skipped for a partition without inner edges, since min()/max() cannot be
    # evaluated on an empty tensor.
    if partition_edges.numel() > 0:
        own_start, own_end = node_ranges[own_partition_idx]
        assert partition_edges[1].min() >= own_start \
            and partition_edges[1].max() < own_end, \
            'found an inner edge whose target node lies outside the own partition'

    all_shard_edges: List[ShardEdgesAndFeatures] = []

    for part_idx in range(partition_book.num_partitions()):
        # obtain the mask for edges originating from partition part_idx
        edge_mask = torch.logical_and(partition_edges[0] >= node_ranges[part_idx][0],
                                      partition_edges[0] < node_ranges[part_idx][1])

        # Reorder the edges in each shard so that edges with the same type
        # follow each other
        if is_heterogeneous:
            edge_mask = _get_type_ordered_edges(
                edge_mask, edge_features[dgl.ETYPE], len(edge_type_list))

        all_shard_edges.append(ShardEdgesAndFeatures(
            (partition_edges[0, edge_mask], partition_edges[1, edge_mask]),
            _mask_features_dict(edge_features, edge_mask, device)
        ))

    return PartitionData(all_shard_edges,
                         node_ranges,
                         node_features,
                         node_type_list,
                         edge_type_list,
                         partition_book
                         )
184 |
185 |
def load_dgl_partition_data(partition_json_file: str,
                            own_partition_idx: int, device: torch.device) -> PartitionData:
    """
    Loads one partition of a graph that was partitioned with DGL's ``partition_graph``
    and wraps it into SAR's partition data structure.

    :param partition_json_file: Path to the .json file containing partitioning data
    :type partition_json_file: str
    :param own_partition_idx: The index of the partition to load. This is typically the
        worker/machine rank
    :type own_partition_idx: int
    :param device: Device on which to place the loaded partition data
    :type device: torch.device
    :returns: The loaded partition data

    """
    # The fifth element returned by load_partition is not needed here.
    (local_graph, node_feats, edge_feats,
     book, _, ntypes, etypes) = load_partition(partition_json_file, own_partition_idx)

    return create_partition_data(graph=local_graph,
                                 own_partition_idx=own_partition_idx,
                                 node_features=node_feats,
                                 edge_features=edge_feats,
                                 partition_book=book,
                                 node_type_list=ntypes,
                                 edge_type_list=etypes,
                                 device=device)
209 |
def load_dgl_partition_data_from_graph(graph: dgl.distributed.DistGraph,
                                       device: torch.device) -> PartitionData:
    """
    Builds SAR's partition data from the local partition of a DistGraph object.

    :param graph: The distributed graph
    :type graph: dgl.distributed.DistGraph
    :param device: Device on which to place the loaded partition data
    :type device: torch.device
    :returns: The loaded partition data
    :raises AssertionError: If the local partition does not carry ``dgl.NID`` node data
        or ``dgl.EID`` edge data

    """
    own_partition_idx = graph.rank()
    local_g = graph.local_partition

    assert dgl.NID in local_g.ndata
    assert dgl.EID in local_g.edata

    # Positions of the inner nodes/edges inside the local partition
    inner_node_idx = local_g.ndata['inner_node'].bool().nonzero().view(-1)
    inner_edge_idx = local_g.edata['inner_edge'].bool().nonzero().view(-1)

    # Original-id mapping for the inner nodes and edges
    orig_n_ids = local_g.ndata[dgl.NID][inner_node_idx]
    orig_e_ids = local_g.edata[dgl.EID][inner_edge_idx]

    # Fetch the local rows of every feature from the distributed tensors.
    # NOTE(review): torch.Tensor(...) is the legacy constructor of the default
    # (float) tensor type - confirm that integer features such as labels are
    # meant to go through it.
    node_features = {}
    for key in list(graph.ndata.keys()):
        node_features[key] = torch.Tensor(graph.ndata[key][orig_n_ids])
    edge_features = {}
    for key in list(graph.edata.keys()):
        edge_features[key] = torch.Tensor(graph.edata[key][orig_e_ids])

    partition_book = graph.get_partition_book()
    node_type_list = local_g.ntypes
    edge_type_list = [local_g.to_canonical_etype(etype) for etype in graph.etypes]

    return create_partition_data(graph=local_g,
                                 own_partition_idx=own_partition_idx,
                                 node_features=node_features,
                                 edge_features=edge_features,
                                 partition_book=partition_book,
                                 node_type_list=node_type_list,
                                 edge_type_list=edge_type_list,
                                 device=device)
--------------------------------------------------------------------------------
/sar/distributed_bn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import Optional
22 | import torch
23 | from torch import Tensor
24 | import torch.distributed as dist
25 | from torch import nn
26 | from torch.nn import Parameter
27 | from torch.nn import init
28 | from .comm import all_reduce, comm_device, is_initialized
29 |
30 |
class DistributedBN1D(nn.Module):
    """Distributed Batch normalization layer

    Normalizes a 2D feature tensor using the global mean and standard deviation calculated
    across all workers. In both the distributed and the local mode the standard deviation is
    ``sqrt(biased_variance + eps)``, so the two modes agree when there is a single worker.


    :param n_feats: The second dimension (feature dimension) in the 2D input tensor
    :type n_feats: int
    :param eps: a value added to the variance for numerical stability
    :type eps: float
    :param affine: When ``True``, the module will use learnable affine parameter
    :type affine: bool
    :param distributed: Boolean specifying whether to run in distributed mode where normalizing
        statistics are calculated across all workers, or local mode where the normalizing
        statistics are calculated using only the local input feature tensor. If not specified,
        it will be set to ``True`` if the user has called :func:`sar.initialize_comms`,
        and ``False`` otherwise
    :type distributed: Optional[bool]

    """
    def __init__(self, n_feats: int, eps: float = 1.0e-5, affine: bool = True,
                 distributed: Optional[bool] = None):
        super().__init__()
        self.n_feats = n_feats
        self.weight: Optional[Parameter]
        self.bias: Optional[Parameter]
        self.affine = affine
        if affine:
            self.weight = Parameter(torch.ones(n_feats))
            self.bias = Parameter(torch.zeros(n_feats))
        else:
            self.weight = None
            self.bias = None

        self.eps = eps

        if distributed is None:
            self.distributed = is_initialized()
        else:
            self.distributed = distributed

    def forward(self, inp):
        '''
        Normalizes every feature (column) of the 2D tensor ``inp`` and applies
        the affine transform when the layer has one.
        '''
        assert inp.ndim == 2, 'distributedBN1D must have a 2D input'
        if self.distributed:
            # var_op yields the mean of the squared features, so the variance
            # is E[x^2] - E[x]^2
            mean, var = mean_op(inp), var_op(inp)
            std = torch.sqrt(var - mean**2 + self.eps)
        else:
            mean = inp.mean(0)
            # Biased variance plus eps, matching the distributed branch. This
            # keeps the result finite for constant features and for a single
            # input row, where the unbiased std is 0 or NaN.
            std = torch.sqrt(inp.var(0, unbiased=False) + self.eps)
        normalized_x = (inp - mean.unsqueeze(0)) / std.unsqueeze(0)

        if self.weight is not None and self.bias is not None:
            result = normalized_x * self.weight.unsqueeze(0) + self.bias.unsqueeze(0)
        else:
            result = normalized_x
        return result

    def reset_parameters(self):
        '''
        Resets the affine parameters (if any) to scale one and shift zero.
        '''
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)
93 |
94 |
95 |
class MeanOp(torch.autograd.Function):  # pylint: disable = abstract-method
    """
    Autograd function computing the per-feature mean of a 2D tensor, where the
    feature sums and the row counts are combined with a SUM ``all_reduce``.
    """
    @staticmethod
    # pylint: disable = arguments-differ
    def forward(ctx, x):
        n_feats = x.size(1)
        # Pack the per-feature sums and the local row count into one buffer so
        # that a single all_reduce call handles both.
        packed = torch.empty(n_feats + 1, device=comm_device())
        packed[:n_feats] = x.sum(0).data.to(comm_device())
        packed[n_feats] = x.size(0)
        all_reduce(packed, op=dist.ReduceOp.SUM, move_to_comm_device=True)

        # The buffer is read back after the call, i.e. the reduction is
        # expected to have updated it in place.
        total_rows = packed[n_feats]
        mean = (packed[:n_feats] / total_rows).to(x.device)
        ctx.n_points = torch.round(total_rows).long().item()
        ctx.inp_size = x.size(0)
        return mean

    @staticmethod
    # pylint: disable = arguments-differ
    def backward(ctx, grad):
        reduced_grad = grad.to(comm_device())
        all_reduce(reduced_grad, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        # Every local input row receives the same gradient: the reduced
        # gradient divided by the total row count stored in forward.
        per_row_grad = reduced_grad.repeat(ctx.inp_size, 1).to(grad.device)
        return per_row_grad / ctx.n_points
115 |
116 |
class VarOp(torch.autograd.Function):  # pylint: disable = abstract-method
    """
    Autograd function computing the per-feature mean of the squared entries of
    a 2D tensor, where the sums of squares and the row counts are combined
    with a SUM ``all_reduce``. The mean is not subtracted here.
    """
    @staticmethod
    # pylint: disable = arguments-differ
    def forward(ctx, features):
        n_feats = features.size(1)
        # Pack the per-feature sums of squares and the local row count into
        # one buffer so that a single all_reduce call handles both.
        packed = torch.empty(n_feats + 1, device=comm_device())
        packed[:n_feats] = features.pow(2).sum(0).data.to(comm_device())
        packed[n_feats] = features.size(0)
        all_reduce(packed, op=dist.ReduceOp.SUM, move_to_comm_device=True)

        total_rows = packed[n_feats]
        variance = (packed[:n_feats] / total_rows).to(features.device)

        ctx.n_points = torch.round(total_rows).long().item()
        ctx.save_for_backward(features)
        return variance

    @staticmethod
    # pylint: disable = arguments-differ
    def backward(ctx, grad):
        features, = ctx.saved_tensors
        reduced_grad = grad.to(comm_device())
        all_reduce(reduced_grad, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        # d(x^2)/dx = 2x, scaled by the reduced gradient and the total row count
        scaled = reduced_grad.to(grad.device).unsqueeze(0) * 2 * features
        return scaled / ctx.n_points
138 |
139 |
# Callable entry points of the custom autograd functions defined above
mean_op = MeanOp.apply
var_op = VarOp.apply
142 |
--------------------------------------------------------------------------------
/sar/edge_softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | import torch
22 | import dgl.function as fn # type: ignore
23 | import dgl # type: ignore
24 | from . import GraphShardManager
25 |
26 |
def edge_softmax(graph: GraphShardManager,
                 logits: torch.Tensor, eids=dgl.ALL, norm_by: str = 'dst') -> torch.Tensor:
    """
    Implements a similar functionality as DGL's ``dgl.nn.edge_softmax`` on distributed graphs.

    Only supports a subset of the possible argument values.

    :param graph: The distributed graph
    :type graph: GraphShardManager
    :param logits: The edge logits. The size of the first dimension should be the same as the number of edges in the ``graph`` argument
    :type logits: torch.Tensor
    :param eids: must be ``dgl.ALL``
    :type eids:
    :param norm_by: must be ``'dst'``
    :type norm_by: str
    :returns: A tensor with the same size as logits containing the softmax-normalized logits
    :raises AssertionError: If ``eids`` is not ``dgl.ALL`` or ``norm_by`` is not ``'dst'``

    """

    assert eids == dgl.ALL, \
        'edge_softmax on GraphShardManager only supported when eids==dgl.ALL'

    assert norm_by == 'dst', \
        'edge_softmax on GraphShardManager only supported when norm_by==dst'

    # All temporary node/edge data is written inside a local scope of the graph
    with graph.local_scope():
        graph.edata['logits'] = logits
        # Per-destination maximum of the incoming logits, computed without
        # tracking gradients
        with torch.no_grad():
            graph.update_all(fn.copy_e('logits', 'temp'),
                             fn.max('temp', 'max_logits'))  # pylint: disable=no-member

        # Subtract the destination's maximum from every incoming edge logit
        # before exponentiating
        graph.apply_edges(
            fn.e_sub_v('logits', 'max_logits', 'adjusted_logits'))  # pylint: disable=no-member

        graph.edata['exp_logits'] = torch.exp(graph.edata.pop('adjusted_logits'))

        # Per-destination sum of the exponentiated logits
        graph.update_all(fn.copy_e('exp_logits', 'temp'),
                         fn.sum('temp', 'normalization'))  # pylint: disable=no-member

        # Divide every edge value by the sum at its destination node
        graph.apply_edges(
            fn.e_div_v('exp_logits', 'normalization', 'sm_output'))  # pylint: disable=no-member

        sm_output = graph.edata.pop('sm_output')

    return sm_output
72 |
--------------------------------------------------------------------------------
/sar/logging_setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import Optional, Union
22 | import logging
23 |
# Package-level logger. It only carries a NullHandler until logging_setup is
# called, and its records are not propagated to ancestor loggers.
logger = logging.getLogger('sar')
logger.addHandler(logging.NullHandler())
logger.propagate = False


def logging_setup(log_level: int, _rank: int = -1, _world_size: int = -1, log_file: Optional[str] = None):
    """
    Attaches an additional handler to the ``sar`` logger and sets the log level.

    :param log_level: Level applied to both the new handler and the ``sar`` logger
    :type log_level: int
    :param _rank: Worker rank; every record is prefixed with ``_rank + 1``
    :type _rank: int
    :param _world_size: Number of workers; shown after the rank in the record prefix
    :type _world_size: int
    :param log_file: If given, records are written to this file (opened in ``'w'`` mode),
        otherwise a ``logging.StreamHandler`` is used
    :type log_file: Optional[str]
    """
    worker_prefix = f'{_rank+1} / {_world_size}'
    record_format = worker_prefix + ' - %(name)s - %(levelname)s - %(message)s'

    handler: Union[logging.FileHandler, logging.StreamHandler]
    if log_file is None:
        handler = logging.StreamHandler()
    else:
        handler = logging.FileHandler(log_file, 'w')

    handler.setLevel(log_level)
    handler.setFormatter(logging.Formatter(record_format))

    logger.addHandler(handler)
    logger.setLevel(log_level)
44 |
--------------------------------------------------------------------------------
/sar/patch_dgl.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Intel Corporation
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 |
21 | from typing import Union, List, Optional
22 | import dgl # type: ignore
23 | from .edge_softmax import edge_softmax
24 | from .core import GraphShardManager
25 |
26 | from . import message_has_parameters
27 |
28 |
29 |
30 | # patch edge_softmax in dgl's nn modules
31 |
class RelGraphConv(dgl.nn.pytorch.conv.RelGraphConv):
    """
    DGL's ``RelGraphConv`` whose ``message`` function is decorated with
    ``message_has_parameters``, naming the parameters of ``self.linear_r``.
    Construction is inherited unchanged from the parent class.
    """

    @message_has_parameters(lambda self: tuple(self.linear_r.parameters()))
    def message(self, edges):
        # Same message computation as the parent; only the decorator is added.
        return super().message(edges)
39 |
40 |
def patched_edge_softmax(graph, *args, **kwargs):
    """
    Dispatches to SAR's ``edge_softmax`` when ``graph`` is a ``GraphShardManager``
    and to DGL's own ``edge_softmax`` for every other graph object.
    """
    use_sar = isinstance(graph, GraphShardManager)
    softmax_fn = edge_softmax if use_sar else dgl.nn.edge_softmax  # pylint: disable=no-member
    return softmax_fn(graph, *args, **kwargs)
46 |
47 |
def patch_dgl():
    """Patches DGL so that attention layers (``gatconv``, ``dotgatconv``,
    ``agnnconv``) use a different ``edge_softmax`` function
    that supports :class:`sar.core.GraphShardManager`. Also replaces DGL's
    ``RelGraphConv`` with a subclass whose ``message`` function carries a
    decorator telling SAR how to find the parameters used to create edge messages.

    """
    conv_pkg = dgl.nn.pytorch.conv
    # Modules are resolved one at a time, right before each is patched.
    for module_name in ('gatconv', 'dotgatconv', 'agnnconv'):
        getattr(conv_pkg, module_name).edge_softmax = patched_edge_softmax

    conv_pkg.RelGraphConv = RelGraphConv
    dgl.nn.RelGraphConv = RelGraphConv
62 |
--------------------------------------------------------------------------------
/security.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation.
3 |
4 | ## Reporting a Vulnerability
5 | Please report any security vulnerabilities in this project [utilizing the guidelines here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html).
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from pathlib import Path

from setuptools import setup, find_packages

# The README next to this file is used as the long description. It is read
# with an explicit encoding so the result does not depend on the platform's
# default locale.
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

setup(
    name='sar-gnn',
    version='0.1.0',
    python_requires='>=3.8',
    install_requires=[
        'dgl>=1.0.0',
        'numpy>=1.22.0',
        'torch>=1.10.0',
        'ifaddr>=0.1.7',
        'packaging>=23.1'
    ],
    packages=find_packages(),
    author='Hesham Mostafa',
    author_email='hesham.mostafa@intel.com',
    maintainer='Kacper Pietkun',
    maintainer_email='kacper.pietkun@intel.com',
    # Adjacent string literals are concatenated verbatim, so each fragment
    # that is followed by another one ends with a space.
    description='A Python library for distributed training of Graph Neural Networks (GNNs) on large graphs, '
                'supporting both full-batch and sampling-based training, and utilizing a sequential aggregation '
                'and rematerialization technique for linear memory scaling.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    project_urls={
        'GitHub': 'https://github.com/IntelLabs/SAR/',
        'Documentation': 'https://sar.readthedocs.io/en/latest/',
    },
    license='MIT',
    classifiers=[
        'License :: OSI Approved :: MIT License',
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3.8"
    ]
)
--------------------------------------------------------------------------------
/tests/base_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sar
3 | import torch
4 | import torch.distributed as dist
5 | import dgl
6 | from constants import *
7 | # IMPORTANT - This module should be imported independently
8 | # only by the child processes - i.e. separate workers
9 |
10 |
def initialize_worker(rank, world_size, tmp_dir, backend="ccl"):
    """
    Boilerplate code for setting up connection between workers

    :param rank: Rank of the current machine
    :type rank: int
    :param world_size: Number of workers. The same as the number of graph partitions
    :type world_size: int
    :param tmp_dir: Path to the directory where ip file will be created
    :type tmp_dir: str
    :param backend: Communication backend passed to ``sar.initialize_comms``
    :type backend: str
    """
    torch.seed()
    # The ip file path is handed to sar.nfs_ip_init, which returns the master
    # address used to initialize the communication.
    master_ip_address = sar.nfs_ip_init(rank, os.path.join(tmp_dir, 'ip_file'))
    sar.initialize_comms(rank, world_size, master_ip_address, backend)
26 |
27 |
def load_partition_data(rank, graph_name, tmp_dir):
    """
    Boilerplate code for loading partition data with standard `full_graph_manager` (FGM)

    :param rank: Rank of the current machine
    :type rank: int
    :param graph_name: Name of the partitioned graph
    :type graph_name: str
    :param tmp_dir: Path to the directory where partition data is located
    :type tmp_dir: str
    :returns: Tuple consisting of the full graph manager, partition features and labels
    """
    partition_data = sar.load_dgl_partition_data(
        os.path.join(tmp_dir, f'{graph_name}.json'), rank, "cpu")
    full_graph_manager = sar.construct_full_graph(partition_data).to('cpu')

    node_feats = partition_data.node_features
    features = sar.suffix_key_lookup(node_feats, 'features')
    labels = sar.suffix_key_lookup(node_feats, 'labels')
    return full_graph_manager, features, labels
46 |
47 |
def load_partition_data_mfg(rank, graph_name, tmp_dir):
    """
    Boilerplate code for loading partition data with message flow graph (MFG)

    :param rank: Rank of the current machine
    :type rank: int
    :param graph_name: Name of the partitioned graph
    :type graph_name: str
    :param tmp_dir: Path to the directory where partition data is located
    :type tmp_dir: str
    :returns: Tuple consisting of the list of blocks returned by ``sar.construct_mfgs``
        (each moved to cpu), partition features and labels
    """
    partition_data = sar.load_dgl_partition_data(
        os.path.join(tmp_dir, f'{graph_name}.json'), rank, "cpu")

    # Seed nodes: local positions of the nodes of type 0, shifted by the start
    # of this worker's node range.
    local_type0_idx = (partition_data.node_features[dgl.NTYPE] == 0).nonzero(as_tuple=True)[0]
    seed_nodes = local_type0_idx + partition_data.node_ranges[sar.comm.rank()][0]

    blocks = [block.to('cpu')
              for block in sar.construct_mfgs(partition_data, seed_nodes, 3, True)]

    node_feats = partition_data.node_features
    features = sar.suffix_key_lookup(node_feats, 'features')
    labels = sar.suffix_key_lookup(node_feats, 'labels')
    return blocks, features, labels
70 |
71 |
def synchronize_processes():
    """
    Function that simulates dist.barrier (using all_reduce because there is an issue with dist.barrier() in ccl)
    """
    # The reduced value itself is irrelevant; only the collective call matters.
    dist.all_reduce(torch.tensor(1), dist.ReduceOp.MAX)
78 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import dgl
3 | from dgl.heterograph import DGLGraph
4 | import pytest
5 | import tempfile
6 | from constants import *
7 | import torch
8 | from torch import Tensor
9 | import multiprocessing as mp
10 | from typing import NamedTuple, Union, Dict
11 |
12 |
class FixtureEnv(NamedTuple):
    """
    Stores information about variables needed by tests

    .. py:attribute:: temp_dir : str

        Path to a temporary directory which will be
        deleted after tests are finished

    .. py:attribute:: homo_graph : DGLGraph

        DGLGraph object representing the homogeneous graph

    .. py:attribute:: hetero_graph : DGLGraph

        DGLGraph object representing the heterogeneous graph

    .. py:attribute:: node_map : Dict[str, Union[Tensor, Dict[str, Tensor]]]

        Dictionary with one node mapping per partitioned graph, i.e. per
        combination of graph kind (homogeneous/heterogeneous) and world size.
        For a homogeneous graph the value is a tensor mapping between shuffled
        node IDs and the original node IDs. For a heterogeneous graph the value
        is a dict of tensors whose key is the node type and whose value is a
        tensor mapping between shuffled node IDs and the original node IDs of
        that node type.
    """
    temp_dir: str
    homo_graph: DGLGraph
    hetero_graph: DGLGraph
    node_map: Dict[str, Union[Tensor, Dict[str, Tensor]]]
44 |
45 |
@pytest.fixture(autouse=True, scope="session")
def fixture_env():
    """
    Create temp directory that will be used by every test.
    Create and save partitioned graphs in that directory.

    The graphs are created and partitioned in a separate process, which hands
    its results back through a manager dictionary. The temporary directory is
    removed and the manager is shut down when the session ends.

    :raises RuntimeError: If the partitioning process exits with a non-zero exit code
    """
    with mp.Manager() as manager, tempfile.TemporaryDirectory() as temp_dir:
        mp_dict = manager.dict()
        p = mp.Process(target=graph_partitioning, args=(mp_dict, temp_dir,))
        p.start()
        p.join()
        # Report a failed partitioning directly instead of as a KeyError on
        # the dictionary lookups below.
        if p.exitcode != 0:
            raise RuntimeError(
                f"graph partitioning process failed with exit code {p.exitcode}")
        yield FixtureEnv(temp_dir,
                         mp_dict["homo_graph"],
                         mp_dict["hetero_graph"],
                         mp_dict["node_map"])
62 |
63 |
def graph_partitioning(mp_dict, temp_dir):
    """
    Create and partition both homogeneous and heterogeneous
    graphs for different world_sizes

    :param mp_dict: Dictionary that receives the created graphs (``"homo_graph"``,
        ``"hetero_graph"``) and the node mappings (``"node_map"``)
    :param temp_dir: Path to the directory where graphs will be partitioned
    :type temp_dir: str
    """
    homo_g = get_random_graph()
    hetero_g = get_random_hetero_graph()

    # (directory/key prefix, graph, graph name) for every graph to partition
    graph_specs = (
        ("homogeneous", homo_g, HOMOGENEOUS_GRAPH_NAME),
        ("heterogeneous", hetero_g, HETEROGENEOUS_GRAPH_NAME),
    )

    node_mappings = {}
    for world_size in (1, 2, 4, 8):
        for prefix, graph, graph_name in graph_specs:
            # The same string names the partition directory and the mapping key
            partition_key = f"{prefix}_{world_size}"
            partition_dir = os.path.join(temp_dir, partition_key)
            os.makedirs(partition_dir)
            node_map, _ = dgl.distributed.partition_graph(graph, graph_name,
                                                          world_size, partition_dir,
                                                          num_hops=1, balance_edges=True,
                                                          return_mapping=True)
            node_mappings[partition_key] = node_map

    mp_dict["homo_graph"] = homo_g
    mp_dict["hetero_graph"] = hetero_g
    mp_dict["node_map"] = node_mappings
97 |
98 |
def get_random_graph():
    """
    Generates small homogeneous graph with features and labels

    :returns: dgl graph
    """
    graph = dgl.add_self_loop(dgl.rand_graph(1000, 2500))
    graph.ndata.clear()

    n_nodes = graph.num_nodes()
    graph.ndata['features'] = torch.rand((n_nodes, 10))
    graph.ndata['labels'] = torch.randint(0, 10, (n_nodes,))
    return graph
111 |
112 |
def get_random_hetero_graph():
    """
    Generates small heterogeneous graph with node features and labels only for the first node type

    :returns: dgl graph
    """
    def random_endpoints():
        # 1000 random node ids in [0, 800)
        return torch.randint(0, 800, (1000,))

    canonical_etypes = [
        ("n_type_1", "rel_1", "n_type_1"),
        ("n_type_1", "rel_2", "n_type_3"),
        ("n_type_4", "rel_3", "n_type_2"),
        ("n_type_4", "rel_4", "n_type_1"),
        ("n_type_1", "rev-rel_1", "n_type_1"),
        ("n_type_3", "rev-rel_2", "n_type_1"),
        ("n_type_2", "rev-rel_3", "n_type_4"),
        ("n_type_1", "rev-rel_4", "n_type_4"),
    ]
    # Source ids are drawn before destination ids for every edge type
    graph_data = {etype: (random_endpoints(), random_endpoints())
                  for etype in canonical_etypes}
    hetero_graph = dgl.heterograph(graph_data)

    labelled_type = "n_type_1"
    hetero_graph.nodes[labelled_type].data["features"] = torch.rand(
        (hetero_graph.num_nodes(labelled_type), 10))
    hetero_graph.nodes[labelled_type].data["labels"] = torch.randint(
        0, 10, (hetero_graph.num_nodes(labelled_type),))
    return hetero_graph
133 |
--------------------------------------------------------------------------------
/tests/constants.py:
--------------------------------------------------------------------------------
# Graph names shared by the test suite: they are passed as the graph name when
# the dummy graphs are partitioned and again when the partitions are loaded.
HOMOGENEOUS_GRAPH_NAME = "dummy_homogeneous_graph"
HETEROGENEOUS_GRAPH_NAME = "dummy_heterogeneous_graph"
3 |
--------------------------------------------------------------------------------
/tests/models.py:
--------------------------------------------------------------------------------
1 | ########################################################################################
2 | # File's content was partially taken from
3 | # https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py
4 | ########################################################################################
5 |
6 | import dgl
7 | import dgl.nn as dglnn
8 | from dgl.nn import HeteroEmbedding
9 | from dgl import DGLGraph
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
class GNNModel(nn.Module):
    """
    Three stacked ``GraphConv`` layers without bias; the first two use ReLU
    and the last one has no activation.

    :param in_dim: size of the input features
    :param h_dim: size of the hidden representations
    :param out_dim: size of the output
    """

    def __init__(self, in_dim: int, h_dim: int, out_dim: int):
        super().__init__()

        layer_specs = [
            (in_dim, h_dim, F.relu),
            (h_dim, h_dim, F.relu),
            (h_dim, out_dim, None),
        ]
        self.convs = nn.ModuleList([
            dgl.nn.GraphConv(n_in, n_out, activation=act, bias=False)
            for n_in, n_out, act in layer_specs
        ])

    def forward(self, graph, features):
        """
        Runs the layers in order.

        :param graph: either a list of blocks (Message Flow Graph), consumed one
            per layer, or a single graph that is reused by every layer
        :param features: input features fed to the first layer
        :returns: output of the last applied layer
        """
        if isinstance(graph, list):
            # Message Flow Graph: pair every layer with its own block
            layer_inputs = zip(self.convs, graph)
        else:
            # Whole graph: every layer sees the same graph
            layer_inputs = ((conv, graph) for conv in self.convs)

        hidden = features
        for conv, layer_graph in layer_inputs:
            hidden = conv(layer_graph, hidden)
        return hidden
33 |
34 |
class HeteroGNNModel(nn.Module):
    """
    Three stacked ``RelGraphConvLayer`` layers; the first two use ReLU and
    the last one has no activation.

    :param g: graph whose edge and node types define the layers
    :param in_dim: size of the input features
    :param h_dim: size of the hidden representations
    :param out_dim: size of the output
    """

    def __init__(self, g: DGLGraph, in_dim: int, h_dim: int, out_dim: int):
        super().__init__()
        # Deduplicated relation names in sorted order
        self.rel_names = sorted(set(g.etypes))

        layer_specs = [
            (in_dim, h_dim, F.relu),
            (h_dim, h_dim, F.relu),
            (h_dim, out_dim, None),
        ]
        self.layers = nn.ModuleList([
            RelGraphConvLayer(n_in, n_out, g.ntypes, self.rel_names, activation=act)
            for n_in, n_out, act in layer_specs
        ])

    def forward(self, graph, h):
        """
        Runs the layers in order.

        :param graph: either a list of blocks (Message Flow Graph), consumed one
            per layer, or a single graph that is reused by every layer
        :param h: input passed to the first layer
        :returns: output of the last applied layer
        """
        if isinstance(graph, list):
            # Message Flow Graph: pair every layer with its own block
            layer_inputs = zip(self.layers, graph)
        else:
            # Whole graph: every layer sees the same graph
            layer_inputs = ((layer, graph) for layer in self.layers)

        for layer, layer_graph in layer_inputs:
            h = layer(layer_graph, h)
        return h
57 |
class RelGraphConvLayer(nn.Module):
    """
    One relational graph convolution layer.

    A weight-less, bias-less ``GraphConv`` (``norm="right"``) is created for
    every relation and wrapped in a ``HeteroGraphConv``; the per-relation
    weights live in separate bias-free linear modules and are handed to the
    convolutions at call time. A per-node-type linear self-loop term is added
    to the result, followed by the optional activation.

    :param in_feat: size of the input features
    :param out_feat: size of the output features
    :param ntypes: node types that get a self-loop linear module
    :param rel_names: relation names that get a convolution and a weight
    :param activation: optional callable applied to every output
    """

    def __init__(self, in_feat, out_feat, ntypes, rel_names, activation=None):
        super().__init__()
        self.rel_names = rel_names
        self.activation = activation

        per_relation_convs = {}
        for rel in rel_names:
            per_relation_convs[rel] = dglnn.GraphConv(
                in_feat, out_feat, norm="right", weight=False, bias=False
            )
        self.conv = dglnn.HeteroGraphConv(per_relation_convs)

        relation_weights = {}
        for rel_name in rel_names:
            relation_weights[rel_name] = nn.Linear(in_feat, out_feat, bias=False)
        self.weight = nn.ModuleDict(relation_weights)

        self_loop_modules = {}
        for ntype in ntypes:
            self_loop_modules[ntype] = nn.Linear(in_feat, out_feat, bias=True)
        self.loop_weights = nn.ModuleDict(self_loop_modules)

    def forward(self, g, inputs):
        """
        Applies the layer.

        :param g: graph (or block) to run the convolutions on
        :param inputs: dictionary mapping node type to its input features
        :returns: dictionary mapping node type to its output features
        """
        with g.local_scope():
            # Transposed linear weights are passed to each relation's GraphConv
            mod_kwargs = {}
            for rel_name in self.rel_names:
                mod_kwargs[rel_name] = {"weight": self.weight[rel_name].weight.T}

            # Only the leading rows that correspond to destination nodes feed
            # the self-loop term
            dst_inputs = {}
            for ntype, feat in inputs.items():
                dst_inputs[ntype] = feat[:g.number_of_dst_nodes(ntype)]

            conv_out = self.conv(g, inputs, mod_kwargs=mod_kwargs)

            outputs = {}
            for ntype, h in conv_out.items():
                h = h + self.loop_weights[ntype](dst_inputs[ntype])
                if self.activation:
                    h = self.activation(h)
                outputs[ntype] = h
            return outputs
97 |
98 |
def extract_embed(node_embed, input_nodes, skip_type=None):
    """
    Looks up embeddings for every node type except ``skip_type``.

    :param node_embed: callable that receives a dictionary mapping node type to node IDs
    :param input_nodes: dictionary mapping node type to node IDs
    :param skip_type: node type left out of the lookup
    :returns: whatever ``node_embed`` returns for the filtered dictionary
    """
    selected_nodes = {}
    for ntype in input_nodes:
        if ntype == skip_type:
            continue
        selected_nodes[ntype] = input_nodes[ntype]
    return node_embed(selected_nodes)
104 |
105 |
def rel_graph_embed(graph, embed_size, num_nodes_dict=None, skip_type=None):
    """
    Creates a ``HeteroEmbedding`` covering every node type of ``graph``
    except ``skip_type``.

    :param graph: graph whose node types get an embedding table
    :param embed_size: embedding size passed to ``HeteroEmbedding``
    :param num_nodes_dict: optional dictionary mapping node type to the number of
        embeddings to allocate; when omitted the node counts of ``graph`` are used
    :param skip_type: node type that gets no embedding table
    :returns: the created ``HeteroEmbedding``
    """
    node_num = {}
    for ntype in graph.ntypes:
        if ntype == skip_type:
            continue
        # Identity check: comparing to None with != can be hijacked by __ne__
        if num_nodes_dict is not None:
            node_num[ntype] = num_nodes_dict[ntype]
        else:
            node_num[ntype] = graph.num_nodes(ntype)
    return HeteroEmbedding(node_num, embed_size)
117 |
--------------------------------------------------------------------------------
/tests/multiprocessing_utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | import traceback
3 | import pytest
4 | import functools
5 | import tempfile
6 |
7 |
def handle_mp_exception(mp_dict):
    """
    Reports an exception that was recorded by a child process.

    The stored traceback (if any) is followed by the string form of every
    argument of the stored exception; the resulting message is printed and
    then used to fail the current test.

    :param mp_dict: Dictionary that is shared between different processes
    :type mp_dict: multiprocessing.managers.DictProxy
    """
    report = mp_dict.get('traceback', "")
    failure = mp_dict['exception']
    if failure.args:
        report += "".join(str(e_arg) for e_arg in failure.args)
    text = str(report)
    print(text, flush=True)
    pytest.fail(text, pytrace=False)
20 |
21 |
def _run_guarded_worker(func, func_args, func_kwargs):
    """
    Entry point of the child processes started by `run_workers`: invokes `func`
    and records any exception in the shared dictionary (first element of
    `func_args`) so that the parent process can report it.
    """
    mp_dict = func_args[0]
    try:
        func(*func_args, **func_kwargs)
    except Exception as e:
        mp_dict['traceback'] = str(traceback.format_exc())
        try:
            mp_dict['exception'] = e
        except Exception:
            # The exception object could not be sent through the manager
            # (e.g. it is not picklable); send a plain stand-in instead.
            mp_dict['exception'] = RuntimeError(repr(e))


def run_workers(func, fixture_env, world_size, *args, **kwargs):
    """
    Starts `world_size` workers, where each of them invokes the function
    specified by the parameter. Ranks 1..world_size-1 run in separate
    daemon processes, rank 0 runs in the calling process. This simulates
    the behaviour of SAR, which is spawned across different machines
    independently from other instances. Each process has an individual
    memory space, so it is a suitable environment for testing SAR.

    An exception raised by a worker running in a child process (including a
    failed assertion) is recorded in the shared dictionary and reported
    after all child processes have been joined.

    :param func: The function that will be invoked by each process. It should take four
    parameters: mp_dict - shared dictionary between different processes, rank - of the current machine,
    world_size - number of workers, fixture_env - prepared test environment (additionally one can pass args and kwargs)
    :type func: function
    :param fixture_env: named tuple with all of the necessary information about prepared environment for the tests
    :type fixture_env: FixtureEnv
    :param world_size: number of workers
    :type world_size: int
    :returns: mp_dict which can be used by workers to return
    results from `func`
    """
    manager = mp.Manager()
    mp_dict = manager.dict()
    processes = []
    for rank in range(1, world_size):
        my_args = (mp_dict, rank, world_size, fixture_env) + args
        p = mp.Process(target=_run_guarded_worker, args=(func, my_args, kwargs))
        p.daemon = True
        p.start()
        processes.append(p)
    # Rank 0 runs in the current process; its exceptions propagate to the caller
    func(mp_dict, 0, world_size, fixture_env, *args, **kwargs)

    for p in processes:
        p.join()
    if 'exception' in mp_dict:
        handle_mp_exception(mp_dict)
    return mp_dict
59 |
60 |
def sar_test(func):
    """
    A decorator function that wraps all SAR tests with the primary objective
    of facilitating module imports in tests without affecting other tests.

    :param func: The function that serves as the entry point to the test.
    :type func: function
    :returns: A function that encapsulates the pytest function.
    """
    @functools.wraps(func)
    def test_wrapper(*args, **kwargs):
        """
        The wrapping process involves defining another nested function, which is then
        invoked by a newly spawned process. The wrapper spawns that process and uses
        the "join" method to wait for the results. Upon completion of the process,
        error and result handling are performed.
        """
        def process_wrapper(func, mp_dict, *args, **kwargs):
            try:
                result = func(*args, **kwargs)
                mp_dict["result"] = result
            except Exception as e:
                mp_dict['traceback'] = str(traceback.format_exc())
                mp_dict["exception"] = e

        manager = mp.Manager()
        mp_dict = manager.dict()

        mp_args = (func, mp_dict) + args
        p = mp.Process(target=process_wrapper, args=mp_args, kwargs=kwargs)
        p.start()
        p.join()

        if 'exception' in mp_dict:
            handle_mp_exception(mp_dict)

        if 'result' not in mp_dict:
            # The process ended without storing a result or an exception
            # (e.g. it crashed or exited through something that is not an
            # ``Exception`` subclass). Fail with the exit code instead of an
            # opaque KeyError on the lookup below.
            pytest.fail(
                f"Test process exited with code {p.exitcode} without reporting a result",
                pytrace=False,
            )

        return mp_dict["result"]
    return test_wrapper
98 |
--------------------------------------------------------------------------------
/tests/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | testpaths = tests/
3 | python_files = test_*.py
4 | python_classes = Test*
5 | python_functions = test_*
--------------------------------------------------------------------------------
/tests/test_comm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from multiprocessing_utils import *
4 | from constants import *
5 | # Do not import DGL and SAR - these modules should be
6 | # independently loaded inside each process
7 |
8 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize('world_size', [2, 4, 8])
@sar_test
def test_sync_params(world_size, backend, fixture_env):
    """
    Verifies that sync_params leaves every worker with the same model
    parameters. Worker 0 stores a copy of its parameters taken before the
    call, the other workers store theirs after the call, and all of them
    are then compared against worker 0.
    """
    def sync_params(mp_dict, rank, world_size, fixture_env, **kwargs):
        import torch
        import sar
        from base_utils import initialize_worker, synchronize_processes
        from models import GNNModel

        initialize_worker(rank, world_size, fixture_env.temp_dir, backend=kwargs["backend"])
        model = GNNModel(16, 8, 4)
        own_key = f"result_{rank}"
        if rank == 0:
            # Snapshot taken before synchronization
            mp_dict[own_key] = deepcopy(model.state_dict())
        sar.sync_params(model)
        if rank != 0:
            mp_dict[own_key] = model.state_dict()

        synchronize_processes()
        reference = mp_dict["result_0"]
        for other_rank in range(1, world_size):
            synced = mp_dict[f"result_{other_rank}"]
            for key in reference.keys():
                assert torch.all(torch.eq(reference[key], synced[key]))

    run_workers(sync_params, fixture_env, world_size, backend=backend)
40 |
41 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize('world_size', [2, 4, 8])
@sar_test
def test_gather_grads(world_size, backend, fixture_env):
    """
    Checks whether parameter's gradients are the same across all
    workers after calling gather_grads function
    """
    def gather_grads(mp_dict, rank, world_size, fixture_env, **kwargs):
        import torch
        import sar
        import torch.nn.functional as F
        from models import GNNModel
        from base_utils import initialize_worker, synchronize_processes, load_partition_data

        temp_dir = fixture_env.temp_dir
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])
        fgm, feat, labels = load_partition_data(rank, HOMOGENEOUS_GRAPH_NAME,
                                                os.path.join(temp_dir, f"homogeneous_{world_size}"))
        model = GNNModel(feat.shape[1], feat.shape[1], labels.max()+1)
        sar.sync_params(model)
        sar_logits = model(fgm, feat)
        sar_loss = F.cross_entropy(sar_logits, labels)
        sar_loss.backward()
        sar.gather_grads(model)
        # detach().clone() copies each gradient; torch.tensor(tensor) does the
        # same but emits a copy-construct warning
        mp_dict[f"result_{rank}"] = [param.grad.detach().clone() for param in model.parameters()]

        synchronize_processes()
        reference_grads = mp_dict["result_0"]
        for other_rank in range(1, world_size):
            other_grads = mp_dict[f"result_{other_rank}"]
            assert len(other_grads) == len(reference_grads)
            for reference_grad, other_grad in zip(reference_grads, other_grads):
                assert torch.all(torch.eq(reference_grad, other_grad))

    run_workers(gather_grads, fixture_env, world_size, backend=backend)
75 |
76 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize("world_size", [2, 4, 8])
@sar_test
def test_all_to_all(world_size, backend, fixture_env):
    """
    Verifies the all_to_all operation. The test is designed in such a way
    that, after calling all_to_all, every worker should hold a list of
    tensors whose values are all equal to that worker's rank.
    """
    def all_to_all(mp_dict, rank, world_size, fixture_env, **kwargs):
        import torch
        import sar
        from base_utils import initialize_worker, synchronize_processes

        initialize_worker(rank, world_size, fixture_env.temp_dir, backend=kwargs["backend"])
        # The tensor at position x is filled with the value x
        outgoing = [torch.tensor([x] * world_size) for x in range(world_size)]
        # Receive buffers start out filled with -1
        incoming = [torch.tensor([-1] * world_size) for _ in range(world_size)]
        sar.comm.all_to_all(incoming, outgoing)
        mp_dict[f"result_{rank}"] = incoming

        synchronize_processes()
        for checked_rank in range(world_size):
            expected = torch.tensor([checked_rank] * world_size)
            for received in mp_dict[f"result_{checked_rank}"]:
                assert torch.all(torch.eq(received, expected))

    run_workers(all_to_all, fixture_env, world_size, backend=backend)
104 |
105 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize("world_size", [2, 4, 8])
@sar_test
def test_exchange_single_tensor(world_size, backend, fixture_env):
    """
    Verifies the exchange_single_tensor operation. The test is designed in
    such a way that, after calling exchange_single_tensor between two
    machines, a machine should receive a tensor with values equal to its rank.
    """
    def exchange_single_tensor(mp_dict, rank, world_size, fixture_env, **kwargs):
        import torch
        import sar
        from base_utils import initialize_worker, synchronize_processes

        initialize_worker(rank, world_size, fixture_env.temp_dir, backend=kwargs["backend"])
        received = []
        for step in range(world_size):
            # Both indices start at the worker's own rank; the send index
            # moves forward and the receive index backward, wrapping around.
            send_idx = (rank + step) % world_size
            recv_idx = (rank - step) % world_size
            send_tensor = torch.tensor([send_idx] * world_size)
            recv_tensor = torch.tensor([-1] * world_size)
            sar.comm.exchange_single_tensor(recv_idx, send_idx, recv_tensor, send_tensor)
            received.append(recv_tensor)
        mp_dict[f"result_{rank}"] = received

        synchronize_processes()
        expected = torch.tensor([rank] * world_size)
        for recv_tensor in mp_dict[f"result_{rank}"]:
            assert torch.all(torch.eq(recv_tensor, expected))

    run_workers(exchange_single_tensor, fixture_env, world_size, backend=backend)
139 |
--------------------------------------------------------------------------------
/tests/test_patch_dgl.py:
--------------------------------------------------------------------------------
1 | from multiprocessing_utils import *
2 | # Do not import DGL and SAR - these modules should be
3 | # independently loaded inside each process
4 |
5 |
@sar_test
def test_patch_dgl():
    """
    Imports DGL, remembers the edge_softmax functions referenced by the GAT,
    DotGAT and AGNN layer modules, then imports SAR, calls `patch_dgl` and
    checks which objects the layer modules reference afterwards.
    """
    import dgl
    # Functions referenced by the layer modules before patching
    softmax_before_patch = (
        dgl.nn.pytorch.conv.gatconv.edge_softmax,
        dgl.nn.pytorch.conv.dotgatconv.edge_softmax,
        dgl.nn.pytorch.conv.agnnconv.edge_softmax,
    )

    import sar
    sar.patch_dgl()

    # Before patching every layer module used DGL's own edge_softmax
    for original_edge_softmax in softmax_before_patch:
        assert original_edge_softmax == dgl.nn.functional.edge_softmax

    # After patching the SAR replacements are in place
    assert dgl.nn.pytorch.conv.gatconv.edge_softmax == sar.patched_edge_softmax
    assert dgl.nn.pytorch.conv.dotgatconv.edge_softmax == sar.patched_edge_softmax
    assert dgl.nn.pytorch.conv.RelGraphConv == sar.RelGraphConv
    assert dgl.nn.RelGraphConv == sar.RelGraphConv
28 |
--------------------------------------------------------------------------------
/tests/test_sar.py:
--------------------------------------------------------------------------------
1 | from multiprocessing_utils import *
2 | from constants import *
3 | import os
4 | # Do not import DGL and SAR - these modules should be
5 | # independently loaded inside each process
6 |
7 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@sar_test
def test_homogeneous_fgm(world_size, backend, fixture_env):
    """
    Runs full graph inference with SAR on the homogeneous graph. The
    per-worker results are concatenated, reordered with the node map and
    their row means are compared with the row means of native DGL
    inference on the whole graph.
    """
    import torch

    def homogeneous_fgm(mp_dict, rank, world_size, fixture_env, **kwargs):
        import sar
        from models import GNNModel
        from base_utils import initialize_worker, load_partition_data

        temp_dir = fixture_env.temp_dir
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])
        partition_dir = os.path.join(temp_dir, f"homogeneous_{world_size}")
        fgm, feat, labels = load_partition_data(rank, HOMOGENEOUS_GRAPH_NAME, partition_dir)
        feat_dim = feat.shape[1]
        model = GNNModel(feat_dim, feat_dim, labels.max()+1).to('cpu')
        sar.sync_params(model)
        model.eval()

        mp_dict[f"result_{rank}"] = model(fgm, feat).detach()
        if rank == 0:
            # Rank 0 also hands back what the reference computation needs
            mp_dict["model"] = model
            mp_dict["graph"] = fixture_env.homo_graph
            mp_dict["node_map"] = fixture_env.node_map[f"homogeneous_{world_size}"]

    mp_dict = run_workers(homogeneous_fgm, fixture_env, world_size, backend=backend)

    # Reference: native DGL inference on the whole graph
    model = mp_dict["model"]
    graph = mp_dict["graph"]
    dgl_logits_mean = model(graph, graph.ndata['features']).detach().mean(axis=1)

    # Concatenate the per-worker results in rank order
    per_rank_logits = [mp_dict[f"result_{rank}"] for rank in range(world_size)]
    sar_logits = torch.cat([torch.tensor([])] + per_rank_logits)
    sar_logits[mp_dict["node_map"]] = sar_logits.clone()
    sar_logits_mean = sar_logits.mean(axis=1)

    assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6))
52 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@sar_test
def test_heterogeneous_fgm(world_size, backend, fixture_env):
    """
    Runs full graph inference with SAR on the heterogeneous graph. The
    per-worker results for ``n_type_1`` are concatenated, reordered with
    the node map and their row means are compared with the row means of
    native DGL inference on the whole graph.
    """
    import torch
    from models import rel_graph_embed

    def heterogeneous_fgm(mp_dict, rank, world_size, fixture_env, **kwargs):
        import sar
        from models import HeteroGNNModel, extract_embed
        from base_utils import initialize_worker, load_partition_data

        temp_dir = fixture_env.temp_dir
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])
        partition_dir = os.path.join(temp_dir, f"heterogeneous_{world_size}")
        fgm, feats, labels = load_partition_data(rank, HETEROGENEOUS_GRAPH_NAME, partition_dir)
        feat_dim = feats.shape[1]
        model = HeteroGNNModel(fgm, feat_dim, feat_dim, labels.max()+1).to('cpu')
        model.eval()
        sar.sync_params(model)

        # Select, for every node type without features, the node map entries
        # between the smallest node ID of this partition and the smallest
        # node ID of the next one (open-ended for the last partition).
        to_extract = {}
        node_map = fixture_env.node_map[f"heterogeneous_{world_size}"]
        for ntype in fgm.srctypes:
            if ntype == "n_type_1":
                continue
            down_lim = fgm._partition_book.partid2nids(rank, ntype).min()
            if rank+1 < world_size:
                up_lim = fgm._partition_book.partid2nids(rank+1, ntype).min()
            else:
                up_lim = None
            to_extract[ntype] = node_map[ntype][down_lim:up_lim]

        embed_layer = kwargs["embed_layer"]
        embeds = extract_embed(embed_layer, to_extract, skip_type="n_type_1")
        embeds.update({"n_type_1": feats[fgm.srcnodes("n_type_1")]})
        embeds = {ntype: emb.to("cpu") for ntype, emb in embeds.items()}

        sar_logits = model(fgm, embeds)["n_type_1"]

        mp_dict[f"result_{rank}"] = sar_logits.detach()
        if rank == 0:
            mp_dict["model"] = model
            mp_dict["node_map"] = fixture_env.node_map[f"heterogeneous_{world_size}"]

    graph = fixture_env.hetero_graph
    max_num_nodes = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    embed_size = graph.ndata["features"]["n_type_1"].shape[1]
    embed_layer = rel_graph_embed(graph, embed_size,
                                  num_nodes_dict=max_num_nodes,
                                  skip_type="n_type_1").to('cpu')
    mp_dict = run_workers(heterogeneous_fgm, fixture_env, world_size, backend=backend,
                          embed_layer=embed_layer)

    # Reference inputs: embedding weights plus the real n_type_1 features
    embeds = embed_layer.weight
    embeds.update({"n_type_1": graph.ndata["features"]["n_type_1"]})
    embeds = {ntype: emb.to("cpu") for ntype, emb in embeds.items()}

    model = mp_dict["model"]
    dgl_logits = model(graph, embeds)["n_type_1"].detach()
    dgl_logits_mean = dgl_logits.mean(axis=1)

    # Concatenate the per-worker results in rank order
    per_rank_logits = [mp_dict[f"result_{rank}"] for rank in range(world_size)]
    sar_logits = torch.cat([torch.tensor([])] + per_rank_logits)
    sar_logits[mp_dict["node_map"]["n_type_1"]] = sar_logits.clone()
    sar_logits_mean = sar_logits.mean(axis=1)

    assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6))
124 |
125 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
@sar_test
def test_homogeneous_mfg(world_size, backend, fixture_env):
    """
    Runs full graph inference with SAR on the homogeneous graph using
    Message Flow Graphs (mfg). The per-worker results are concatenated,
    reordered with the node map and their row means are compared with the
    row means of native DGL inference on the whole graph.
    """
    import torch

    def homogeneous_mfg(mp_dict, rank, world_size, fixture_env, **kwargs):
        import sar
        from models import GNNModel
        from base_utils import initialize_worker, load_partition_data_mfg

        temp_dir = fixture_env.temp_dir
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])
        partition_dir = os.path.join(temp_dir, f"homogeneous_{world_size}")
        blocks, feat, labels = load_partition_data_mfg(rank, HOMOGENEOUS_GRAPH_NAME, partition_dir)
        feat_dim = feat.shape[1]
        model = GNNModel(feat_dim, feat_dim, labels.max()+1).to('cpu')
        sar.sync_params(model)
        model.eval()

        mp_dict[f"result_{rank}"] = model(blocks, feat).detach()
        if rank == 0:
            # Rank 0 also hands back what the reference computation needs
            mp_dict["model"] = model
            mp_dict["graph"] = fixture_env.homo_graph
            mp_dict["node_map"] = fixture_env.node_map[f"homogeneous_{world_size}"]

    mp_dict = run_workers(homogeneous_mfg, fixture_env, world_size, backend=backend)

    # Reference: native DGL inference on the whole graph
    model = mp_dict["model"]
    graph = mp_dict["graph"]
    dgl_logits_mean = model(graph, graph.ndata['features']).detach().mean(axis=1)

    # Concatenate the per-worker results in rank order
    per_rank_logits = [mp_dict[f"result_{rank}"] for rank in range(world_size)]
    sar_logits = torch.cat([torch.tensor([])] + per_rank_logits)
    sar_logits[mp_dict["node_map"]] = sar_logits.clone()
    sar_logits_mean = sar_logits.mean(axis=1)

    assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6))
171 |
172 |
@pytest.mark.parametrize("backend", ["ccl", "gloo"])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@sar_test
def test_heterogeneous_mfg(world_size, backend, fixture_env):
    """
    Runs full graph inference with SAR on the heterogeneous graph using
    Message Flow Graphs (mfg). The per-worker results for ``n_type_1`` are
    concatenated, reordered with the node map and their row means are
    compared with the row means of native DGL inference on the whole graph.
    """
    import torch
    from models import rel_graph_embed

    def heterogeneous_mfg(mp_dict, rank, world_size, fixture_env, **kwargs):
        import sar
        from models import HeteroGNNModel, extract_embed
        from base_utils import initialize_worker, load_partition_data_mfg

        temp_dir = fixture_env.temp_dir
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])

        partition_dir = os.path.join(temp_dir, f"heterogeneous_{world_size}")
        blocks, feats, labels = load_partition_data_mfg(rank, HETEROGENEOUS_GRAPH_NAME, partition_dir)
        feat_dim = feats.shape[1]
        model = HeteroGNNModel(blocks[0], feat_dim, feat_dim, labels.max()+1).to('cpu')
        model.eval()
        sar.sync_params(model)

        # Select, for every node type without features, the node map entries
        # between the smallest node ID of this partition and the smallest
        # node ID of the next one (open-ended for the last partition), then
        # index them with the source nodes of the first block.
        to_extract = {}
        node_map = fixture_env.node_map[f"heterogeneous_{world_size}"]
        for ntype in blocks[0].srctypes:
            if ntype == "n_type_1":
                continue
            down_lim = blocks[0]._partition_book.partid2nids(rank, ntype).min()
            if rank+1 < world_size:
                up_lim = blocks[0]._partition_book.partid2nids(rank+1, ntype).min()
            else:
                up_lim = None
            partition_ids = node_map[ntype][down_lim:up_lim]
            to_extract[ntype] = partition_ids[blocks[0].srcnodes(ntype)]

        embed_layer = kwargs["embed_layer"]
        embeds = extract_embed(embed_layer, to_extract, skip_type="n_type_1")
        embeds.update({"n_type_1": feats[blocks[0].srcnodes("n_type_1")]})
        embeds = {ntype: emb.to("cpu") for ntype, emb in embeds.items()}

        sar_logits = model(blocks, embeds)["n_type_1"]

        mp_dict[f"result_{rank}"] = sar_logits.detach()
        if rank == 0:
            mp_dict["model"] = model
            mp_dict["node_map"] = fixture_env.node_map[f"heterogeneous_{world_size}"]

    graph = fixture_env.hetero_graph
    max_num_nodes = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    embed_size = graph.ndata["features"]["n_type_1"].shape[1]
    embed_layer = rel_graph_embed(graph, embed_size,
                                  num_nodes_dict=max_num_nodes,
                                  skip_type="n_type_1").to('cpu')
    mp_dict = run_workers(heterogeneous_mfg, fixture_env, world_size, backend=backend,
                          embed_layer=embed_layer)

    # Reference inputs: embedding weights plus the real n_type_1 features
    embeds = embed_layer.weight
    embeds.update({"n_type_1": graph.ndata["features"]["n_type_1"]})
    embeds = {ntype: emb.to("cpu") for ntype, emb in embeds.items()}

    model = mp_dict["model"]
    dgl_logits = model(graph, embeds)["n_type_1"].detach()
    dgl_logits_mean = dgl_logits.mean(axis=1)

    # Concatenate the per-worker results in rank order
    per_rank_logits = [mp_dict[f"result_{rank}"] for rank in range(world_size)]
    sar_logits = torch.cat([torch.tensor([])] + per_rank_logits)
    sar_logits[mp_dict["node_map"]["n_type_1"]] = sar_logits.clone()
    sar_logits_mean = sar_logits.mean(axis=1)

    assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6))
246 |
247 |
@pytest.mark.parametrize("backend", ["ccl"])
@pytest.mark.parametrize('world_size', [1])
@sar_test
def test_convert_dist_graph(world_size, backend, fixture_env):
    """
    Creates DGL's DistGraph object from the random graph partitioned into
    one part (only way to test DistGraph locally), converts it into SAR's
    GraphShardManager and compares relevant properties of both objects.
    """
    def convert_dist_graph(mp_dict, rank, world_size, fixture_env, **kwargs):
        import dgl
        import sar
        from base_utils import initialize_worker

        temp_dir = fixture_env.temp_dir
        partition_file = os.path.join(temp_dir, f'homogeneous_{world_size}', f"{HOMOGENEOUS_GRAPH_NAME}.json")
        initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"])
        dgl.distributed.initialize("kv_ip_config.txt")
        dist_g = dgl.distributed.DistGraph(HOMOGENEOUS_GRAPH_NAME, part_config=partition_file)

        sar_g = sar.convert_dist_graph(dist_g)

        # One graph shard is expected per partition
        shard_count = len(sar_g.graph_shard_managers[0].graph_shards)
        assert shard_count == dist_g.get_partition_book().num_partitions()
        # Sizes and type lists must match between both representations
        assert dist_g.num_edges() == sar_g.num_edges()
        assert dist_g.num_nodes() == sar_g.num_nodes()
        assert dist_g.ntypes == sar_g.ntypes
        assert dist_g.etypes == sar_g.etypes

    run_workers(convert_dist_graph, fixture_env, world_size, backend=backend)
277 |
--------------------------------------------------------------------------------