├── .github └── workflows │ └── sar_test.yaml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile └── source │ ├── _templates │ └── autosummary │ │ ├── distneighborsampler.rst │ │ └── graphshardmanager.rst │ ├── comm.rst │ ├── common_tuples.rst │ ├── conf.py │ ├── data_loading.rst │ ├── full_batch.rst │ ├── images │ ├── dom_parallel_naive.png │ ├── dom_parallel_naive.svg │ ├── dom_parallel_remat.pdf │ ├── dom_parallel_remat.png │ ├── dom_parallel_remat.svg │ ├── one_shot_aggregation.png │ ├── one_shot_aggregation.svg │ ├── papers_gat_memory.png │ ├── papers_os_scaling.png │ ├── papers_sage_memory.png │ ├── papers_train_full_doc.png │ └── sar_vs_distdgl.png │ ├── index.rst │ ├── model_prepare.rst │ ├── quick_start.rst │ ├── sampling_training.rst │ ├── sar_config.rst │ ├── sar_modes.rst │ └── shards.rst ├── examples ├── README.md ├── SIGN │ ├── README.md │ └── train_sign_with_sar.py ├── correct_and_smooth.py ├── partition_graph.py ├── rgcn-hetero │ ├── README.md │ ├── model.py │ ├── train_heterogeneous_graph.py │ └── train_heterogeneous_graph_mfg.py ├── train_dist_appnp_with_sar.py ├── train_distdgl_with_sar_inference.py ├── train_homogeneous_graph_advanced.py ├── train_homogeneous_graph_basic.py └── train_homogeneous_sampling_basic.py ├── pyproject.toml ├── requirements.txt ├── sar ├── __init__.py ├── comm.py ├── common_tuples.py ├── config.py ├── construct_shard_manager.py ├── core │ ├── __init__.py │ ├── full_partition_block.py │ ├── graphshard.py │ ├── sampling.py │ └── sar_aggregation.py ├── data_loading.py ├── distributed_bn.py ├── edge_softmax.py ├── logging_setup.py └── patch_dgl.py ├── security.md ├── setup.py └── tests ├── base_utils.py ├── conftest.py ├── constants.py ├── models.py ├── multiprocessing_utils.py ├── pytest.ini ├── test_comm.py ├── test_hetero_graph_shard_manager.py ├── test_patch_dgl.py └── test_sar.py /.github/workflows/sar_test.yaml: 
-------------------------------------------------------------------------------- 1 | name: SAR tests 2 | permissions: {} 3 | 4 | on: 5 | pull_request: 6 | branches: [main] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | sar_tests: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Pull SAR 14 | uses: actions/checkout@v3 15 | with: 16 | fetch-depth: 0 17 | 18 | - name: Setup Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install requirements 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install pytest 27 | # oneccl_bind_pt for torch 2.0.0 and python 3.10 28 | wget https://intel-extension-for-pytorch.s3.amazonaws.com/torch_ccl/cpu/oneccl_bind_pt-2.0.0%2Bcpu-cp310-cp310-linux_x86_64.whl 29 | python -m pip install oneccl_bind_pt-2.0.0+cpu-cp310-cp310-linux_x86_64.whl 30 | python -m pip install -e . "torch==2.0.0" 31 | python -m pip install pandas pyyaml pydantic 32 | 33 | - name: Run pytest 34 | run: | 35 | set +e 36 | python -m pytest tests/ -sv 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache 2 | __pycache__/ 3 | *.pyc 4 | 5 | # Jupyter notebook checkpoints 6 | .ipynb_checkpoints/ 7 | 8 | # Compiled Python files 9 | *.pyc 10 | *.pyo 11 | *.pyd 12 | __pycache__/ 13 | 14 | # Build directories 15 | build/ 16 | dist/ 17 | *.egg-info/ 18 | 19 | 20 | # Package distribution 21 | *.egg 22 | *.egg-info 23 | 24 | # IDE and editor files 25 | .vscode/ 26 | .idea/ 27 | *.iml 28 | *.iws 29 | *.ipr 30 | 31 | # Unit test / coverage reports 32 | htmlcov/ 33 | .tox/ 34 | .coverage 35 | .coverage.* 36 | .cache 37 | nosetests.xml 38 | coverage.xml 39 | *.cover 40 | .hypothesis/ 41 | .pytest_cache/ 42 | 43 | # Environments 44 | .env 45 | .venv 46 | env/ 47 | venv/ 48 | ENV/ 49 | env.bak/ 50 | venv.bak/ 51 | 52 | # Datasets and partitions 53 | dataset/ 54 | datasets/ 
55 | partition_data/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Set the version of Python and other tools you might need 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.9" 9 | 10 | # Build documentation in the docs/ directory with Sphinx 11 | sphinx: 12 | configuration: docs/source/conf.py 13 | 14 | 15 | # Optionally declare the Python requirements required to build your docs 16 | python: 17 | install: 18 | - requirements: requirements.txt 19 | 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Intel Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PROJECT NOT UNDER ACTIVE MANAGEMENT # 2 | This project will no longer be maintained by Intel. 3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project. 4 | Intel no longer accepts patches to this project. 5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project. 6 | 7 | 8 | [Documentation](https://sar.readthedocs.io/en/latest/) | [Examples](https://github.com/IntelLabs/SAR/tree/main/examples) 9 | 10 | SAR is a pure Python library for distributed training of Graph Neural Networks (GNNs) on large graphs. SAR is built on top of PyTorch and DGL and supports distributed full-batch training as well as distributed sampling-based training. SAR is particularly suited for training GNNs on large graphs since the graph is partitioned across the training machines. In full-batch training, SAR can utilize the [sequential aggregation and rematerialization technique](https://proceedings.mlsys.org/paper_files/paper/2022/hash/1d781258d409a6efc66cd1aa14a1681c-Abstract.html) to guarantee linear memory scaling, i.e., the memory needed to store the GNN activations in each host is guaranteed to go down linearly with the number of hosts, even for densely connected graphs. 11 | 12 | SAR requires minimal changes to existing GNN training code. SAR directly uses the graph partitioning data created by [DGL's partitioning tools](https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html) and can thus be used as a drop-in replacement for DGL's distributed sampling-based training.
To get started using SAR, check out [SAR's documentation](https://sar.readthedocs.io/en/latest/) and the examples under the `examples/` folder. 13 | 14 | 15 | ## Installing required packages 16 | ```shell 17 | pip3 install -r requirements.txt 18 | ``` 19 | Python 3.8 or higher is required. You also need to install [torch CCL](https://github.com/intel/torch-ccl) if you want to use Intel's OneCCL communication backend. 20 | 21 | ## Full-batch training Performance on ogbn-papers100M 22 | SAR consumes up to 2x less memory when training a 3-layer GraphSage network on ogbn-papers100M (111M nodes, 3.2B edges), and up to 4x less memory when training a 3-layer Graph Attention Network (GAT). SAR achieves near-linear scaling for the peak memory requirements per machine. We use a 3-layer GraphSage network with hidden layer size of 256, and a 3-layer GAT network with hidden layer size of 128 and 4 attention heads. We use batch normalization between all layers. 23 | 24 | 25 | 26 | 27 | The run-time of SAR improves as we add more machines. At 128 machines, the epoch time is 3.8s. Each machine is a 2-socket machine with 2 Icelake processors (36 cores each). The machines are connected using Infiniband HDR (200 Gbps) links. After 100 epochs, training has converged. We use a 3-layer GraphSage network with hidden layer size of 256 and batch normalization between all layers. The training curve is the same regardless of the number of machines/partitions. 28 | 29 | 30 | 31 | 32 | ## Sampling-based training Performance on ogbn-papers100M 33 | SAR is considerably faster than DistDGL in sampling-based training on CPUs. Each machine is a 2-socket machine with 2 Icelake processors (36 cores each). The machines are connected using Infiniband HDR (200 Gbps) links. We benchmarked using a 3-layer GraphSage network with hidden layer size of 256. We used a batch size of 1000 per machine.
34 | 35 | 36 | 37 | 38 | 39 | ## Cite 40 | 41 | If you use SAR in your publication, we would appreciate it if you cite the SAR paper: 42 | ``` 43 | @article{mostafa2021sequential, 44 | title={Sequential Aggregation and Rematerialization: Distributed Full-batch Training of Graph Neural Networks on Large Graphs}, 45 | author={Mostafa, Hesham}, 46 | journal={MLSys}, 47 | year={2022} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | SOURCEDIR = source 5 | BUILDDIR = build 6 | 7 | # Put it first so that "make" without argument is like "make help". 8 | help: 9 | sphinx-build -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(O) 10 | 11 | .PHONY: help Makefile 12 | 13 | %: Makefile 14 | sphinx-build -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(O) 15 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/distneighborsampler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/graphshardmanager.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | -------------------------------------------------------------------------------- /docs/source/comm.rst: -------------------------------------------------------------------------------- 1 | .. _comm-guide: 2 | .. 
currentmodule:: sar 3 | 4 | 5 | SAR's communication routines 6 | ============================= 7 | SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends: ``ccl``, ``nccl``, ``mpi`` and ``gloo``. (Note: the ``gloo`` backend may be less efficient than the other backends because it does not support the ``all_to_all`` routine, so SAR must use its own implementation, which issues multiple asynchronous sends (``torch.distributed.isend``) between workers.) Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs. 8 | 9 | The ``ccl`` backend uses `Intel's OneCCL `_ library. You can install the PyTorch bindings for OneCCL `here `_ . ``ccl`` is the preferred backend when training on CPUs. 10 | 11 | You can train on CPUs and still use the ``nccl`` backend, or you can train on GPUs and use the ``ccl`` backend. However, you will incur extra overhead to move tensors back and forth between the CPU and GPU in order to provide the right tensors to the communication backend. 12 | 13 | In an environment with a networked file system, initializing ``torch.distributed`` is quite easy: :: 14 | 15 | if backend_name == 'nccl': 16 | comm_device = torch.device('cuda') 17 | else: 18 | comm_device = torch.device('cpu') 19 | 20 | master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file) 21 | sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device) 22 | 23 | .. 24 | :func:`sar.initialize_comms` tries to initialize the torch.distributed process group, but only if it has not been initialized yet. Users can also initialize the process group themselves before calling :func:`sar.initialize_comms`. 25 | :func:`sar.nfs_ip_init` communicates the master's IP address to the workers through the file system.
In the absence of a networked file system, you should develop your own mechanism to communicate the master's IP address. 26 | 27 | You can specify the name of the socket that will be used for communication with the ``SAR_SOCKET_NAME`` environment variable (if not specified, the first available socket will be selected). 28 | 29 | 30 | 31 | Relevant methods 32 | --------------------------------------------------------------------------- 33 | 34 | .. autosummary:: 35 | :toctree: comm package 36 | 37 | initialize_comms 38 | rank 39 | world_size 40 | sync_params 41 | gather_grads 42 | nfs_ip_init 43 | -------------------------------------------------------------------------------- /docs/source/common_tuples.rst: -------------------------------------------------------------------------------- 1 | .. _common-tuples: 2 | 3 | .. currentmodule:: sar.common_tuples 4 | 5 | 6 | Data tuples 7 | ============================= 8 | ``sar.common_tuples`` defines a number of useful ``NamedTuple`` classes that are used to exchange data between the different parts of SAR: 9 | 10 | ..
autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | :template: classtemplate 14 | 15 | PartitionData 16 | ShardEdgesAndFeatures 17 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath('../..')) 4 | 5 | 6 | # -- Project information ----------------------------------------------------- 7 | 8 | project = 'SAR' 9 | copyright = '2022, Hesham Mostafa' 10 | author = 'Hesham Mostafa' 11 | 12 | # The full version, including alpha/beta/rc tags 13 | release = '1.0' 14 | import sar # noqa 15 | 16 | add_module_names = False 17 | 18 | extensions = [ 19 | 'sphinx.ext.autodoc', 20 | 'sphinx.ext.autosummary', 21 | 22 | ] 23 | 24 | def autodoc_skip_member_handler(app, what, name, obj, skip, options): 25 | return name == 'forward' or name == 'extra_repr' 26 | 27 | #def setup(app): 28 | # app.connect('autodoc-skip-member', autodoc_skip_member_handler) 29 | 30 | templates_path = ['_templates'] 31 | 32 | 33 | html_theme = 'sphinx_rtd_theme' 34 | 35 | html_static_path = ['_static'] 36 | -------------------------------------------------------------------------------- /docs/source/data_loading.rst: -------------------------------------------------------------------------------- 1 | .. _data-loading: 2 | 3 | 4 | Data loading and graph construction 5 | ========================================================== 6 | After partitioning the graph using DGL's `partition_graph `_ function, SAR can load the graph data using :func:`sar.load_dgl_partition_data`. This yields a :class:`sar.common_tuples.PartitionData` object. The ``PartitionData`` object can then be used to construct various types of graph-like objects that can be passed to GNN models. 
You can construct graph objects to use for distributed full-batch training or graph objects to use for distributed sampling-based training as follows: 7 | 8 | .. contents:: :local: 9 | :depth: 3 10 | 11 | 12 | Full-batch training 13 | --------------------------------------------------------------------------------------- 14 | 15 | Constructing the full graph for sequential aggregation and rematerialization 16 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 17 | Construct a single distributed graph object of type :class:`sar.core.GraphShardManager`:: 18 | 19 | shard_manager = sar.construct_full_graph(partition_data) 20 | 21 | .. 22 | 23 | The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLGraph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLGraph``. See :ref:`the distributed graph limitations section` for some exceptions. 24 | 25 | Constructing Message Flow Graphs (MFGs) for sequential aggregation and rematerialization 26 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 27 | In node classification tasks, gradients only backpropagate from the labeled nodes. DGL uses the concept of message flow graphs to construct layer-specific bipartite graphs that update only a subset of nodes in each layer. These are the nodes that will ultimately affect the output, assuming each node only aggregates messages from its neighbors in every layer.
28 | 29 | If training a K-layer GNN on a node classification task, you can construct K distributed graph objects that reflect the message flow graphs at each layer using :func:`sar.construct_mfgs`: 30 | :: 31 | 32 | class GNNModel(nn.Module): 33 | def __init__(self, n_layers: int): 34 | super().__init__() 35 | self.convs = nn.ModuleList([ 36 | dgl.nn.SAGEConv(100, 100) 37 | for _ in range(n_layers) 38 | ]) 39 | 40 | def forward(self, blocks: List[sar.GraphShardManager], features: torch.Tensor): 41 | for idx in range(len(self.convs)): 42 | features = self.convs[idx](blocks[idx], features) 43 | return features 44 | 45 | K = 3 # number of layers 46 | gnn_model = GNNModel(K) 47 | train_blocks = sar.construct_mfgs(partition_data, 48 | global_indices_of_labeled_nodes_in_partition, 49 | K) 50 | model_out = gnn_model(train_blocks, local_node_features) 51 | 52 | .. 53 | 54 | Using message flow graphs at each layer can substantially lower run-time and memory consumption in node classification tasks with few labeled nodes. 55 | 56 | 57 | Constructing full graph or MFGs for one-shot aggregation 58 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 59 | As described in :ref:`training modes `, SAR supports doing one-shot distributed aggregation (mode 3). To run in this mode, you should extract the full partition graph from the :class:`sar.core.GraphShardManager` object and use that during training. When using the full graph: 60 | :: 61 | 62 | shard_manager = sar.construct_full_graph(partition_data) 63 | one_shot_graph = shard_manager.get_full_partition_graph() 64 | del shard_manager 65 | ## Use one_shot_graph from now on. 66 | 67 | .. 68 | 69 | When using MFGs: 70 | :: 71 | 72 | train_blocks = sar.construct_mfgs(partition_data, 73 | global_indices_of_labeled_nodes_in_partition, 74 | n_layers) 75 | one_shot_blocks = [block.get_full_partition_graph() for block in train_blocks] 76 | del train_blocks 77 | ## Use one_shot_blocks from now on 78 | 79 | ..
80 | 81 | 82 | Sampling-based training 83 | --------------------------------------------------------------------------------------- 84 | 85 | For sampling-based training, use the dataloader provided by SAR: :func:`sar.DataLoader` to construct globally-sampled graphs. The sampled graphs are vanilla DGL graphs that reside solely on the local machines. SAR provides a global neighbor sampler: :class:`sar.DistNeighborSampler` that defines the sampling process from the distributed graph. A typical use case is: 86 | 87 | :: 88 | 89 | shard_manager = sar.construct_full_graph(partition_data) 90 | 91 | neighbor_sampler = sar.DistNeighborSampler( 92 | [15, 10, 5], #Fanout for every layer 93 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block 94 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block 95 | ) 96 | 97 | dataloader = sar.DataLoader( 98 | shard_manager, #Distributed graph 99 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes 100 | neighbor_sampler, #Distributed sampler 101 | batch_size) 102 | 103 | for blocks in dataloader: 104 | output = gnn_model(blocks) 105 | ... 106 | 107 | .. 108 | 109 | 110 | Full-graph inference 111 | --------------------------------------------------------------------------------------- 112 | SAR can also be used solely for model evaluation. When performing mini-batch distributed training with DGL, it is often preferable to evaluate the model on the entire graph. To accomplish this, SAR can turn a `DistGraph `_ object into a GraphShardManager object, allowing for distributed full-graph inference. The procedure is simple: no further steps are required, because the model parameters are already synchronized across workers at inference time.
You can use :func:`sar.convert_dist_graph` in the following way to perform full-graph inference: 113 | :: 114 | 115 | class GNNModel(nn.Module): 116 | def __init__(self, n_layers: int): 117 | super().__init__() 118 | self.convs = nn.ModuleList([ 119 | dgl.nn.SAGEConv(100, 100) 120 | for _ in range(n_layers) 121 | ]) 122 | 123 | # forward function prepared for mini-batch training 124 | def forward(self, blocks: List[DGLBlock], features: torch.Tensor): 125 | h = features 126 | for layer, block in zip(self.convs, blocks): 127 | h = layer(block, h) 128 | return h 129 | 130 | # implement inference function for full-graph input 131 | def full_graph_inference(self, graph: sar.GraphShardManager, features: torch.Tensor): 132 | h = features 133 | for layer in self.convs: 134 | h = layer(graph, h) 135 | return h 136 | 137 | # model wrapped in pytorch DistributedDataParallel 138 | gnn_model = torch.nn.parallel.DistributedDataParallel(GNNModel(3)) 139 | 140 | # Convert DistGraph into GraphShardManager 141 | gsm = sar.convert_dist_graph(g).to(device) 142 | 143 | # Access to model through DistributedDataParallel module field 144 | model_out = gnn_model.module.full_graph_inference(gsm, local_node_features) 145 | .. 146 | 147 | 148 | Relevant methods 149 | --------------------------------------------------------------------------------------- 150 | 151 | .. currentmodule:: sar 152 | 153 | 154 | ..
autosummary:: 155 | :toctree: Data loading and graph construction 156 | :template: distneighborsampler 157 | 158 | 159 | load_dgl_partition_data 160 | construct_full_graph 161 | construct_mfgs 162 | convert_dist_graph 163 | DataLoader 164 | DistNeighborSampler 165 | -------------------------------------------------------------------------------- /docs/source/full_batch.rst: -------------------------------------------------------------------------------- 1 | Full-batch training 2 | =================== 3 | Distributed full-batch training in SAR may require some changes to your existing GNN model. Check :ref:`preparing your GNN for full-batch training in SAR` for more details. SAR supports multiple :ref:`training modes` that are suitable for different graph sizes and that trade speed for memory efficiency. For more information about SAR core distributed graph objects, check out the :ref:`distributed graph objects` section. 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | :titlesonly: 8 | 9 | Preparing your GNN for full-batch training in SAR 10 | Training modes 11 | Distributed Graph Objects 12 | -------------------------------------------------------------------------------- /docs/source/images/dom_parallel_naive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_naive.png -------------------------------------------------------------------------------- /docs/source/images/dom_parallel_remat.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_remat.pdf -------------------------------------------------------------------------------- /docs/source/images/dom_parallel_remat.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/dom_parallel_remat.png -------------------------------------------------------------------------------- /docs/source/images/one_shot_aggregation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/one_shot_aggregation.png -------------------------------------------------------------------------------- /docs/source/images/papers_gat_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_gat_memory.png -------------------------------------------------------------------------------- /docs/source/images/papers_os_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_os_scaling.png -------------------------------------------------------------------------------- /docs/source/images/papers_sage_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_sage_memory.png -------------------------------------------------------------------------------- /docs/source/images/papers_train_full_doc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/papers_train_full_doc.png -------------------------------------------------------------------------------- /docs/source/images/sar_vs_distdgl.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/SAR/9abdd1414bbdd9bf91b67b82376a985b0bdf9b72/docs/source/images/sar_vs_distdgl.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. SAR documentation master file, created by 2 | sphinx-quickstart on Mon Mar 28 07:30:12 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | Welcome to SAR's documentation! 8 | =============================== 9 | 10 | 11 | SAR is a pure Python library built on top of `DGL `_ to accelerate distributed training of Graph Neural Networks (GNNs) on large graphs. SAR supports both full-batch training and sampling-based training. For full-batch training, SAR supports the `Sequential Aggregation and Rematerialization (SAR) `_ scheme to reduce peak per-machine memory consumption and guarantee that model memory consumption per worker goes down linearly with the number of workers. This is achieved by eliminating most of the data redundancy (due to the halo effect) involved in standard spatially parallel training. 12 | 13 | SAR uses the graph partition data generated by DGL's `partitioning utilities `_. It can thus be used as a drop-in replacement for DGL's sampling-based distributed training. SAR enables scalable, distributed training on very large graphs, and supports multiple training modes that balance speed against memory efficiency. SAR requires minimal changes to existing single-host DGL training code. See the quick start guide to get started using SAR. 14 | 15 | ..
toctree:: 16 | :maxdepth: 1 17 | 18 | Quick start 19 | Data loading and graph construction 20 | Communication routines 21 | Full-batch training 22 | Sampling-based training 23 | Data tuples 24 | SAR Configuration 25 | 26 | 27 | 28 | Index 29 | ================== 30 | 31 | * :ref:`genindex` 32 | -------------------------------------------------------------------------------- /docs/source/model_prepare.rst: -------------------------------------------------------------------------------- 1 | .. _model-prepare: 2 | 3 | 4 | Preparing your GNN model for SAR 5 | ========================================================= 6 | 7 | The basic graph object in SAR is :class:`sar.core.GraphShardManager`. It can typically be used as a drop-in replacement for DGL's native graph object and provided as the input graph to most GNN layers. See :ref:`the distributed graph limitations section` for some important limitations of this approach. There are situations where you need to modify your layer to accommodate :class:`sar.core.GraphShardManager` or to modify your GNN network to take into account the distributed nature of the training. Three such situations are outlined here: 8 | 9 | Edge softmax 10 | ------------------------------------------------------------------------------------ 11 | DGL's ``edge_softmax`` function expects a native DGL graph object and will not work with a :class:`sar.core.GraphShardManager` object. Instead, you must use SAR's implementation :func:`sar.edge_softmax` which accepts a :class:`sar.core.GraphShardManager` object. DGL's attention-based GNN layers make use of DGL's ``edge_softmax`` function.
One way to use these layers with SAR is to monkey-patch them as shown below: 12 | :: 13 | 14 | import dgl 15 | import sar 16 | def patched_edge_softmax(graph, *args, **kwargs): 17 | if isinstance(graph, sar.GraphShardManager): 18 | return sar.edge_softmax(graph, *args, **kwargs) 19 | 20 | return dgl.nn.edge_softmax(graph, *args, **kwargs) # pylint: disable=no-member 21 | 22 | 23 | dgl.nn.pytorch.conv.gatconv.edge_softmax = patched_edge_softmax 24 | dgl.nn.pytorch.conv.dotgatconv.edge_softmax = patched_edge_softmax 25 | dgl.nn.pytorch.conv.agnnconv.edge_softmax = patched_edge_softmax 26 | 27 | .. 28 | 29 | ``patched_edge_softmax`` dispatches to either DGL's or SAR's implementation depending on the type of the input graph. SAR has the convenience function :func:`sar.patch_dgl` that runs the above code to patch DGL's attention-based GNN layers. 30 | 31 | Parameterized message functions 32 | ----------------------------------------------------------------------------------- 33 | 34 | SAR's sequential rematerialization of the computational graph during the backward pass must be aware of any learnable parameters used to create the edge messages. SAR needs to know of these parameters so that it can correctly backpropagate gradients to them. There is no easy way for SAR to automatically detect the learnable parameters used by the message function. It is thus up to the user to use the :func:`sar.core.message_has_parameters` decorator to tell SAR about these parameters. For example, DGL's ``RelGraphConv`` layer uses a message function with learnable parameters.
To avoid the need to modify the original code of ``RelGraphConv``, we can subclass it as follows to provide the necessary decorator for the message function, and then use the subclass in the GNN model: 35 | :: 36 | 37 | import dgl 38 | import sar 39 | 40 | class RelGraphConv_sar(dgl.nn.pytorch.conv.RelGraphConv): 41 | def __init__(self, *args, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | 44 | @sar.message_has_parameters(lambda self: tuple(self.linear_r.parameters())) 45 | def message(self, edges): 46 | return super().message(edges) 47 | 48 | .. 49 | 50 | SAR has the convenience function :func:`sar.patch_dgl` that defines a new ``RelGraphConv`` layer as described in the code above and uses it to replace DGL's ``RelGraphConv`` layer. 51 | 52 | 53 | Batch normalization 54 | ----------------------------------------------------------------------------------- 55 | The batch normalization layers in PyTorch such as ``torch.nn.BatchNorm1d`` will normalize the GNN node features using statistics obtained only from the node features in the local partition. So the normalizing factors (mean and standard deviation) will be different in each worker, and will depend on the way the graph is partitioned. To normalize using global statistics obtained from all nodes in the graph, you can use :class:`sar.DistributedBN1D`. :class:`sar.DistributedBN1D` has a similar interface to ``torch.nn.BatchNorm1d``. For example:: 56 | 57 | norm_layer = sar.DistributedBN1D(out_dim, affine=True) 58 | .. 59 | #Will normalize the features of the nodes in the partition 60 | #by the global node statistics (mean and standard deviation) 61 | normalized_activations = norm_layer(partition_node_features) 62 | 63 | .. 64 | 65 | Relevant methods 66 | --------------------------------------------------------------------------- 67 | 68 | ..
autosummary:: 69 | :toctree: Adapting GNNs to SAR 70 | :template: graphshardmanager 71 | 72 | sar.core.message_has_parameters 73 | sar.edge_softmax 74 | sar.DistributedBN1D 75 | sar.patch_dgl 76 | -------------------------------------------------------------------------------- /docs/source/quick_start.rst: -------------------------------------------------------------------------------- 1 | .. _quick-start: 2 | 3 | Quick start guide 4 | =============================== 5 | Follow the following steps to enable distributed training in your DGL code: 6 | 7 | .. contents:: 8 | :depth: 2 9 | :local: 10 | :backlinks: top 11 | 12 | Partition the graph 13 | ---------------------------------- 14 | Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` (in DGL < 1.0) in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``. 
15 | 16 | 17 | An example of partitioning the ogbn-arxiv graph in two parts: :: 18 | 19 | import dgl 20 | import torch 21 | from ogb.nodeproppred import DglNodePropPredDataset 22 | 23 | dataset = DglNodePropPredDataset(name='ogbn-arxiv') 24 | graph = dataset[0][0] 25 | graph = dgl.to_bidirected(graph, copy_ndata=True) 26 | graph = dgl.add_self_loop(graph) 27 | 28 | labels = dataset[0][1].view(-1) 29 | split_idx = dataset.get_idx_split() 30 | 31 | 32 | def _idx_to_mask(idx_tensor): 33 | mask = torch.BoolTensor(graph.number_of_nodes()).fill_(False) 34 | mask[idx_tensor] = True 35 | return mask 36 | 37 | 38 | train_mask, val_mask, test_mask = map( 39 | _idx_to_mask, [split_idx['train'], split_idx['valid'], split_idx['test']]) 40 | features = graph.ndata['feat'] 41 | graph.ndata.clear() 42 | for name, val in zip(['train_mask', 'val_mask', 'test_mask', 'labels', 'features'], 43 | [train_mask, val_mask, test_mask, labels, features]): 44 | graph.ndata[name] = val 45 | 46 | dgl.distributed.partition_graph( 47 | graph, 'arxiv', 2, './test_partition_data/', num_hops=1) # use reshuffle=True in DGL < 1.0 48 | 49 | .. 50 | 51 | Note that we add the labels, and the train/test/validation masks as node features so that they get split into multiple parts alongside the graph. 52 | 53 | 54 | Initialize communication 55 | ---------------------------------- 56 | SAR uses the `torch.distributed `_ package to handle all communication. See the :ref:`Communication Guide ` for more information on the communication routines. We require the IP address of the master worker/machine (the machine with rank 0) to initialize the ``torch.distributed`` package. In an environment with a networked file system where all workers/machines share a common file system, we can communicate the master's IP address through the file system. In that case, use :func:`sar.nfs_ip_init` to obtain the master ip address. 
57 | 58 | Initialize the communication through a call to :func:`sar.initialize_comms`, specifying the current worker index, the total number of workers (which should be the same as the number of partitions from step 1), the master's IP address, and the communication device. The latter is the device on which SAR should place the tensors before sending them through the communication backend. For example: :: 59 | 60 | if backend_name == 'nccl': 61 | comm_device = torch.device('cuda') 62 | else: 63 | comm_device = torch.device('cpu') 64 | master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file) 65 | sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device) 66 | 67 | .. 68 | 69 | ``backend_name`` can be ``ccl``, ``nccl``, ``mpi`` or ``gloo``. 70 | 71 | 72 | 73 | Load partition data and construct graph 74 | ----------------------------------------------------------------- 75 | Use :func:`sar.load_dgl_partition_data` to load one graph partition from DGL's partition data in each worker. :func:`sar.load_dgl_partition_data` returns a :class:`sar.common_tuples.PartitionData` object that contains all the information about the partition. 76 | 77 | There are several ways to construct a distributed graph-like object from ``PartitionData``. See :ref:`constructing distributed graphs ` for more details. Here we will use the simplest method: :func:`sar.construct_full_graph`, which returns a :class:`sar.core.GraphShardManager` object that implements much of the GNN-related functionality of DGL's native graph objects. ``GraphShardManager`` can thus be used as a drop-in replacement for DGL's native graphs or it can be passed to SAR's samplers and data loaders to construct graph mini-batches.
78 | 79 | :: 80 | 81 | partition_data = sar.load_dgl_partition_data( 82 | json_file_path, #Path to .json file created by DGL's partition_graph 83 | rank, #Worker rank 84 | device #Device to place the partition data (CPU or GPU) 85 | ) 86 | shard_manager = sar.construct_full_graph(partition_data) 87 | 88 | .. 89 | 90 | Full-batch training 91 | --------------------------------------------------------------------------- 92 | Full-batch training using SAR follows a very similar pattern as single-host training. Instead of using a vanilla DGL graph, we use a :class:`sar.core.GraphShardManager`. After initializing the communication backend, loading graph data and constructing the distributed graph, a simple training loop is :: 93 | 94 | gnn_model = construct_GNN_model(...) 95 | optimizer = torch.optim.Adam(gnn_model.parameters(),..) 96 | sar.sync_params(gnn_model) 97 | for train_iter in range(n_train_iters): 98 | model_out = gnn_model(shard_manager,features) 99 | loss = calculate_loss(model_out,labels) 100 | optimizer.zero_grad() 101 | loss.backward() 102 | sar.gather_grads(gnn_model) 103 | optimizer.step() 104 | 105 | .. 106 | 107 | In a distributed setting, each worker will construct the GNN model. Before training, we should synchronize the model parameters across all workers. :func:`sar.sync_params` is a convenience function that does just that. At the end of every training iteration, each worker needs to gather and sum the parameter gradients from all other workers before making the parameter update. This can be done using :func:`sar.gather_grads`. 108 | 109 | See :ref:`training modes ` for the different full-batch training modes. 
110 | 111 | Sampling-based or mini-batch training 112 | --------------------------------------------------------------------------- 113 | A simple sampling-based training loop looks as follows: 114 | 115 | :: 116 | 117 | neighbor_sampler = sar.DistNeighborSampler( 118 | [15, 10, 5], #Fanout for every layer 119 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block 120 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block 121 | ) 122 | 123 | dataloader = sar.DataLoader( 124 | shard_manager, #Distributed graph 125 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes 126 | neighbor_sampler, #Distributed sampler 127 | batch_size) 128 | 129 | for blocks in dataloader: 130 | output = gnn_model(blocks) 131 | loss = calculate_loss(output,labels) 132 | optimizer.zero_grad() 133 | loss.backward() 134 | sar.gather_grads(gnn_model) 135 | optimizer.step() 136 | 137 | .. 138 | 139 | 140 | We use :class:`sar.DistNeighborSampler` to construct a distributed sampler and :func:`sar.DataLoader` to construct an iterator that returns standard local DGL blocks constructed from the distributed graph. 141 | 142 | 143 | For complete examples, check the examples folder in the Git repository. 144 | -------------------------------------------------------------------------------- /docs/source/sampling_training.rst: -------------------------------------------------------------------------------- 1 | .. _sampling: 2 | 3 | 4 | Distributed Sampling-based training 5 | ================================================================== 6 | In addition to distributed full-batch training, the SAR library also supports distributed sampling-based training.
The main difference between SAR's distributed sampling-based component and DistDGL is that SAR uses collective communication primitives such as ``all_to_all`` during the distributed mini-batch generation steps, while DistDGL uses point-to-point communication. One common use case in GNN training is to use sampling-based training followed by full-batch inference. Since SAR supports sampling-based as well as full-batch training and inference, this use case is particularly easy to implement. The same GNN model can be used for both full-batch and sampling-based runs. A simple 3-layer GraphSage model: 7 | 8 | :: 9 | 10 | class GNNModel(nn.Module): 11 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): 12 | super().__init__() 13 | 14 | self.convs = nn.ModuleList([ 15 | dgl.nn.SAGEConv(in_dim, hidden_dim, aggregator_type='mean'), 16 | dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean'), 17 | dgl.nn.SAGEConv(hidden_dim, out_dim, aggregator_type='mean'), 18 | ]) 19 | 20 | def forward(self, blocks: List[Union[DGLBlock, sar.GraphShardManager]], features: torch.Tensor): 21 | for idx, conv in enumerate(self.convs): 22 | features = conv(blocks[idx], features) 23 | if idx < len(self.convs) - 1: 24 | features = F.relu(features, inplace=True) 25 | 26 | return features 27 | 28 | .. 29 | 30 | Since :class:`sar.core.GraphShardManager` can be used as a drop-in replacement for DGL's native graph objects, we can use a standard DGL model and either pass it the sampled ``DGLBlock`` objects or the full distributed graph. 31 | 32 | As in full-batch training, we first load the DGL-generated partition data, and construct the full distributed graph. We then define the sampling strategy and dataloader. We use SAR's :class:`sar.DistNeighborSampler` and :func:`sar.DataLoader` to define the sampling strategy and the distributed dataloader, respectively.
33 | 34 | :: 35 | 36 | partition_data = sar.load_dgl_partition_data( 37 | args.partitioning_json_file, args.rank, torch.device('cpu')) 38 | 39 | full_graph_manager = sar.construct_full_graph( 40 | partition_data) # Keep full graph on CPU 41 | 42 | 43 | neighbor_sampler = sar.DistNeighborSampler( 44 | [15, 10, 5], #Fanout for every layer 45 | input_node_features={'features': features}, #Input features to add to srcdata of first layer's sampled block 46 | output_node_features={'labels': labels} #Output features to add to dstdata of last layer's sampled block 47 | ) 48 | 49 | dataloader = sar.DataLoader( 50 | full_graph_manager, #Distributed graph 51 | train_nodes, #Global indices of nodes that will form the root of the sampled graphs. In node classification, these are the labeled nodes 52 | neighbor_sampler, #Distributed sampler 53 | batch_size) 54 | 55 | .. 56 | 57 | A typical training loop is shown below. 58 | 59 | :: 60 | 61 | gnn_model = construct_GNN_model(...) 62 | optimizer = torch.optim.Adam(gnn_model.parameters(),..) 63 | sar.sync_params(gnn_model) 64 | 65 | 66 | for epoch in range(n_epochs): 67 | model.train() 68 | for blocks in dataloader: 69 | block_features = blocks[0].srcdata['features'] 70 | block_labels = blocks[-1].dstdata['labels'] 71 | logits = gnn_model(blocks, block_features) 72 | 73 | output = gnn_model(blocks) 74 | loss = calculate_loss(output, block_labels) 75 | optimizer.zero_grad() 76 | loss.backward() 77 | sar.gather_grads(gnn_model) 78 | optimizer.step() 79 | 80 | # inference 81 | model.eval() 82 | with torch.no_grad(): 83 | logits = gnn_model_cpu([full_graph_manager] * n_layers, features) 84 | calculate_loss_accuracy(logits, full_graph_labels) 85 | 86 | .. 87 | 88 | Note that we obtain instances of standard ``DGLBlock`` from the distributed dataloader every training iteration. After every epoch, we run distributed full-graph inference using the :class:`sar.core.GraphShardManager`. 
We use the same ``GraphShardManager`` object at each layer. Alternatively, as described in the :ref:`data loading section`, we can construct layer-specific distributed message flow graphs (MFGs) to avoid computing redundant node features at each layer. Redundant node features are the node features that do not contribute to the output at the labeled nodes. 89 | 90 | 91 | 92 | Relevant classes and methods 93 | --------------------------------------------------------------------------- 94 | 95 | .. currentmodule:: sar 96 | 97 | .. autosummary:: 98 | :toctree: Graph Shard classes 99 | :template: graphshardmanager 100 | 101 | GraphShardManager 102 | DataLoader 103 | DistNeighborSampler 104 | 105 | -------------------------------------------------------------------------------- /docs/source/sar_config.rst: -------------------------------------------------------------------------------- 1 | .. _sar-config: 2 | .. currentmodule:: sar 3 | 4 | SAR configuration 5 | ======================================= 6 | .. autoclass:: sar.Config 7 | -------------------------------------------------------------------------------- /docs/source/sar_modes.rst: -------------------------------------------------------------------------------- 1 | .. _sar-modes: 2 | 3 | SAR's training modes 4 | ============================= 5 | SAR can run the distributed GNN forward and backward passes in three distinct modes: 6 | 7 | .. contents:: :local: 8 | :depth: 2 9 | 10 | 11 | Mode 1: Sequential aggregation and rematerialization 12 | ------------------------------------------------------------------------------------ 13 | SAR's main training mode uses sequential aggregation and rematerialization to avoid fully materializing the computational graph at any single worker. The distributed training scheme is illustrated below for 3 workers/partitions. The figure illustrates the forward and backward pass steps for machine/worker 1. 14 | 15 | ..
image:: ./images/dom_parallel_remat.png 16 | :alt: SAR training 17 | :width: 500 px 18 | 19 | **Forward pass:** When doing distributed message passing and aggregation, all workers disable PyTorch's autograd to stop PyTorch from creating the computational graph. The forward pass steps in worker 1 are: 20 | 21 | #. Aggregate messages from nodes in the local partition. 22 | #. Fetch neighboring node features from worker 2 and aggregate their messages. Delete the fetched features. 23 | #. Fetch neighboring node features from worker 3 and aggregate their messages. Delete the fetched features. 24 | 25 | **Backward pass:** Since no computational graph was constructed in the forward pass, the workers need to reconstruct the computational graph to backpropagate the gradients. The backward pass steps in worker 1, when it receives errors :math:`e_2,e_3` for nodes :math:`v_2,v_3`, are: 26 | 27 | #. Re-aggregate messages from nodes in the local partition with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph. 28 | #. Re-fetch neighboring nodes from worker 2 and re-aggregate their messages with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph. 29 | #. Re-fetch neighboring nodes from worker 3 and re-aggregate their messages with autograd enabled, backpropagate along the constructed computational graph, then delete all intermediate tensors to delete the computational graph. 30 | 31 | Note that many GNN layers simply use a sum or mean operation to aggregate the features of their neighbors. In that case, we do not need to reconstruct the computational graph during the backward pass, as the gradients of the input features can be easily obtained from the gradients of the output features.
To see this, consider the operation :math:`z = x + y`, and assume we have the gradient of the loss w.r.t. :math:`z`, denoted :math:`e_z`. Then the gradients of the loss w.r.t. :math:`x` and :math:`y` are :math:`e_x = e_y = e_z`. SAR automatically detects this situation and directly pushes the correct gradients to the local and remote nodes without re-fetching remote features or re-constructing the computational graph. 32 | 33 | Mode 1 is the default mode. It runs automatically when you pass a :class:`sar.core.GraphShardManager` object to your GNN model. For example: 34 | :: 35 | 36 | partition_data = sar.load_dgl_partition_data( 37 | json_file_path, #Path to .json file created by DGL's partition_graph 38 | rank, #Worker rank 39 | device #Device to place the partition data (CPU or GPU) 40 | ) 41 | shard_manager = sar.construct_full_graph(partition_data) 42 | model_out = gnn_model (shard_manager,local_node_features) 43 | loss_function(model_out).backward() 44 | 45 | .. 46 | 47 | Mode 1 is the most memory-efficient mode. Excluding the GNN parameters which are replicated across all workers, mode 1 guarantees that peak memory consumption per worker will go down linearly with the number of workers used in training, even for densely connected graphs. 48 | 49 | Mode 2: Sequential aggregation 50 | ------------------------------------------------------------------------------------ 51 | For many GNN layers, mode 1 must sequentially rematerialize the computational graph during the backward pass. This is the case, for example, for GAT layers. This introduces extra communication overhead during the backward pass in order to re-fetch the features of remote nodes. It also introduces an extra compute overhead in order to sequentially construct the computational graph (by re-executing the forward pass with autograd enabled). 52 | 53 | Mode 2 avoids this overhead by constructing the computational graph during the forward pass. This can potentially take up a lot of memory.
In a densely connected graph, for example, this may cause the features of all nodes in the graph to be materialized at every worker. The figure below illustrates the forward and backward pass in mode 2. Only the activity in worker 1 is shown. Note that remote node features are stored in worker 1 during the forward pass as part of the computational graph at worker 1. In the backward pass, SAR has access to the full computational graph and uses it to backpropagate gradients to remote nodes without any re-fetching. 54 | 55 | .. image:: ./images/dom_parallel_naive.png 56 | :alt: No re-materialization 57 | :width: 500 px 58 | 59 | Mode 2 can be enabled by disabling sequential rematerialization in SAR's configuration object :class:`sar.Config`:: 60 | 61 | sar.Config.disable_sr = True 62 | partition_data = sar.load_dgl_partition_data( 63 | json_file_path, #Path to .json file created by DGL's partition_graph 64 | rank, #Worker rank 65 | device #Device to place the partition data (CPU or GPU) 66 | ) 67 | shard_manager = sar.construct_full_graph(partition_data) 68 | model_out = gnn_model (shard_manager,local_node_features) 69 | loss_function(model_out).backward() 70 | 71 | .. 72 | 73 | Mode 3: One-shot aggregation 74 | ------------------------------------------------------------------------------------ 75 | Modes 1 and 2 follow a sequential aggregation approach where data from remote partitions are sequentially fetched. This might introduce scalability issues since the forward and backward pass in each layer will involve N communication rounds each (where N is the number of workers/partitions). Sequential aggregation thus introduces N synchronization points in each layer's forward and backward passes as each worker needs to wait until every other worker has finished its aggregation step before moving to the next step in the aggregation sequence (see the steps in the figure above).
76 | 77 | In mode 3, the one-shot aggregation mode, each worker fetches all remote data in one communication round and does one aggregation round to aggregate message from all remotely fetched nodes. This is illustrated in the figure below: 78 | 79 | .. image:: ./images/one_shot_aggregation.png 80 | :alt: One shot aggregation 81 | :width: 500 px 82 | 83 | One advantage of mode 3 is that it only requires one communication round per layer in each of the forward and backward passes. One disadvantage is that mode 3 does not hide the communication latency. Due to the sequential nature of modes 1 and 2, SAR is able to simultaneously process data from one remote partition while pre-fetching data from the next remote partition in the aggregation sequence. Modes 1 and 2 can thus better hide the communication latency than mode 3. The memory requirements of mode 3 are similar to mode 2. 84 | 85 | To train in mode 3, you should extract the full partition graph from the :class:`sar.core.GraphShardManager` object and use that during training. 86 | :: 87 | 88 | partition_data = sar.load_dgl_partition_data( 89 | json_file_path, #Path to .json file created by DGL's partition_graph 90 | rank, #Worker rank 91 | device #Device to place the partition data (CPU or GPU) 92 | ) 93 | shard_manager = sar.construct_full_graph(partition_data) 94 | one_shot_graph = shard_manager.get_full_partition_graph() 95 | model_out = gnn_model (one_shot_graph,local_node_features) 96 | loss_function(model_out).backward() 97 | 98 | .. 99 | -------------------------------------------------------------------------------- /docs/source/shards.rst: -------------------------------------------------------------------------------- 1 | .. _shards: 2 | 3 | 4 | Distributed Graph Representation 5 | ================================================================== 6 | SAR represents the full graph as :math:`N^2` graph shards where :math:`N` is the number of workers/partitions. 
Each graph shard represents the edges from one partition to another and the features associated with these edges. :class:`sar.core.GraphShard` represents a single graph shard. Each worker stores the :math:`N` graph shards containing the incoming edges for the nodes in the worker's partition. These :math:`N` graph shards are managed by the :class:`sar.core.GraphShardManager` class. :class:`sar.core.GraphShardManager` implements a distributed version of ``update_all`` and ``apply_edges`` which are the main methods used by GNNs to create and exchange messages in the graph. :class:`sar.core.GraphShardManager` implements ``update_all`` and ``apply_edges`` in a sequential manner by iterating through the :math:`N` graph shards to sequentially create and aggregate messages from each partition into the local partition. 7 | 8 | The :meth:`sar.core.GraphShardManager.get_full_partition_graph` method can be used to combine the worker's :math:`N` graph shards into one monolithic graph object that represents all the incoming edges for nodes in the local partition. It returns a :class:`sar.core.DistributedBlock` object. The implementation of ``update_all`` and ``apply_edges`` in :class:`sar.core.DistributedBlock` is not sequential. Instead, it fetches all remote features in one step and aggregates all incoming messages to the local partition in one step. 9 | 10 | In the distributed implementation of the sequential backward pass in ``update_all`` and ``apply_edges`` in :class:`sar.core.GraphShardManager`, it is not possible for SAR to automatically detect if the message function uses any learnable parameters, and thus SAR will not be able to backpropagate gradients to these message parameters. To tell SAR that a message function is parameterized, use the :func:`sar.core.message_has_parameters` decorator to decorate message functions that use learnable parameters. 11 | 12 | 13 | ..
_shard-limitations: 14 | 15 | Limitations of the distributed graph objects 16 | ------------------------------------------------------------------------------------ 17 | Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not implement the ``successors`` and ``predecessors`` methods. It primarily supports the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. To remain compatible with ``DGLGraph``, :class:`sar.core.GraphShardManager` also provides access to the ``ndata`` member, which works as an alias to ``srcdata``; however, it is not accessible when working with MFGs. 18 | 19 | :class:`sar.core.GraphShardManager` also supports the ``in_degrees`` and ``out_degrees`` members and supports querying the number of nodes and edges in the graph. 20 | 21 | The ``update_all`` method in :class:`sar.core.GraphShardManager` only supports the 4 standard reduce functions in DGL: ``max``, ``min``, ``sum``, and ``mean``. The reason behind this is that SAR runs a sequential reduction of messages and therefore requires that :math:`reduce(msg_1,msg_2,msg_3) = reduce(msg_1,reduce(msg_2,msg_3))`. 22 | 23 | .. currentmodule:: sar.core 24 | 25 | 26 | Relevant classes and methods 27 | --------------------------------------------------------------------------- 28 | 29 | ..
autosummary:: 30 | :toctree: Graph Shard classes 31 | :template: graphshardmanager 32 | 33 | GraphShard 34 | GraphShardManager 35 | DistributedBlock 36 | message_has_parameters 37 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## Graph partitioning 2 | 3 | The ``partition_graph.py`` script can be used to partition both homogeneous and heterogeneous graphs. It utilizes DGL's METIS-based partitioning algorithm to divide the graphs into smaller partitions. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that it is correctly partitioned with the graph. 4 | Similarly, edge-related information must be included in the graph's ``edata`` dictionary. 5 | 6 | ### Supported datasets: 7 | - ogbn-products, ogbn-arxiv, ogbn-mag from [Open Graph Benchmarks](https://ogb.stanford.edu/) 8 | - cora, citeseer, pubmed 9 | 10 | ## Full-batch Training 11 | 12 | The script ``train_homogeneous_graph_basic.py`` demonstrates the basic functionality of SAR. It runs distributed training using a 3-layer GraphSage network on a partitioned graph. If you want to train using ``N`` workers, then you need to launch the script ``N`` times, preferably on separate machines.
For example, for ``N=2``, and assuming the two workers are on the same network file system, you can launch the 2 workers using the following two commands: 13 | 14 | ```shell 15 | python3 train_homogeneous_graph_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 16 | python3 train_homogeneous_graph_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 17 | 18 | ``` 19 | The worker with ``rank=0`` (the master) will write its address to the file specified by the ``--ip-file`` option and the other worker(s) will read this file and connect to the master. 20 | 21 | 22 | The ``train_homogeneous_graph_advanced.py`` script demonstrates more advanced features of SAR such as distributed construction of Message Flow Graphs (MFGs), and the multiple training modes supported by SAR. 23 | 24 | The ``train_heterogeneous_graph.py`` script demonstrates training on a heterogeneous graph (ogbn-mag). The script trains a 3-layer R-GCN. 25 | 26 | 27 | ## Sampling-based Training 28 | The script ``train_homogeneous_sampling_basic.py`` demonstrates distributed sampling-based training on SAR. It demonstrates the unique ability of the SAR library to run distributed sampling-based training followed by memory-efficient distributed full-graph inference. The script uses a 3-layer GraphSage network. 
Assuming 2 machines, you can launch training on the 2 machines using the following 2 commands, where each is executed on a different machine: 29 | 30 | ```shell 31 | python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 32 | python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 33 | 34 | ``` 35 | 36 | ## Distributed Mini-Batch Training with Full-Graph inference 37 | The script ``train_distdgl_with_sar_inference.py`` showcases how SAR can be effectively combined with native DGL distributed training. In this particular example, the training process utilizes a sampling approach, while the evaluation phase leverages the SAR library to perform computations on the entire graph. 38 | ```shell 39 | python /home/ubuntu/workspace/dgl/tools/launch.py \ 40 | --workspace /home/ubuntu/workspace/SAR/examples \ 41 | --num_trainers 1 \ 42 | --num_samplers 2 \ 43 | --num_servers 1 \ 44 | --part_config partition_data/ogbn-products.json \ 45 | --ip_config ip_config.txt \ 46 | "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 --part_config partition_data/ogbn-products.json" 47 | ``` 48 | 49 | ## Correct and Smooth 50 | Example taken from the [DGL implementation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. The code is adjusted to perform distributed training with SAR. The modifications change the way data normalization is performed - workers need to communicate with each other to calculate the mean and standard deviation for the entire dataset (not just their partition). Moreover, workers need to be synchronized with each other to calculate the sigma value required during the "correct" phase.
51 | 52 | For instance, you can run the example with the following commands (two-machine scenario): 53 | 54 | * **Plain MLP + C&S** 55 | * Rank 0 machine: 56 | ```shell 57 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale 58 | ``` 59 | 60 | * Rank 1 machine: 61 | ```shell 62 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale 63 | ``` 64 | 65 | * **Plain Linear + C&S** 66 | * Rank 0 machine: 67 | ```shell 68 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale 69 | ``` 70 | 71 | * Rank 1 machine: 72 | ```shell 73 | python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale 74 | ``` 75 | -------------------------------------------------------------------------------- /examples/SIGN/README.md: -------------------------------------------------------------------------------- 1 | ## SIGN: Scalable Inception Graph Neural Networks 2 | 3 | Original script: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sign 4 | 5 | The provided `train_sign_with_sar.py` script is an example of how to integrate SAR to preprocess graph data for training.
6 | 7 | ### Results 8 | Obtained results for two partitions: 9 | - ogbn-products: 0.7832 10 | - reddit: 0.9639 11 | 12 | ### Run command: 13 | 14 | ``` 15 | python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 16 | 17 | python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 1 --world-size 2 18 | ``` -------------------------------------------------------------------------------- /examples/SIGN/train_sign_with_sar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import dgl 6 | import dgl.function as fn 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | import sar 13 | 14 | def load_dataset(filename, rank, device): 15 | partition_data = sar.load_dgl_partition_data(filename, rank, device) 16 | # Obtain train,validation, and test masks 17 | # These are stored as node features. Partitioning may prepend 18 | # the node type to the mask names. 
So we use the convenience function 19 | # suffix_key_lookup to look up the mask name while ignoring the 20 | # arbitrary node type 21 | masks = {} 22 | for mask_name, indices_name in zip(["train_mask", "val_mask", "test_mask"], 23 | ["train_indices", "val_indices", "test_indices"]): 24 | boolean_mask = sar.suffix_key_lookup(partition_data.node_features, 25 | mask_name) 26 | masks[indices_name] = boolean_mask.nonzero( 27 | as_tuple=False).view(-1).to(device) 28 | print(partition_data.node_features.keys()) 29 | 30 | feature_name, label_name = ('feat', 'label') if 'reddit' in filename \ 31 | else ('features', 'labels') 32 | labels = sar.suffix_key_lookup(partition_data.node_features, 33 | label_name).long().to(device) 34 | 35 | # Obtain the number of classes by finding the max label across all workers 36 | n_classes = labels.max() + 1 37 | sar.comm.all_reduce(n_classes, torch.distributed.ReduceOp.MAX, move_to_comm_device=True) 38 | n_classes = n_classes.item() 39 | 40 | features = sar.suffix_key_lookup(partition_data.node_features, feature_name).to(device) 41 | full_graph_manager = sar.construct_full_graph(partition_data).to(device) 42 | 43 | full_graph_manager.ndata["feat"] = features 44 | full_graph_manager.ndata["label"] = labels 45 | return full_graph_manager, n_classes, \ 46 | masks["train_indices"], masks["val_indices"], masks["test_indices"], 47 | 48 | class FeedForwardNet(nn.Module): 49 | def __init__(self, in_feats, hidden, out_feats, n_layers, dropout): 50 | super(FeedForwardNet, self).__init__() 51 | self.layers = nn.ModuleList() 52 | self.n_layers = n_layers 53 | if n_layers == 1: 54 | self.layers.append(nn.Linear(in_feats, out_feats)) 55 | else: 56 | self.layers.append(nn.Linear(in_feats, hidden)) 57 | for _ in range(n_layers - 2): 58 | self.layers.append(nn.Linear(hidden, hidden)) 59 | self.layers.append(nn.Linear(hidden, out_feats)) 60 | if self.n_layers > 1: 61 | self.prelu = nn.PReLU() 62 | self.dropout = nn.Dropout(dropout) 63 | 
self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | gain = nn.init.calculate_gain("relu") 67 | for layer in self.layers: 68 | nn.init.xavier_uniform_(layer.weight, gain=gain) 69 | nn.init.zeros_(layer.bias) 70 | 71 | def forward(self, x): 72 | for layer_id, layer in enumerate(self.layers): 73 | x = layer(x) 74 | if layer_id < self.n_layers - 1: 75 | x = self.dropout(self.prelu(x)) 76 | return x 77 | 78 | 79 | class Model(nn.Module): 80 | def __init__(self, in_feats, hidden, out_feats, R, n_layers, dropout): 81 | super(Model, self).__init__() 82 | self.dropout = nn.Dropout(dropout) 83 | self.prelu = nn.PReLU() 84 | self.inception_ffs = nn.ModuleList() 85 | for hop in range(R + 1): 86 | self.inception_ffs.append( 87 | FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout) 88 | ) 89 | # self.linear = nn.Linear(hidden * (R + 1), out_feats) 90 | self.project = FeedForwardNet( 91 | (R + 1) * hidden, hidden, out_feats, n_layers, dropout 92 | ) 93 | 94 | def forward(self, feats): 95 | hidden = [] 96 | for feat, ff in zip(feats, self.inception_ffs): 97 | hidden.append(ff(feat)) 98 | out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1)))) 99 | return out 100 | 101 | 102 | def calc_weight(g): 103 | """ 104 | Compute row_normalized(D^(-1/2)AD^(-1/2)) 105 | """ 106 | with g.local_scope(): 107 | # compute D^(-0.5)*D(-1/2), assuming A is Identity 108 | g.ndata["in_deg"] = g.in_degrees().float().pow(-0.5) 109 | g.ndata["out_deg"] = g.out_degrees().float().pow(-0.5) 110 | g.apply_edges(fn.u_mul_v("out_deg", "in_deg", "weight")) 111 | # row-normalize weight 112 | g.update_all(fn.copy_e("weight", "msg"), fn.sum("msg", "norm")) 113 | g.apply_edges(fn.e_div_v("weight", "norm", "weight")) 114 | return g.edata["weight"] 115 | 116 | 117 | def preprocess(g, features, args): 118 | """ 119 | Pre-compute the average of n-th hop neighbors 120 | """ 121 | with torch.no_grad(): 122 | g.edata["weight"] = calc_weight(g) 123 | g.ndata["feat_0"] = features 124 | 
for hop in range(1, args.R + 1): 125 | g.update_all( 126 | fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"), 127 | fn.sum("msg", f"feat_{hop}"), 128 | ) 129 | res = [] 130 | for hop in range(args.R + 1): 131 | res.append(g.ndata.pop(f"feat_{hop}")) 132 | return res 133 | 134 | 135 | def prepare_data(device, args): 136 | data = load_dataset(args.partitioning_json_file, args.rank, device) 137 | g, n_classes, train_nid, val_nid, test_nid = data 138 | g = g.to(device) 139 | in_feats = g.ndata["feat"].shape[1] 140 | feats = preprocess(g, g.ndata["feat"], args) 141 | labels = g.ndata["label"] 142 | # move to device 143 | train_nid = train_nid.to(device) 144 | val_nid = val_nid.to(device) 145 | test_nid = test_nid.to(device) 146 | train_feats = [x[train_nid] for x in feats] 147 | train_labels = labels[train_nid] 148 | return ( 149 | feats, 150 | labels, 151 | train_feats, 152 | train_labels, 153 | in_feats, 154 | n_classes, 155 | train_nid, 156 | val_nid, 157 | test_nid, 158 | ) 159 | 160 | def evaluate(args, model, feats, labels, train, val, test): 161 | with torch.no_grad(): 162 | batch_size = args.eval_batch_size 163 | if batch_size <= 0: 164 | pred = model(feats) 165 | else: 166 | pred = [] 167 | num_nodes = labels.shape[0] 168 | n_batch = (num_nodes + batch_size - 1) // batch_size 169 | for i in range(n_batch): 170 | batch_start = i * batch_size 171 | batch_end = min((i + 1) * batch_size, num_nodes) 172 | batch_feats = [feat[batch_start:batch_end] for feat in feats] 173 | pred.append(model(batch_feats)) 174 | pred = torch.cat(pred) 175 | 176 | pred = torch.argmax(pred, dim=1) 177 | correct = (pred == labels).float() 178 | 179 | # Sum the n_correct, and number of mask elements across all workers 180 | results = [] 181 | for mask in [train, val, test]: 182 | n_correct = correct[mask].sum() 183 | results.extend([n_correct, mask.numel()]) 184 | 185 | acc_vec = torch.FloatTensor(results) 186 | # Sum the n_correct, and number of mask elements across all workers 187 | 
sar.comm.all_reduce(acc_vec, op=torch.distributed.ReduceOp.SUM, move_to_comm_device=True) 188 | (train_acc, val_acc, test_acc) = \ 189 | (acc_vec[0] / acc_vec[1], 190 | acc_vec[2] / acc_vec[3], 191 | acc_vec[4] / acc_vec[5],) 192 | 193 | return train_acc, val_acc, test_acc 194 | 195 | 196 | def main(args): 197 | if args.gpu < 0: 198 | device = "cpu" 199 | else: 200 | device = "cuda:{}".format(args.gpu) 201 | 202 | master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) 203 | sar.initialize_comms(args.rank, 204 | args.world_size, master_ip_address, 205 | args.backend) 206 | 207 | data = prepare_data(device, args) 208 | ( 209 | feats, 210 | labels, 211 | train_feats, 212 | train_labels, 213 | in_size, 214 | num_classes, 215 | train_nid, 216 | val_nid, 217 | test_nid, 218 | ) = data 219 | 220 | model = Model( 221 | in_size, 222 | args.num_hidden, 223 | num_classes, 224 | args.R, 225 | args.ff_layer, 226 | args.dropout, 227 | ).to(device) 228 | loss_fcn = nn.CrossEntropyLoss() 229 | optimizer = torch.optim.Adam( 230 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 231 | ) 232 | sar.sync_params(model) 233 | 234 | best_epoch = 0 235 | best_val = 0 236 | best_test = 0 237 | 238 | for epoch in range(1, args.num_epochs + 1): 239 | start = time.time() 240 | model.train() 241 | loss = loss_fcn(model(train_feats), train_labels) 242 | optimizer.zero_grad() 243 | loss.backward() 244 | sar.gather_grads(model) 245 | optimizer.step() 246 | 247 | if epoch % args.eval_every == 0: 248 | model.eval() 249 | acc = evaluate( 250 | args, model, feats, labels, train_nid, val_nid, test_nid 251 | ) 252 | end = time.time() 253 | log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start) 254 | log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format( 255 | *acc 256 | ) 257 | print(log) 258 | if acc[1] > best_val: 259 | best_val = acc[1] 260 | best_epoch = epoch 261 | best_test = acc[2] 262 | 263 | print( 264 | "Best Epoch {}, Val {:.4f}, Test {:.4f}".format( 265 
| best_epoch, best_val, best_test 266 | ) 267 | ) 268 | 269 | 270 | if __name__ == "__main__": 271 | parser = argparse.ArgumentParser(description="SIGN") 272 | parser.add_argument("--partitioning-json-file", default="", type=str, 273 | help="Path to the .json file containing partitioning information") 274 | parser.add_argument("--ip-file", type=str, default="./ip_file", 275 | help="File with ip-address. " 276 | "Worker 0 creates this file and all others read it") 277 | parser.add_argument("--backend", type=str, default="ccl", 278 | choices=["ccl", "nccl", "mpi"], 279 | help="Communication backend to use") 280 | parser.add_argument("--rank", type=int, default=0, 281 | help="Rank of the current worker") 282 | parser.add_argument("--world-size", default=2, type=int, 283 | help="Number of workers ") 284 | parser.add_argument("--num-epochs", type=int, default=1000) 285 | parser.add_argument("--num-hidden", type=int, default=256) 286 | parser.add_argument("--R", type=int, default=3, help="number of hops") 287 | parser.add_argument("--lr", type=float, default=0.003) 288 | parser.add_argument("--dropout", type=float, default=0.5) 289 | parser.add_argument("--gpu", type=int, default=-1) 290 | parser.add_argument("--weight-decay", type=float, default=0) 291 | parser.add_argument("--eval-every", type=int, default=10) 292 | parser.add_argument("--eval-batch-size", type=int, default=250000, 293 | help="evaluation batch size, -1 for full batch") 294 | parser.add_argument("--ff-layer", type=int, default=2, help="number of feed-forward layers") 295 | args = parser.parse_args() 296 | 297 | print(args) 298 | main(args) 299 | -------------------------------------------------------------------------------- /examples/partition_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated 
documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | from argparse import ArgumentParser 22 | import dgl # type:ignore 23 | from dgl import AddReverse, Compose, ToSimple 24 | import torch 25 | from ogb.nodeproppred import DglNodePropPredDataset # type:ignore 26 | from dgl.data import ( 27 | CiteseerGraphDataset, 28 | CoraGraphDataset, 29 | PubmedGraphDataset, 30 | RedditDataset 31 | ) 32 | 33 | SUPPORTED_DATASETS = { 34 | "cora": CoraGraphDataset, 35 | "citeseer": CiteseerGraphDataset, 36 | "pubmed": PubmedGraphDataset, 37 | 'reddit': RedditDataset, 38 | "ogbn-products": DglNodePropPredDataset, 39 | "ogbn-arxiv": DglNodePropPredDataset, 40 | "ogbn-mag": DglNodePropPredDataset, 41 | } 42 | 43 | parser = ArgumentParser(description="Graph partitioning for common graph datasets") 44 | 45 | parser.add_argument("--dataset-root", type=str, default="./datasets/", 46 | help="The OGB datasets folder") 47 | 48 | parser.add_argument("--dataset-name", type=str, default="ogbn-arxiv", 49 | choices=['ogbn-arxiv', 'ogbn-products', 'ogbn-mag', 50 | 'cora', 'citeseer', 'pubmed', 'reddit'], 51 | help="Dataset name") 52 | 53 | parser.add_argument("--partition-out-path", type=str, default="./partition_data/", 54 | help="Path to the output directory for the partition data") 55 | 56 | parser.add_argument("--num-partitions", type=int, default=2, 57 | help="Number of graph partitions to generate") 58 | 59 | def get_dataset(args): 60 | dataset_name = args.dataset_name 61 | if dataset_name in ["cora", "citeseer", "pubmed"]: 62 | return SUPPORTED_DATASETS[dataset_name](args.dataset_root) 63 | elif dataset_name == 'reddit': 64 | return SUPPORTED_DATASETS[dataset_name](self_loop=True, raw_dir=args.dataset_root) 65 | else: 66 | return SUPPORTED_DATASETS[dataset_name](dataset_name, args.dataset_root) 67 | 68 | def prepare_features(args, dataset, graph): 69 | if args.dataset_name in ['cora', 'citeseer', 'pubmed', 'reddit']: 70 | assert all([x in graph.ndata.keys() for x in ['train_mask', 'val_mask', 'test_mask']]) 71 | return 72 | 73 | 
split_idx = dataset.get_idx_split() 74 | ntype = "paper" if args.dataset_name == "ogbn-mag" else None 75 | 76 | def idx_to_mask(idx_tensor): 77 | mask = torch.BoolTensor(graph.number_of_nodes(ntype)).fill_(False) 78 | if ntype: 79 | mask[idx_tensor[ntype]] = True 80 | else: 81 | mask[idx_tensor] = True 82 | return mask 83 | 84 | train_mask, val_mask, test_mask = map( 85 | idx_to_mask, [split_idx["train"], split_idx["valid"], split_idx["test"]]) 86 | 87 | if "feat" in graph.ndata.keys(): 88 | features = graph.ndata["feat"] 89 | else: 90 | features = graph.ndata["features"] 91 | 92 | graph.ndata.clear() 93 | 94 | labels = dataset[0][1] 95 | if ntype: 96 | features = features[ntype] 97 | labels = labels[ntype] 98 | labels = labels.view(-1) 99 | 100 | for name, val in zip(["train_mask", "val_mask", "test_mask", "labels", "features"], 101 | [train_mask, val_mask, test_mask, labels, features]): 102 | graph.ndata[name] = {ntype: val} if ntype else val 103 | 104 | def main(): 105 | args = parser.parse_args() 106 | dataset = get_dataset(args) 107 | dataset_name = args.dataset_name 108 | if dataset_name.startswith("ogbn"): 109 | graph = dataset[0][0] 110 | else: 111 | graph = dataset[0] 112 | 113 | if dataset_name != "ogbn-mag": 114 | graph = dgl.remove_self_loop(graph) 115 | graph = dgl.to_bidirected(graph, copy_ndata=True) 116 | graph = dgl.add_self_loop(graph) 117 | else: 118 | 119 | transform = Compose([ToSimple(), AddReverse()]) 120 | graph = transform(graph) 121 | 122 | prepare_features(args, dataset, graph) 123 | balance_ntypes = graph.ndata["train_mask"] \ 124 | if dataset_name in ["ogbn-products", "ogbn-arxiv"] else None 125 | dgl.distributed.partition_graph( 126 | graph, args.dataset_name, 127 | args.num_partitions, 128 | args.partition_out_path, 129 | num_hops=1, 130 | balance_ntypes=balance_ntypes, 131 | balance_edges=True) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | 
-------------------------------------------------------------------------------- /examples/rgcn-hetero/README.md: -------------------------------------------------------------------------------- 1 | ## Hetero RGCN 2 | 3 | Example script for the ogbn-mag dataset. 4 | Original script: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py 5 | Two scripts are provided here: `train_heterogeneous_graph.py` and `train_heterogeneous_graph_mfg.py`. The former performs simple full-graph training and inference. The latter is a training and inference script that utilizes Message Flow Graphs (MFG) - this approach is more computationally efficient, because it computes embeddings only for the nodes that require them, i.e. labeled nodes. 6 | 7 | ### Results 8 | Obtained results for two partitions (ogbn-mag dataset): 9 | - Train Acc: 77.18 ± 2.85% 10 | - Validation Acc: 40.03 ± 0.45% 11 | - Test Acc: 39.06 ± 0.44% 12 | The presented results are the average accuracies obtained after running 10 experiments (1 experiment = 60 epochs). Results from each experiment (train/val/test accuracies) were not necessarily taken from the 60th epoch; they were recorded at the epoch with the highest validation accuracy. 
13 | (Note: Results achieved in https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py are different, because it uses mini-batch training instead of full-graph like in SAR) 14 | 15 | 16 | ### Run command: 17 | 18 | ``` 19 | python examples/rgcn-hetero/train_heterogeneous_graph.py --partitioning-json-file partition_data/ogbn-mag.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 --train-iters 60 20 | 21 | python examples/rgcn-hetero/train_heterogeneous_graph.py --partitioning-json-file partition_data/ogbn-mag.json --ip-file ip_file --backend ccl --rank 1 --world-size 2 --train-iters 60 22 | ``` -------------------------------------------------------------------------------- /examples/rgcn-hetero/model.py: -------------------------------------------------------------------------------- 1 | ################################################################################################################ 2 | # File's content taken from https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py 3 | ################################################################################################################ 4 | 5 | import dgl.nn as dglnn 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from dgl.nn import HeteroEmbedding 9 | 10 | def extract_embed(node_embed, input_nodes): 11 | emb = node_embed( 12 | {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != "paper"} 13 | ) 14 | return emb 15 | 16 | def rel_graph_embed(graph, embed_size, num_nodes_dict): 17 | node_num = {} 18 | for ntype in graph.ntypes: 19 | if ntype == "paper": 20 | continue 21 | node_num[ntype] = num_nodes_dict[ntype] 22 | embeds = HeteroEmbedding(node_num, embed_size) 23 | return embeds 24 | 25 | 26 | class RelGraphConvLayer(nn.Module): 27 | def __init__( 28 | self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0, self_loop=True 29 | ): 30 | super(RelGraphConvLayer, self).__init__() 31 | 
self.in_feat = in_feat 32 | self.out_feat = out_feat 33 | self.ntypes = ntypes 34 | self.rel_names = rel_names 35 | self.activation = activation 36 | self.self_loop = self_loop 37 | 38 | self.conv = dglnn.HeteroGraphConv( 39 | { 40 | rel: dglnn.GraphConv( 41 | in_feat, out_feat, norm="right", weight=False, bias=False 42 | ) 43 | for rel in rel_names 44 | } 45 | ) 46 | 47 | self.weight = nn.ModuleDict( 48 | { 49 | rel_name: nn.Linear(in_feat, out_feat, bias=False) 50 | for rel_name in self.rel_names 51 | } 52 | ) 53 | 54 | # weight for self loop 55 | if self.self_loop: 56 | self.loop_weights = nn.ModuleDict( 57 | { 58 | ntype: nn.Linear(in_feat, out_feat, bias=True) 59 | for ntype in self.ntypes 60 | } 61 | ) 62 | 63 | self.dropout = nn.Dropout(dropout) 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | for layer in self.weight.values(): 68 | layer.reset_parameters() 69 | if self.self_loop: 70 | for layer in self.loop_weights.values(): 71 | layer.reset_parameters() 72 | 73 | def forward(self, g, inputs): 74 | """ 75 | Parameters 76 | ---------- 77 | g : DGLGraph 78 | Input graph. 79 | inputs : dict[str, torch.Tensor] 80 | Node feature for each node type. 81 | 82 | Returns 83 | ------- 84 | dict[str, torch.Tensor] 85 | New node features for each node type. 
86 | """ 87 | with g.local_scope(): 88 | wdict = { 89 | rel_name: {"weight": self.weight[rel_name].weight.T} 90 | for rel_name in self.rel_names 91 | } 92 | 93 | if self.self_loop: 94 | inputs_dst = { 95 | k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items() 96 | } 97 | 98 | hs = self.conv(g, inputs, mod_kwargs=wdict) 99 | 100 | def _apply(ntype, h): 101 | if self.self_loop: 102 | h = h + self.loop_weights[ntype](inputs_dst[ntype]) 103 | if self.activation: 104 | h = self.activation(h) 105 | return self.dropout(h) 106 | 107 | return {ntype: _apply(ntype, h) for ntype, h in hs.items()} 108 | 109 | 110 | class EntityClassify(nn.Module): 111 | def __init__(self, g, in_dim, out_dim): 112 | super(EntityClassify, self).__init__() 113 | self.in_dim = in_dim 114 | self.h_dim = 64 115 | self.out_dim = out_dim 116 | self.rel_names = list(set(g.etypes)) 117 | self.rel_names.sort() 118 | self.dropout = 0.5 119 | 120 | self.layers = nn.ModuleList() 121 | # i2h 122 | self.layers.append( 123 | RelGraphConvLayer( 124 | self.in_dim, 125 | self.h_dim, 126 | g.ntypes, 127 | self.rel_names, 128 | activation=F.relu, 129 | dropout=self.dropout, 130 | self_loop=g.tgt_in_src 131 | ) 132 | ) 133 | 134 | # h2o 135 | self.layers.append( 136 | RelGraphConvLayer( 137 | self.h_dim, 138 | self.out_dim, 139 | g.ntypes, 140 | self.rel_names, 141 | activation=None, 142 | self_loop=g.tgt_in_src 143 | ) 144 | ) 145 | 146 | def reset_parameters(self): 147 | for layer in self.layers: 148 | layer.reset_parameters() 149 | 150 | def forward(self, graph, h): 151 | if isinstance(graph, list): 152 | # Message Flow Graph 153 | for layer, block in zip(self.layers, graph): 154 | h = layer(block, h) 155 | else: 156 | for layer in self.layers: 157 | h = layer(graph, h) 158 | return h 159 | -------------------------------------------------------------------------------- /examples/rgcn-hetero/train_heterogeneous_graph.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from argparse import ArgumentParser 22 | from model import EntityClassify, rel_graph_embed, extract_embed 23 | import time 24 | import itertools 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.distributed as dist 28 | import dgl # type: ignore 29 | 30 | import sar 31 | 32 | 33 | parser = ArgumentParser( 34 | description="GNN training on node classification tasks in heterogenous graphs (MFG)") 35 | 36 | 37 | parser.add_argument("--partitioning-json-file", type=str, default="", 38 | help="Path to the .json file containing partitioning information") 39 | 40 | parser.add_argument("--ip-file", default="./ip_file", type=str, 41 | help="File with ip-address. 
Worker 0 creates this file and all others read it") 42 | 43 | parser.add_argument("--backend", default="nccl", type=str, choices=["ccl", "nccl", "mpi", "gloo"], 44 | help="Communication backend to use") 45 | 46 | parser.add_argument("--cpu-run", action="store_true", 47 | help="Run on CPUs if set, otherwise run on GPUs") 48 | 49 | parser.add_argument("--train-iters", default=60, type=int, 50 | help="number of training iterations") 51 | 52 | parser.add_argument("--lr", type=float, default=0.01, 53 | help="learning rate") 54 | 55 | parser.add_argument("--rank", default=0, type=int, 56 | help="Rank of the current worker") 57 | 58 | parser.add_argument("--world-size", default=2, type=int, 59 | help="Number of workers") 60 | 61 | parser.add_argument("--features-dim", default=128, type=int, 62 | help="Dimension of the node features") 63 | 64 | 65 | def main(): 66 | args = parser.parse_args() 67 | print('args', args) 68 | 69 | # Patch DGL's attention-based layers and RelGraphConv to support distributed graphs 70 | sar.patch_dgl() 71 | 72 | use_gpu = torch.cuda.is_available() and not args.cpu_run 73 | device = torch.device('cuda' if use_gpu else 'cpu') 74 | 75 | # Obtain the ip address of the master through the network file system 76 | master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) 77 | sar.initialize_comms(args.rank, 78 | args.world_size, master_ip_address, 79 | args.backend) 80 | 81 | # Load DGL partition data 82 | partition_data = sar.load_dgl_partition_data( 83 | args.partitioning_json_file, args.rank, device) 84 | partition_data.node_features["paper/features"] = partition_data.node_features["paper/features"].float() 85 | 86 | # Obtain train,validation, and test masks 87 | # These are stored as node features. Partitioning may prepend 88 | # the node type to the mask names. 
So we use the convenience function 89 | # suffix_key_lookup to look up the mask name while ignoring the 90 | # arbitrary node type 91 | #The train/val/test masks are only defined for nodes with type 'paper'. 92 | #We set the ``expand_to_all`` flag to expand the mask to all nodes in the 93 | #graph (mask will be filled with zeros). We use the expand_all option when 94 | #loading other node-type specific tensors such as features and labels 95 | 96 | bool_masks = {} 97 | for mask_name in ['train_mask', 'val_mask', 'test_mask']: 98 | local_mask = sar.suffix_key_lookup(partition_data.node_features, 99 | mask_name, 100 | expand_to_all = False, 101 | type_list = partition_data.node_type_names) 102 | bool_masks[mask_name] = local_mask.bool() 103 | 104 | labels = sar.suffix_key_lookup(partition_data.node_features, 105 | 'labels', 106 | expand_to_all = False, 107 | type_list = partition_data.node_type_names).long().to(device) 108 | 109 | # Obtain the number of classes by finding the max label across all workers 110 | num_labels = labels.max() + 1 111 | sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True) 112 | num_labels = num_labels.item() 113 | 114 | features = sar.suffix_key_lookup(partition_data.node_features, 115 | 'features', 116 | type_list = partition_data.node_type_names 117 | ).to(device) 118 | 119 | full_graph_manager = sar.construct_full_graph(partition_data).to(device) 120 | 121 | max_num_nodes = {} 122 | for ntype in partition_data.partition_book.ntypes: 123 | nodes = full_graph_manager.srcnodes(ntype) 124 | nodes_max = nodes.max() 125 | max_num_nodes[ntype] = nodes_max + 1 126 | 127 | embed_layer = rel_graph_embed(full_graph_manager, args.features_dim, max_num_nodes).to(device) 128 | gnn_model = EntityClassify(full_graph_manager, args.features_dim, num_labels).to(device) 129 | 130 | print('model', gnn_model) 131 | embed_layer.reset_parameters() 132 | gnn_model.reset_parameters() 133 | 134 | # Synchronize the model parmeters across all 
workers 135 | sar.sync_params(gnn_model) 136 | 137 | # Obtain the number of labeled nodes in the training 138 | # This will be needed to properly obtain a cross entropy loss 139 | # normalized by the number of training examples 140 | n_train_points = torch.LongTensor([bool_masks["train_mask"].sum().item()]) 141 | sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True) 142 | n_train_points = n_train_points.item() 143 | 144 | all_params = itertools.chain( 145 | gnn_model.parameters(), embed_layer.parameters() 146 | ) 147 | optimizer = torch.optim.Adam(all_params, lr=args.lr) 148 | 149 | for train_iter_idx in range(args.train_iters): 150 | # Train 151 | t_1 = time.time() 152 | gnn_model.train() 153 | 154 | embeds = extract_embed(embed_layer, {ntype: full_graph_manager.srcnodes(ntype) for ntype in full_graph_manager.srctypes}) 155 | embeds.update({"paper": features[full_graph_manager.srcnodes("paper")]}) 156 | embeds = {k: e.to(device) for k, e in embeds.items()} 157 | 158 | logits = gnn_model(full_graph_manager, embeds) 159 | logits = logits["paper"].log_softmax(dim=-1) 160 | train_mask = bool_masks["train_mask"] 161 | loss = F.nll_loss(logits[train_mask], labels[train_mask], reduction="sum") / n_train_points 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | # Do not forget to gather the parameter gradients from all workers 166 | sar.gather_grads(gnn_model) 167 | optimizer.step() 168 | train_time = time.time() - t_1 169 | 170 | # Calculate accuracy for train/validation/test 171 | results = [] 172 | gnn_model.eval() 173 | with torch.no_grad(): 174 | embeds = extract_embed(embed_layer, {ntype: full_graph_manager.srcnodes(ntype) for ntype in full_graph_manager.srctypes}) 175 | embeds.update({"paper": features[full_graph_manager.srcnodes("paper")]}) 176 | embeds = {k: e.to(device) for k, e in embeds.items()} 177 | 178 | logits = gnn_model(full_graph_manager, embeds) 179 | logits = logits["paper"].log_softmax(dim=-1) 180 | 181 | for 
mask_name in ['train_mask', 'val_mask', 'test_mask']: 182 | masked_nodes = bool_masks[mask_name] 183 | if masked_nodes.sum() > 0: 184 | active_logits = logits[masked_nodes] 185 | active_labels = labels[masked_nodes] 186 | loss = F.nll_loss(active_logits, active_labels, reduction="sum") 187 | n_correct = (active_logits.argmax(1) == active_labels).float().sum() 188 | results.extend([loss.item(), n_correct.item(), masked_nodes.sum().item()]) 189 | else: 190 | results.extend([0.0, 0.0, 0.0]) 191 | 192 | loss_acc_vec = torch.FloatTensor(results) 193 | # Sum the n_correct, and number of mask elements across all workers 194 | sar.comm.all_reduce(loss_acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True) 195 | (train_loss, train_acc, val_loss, val_acc, test_loss, test_acc) = \ 196 | (loss_acc_vec[0] / loss_acc_vec[2], 197 | loss_acc_vec[1] / loss_acc_vec[2], 198 | loss_acc_vec[3] / loss_acc_vec[5], 199 | loss_acc_vec[4] / loss_acc_vec[5], 200 | loss_acc_vec[6] / loss_acc_vec[8], 201 | loss_acc_vec[7] / loss_acc_vec[8]) 202 | 203 | result_message = ( 204 | f"iteration [{train_iter_idx}/{args.train_iters}] | " 205 | ) 206 | result_message += ', '.join([ 207 | f"train loss={train_loss:.4f}, " 208 | f"Accuracy: " 209 | f"train={100 * train_acc:.4f} " 210 | f"valid={100 * val_acc:.4f} " 211 | f"test={100 * test_acc:.4f} " 212 | f" | train time = {train_time} " 213 | f" |" 214 | ]) 215 | print(result_message, flush=True) 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /examples/rgcn-hetero/train_heterogeneous_graph_mfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without 
# limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from argparse import ArgumentParser
from model import EntityClassify, rel_graph_embed, extract_embed
import time
import itertools
import torch
import torch.nn.functional as F
import torch.distributed as dist
import dgl  # type: ignore

import sar


parser = ArgumentParser(
    description="GNN training on node classification tasks in heterogenous graphs")


parser.add_argument("--partitioning-json-file", type=str, default="",
                    help="Path to the .json file containing partitioning information")

parser.add_argument("--ip-file", default="./ip_file", type=str,
                    help="File with ip-address. Worker 0 creates this file and all others read it")

parser.add_argument("--backend", default="nccl", type=str, choices=["ccl", "nccl", "mpi", "gloo"],
                    help="Communication backend to use")

parser.add_argument("--cpu-run", action="store_true",
                    help="Run on CPUs if set, otherwise run on GPUs")

parser.add_argument("--train-iters", default=60, type=int,
                    help="number of training iterations")

parser.add_argument("--lr", type=float, default=0.01,
                    help="learning rate")

parser.add_argument("--rank", default=0, type=int,
                    help="Rank of the current worker")

parser.add_argument("--world-size", default=2, type=int,
                    help="Number of workers")

parser.add_argument("--features-dim", default=128, type=int,
                    help="Dimension of the node features")


def main():
    """Train and evaluate an EntityClassify model on this worker's partition using MFGs.

    Parses the command-line arguments, joins the SAR process group, builds a
    training MFG (seeded at the local train nodes) and an evaluation MFG
    (seeded at the local train, val and test nodes), then runs
    ``args.train_iters`` iterations. After each iteration it prints the loss
    and the accuracies, summed over all workers with ``sar.comm.all_reduce``.
    """
    args = parser.parse_args()
    print('args', args)

    # Patch DGL's attention-based layers and RelGraphConv to support distributed graphs
    sar.patch_dgl()

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size, master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)
    partition_data.node_features["paper/features"] = \
        partition_data.node_features["paper/features"].float()

    # Obtain train, validation, and test masks.
    # These are stored as node features. Partitioning may prepend the node
    # type to the mask names, so we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type.
    # The train/val/test masks are only defined for nodes of type 'paper'.
    # Passing ``expand_to_all=True`` expands a mask to all nodes in the graph
    # (entries for the other node types are filled with zeros). The boolean
    # masks below are looked up with ``expand_to_all=False``; the index
    # tensors are computed from the expanded masks.
    bool_masks = {}
    for mask_name in ['train_mask', 'val_mask', 'test_mask']:
        local_mask = sar.suffix_key_lookup(partition_data.node_features,
                                           mask_name,
                                           expand_to_all=False,
                                           type_list=partition_data.node_type_names)
        bool_masks[mask_name] = local_mask.bool()

    indices_masks = {}
    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
                                       ['train_indices', 'val_indices', 'test_indices']):
        global_mask = sar.suffix_key_lookup(partition_data.node_features,
                                            mask_name,
                                            expand_to_all=True,
                                            type_list=partition_data.node_type_names)
        indices_masks[indices_name] = global_mask.nonzero(as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'labels',
                                   expand_to_all=False,
                                   type_list=partition_data.node_type_names).long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features,
                                     'features',
                                     type_list=partition_data.node_type_names
                                     ).to(device)

    # Create MFGs. The seed indices are offset by the start of this rank's
    # node range (presumably turning partition-local ids into global ids).
    node_offset = partition_data.node_ranges[sar.comm.rank()][0]
    train_blocks = sar.construct_mfgs(partition_data,
                                      indices_masks['train_indices'] + node_offset,
                                      2)
    eval_blocks = sar.construct_mfgs(partition_data,
                                     torch.cat((indices_masks['train_indices'],
                                                indices_masks['val_indices'],
                                                indices_masks['test_indices'])) + node_offset,
                                     2)
    train_blocks = [block.to(device) for block in train_blocks]
    eval_blocks = [block.to(device) for block in eval_blocks]

    # Per node type: one past the largest source node id in the first
    # evaluation MFG; passed to rel_graph_embed below.
    max_num_nodes = {}
    for ntype in partition_data.partition_book.ntypes:
        max_num_nodes[ntype] = eval_blocks[0].srcnodes(ntype).max() + 1

    embed_layer = rel_graph_embed(eval_blocks[0], args.features_dim, max_num_nodes).to(device)
    gnn_model = EntityClassify(eval_blocks[0], args.features_dim, num_labels).to(device)

    print('model', gnn_model)
    embed_layer.reset_parameters()
    gnn_model.reset_parameters()

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)
    # NOTE(review): embed_layer is neither synchronized here nor passed to
    # sar.gather_grads below, so its parameters are updated from local
    # gradients only on each worker -- confirm this is intended.

    # Obtain the number of labeled nodes in the training
    # This will be needed to properly obtain a cross entropy loss
    # normalized by the number of training examples
    n_train_points = torch.LongTensor([indices_masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    all_params = itertools.chain(
        gnn_model.parameters(), embed_layer.parameters()
    )
    optimizer = torch.optim.Adam(all_params, lr=args.lr)

    for train_iter_idx in range(args.train_iters):
        # Train
        t_1 = time.time()
        gnn_model.train()

        # Embeddings for every source node type, with the 'paper' entry
        # replaced by the stored paper features
        embeds = extract_embed(embed_layer,
                               {ntype: train_blocks[0].srcnodes(ntype)
                                for ntype in train_blocks[0].srctypes})
        embeds.update({"paper": features[train_blocks[0].srcnodes("paper")]})
        embeds = {k: e.to(device) for k, e in embeds.items()}

        logits = gnn_model(train_blocks, embeds)
        logits = logits["paper"].log_softmax(dim=-1)
        train_mask = bool_masks["train_mask"]
        loss = F.nll_loss(logits, labels[train_mask], reduction="sum") / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Calculate accuracy for train/validation/test
        results = []
        gnn_model.eval()
        with torch.no_grad():
            embeds = extract_embed(embed_layer,
                                   {ntype: eval_blocks[0].srcnodes(ntype)
                                    for ntype in eval_blocks[0].srctypes})
            embeds.update({"paper": features[eval_blocks[0].srcnodes("paper")]})
            embeds = {k: e.to(device) for k, e in embeds.items()}

            logits = gnn_model(eval_blocks, embeds)
            logits = logits["paper"].log_softmax(dim=-1)

            # NOTE(review): the logits are presumably ordered like the seeds of
            # eval_blocks (train, val and test indices concatenated), while the
            # boolean masks index the partition's 'paper' nodes; indexing with
            # them assumes both orderings coincide -- confirm against
            # sar.construct_mfgs.
            for mask_name in ['train_mask', 'val_mask', 'test_mask']:
                masked_nodes = bool_masks[mask_name]
                if masked_nodes.sum() > 0:
                    active_logits = logits[masked_nodes]
                    active_labels = labels[masked_nodes]
                    loss = F.nll_loss(active_logits, active_labels, reduction="sum")
                    n_correct = (active_logits.argmax(1) == active_labels).float().sum()
                    results.extend([loss.item(), n_correct.item(), masked_nodes.sum().item()])
                else:
                    results.extend([0.0, 0.0, 0.0])

        loss_acc_vec = torch.FloatTensor(results)
        # Sum the losses, n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(loss_acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        (train_loss, train_acc, val_loss, val_acc, test_loss, test_acc) = \
            (loss_acc_vec[0] / loss_acc_vec[2],
             loss_acc_vec[1] / loss_acc_vec[2],
             loss_acc_vec[3] / loss_acc_vec[5],
             loss_acc_vec[4] / loss_acc_vec[5],
             loss_acc_vec[6] / loss_acc_vec[8],
             loss_acc_vec[7] / loss_acc_vec[8])

        result_message = (
            f"iteration [{train_iter_idx}/{args.train_iters}] | "
        )
        result_message += (
            f"train loss={train_loss:.4f}, "
            f"Accuracy: "
            f"train={100 * train_acc:.4f} "
            f"valid={100 * val_acc:.4f} "
            f"test={100 * test_acc:.4f} "
            f" | train time = {train_time} "
            f" |"
        )
        print(result_message, flush=True)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /examples/train_dist_appnp_with_sar.py:
# --------------------------------------------------------------------------------
from argparse import ArgumentParser

import dgl  # type: ignore
from dgl.nn.pytorch.conv import APPNPConv

import sar

import time
import torch
import torch.nn.functional as F
from torch import nn
import torch.distributed as dist

parser = ArgumentParser(description="APPNP example")

parser.add_argument("--partitioning-json-file", type=str, default="",
                    help="Path to the .json file containing partitioning information")

parser.add_argument("--ip-file", type=str, default="./ip_file",
                    help="File with ip-address. Worker 0 creates this file and all others read it")

parser.add_argument("--backend", type=str, default="nccl",
                    choices=["ccl", "nccl", "mpi", "gloo"],
                    help="Communication backend to use")

parser.add_argument("--cpu-run", action="store_true",
                    help="Run on CPUs if set, otherwise run on GPUs")

parser.add_argument("--train-iters", type=int, default=100,
                    help="number of training iterations")

parser.add_argument("--lr", type=float, default=1e-2,
                    help="learning rate")

parser.add_argument("--rank", type=int, default=0,
                    help="Rank of the current worker")

parser.add_argument("--world-size", type=int, default=2,
                    help="Number of workers")

parser.add_argument("--hidden-layer-dim", type=int, default=[64], nargs="+",
                    help="Dimension of GNN hidden layer")

parser.add_argument("--k", type=int, default=10,
                    help="Number of propagation steps")

parser.add_argument("--alpha", type=float, default=0.1,
                    help="Teleport Probability")

parser.add_argument("--in-drop", type=float, default=0.5,
                    help="input feature dropout")

parser.add_argument("--edge-drop", type=float, default=0.5,
                    help="edge propagation dropout")


class APPNP(nn.Module):
    """MLP prediction followed by APPNP propagation over a fixed graph.

    Args:
        g: graph used by the propagation step (stored on the module).
        in_feats: input feature dimension.
        hiddens: list of hidden-layer widths; one ``nn.Linear`` per entry.
        n_classes: number of output classes.
        activation: activation applied after every layer except the last.
        feat_drop: dropout probability on the MLP input and before the
            output layer; falsy disables dropout.
        edge_drop: edge dropout passed to ``APPNPConv``.
        alpha: teleport probability passed to ``APPNPConv``.
        k: number of propagation steps passed to ``APPNPConv``.
    """

    def __init__(
        self,
        g,
        in_feats,
        hiddens,
        n_classes,
        activation,
        feat_drop,
        edge_drop,
        alpha,
        k,
    ):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(nn.Linear(in_feats, hiddens[0]))
        # hidden layers
        for i in range(1, len(hiddens)):
            self.layers.append(nn.Linear(hiddens[i - 1], hiddens[i]))
        # output layer
        self.layers.append(nn.Linear(hiddens[-1], n_classes))
        self.activation = activation
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x: x
        self.propagate = APPNPConv(k, alpha, edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every linear layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, features):
        """Return per-node logits for ``features`` after ``k`` propagation steps."""
        # prediction step
        h = features
        h = self.feat_drop(h)
        h = self.activation(self.layers[0](h))
        for layer in self.layers[1:-1]:
            h = self.activation(layer(h))
        h = self.layers[-1](self.feat_drop(h))
        # propagation step
        h = self.propagate(self.g, h)
        return h


def evaluate(model, features, labels, masks):
    """Compute train/validation/test accuracy summed over all workers.

    Args:
        model: the model to evaluate; it is switched to eval mode.
        features: node features passed to ``model``.
        labels: node labels, indexed with the same indices as the logits.
        masks: dict with 'train_indices', 'val_indices' and 'test_indices'
            index tensors.

    Returns:
        ``(train_acc, val_acc, test_acc)``: for each split, the number of
        correct predictions divided by the number of nodes, both summed
        across workers with ``sar.comm.all_reduce``.
    """
    model.eval()
    train_mask, val_mask, test_mask = masks['train_indices'], masks['val_indices'], masks['test_indices']
    with torch.no_grad():
        logits = model(features)
        results = []
        for mask in [train_mask, val_mask, test_mask]:
            n_correct = (logits[mask].argmax(1) ==
                         labels[mask]).float().sum()
            # Use a Python scalar so building the FloatTensor below does not
            # depend on the device the logits live on
            results.extend([n_correct.item(), mask.numel()])

        acc_vec = torch.FloatTensor(results)
        # Sum the n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        (train_acc, val_acc, test_acc) = \
            (acc_vec[0] / acc_vec[1],
             acc_vec[2] / acc_vec[3],
             acc_vec[4] / acc_vec[5])

    return train_acc, val_acc, test_acc


def main():
    """Train APPNP on this worker's partition with full-graph SAR training.

    Accuracy on all splits is evaluated (collectively, on every worker) and
    printed every 10 iterations.
    """
    args = parser.parse_args()
    print('args', args)

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size,
                         master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)

    # Obtain train, validation, and test masks.
    # These are stored as node features. Partitioning may prepend
    # the node type to the mask names. So we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type
    masks = {}
    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
                                       ['train_indices', 'val_indices', 'test_indices']):
        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
                                             mask_name)
        masks[indices_name] = boolean_mask.nonzero(
            as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'label').long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features, 'feat').to(device)
    full_graph_manager = sar.construct_full_graph(partition_data).to(device)

    # We do not need the partition data anymore
    del partition_data

    # The model must live on the same device as the features and the graph;
    # without .to(device) a GPU run would feed CUDA features to CPU weights.
    gnn_model = APPNP(
        full_graph_manager,
        features.size(1),
        args.hidden_layer_dim,
        num_labels,
        F.relu,
        args.in_drop,
        args.edge_drop,
        args.alpha,
        args.k).to(device)

    gnn_model.reset_parameters()
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    # Obtain the number of labeled nodes in the training
    # This will be needed to properly obtain a cross entropy loss
    # normalized by the number of training examples
    n_train_points = torch.LongTensor([masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr, weight_decay=5e-4)
    for train_iter_idx in range(args.train_iters):
        # Train
        gnn_model.train()
        t_1 = time.time()
        logits = gnn_model(features)
        loss = F.cross_entropy(logits[masks['train_indices']],
                               labels[masks['train_indices']], reduction='sum') / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        if (train_iter_idx + 1) % 10 == 0:
            # evaluate() performs an all_reduce, so every worker must reach it
            train_acc, val_acc, test_acc = evaluate(gnn_model, features, labels, masks)

            result_message = (
                f"iteration [{train_iter_idx + 1}/{args.train_iters}] | "
            )
            result_message += (
                f"train loss={loss:.4f}, "
                f"Accuracy: "
                f"train={train_acc:.4f} "
                f"valid={val_acc:.4f} "
                f"test={test_acc:.4f} "
                f" | train time = {train_time} "
                f" |"
            )
            print(result_message, flush=True)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /examples/train_homogeneous_graph_basic.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 Intel Corporation

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions
# of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import List, Union, Dict
from argparse import ArgumentParser
import os
import logging
import time
import torch
import torch.nn.functional as F
from torch import nn
import torch.distributed as dist
import dgl  # type: ignore

import sar


parser = ArgumentParser(
    description="GNN training on node classification tasks in homogeneous graphs")


parser.add_argument(
    "--partitioning-json-file",
    type=str,
    default="",
    help="Path to the .json file containing partitioning information "
)

parser.add_argument('--ip-file', default='./ip_file', type=str,
                    help='File with ip-address. Worker 0 creates this file and all others read it ')


parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'],
                    help='Communication backend to use '
                    )

parser.add_argument(
    "--cpu-run", action="store_true",
    help="Run on CPUs if set, otherwise run on GPUs "
)


parser.add_argument('--train-iters', default=100, type=int,
                    help='number of training iterations ')

parser.add_argument(
    "--lr",
    type=float,
    default=1e-2,
    help="learning rate"
)


parser.add_argument('--rank', default=0, type=int,
                    help='Rank of the current worker ')

parser.add_argument('--world-size', default=2, type=int,
                    help='Number of workers ')

parser.add_argument('--hidden-layer-dim', default=256, type=int,
                    help='Dimension of GNN hidden layer')


class GNNModel(nn.Module):
    """Three stacked ``SAGEConv`` layers (mean aggregator) with ReLU in between.

    Args:
        in_dim: input feature dimension.
        hidden_dim: width of the two hidden layers.
        out_dim: output dimension (number of classes).
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()

        self.convs = nn.ModuleList([
            # pylint: disable=no-member
            dgl.nn.SAGEConv(in_dim, hidden_dim, aggregator_type='mean'),
            # pylint: disable=no-member
            dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean'),
            # pylint: disable=no-member
            dgl.nn.SAGEConv(hidden_dim, out_dim, aggregator_type='mean'),
        ])

    def forward(self, graph: sar.HeteroGraphShardManager, features: torch.Tensor):
        """Apply every conv to ``graph``; ReLU follows all but the last layer."""
        for idx, conv in enumerate(self.convs):
            features = conv(graph, features)
            if idx < len(self.convs) - 1:
                features = F.relu(features, inplace=True)

        return features


def main():
    """Full-graph SAR training of ``GNNModel`` on this worker's partition.

    Each iteration runs one forward/backward pass over the whole distributed
    graph and prints the training loss along with the train/val/test accuracy
    computed from the same forward pass, summed across all workers.
    """
    args = parser.parse_args()
    print('args', args)

    use_gpu = torch.cuda.is_available() and not args.cpu_run
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Obtain the ip address of the master through the network file system
    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
    sar.initialize_comms(args.rank,
                         args.world_size,
                         master_ip_address,
                         args.backend)

    # Load DGL partition data
    partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, args.rank, device)

    # Obtain train, validation, and test masks.
    # These are stored as node features. Partitioning may prepend
    # the node type to the mask names. So we use the convenience function
    # suffix_key_lookup to look up the mask name while ignoring the
    # arbitrary node type
    masks = {}
    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
                                       ['train_indices', 'val_indices', 'test_indices']):
        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
                                             mask_name)
        masks[indices_name] = boolean_mask.nonzero(
            as_tuple=False).view(-1).to(device)

    labels = sar.suffix_key_lookup(partition_data.node_features,
                                   'labels').long().to(device)

    # Obtain the number of classes by finding the max label across all workers
    num_labels = labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
    num_labels = num_labels.item()

    features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device)
    full_graph_manager = sar.construct_full_graph(partition_data).to(device)

    # We do not need the partition data anymore
    del partition_data

    gnn_model = GNNModel(features.size(1),
                         args.hidden_layer_dim,
                         num_labels).to(device)
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model)

    # Obtain the number of labeled nodes in the training
    # This will be needed to properly obtain a cross entropy loss
    # normalized by the number of training examples
    n_train_points = torch.LongTensor([masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
    n_train_points = n_train_points.item()

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr)
    for train_iter_idx in range(args.train_iters):
        # Train
        t_1 = time.time()
        logits = gnn_model(full_graph_manager, features)
        loss = F.cross_entropy(logits[masks['train_indices']],
                               labels[masks['train_indices']], reduction='sum') / n_train_points

        optimizer.zero_grad()
        loss.backward()
        # Do not forget to gather the parameter gradients from all workers
        sar.gather_grads(gnn_model)
        optimizer.step()
        train_time = time.time() - t_1

        # Calculate accuracy for train/validation/test (from the logits of
        # the training forward pass above)
        results = []
        for indices_name in ['train_indices', 'val_indices', 'test_indices']:
            n_correct = (logits[masks[indices_name]].argmax(1) ==
                         labels[masks[indices_name]]).float().sum()
            # Python scalars keep the FloatTensor construction below
            # independent of the device the logits live on
            results.extend([n_correct.item(), masks[indices_name].numel()])

        acc_vec = torch.FloatTensor(results)
        # Sum the n_correct, and number of mask elements across all workers
        sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
        (train_acc, val_acc, test_acc) = \
            (acc_vec[0] / acc_vec[1],
             acc_vec[2] / acc_vec[3],
             acc_vec[4] / acc_vec[5])

        result_message = (
            f"iteration [{train_iter_idx}/{args.train_iters}] | "
        )
        result_message += (
            f"train loss={loss:.4f}, "
            f"Accuracy: "
            f"train={train_acc:.4f} "
            f"valid={val_acc:.4f} "
            f"test={test_acc:.4f} "
            f" | train time = {train_time} "
            f" |"
        )
        print(result_message, flush=True)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
/examples/train_homogeneous_sampling_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | from typing import List, Union, Dict 22 | from argparse import ArgumentParser 23 | import os 24 | import logging 25 | import psutil 26 | import time 27 | import torch 28 | import torch.nn.functional as F 29 | from torch import nn 30 | import torch.distributed as dist 31 | import dgl # type: ignore 32 | from dgl.heterograph import DGLBlock # type: ignore 33 | 34 | 35 | import sar 36 | 37 | 38 | parser = ArgumentParser( 39 | description="GNN training on node classification tasks in homogeneous graphs") 40 | 41 | 42 | parser.add_argument( 43 | "--partitioning-json-file", 44 | type=str, 45 | default="", 46 | help="Path to the .json file containing partitioning information " 47 | ) 48 | 49 | parser.add_argument('--ip-file', default='./ip_file', type=str, 50 | help='File with ip-address. Worker 0 creates this file and all others read it ') 51 | 52 | 53 | parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], 54 | help='Communication backend to use ' 55 | ) 56 | 57 | parser.add_argument( 58 | "--cpu-run", action="store_true", 59 | help="Run on CPUs if set, otherwise run on GPUs " 60 | ) 61 | 62 | parser.add_argument( 63 | "--precompute-batches", action="store_true", 64 | help="Precompute the batches " 65 | ) 66 | 67 | 68 | parser.add_argument( 69 | "--optimized-batches-cache", 70 | type=str, 71 | default="", 72 | help="Prefix of the files used to store precomputed batches " 73 | ) 74 | 75 | 76 | parser.add_argument('--train-iters', default=100, type=int, 77 | help='number of training iterations ') 78 | 79 | parser.add_argument( 80 | "--lr", 81 | type=float, 82 | default=1e-2, 83 | help="learning rate" 84 | ) 85 | 86 | 87 | parser.add_argument('--rank', default=0, type=int, 88 | help='Rank of the current worker ') 89 | 90 | parser.add_argument('--world-size', default=2, type=int, 91 | help='Number of workers ') 92 | 93 | parser.add_argument('--hidden-layer-dim', default=256, type=int, 94 | help='Dimension of GNN hidden 
layer') 95 | 96 | parser.add_argument('--batch-size', default=5000, type=int, 97 | help='per worker batch size ') 98 | 99 | parser.add_argument('--num-workers', default=0, type=int, 100 | help='number of dataloader workers ') 101 | 102 | parser.add_argument('--fanout', nargs="+", type=int, 103 | help='fanouts for sampling ') 104 | 105 | parser.add_argument('--max-collective-size', default=0, type=int, 106 | help='The maximum allowed size of the data in a collective. \ 107 | If a collective would communicate more than this maximum, it is split into multiple collectives.\ 108 | Collective calls with large data may cause instabilities in some communication backends ') 109 | 110 | 111 | class GNNModel(nn.Module): 112 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): 113 | super().__init__() 114 | 115 | self.convs = nn.ModuleList([ 116 | # pylint: disable=no-member 117 | dgl.nn.SAGEConv(in_dim, hidden_dim, aggregator_type='mean'), 118 | # pylint: disable=no-member 119 | dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean'), 120 | # pylint: disable=no-member 121 | dgl.nn.SAGEConv(hidden_dim, out_dim, aggregator_type='mean'), 122 | ]) 123 | 124 | def forward(self, blocks: List[Union[DGLBlock, sar.GraphShardManager]], features: torch.Tensor): 125 | for idx, conv in enumerate(self.convs): 126 | features = conv(blocks[idx], features) 127 | if idx < len(self.convs) - 1: 128 | features = F.relu(features, inplace=True) 129 | 130 | return features 131 | 132 | 133 | def main(): 134 | # psutil.Process().cpu_affinity([8]) 135 | args = parser.parse_args() 136 | print('args', args) 137 | 138 | use_gpu = torch.cuda.is_available() and not args.cpu_run 139 | device = torch.device('cuda') if use_gpu else torch.device('cpu') 140 | 141 | if args.rank == -1: 142 | # Try to infer the worker's rank from environment variables 143 | # created by mpirun or similar MPI launchers 144 | args.rank = int(os.environ.get("PMI_RANK", -1)) 145 | if args.rank == -1: 146 | 
args.rank = int(os.environ["RANK"]) 147 | 148 | # os.environ.putenv("GOMP_CPU_AFFINITY", "10,11") 149 | # os.environ.putenv("OMP_NUM_THREADS", "16") 150 | # Obtain the ip address of the master through the network file system 151 | master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) 152 | sar.initialize_comms(args.rank, 153 | args.world_size, master_ip_address, 154 | args.backend) 155 | 156 | # Load DGL partition data 157 | partition_data = sar.load_dgl_partition_data( 158 | args.partitioning_json_file, args.rank, torch.device('cpu')) 159 | 160 | # Obtain train,validation, and test masks 161 | # These are stored as node features. Partitioning may prepend 162 | # the node type to the mask names. So we use the convenience function 163 | # suffix_key_lookup to look up the mask name while ignoring the 164 | # arbitrary node type 165 | masks = {} 166 | for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'], 167 | ['train_indices', 'val_indices', 'test_indices']): 168 | boolean_mask = sar.suffix_key_lookup(partition_data.node_features, 169 | mask_name) 170 | masks[indices_name] = boolean_mask.nonzero( 171 | as_tuple=False).view(-1).to(device) 172 | print(f'mask {indices_name} : {masks[indices_name]} ') 173 | labels = sar.suffix_key_lookup(partition_data.node_features, 174 | 'labels').long().to(device) 175 | 176 | # Obtain the number of classes by finding the max label across all workers 177 | num_labels = labels.max() + 1 178 | sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, 179 | move_to_comm_device=True) 180 | num_labels = num_labels.item() 181 | 182 | features = sar.suffix_key_lookup( 183 | partition_data.node_features, 'features') # keep features always on CPU 184 | full_graph_manager = sar.construct_full_graph( 185 | partition_data) # Keep full graph on CPU 186 | 187 | node_ranges = partition_data.node_ranges 188 | del partition_data 189 | 190 | gnn_model = GNNModel(features.size(1), 191 | args.hidden_layer_dim, 192 | 
num_labels).to(device) 193 | 194 | # gnn_model_cpu will be used for inference 195 | if use_gpu: 196 | gnn_model_cpu = GNNModel(features.size(1), 197 | args.hidden_layer_dim, 198 | num_labels) 199 | else: 200 | gnn_model_cpu = gnn_model 201 | print('model', gnn_model) 202 | 203 | # Synchronize the model parmeters across all workers 204 | sar.sync_params(gnn_model) 205 | 206 | optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr) 207 | 208 | neighbor_sampler = sar.core.sampling.DistNeighborSampler([15, 10, 5], 209 | input_node_features={ 210 | 'features': features}, 211 | output_node_features={ 212 | 'labels': labels}, 213 | output_device=device 214 | ) 215 | 216 | train_nodes = masks['train_indices'] + node_ranges[sar.rank()][0] 217 | 218 | dataloader = sar.core.sampling.DataLoader( 219 | full_graph_manager, 220 | train_nodes, 221 | neighbor_sampler, 222 | args.batch_size, 223 | shuffle=True, 224 | precompute_optimized_batches=args.precompute_batches, 225 | optimized_batches_cache=( 226 | args.optimized_batches_cache if args.optimized_batches_cache else None), 227 | num_workers=args.num_workers) 228 | 229 | print('sampling graph edata', full_graph_manager.sampling_graph.edata) 230 | 231 | for k in list(masks.keys()): 232 | masks[k] = masks[k].to(device) 233 | 234 | for train_iter_idx in range(args.train_iters): 235 | total_loss = 0 236 | gnn_model.train() 237 | sar.Config.max_collective_size = 0 238 | 239 | train_t1 = dataloader_t1 = time.time() 240 | n_total = n_correct = 0 241 | pure_training_time = 0 242 | for block_idx, blocks in enumerate(dataloader): 243 | start_t1 = time.time() 244 | loading_time = time.time() - dataloader_t1 245 | print(f'in block {block_idx} : {blocks}. 
Loaded in {loading_time}') 246 | # print('block edata', [block.edata[dgl.EID] for block in blocks]) 247 | blocks = [b.to(device) for b in blocks] 248 | block_features = blocks[0].srcdata['features'] 249 | block_labels = blocks[-1].dstdata['labels'] 250 | logits = gnn_model(blocks, block_features) 251 | 252 | loss = F.cross_entropy(logits, block_labels, reduction='mean') 253 | n_correct += (logits.argmax(1) == 254 | block_labels).float().sum() 255 | n_total += len(block_labels) 256 | 257 | total_loss += loss.item() 258 | optimizer.zero_grad() 259 | loss.backward() 260 | # Do not forget to gather the parameter gradients from all workers 261 | tg = time.time() 262 | sar.gather_grads(gnn_model) 263 | print('gather grad time', time.time() - tg) 264 | optimizer.step() 265 | dataloader_t1 = time.time() 266 | pure_training_time += (time.time() - start_t1) 267 | 268 | train_time = time.time() - train_t1 269 | print('train time', train_time, flush=True) 270 | print('pure train time', pure_training_time, flush=True) 271 | 272 | print('loss', total_loss, flush=True) 273 | print('accuracy ', n_correct/n_total, flush=True) 274 | 275 | # Full graph inference is done on CPUs using sequential 276 | # aggregation and re-materialization 277 | gnn_model_cpu.eval() 278 | if gnn_model_cpu is not gnn_model: 279 | gnn_model_cpu.load_state_dict(gnn_model.state_dict()) 280 | sar.Config.max_collective_size = args.max_collective_size 281 | with torch.no_grad(): 282 | # Calculate accuracy for train/validation/test 283 | logits = gnn_model_cpu( 284 | [full_graph_manager] * 3, features) 285 | results = [] 286 | for indices_name in ['train_indices', 'val_indices', 'test_indices']: 287 | n_correct = (logits[masks[indices_name]].argmax(1) == 288 | labels[masks[indices_name]].cpu()).float().sum() 289 | results.extend([n_correct, masks[indices_name].numel()]) 290 | 291 | acc_vec = torch.FloatTensor(results) 292 | # Sum the n_correct, and number of mask elements across all workers 293 | 
sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, 294 | move_to_comm_device=True) 295 | (train_acc, val_acc, test_acc) = (acc_vec[0] / acc_vec[1], 296 | acc_vec[2] / acc_vec[3], 297 | acc_vec[4] / acc_vec[5]) 298 | 299 | result_message = ( 300 | f"iteration [{train_iter_idx}/{args.train_iters}] | " 301 | ) 302 | result_message += ', '.join([ 303 | f"Accuracy: " 304 | f"train={train_acc:.4f} " 305 | f"valid={val_acc:.4f} " 306 | f"test={test_acc:.4f} " 307 | f" | train time = {train_time} " 308 | f" |" 309 | ]) 310 | print(result_message, flush=True) 311 | 312 | 313 | if __name__ == '__main__': 314 | main() 315 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 68.2.2", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links=https://data.dgl.ai/wheels/repo.html 2 | dgl>=1.0.0 3 | numpy>=1.22.0 4 | torch>=1.10.0 5 | ifaddr>=0.1.7 -------------------------------------------------------------------------------- /sar/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in 
all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | 22 | ''' 23 | Top-level SAR package 24 | ''' 25 | from . import core 26 | 27 | from .comm import initialize_comms, rank, world_size, comm_device,\ 28 | nfs_ip_init, sync_params, gather_grads 29 | from .core import GraphShardManager, HeteroGraphShardManager, message_has_parameters, DistributedBlock,\ 30 | DistNeighborSampler, DataLoader 31 | from .construct_shard_manager import construct_mfgs, construct_full_graph, convert_dist_graph 32 | from .data_loading import load_dgl_partition_data, suffix_key_lookup 33 | from .distributed_bn import DistributedBN1D 34 | from .config import Config 35 | from .edge_softmax import edge_softmax 36 | from .patch_dgl import patch_dgl, patched_edge_softmax, RelGraphConv 37 | from .logging_setup import logging_setup, logger 38 | 39 | 40 | __all__ = ['initialize_comms', 'rank', 'world_size', 'nfs_ip_init', 41 | 'comm_device', 'DistributedBN1D', 42 | 'construct_mfgs', 'construct_full_graph', 'convert_dist_graph', 'GraphShardManager', 'HeteroGraphShardManager', 43 | 'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax', 44 | 'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader', 45 | 'logging_setup', 'logger', 'RelGraphConv', 'sync_params', 'gather_grads', 'patch_dgl', 'patched_edge_softmax'] 46 | -------------------------------------------------------------------------------- /sar/common_tuples.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | ''' 22 | Tuples for grouping related data 23 | ''' 24 | from typing import NamedTuple, Dict, Tuple, List, Optional, Any, TYPE_CHECKING 25 | from enum import Enum 26 | from torch import Tensor 27 | import dgl # type: ignore 28 | 29 | if TYPE_CHECKING: 30 | from .core.graphshard import GraphShardManager 31 | from .core.sar_aggregation import BackwardManager 32 | 33 | # Tags where an aggregation input tensor comes from (paired with a name in AggregationData.all_input_names); meanings inferred from member names - confirm in sar_aggregation.py 34 | class TensorPlace(Enum): 35 | SRC = 0 36 | DST = 1 37 | EDGE = 2 38 | PARAM = 3 39 | 40 | 41 | class ShardEdgesAndFeatures(NamedTuple): 42 | ''' 43 | Stores the edge information for all edges connecting nodes in one partition to 44 | nodes in another partition.
For an N-way partition, each worker will have N ShardEdgesAndFeatures objects, 45 | where each object contains data for incoming edges from each partition (including the worker's own 46 | partition) 47 | 48 | 49 | .. py:attribute:: edges : Tuple[Tensor,Tensor] 50 | 51 | The source and destination global node ids for each edge in the shard 52 | 53 | 54 | .. py:attribute:: edge_features : Dict[str,Tensor] 55 | 56 | A dictionary of the edge features 57 | 58 | ''' 59 | edges: Tuple[Tensor, Tensor] 60 | edge_features: Dict[str, Tensor] 61 | 62 | # Edge shards plus source/target node ranges and seed nodes; presumably the inputs for building a GraphShardManager - confirm in construct_shard_manager.py 63 | class GraphShardManagerData(NamedTuple): 64 | all_shard_edges: List[ShardEdgesAndFeatures] 65 | src_node_ranges: List[Tuple[int, int]] 66 | tgt_node_range: Tuple[int, int] 67 | tgt_seed_nodes: Tensor 68 | local_src_seed_nodes: Tensor 69 | 70 | 71 | class PartitionData(NamedTuple): 72 | ''' 73 | Stores all the data for the local partition 74 | 75 | 76 | .. py:attribute:: all_shard_edges : List[ShardEdgesAndFeatures] 77 | 78 | A list of ShardEdgesAndFeatures objects. One for edges incoming from each partition 79 | 80 | 81 | .. py:attribute:: node_ranges : List[Tuple[int,int]] 82 | 83 | node_ranges[i] is a tuple of the start and end global node indices for nodes in partition i. 84 | 85 | .. py:attribute:: node_features : Dict[str,Tensor] 86 | 87 | Dictionary of node features for nodes in local partition 88 | 89 | .. py:attribute:: node_type_names : List[str] 90 | 91 | List of node type names. Use in conjunction with dgl.NTYPE node features to get\ 92 | the node type of each node 93 | 94 | .. py:attribute:: edge_type_names : List[str] 95 | 96 | List of edge type names. Use in conjunction with dgl.ETYPE edge features to get\ 97 | the edge type of each edge 98 | 99 | ..
py:attribute:: partition_book : dgl.distributed.GraphPartitionBook 100 | 101 | The graph partition information 102 | 103 | 104 | ''' 105 | 106 | all_shard_edges: List[ShardEdgesAndFeatures] 107 | node_ranges: List[Tuple[int, int]] 108 | node_features: Dict[str, Tensor] 109 | node_type_names: List[str] 110 | edge_type_names: List[str] 111 | partition_book: dgl.distributed.GraphPartitionBook 112 | 113 | # Bundles the arguments of one message-passing/aggregation call on a GraphShardManager (inferred from field names - confirm in sar_aggregation.py) 114 | class AggregationData(NamedTuple): 115 | graph_shard_manager: "GraphShardManager" 116 | message_func: Any 117 | reduce_func: Any 118 | etype: Any 119 | all_input_names: List[Tuple[TensorPlace, str]] 120 | n_params: int 121 | grad_enabled: bool 122 | remote_data: bool 123 | 124 | # One edge shard: its index and the source-node, target-node and edge index ranges it covers (inferred from field names) 125 | class ShardInfo(NamedTuple): 126 | shard_idx: int 127 | src_node_range: Tuple[int, int] 128 | tgt_node_range: Tuple[int, int] 129 | edge_range: Tuple[int, int] 130 | 131 | # A (name, IP address) pair; presumably a network interface used during communication setup - confirm in comm.py 132 | class SocketInfo(NamedTuple): 133 | name: str 134 | ip_addr: str 135 | -------------------------------------------------------------------------------- /sar/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice 
and this permission notice shall be included in all 11 | # copies or substantial portions of the Software.
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import torch 22 | 23 | 24 | class Config(object): 25 | ''' 26 | General configuration for the SAR library. 27 | 28 | 29 | .. py:attribute:: disable_sr : bool 30 | 31 | Disables sequential re-materialization of the computational graph during the backward pass.\ 32 | The computational graph is constructed normally during the forward pass. default : False 33 | 34 | 35 | 36 | .. py:attribute:: max_collective_size : int 37 | 38 | Limits the maximum size of data in torch.distributed.all_to_all collective calls. If non-zero,\ 39 | the sar.comm.all_to_all wrapper method will break down the collective call into multiple torch.distributed.all_to_all\ 40 | calls so that the size of the data in each call is below max_collective_size. default : 0 41 | 42 | .. py:attribute:: pipeline_depth : int 43 | 44 | Sets the communication pipeline depth when doing sequential aggregation or sequential re-materialization.\ 45 | In a separate thread, SAR will pre-fetch up to ``pipeline_depth`` remote partitions into a data queue that will then\ 46 | be processed by the compute thread. Higher values will increase memory consumption but may hide\ 47 | communication latency.
default : 1 48 | 49 | 50 | ''' 51 | # Plain class attributes that act as process-wide settings; callers assign them directly, e.g. sar.Config.max_collective_size = ... 52 | disable_sr: bool = False 53 | max_collective_size: int = 0 54 | pipeline_depth: int = 1 55 | -------------------------------------------------------------------------------- /sar/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE.
20 | 21 | ''' 22 | Modules for sharded data representation and management 23 | ''' 24 | from .graphshard import GraphShard, GraphShardManager, HeteroGraphShardManager 25 | from .sar_aggregation import message_has_parameters 26 | from .full_partition_block import DistributedBlock 27 | from .sampling import DistNeighborSampler, DataLoader 28 | 29 | __all__ = ['GraphShard', 'GraphShardManager', 'HeteroGraphShardManager', 'message_has_parameters', 30 | 'DistributedBlock', 'DistNeighborSampler', 'DataLoader'] 31 | -------------------------------------------------------------------------------- /sar/core/full_partition_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | from typing import List, Dict, Optional, Tuple 22 | import logging 23 | from collections.abc import MutableMapping 24 | import torch 25 | import dgl # type: ignore 26 | from torch import Tensor 27 | import torch.distributed as dist 28 | from torch.autograd import profiler 29 | 30 | 31 | from ..comm import all_to_all, world_size, rank, all_reduce 32 | 33 | logger = logging.getLogger(__name__) 34 | logger.addHandler(logging.NullHandler()) 35 | logger.setLevel(logging.DEBUG) 36 | 37 | 38 | class ProxyDataView(MutableMapping): 39 | """A distributed dictionary""" 40 | 41 | def __init__(self, tensor_sz: int, base_dict: MutableMapping, 42 | indices_required_from_me: List[Tensor], 43 | sizes_expected_from_others: List[int]): 44 | self.base_dict = base_dict 45 | self.tensor_sz = tensor_sz 46 | self.indices_required_from_me = indices_required_from_me 47 | self.sizes_expected_from_others = sizes_expected_from_others 48 | 49 | def set_base_dict(self, new_base_dict: MutableMapping): 50 | self.base_dict = new_base_dict 51 | 52 | def __setitem__(self, key: str, value: Tensor): 53 | assert value.size(0) == self.tensor_sz, \ 54 | f'Tenosr size {value.size()} does not match graph data size {self.tensor_sz}' 55 | logger.debug(f'Distributing item {key} among all DistributedBlocks') 56 | 57 | with profiler.record_function("COMM_FETCH"): 58 | exchange_result = tensor_exchange_op( 59 | value, self.indices_required_from_me, self.sizes_expected_from_others) 60 | 61 | self.base_dict[key] = exchange_result 62 | 63 | def __getitem__(self, key: str): 64 | return self.base_dict[key] 65 | 66 | def __delitem__(self, key: str): 67 | del self.base_dict[key] 68 | 69 | def __iter__(self): 70 | return iter(self.base_dict) 71 | 72 | def __len__(self): 73 | return len(self.base_dict) 74 | 75 | 76 | class DistributedBlock: 77 | """ 78 | A wrapper around a dgl.DGLBlock object. The DGLBlock object represents all the edges incoming 79 | to the local partition. 
It communicates with remote partitions to implement one-shot communication and 80 | aggregation in the forward and backward passes . You should not construct DistributedBlock directly, 81 | but instead use :meth:`GraphShardManager.get_full_partition_graph` 82 | 83 | :param block: A DGLBlock object representing all edges incoming to the local partition 84 | :type block: 85 | :param indices_required_from_me: The local node indices required by every other partition to carry out\ 86 | one-hop aggregation 87 | :type indices_required_from_me: List[Tensor] 88 | :param sizes_expected_from_others: The number of remote indices that we need to fetch\ 89 | from remote partitions to update the features of the nodes in the local partition 90 | :type sizes_expected_from_others: List[int] 91 | :param src_ranges: The global node ids of the start node and end node in each partition. Nodes in each\ 92 | partition have consecutive indices 93 | :type src_ranges: List[Tuple[int, int]] 94 | :param unique_src_nodes: The absolute node indices of the source nodes in each remote partition 95 | :type unique_src_nodes: List[Tensor] 96 | :param input_nodes: The indices of the input nodes relative to the starting node index of the local partition\ 97 | The input nodes are the nodes needed to produce the output node features assuming one-hop aggregation 98 | :type input_nodes: Tensor 99 | :param seeds: The node indices of the output nodes relative to the starting node index of the local partition 100 | :type seeds: Tensor 101 | :param edge_type_names: A list of edge type names 102 | :type edge_type_names: List[str] 103 | 104 | """ 105 | 106 | def __init__(self, block, indices_required_from_me: List[Tensor], 107 | sizes_expected_from_others: List[int], 108 | src_ranges: List[Tuple[int, int]], 109 | unique_src_nodes: List[Tensor], 110 | input_nodes: Tensor, 111 | seeds: Tensor, 112 | edge_type_names: List[str]): 113 | 114 | self._block = block 115 | self.indices_required_from_me = 
indices_required_from_me 116 | self.sizes_expected_from_others = sizes_expected_from_others 117 | self.src_ranges = src_ranges 118 | self.unique_src_nodes = unique_src_nodes 119 | self.edge_type_names = edge_type_names 120 | self.input_nodes = input_nodes 121 | self.seeds = seeds 122 | 123 | self.srcdata = ProxyDataView(input_nodes.size(0), 124 | block.srcdata, indices_required_from_me, sizes_expected_from_others) 125 | 126 | self.out_degrees_cache: Dict[Optional[str], Tensor] = {} 127 | 128 | def out_degrees(self, vertices=dgl.ALL, etype=None) -> Tensor: 129 | if etype not in self.out_degrees_cache: 130 | src_out_degrees = self._block.out_degrees(etype=etype) 131 | src_out_degrees_split = torch.split(src_out_degrees, self.sizes_expected_from_others) 132 | 133 | for comm_round in range(world_size()): 134 | out_degrees = torch.zeros( 135 | self.src_ranges[comm_round][1] - self.src_ranges[comm_round][0], 136 | dtype=self._block.idtype).to(self._block.device) 137 | 138 | out_degrees[self.unique_src_nodes[comm_round] - self.src_ranges[comm_round][0] 139 | ] = src_out_degrees_split[comm_round] 140 | all_reduce(out_degrees, op=dist.ReduceOp.SUM, move_to_comm_device=True) 141 | if comm_round == rank(): 142 | out_degrees[out_degrees == 0] = 1 143 | self.out_degrees_cache[etype] = out_degrees.to(self._block.device) 144 | 145 | if vertices == dgl.ALL: 146 | return self.out_degrees_cache[etype] 147 | 148 | return self.out_degrees_cache[etype][vertices] 149 | 150 | def to(self, device: torch.device): 151 | self._block = self._block.to(device) 152 | self.srcdata.set_base_dict(self._block.srcdata) 153 | return self 154 | 155 | def __getattr__(self, name): 156 | return getattr(self._block, name) 157 | 158 | 159 | class TensorExchangeOp(torch.autograd.Function): # pylint: disable = abstract-method 160 | @ staticmethod 161 | # pylint: disable = arguments-differ,unused-argument 162 | def forward(ctx, val: Tensor, indices_required_from_me: Tensor, # type: ignore 163 | 
sizes_expected_from_others: Tensor) -> Tensor: # type: ignore 164 | ctx.sizes_expected_from_others = sizes_expected_from_others 165 | ctx.indices_required_from_me = indices_required_from_me 166 | ctx.input_size = val.size() 167 | 168 | send_tensors = [val[indices] for indices in indices_required_from_me] 169 | recv_tensors = [val.new(sz_from_others, *val.size()[1:]) 170 | for sz_from_others in sizes_expected_from_others] 171 | 172 | all_to_all(recv_tensors, send_tensors, move_to_comm_device=True) 173 | 174 | return torch.cat(recv_tensors) 175 | 176 | @ staticmethod 177 | # pylint: disable = arguments-differ 178 | # type: ignore 179 | def backward(ctx, grad): 180 | send_tensors = list(torch.split(grad, ctx.sizes_expected_from_others)) 181 | recv_tensors = [grad.new(len(indices), *grad.size()[1:]) 182 | for indices in ctx.indices_required_from_me] 183 | all_to_all(recv_tensors, send_tensors, move_to_comm_device=True) 184 | 185 | input_grad = grad.new(ctx.input_size).zero_() 186 | for r_tensor, indices in zip(recv_tensors, ctx.indices_required_from_me): 187 | input_grad[indices] += r_tensor 188 | 189 | return input_grad, None, None 190 | 191 | 192 | tensor_exchange_op = TensorExchangeOp.apply 193 | -------------------------------------------------------------------------------- /sar/data_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included 
in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from typing import List, Tuple, Dict, Optional 22 | import torch 23 | from torch import Tensor 24 | import dgl # type: ignore 25 | from dgl.distributed.partition import load_partition # type: ignore 26 | from .common_tuples import PartitionData, ShardEdgesAndFeatures 27 | 28 | 29 | def suffix_key_lookup(feature_dict: Dict[str, Tensor], key: str, 30 | expand_to_all: bool = False, 31 | type_list: Optional[List[str]] = None) -> Tensor: 32 | """ 33 | Looks up the provided key in the provided dictionary. Uses suffix matching, where a dictionary 34 | key matches if ends with the provided key. This allows feature name lookup in the edge/node 35 | feature dictionaries in DGL's partition data whose keys have the form 36 | ``{node or edge type name}/{feature_name}``. In heterogeneous graphs, some features might only 37 | be present for certain node/edge types. Set the ``expand_to_all`` flag to expand the 38 | feature tensor to all nodes/edges in the graph. The expanded tensor will be zero for all 39 | nodes/edges where the requested feature is not present 40 | 41 | 42 | :param feature_dict: Node or edge feature dictionary 43 | :type feature_dict: Dict[str, Tensor] 44 | :param key: Key to look up. 45 | :type key: str 46 | :param expand_to_all: Expand feature tensor to all nodes/edges. 47 | :type expand_to_all: bool 48 | :param type_list: List of edge or node type names. 
Required if ``expand_to_all`` is ``True`` 49 | :type type_list: Optional[List[str]] 50 | :returns: The matched (possibly expanded) feature tensor 51 | 52 | """ 53 | matched_keys = [k for k in feature_dict if k.endswith(key)] 54 | if len(matched_keys) == 0: 55 | return torch.LongTensor([]) 56 | assert len(matched_keys) == 1 57 | matched_features = feature_dict[matched_keys[0]] 58 | if expand_to_all: 59 | assert type_list is not None 60 | if len(type_list) > 1 and dgl.NTYPE in feature_dict: 61 | type_id = feature_dict[dgl.NTYPE] 62 | key_node_type = matched_keys[0].split('/')[0] 63 | node_type_idx = type_list.index(key_node_type) 64 | 65 | expanded_features = matched_features.new( 66 | type_id.size(0), *matched_features.size()[1:]).zero_() 67 | expanded_features[type_id == node_type_idx] = matched_features 68 | return expanded_features 69 | 70 | return matched_features 71 | 72 | 73 | def _mask_features_dict(edge_features: Dict[str, Tensor], 74 | mask: Tensor, device: torch.device) -> Dict[str, Tensor]: 75 | #TODO(kpietkun): allow using edge features 76 | return {k: edge_features[k][mask].to(device) for k in edge_features if k == dgl.ETYPE} 77 | 78 | 79 | def _get_type_ordered_edges(edge_mask: Tensor, edge_types: Tensor, 80 | n_edge_types: int) -> Tensor: 81 | reordered_edge_mask: List[Tensor] = [] 82 | for edge_type_idx in range(n_edge_types): 83 | edge_mask_typed = torch.logical_and( 84 | edge_mask, edge_types == edge_type_idx) 85 | reordered_edge_mask.append( 86 | edge_mask_typed.nonzero(as_tuple=False).view(-1)) 87 | 88 | return torch.cat(reordered_edge_mask) 89 | 90 | 91 | def create_partition_data(graph: dgl.DGLGraph, 92 | own_partition_idx: int, 93 | node_features: Dict[str, torch.Tensor], 94 | edge_features: Dict[str, Tensor], 95 | partition_book: dgl.distributed.GraphPartitionBook, 96 | node_type_list: List[str], 97 | edge_type_list: List[str], 98 | device: torch.device) -> PartitionData: 99 | """ 100 | Creates SAR's PartitionData object basing on graph 
partition and features. 101 | 102 | :param graph: The graph partition structure for specific ``own_partition_idx`` 103 | :type graph: dgl.DGLGraph 104 | :param own_partition_idx: The index of the partition to create. This is typically the\ 105 | worker/machine rank 106 | :type own_partition_idx: int 107 | :param node_features: Dictionary containing node features for graph partition 108 | :type node_features: Dict[str, Tensor] 109 | :param edge_features: Dictionary containing edge features for graph partition 110 | :type edge_features: Dict[(str, str, str), Tensor] 111 | :param partition_book: The graph partition information 112 | :type partition_book: dgl.distributed.GraphPartitionBook 113 | :param node_type_list: List of node types 114 | :type node_type_list: List[str] 115 | :param edge_type_list: List of edge types 116 | :type edge_type_list: List[str] 117 | :param device: Device on which to place the loaded partition data 118 | :type device: torch.device 119 | :returns: The loaded partition data 120 | """ 121 | is_heterogeneous = (len(edge_type_list) > 1) 122 | # Delete redundant edge features with keys {relation name}/reltype. 
graph.edata[dgl.ETYPE ] already contains 123 | # the edge type in a heterogeneous graph 124 | if is_heterogeneous: 125 | for edge_feat_key in list(edge_features.keys()): 126 | if 'reltype' in edge_feat_key: 127 | del edge_features[edge_feat_key] 128 | 129 | # Obtain the node ranges in each partition in the homogenized graph 130 | start_node_idx = 0 131 | node_ranges: List[Tuple[int, int]] = [] 132 | for part_metadata in partition_book.metadata(): 133 | node_ranges.append( 134 | (start_node_idx, start_node_idx + part_metadata['num_nodes'])) 135 | start_node_idx += part_metadata['num_nodes'] 136 | 137 | # Include the node types in the node feature dictionary 138 | if dgl.NTYPE in graph.ndata: 139 | node_features[dgl.NTYPE] = graph.ndata[dgl.NTYPE][graph.ndata['inner_node'].bool()] 140 | else: 141 | node_features[dgl.NTYPE] = torch.zeros(graph.num_nodes(), dtype=torch.int32)[graph.ndata['inner_node'].bool()] 142 | 143 | # Include the edge types in the edge feature dictionary 144 | inner_edge_mask = graph.edata['inner_edge'].bool() 145 | if dgl.ETYPE in graph.edata: 146 | edge_features[dgl.ETYPE] = graph.edata[dgl.ETYPE][inner_edge_mask] 147 | else: 148 | edge_features[dgl.ETYPE] = torch.zeros(graph.num_edges(), dtype=torch.int32)[inner_edge_mask] 149 | 150 | # Obtain the inner edges. 
These are the partition edges 151 | local_partition_edges = torch.stack(graph.all_edges())[:, inner_edge_mask] 152 | # Use global node ids in partition_edges 153 | partition_edges = graph.ndata[dgl.NID][local_partition_edges] 154 | 155 | # Check that all target nodes lie in the current partition 156 | assert partition_edges[1].min() >= node_ranges[own_partition_idx][0] \ 157 | and partition_edges[1].max() < node_ranges[own_partition_idx][1] 158 | 159 | all_shard_edges: List[ShardEdgesAndFeatures] = [] 160 | 161 | for part_idx in range(partition_book.num_partitions()): 162 | # obtain the mask for edges originating from partition part_idx 163 | edge_mask = torch.logical_and(partition_edges[0] >= node_ranges[part_idx][0], 164 | partition_edges[0] < node_ranges[part_idx][1]) 165 | 166 | # Reorder the edges in each shard so that edges with the same type 167 | # follow each other 168 | if is_heterogeneous: 169 | edge_mask = _get_type_ordered_edges( 170 | edge_mask, edge_features[dgl.ETYPE], len(edge_type_list)) 171 | 172 | all_shard_edges.append(ShardEdgesAndFeatures( 173 | (partition_edges[0, edge_mask], partition_edges[1, edge_mask]), 174 | _mask_features_dict(edge_features, edge_mask, device) 175 | )) 176 | 177 | return PartitionData(all_shard_edges, 178 | node_ranges, 179 | node_features, 180 | node_type_list, 181 | edge_type_list, 182 | partition_book 183 | ) 184 | 185 | 186 | def load_dgl_partition_data(partition_json_file: str, 187 | own_partition_idx: int, device: torch.device) -> PartitionData: 188 | """ 189 | Loads partition data created by DGL's ``partition_graph`` function 190 | 191 | :param partition_json_file: Path to the .json file containing partitioning data 192 | :type partition_json_file: str 193 | :param own_partition_idx: The index of the partition to load. 
This is typically the\ 194 | worker/machine rank 195 | :type own_partition_idx: int 196 | :param device: Device on which to place the loaded partition data 197 | :type device: torch.device 198 | :returns: The loaded partition data 199 | 200 | """ 201 | (graph, node_features, 202 | edge_features, partition_book, _, 203 | node_type_list, edge_type_list) = load_partition(partition_json_file, own_partition_idx) 204 | 205 | return create_partition_data(graph, own_partition_idx, 206 | node_features, edge_features, 207 | partition_book, node_type_list, 208 | edge_type_list, device) 209 | 210 | def load_dgl_partition_data_from_graph(graph: dgl.distributed.DistGraph, 211 | device: torch.device) -> PartitionData: 212 | """ 213 | Loads partition data from DistGraph object 214 | 215 | :param graph: The distributed graph 216 | :type graph: dgl.distributed.DistGraph 217 | :param device: Device on which to place the loaded partition data 218 | :type device: torch.device 219 | :returns: The loaded partition data 220 | 221 | """ 222 | own_partition_idx = graph.rank() 223 | local_g = graph.local_partition 224 | 225 | assert dgl.NID in local_g.ndata 226 | assert dgl.EID in local_g.edata 227 | 228 | # get originalmapping for node and edge ids 229 | orig_n_ids = local_g.ndata[dgl.NID][local_g.ndata['inner_node'].bool().nonzero().view(-1)] 230 | orig_e_ids = local_g.edata[dgl.EID][local_g.edata['inner_edge'].bool().nonzero().view(-1)] 231 | 232 | # fetch local features from DistTensor 233 | node_features = {key : torch.Tensor(graph.ndata[key][orig_n_ids]) for key in list(graph.ndata.keys())} 234 | edge_features = {key : torch.Tensor(graph.edata[key][orig_e_ids]) for key in list(graph.edata.keys())} 235 | 236 | partition_book = graph.get_partition_book() 237 | node_type_list = local_g.ntypes 238 | edge_type_list = [local_g.to_canonical_etype(etype) for etype in graph.etypes] 239 | 240 | return create_partition_data(local_g, own_partition_idx, 241 | node_features, edge_features, 242 | 
partition_book, node_type_list, 243 | edge_type_list, device) -------------------------------------------------------------------------------- /sar/distributed_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from typing import Optional 22 | import torch 23 | from torch import Tensor 24 | import torch.distributed as dist 25 | from torch import nn 26 | from torch.nn import Parameter 27 | from torch.nn import init 28 | from .comm import all_reduce, comm_device, is_initialized 29 | 30 | 31 | class DistributedBN1D(nn.Module): 32 | """Distributed Batch normalization layer 33 | 34 | Normalizes a 2D feature tensor using the global mean and standard deviation calculated across all workers. 
35 | 36 | 37 | :param n_feats: The second dimension (feature dimension) in the 2D input tensor 38 | :type n_feats: int 39 | :param eps: a value added to the variance for numerical stability 40 | :type eps: float 41 | :param affine: When ``True``, the module will use learnable affine parameter 42 | :type affine: bool 43 | :param distributed: Boolean speficying whether to run in distributed mode where normalizing\ 44 | statistics are calculated across all workers, or local mode where the normalizing statistics\ 45 | are calculated using only the local input feature tensor. If not specified, it will be set to\ 46 | ``True`` if the user has called :func:`sar.initialize_comms`, and ``False`` otherwise 47 | :type distributed: Optional[bool] 48 | 49 | """ 50 | def __init__(self, n_feats: int, eps: float = 1.0e-5, affine: bool = True, distributed: Optional[bool] = None): 51 | super().__init__() 52 | self.n_feats = n_feats 53 | self.weight: Optional[Parameter] 54 | self.bias: Optional[Parameter] 55 | self.affine = affine 56 | if affine: 57 | self.weight = Parameter(torch.ones(n_feats)) 58 | self.bias = Parameter(torch.zeros(n_feats)) 59 | else: 60 | self.weight = None 61 | self.bias = None 62 | 63 | self.eps = eps 64 | 65 | if distributed is None: 66 | self.distributed = is_initialized() 67 | else: 68 | self.distributed = distributed 69 | 70 | def forward(self, inp): 71 | ''' 72 | forward implementation of DistributedBN1D 73 | ''' 74 | assert inp.ndim == 2, 'distributedBN1D must have a 2D input' 75 | if self.distributed: 76 | mean, var = mean_op(inp), var_op(inp) 77 | std = torch.sqrt(var - mean**2 + self.eps) 78 | else: 79 | mean = inp.mean(0) 80 | std = inp.std(0) 81 | normalized_x = (inp - mean.unsqueeze(0)) / std.unsqueeze(0) 82 | 83 | if self.weight is not None and self.bias is not None: 84 | result = normalized_x * self.weight.unsqueeze(0) + self.bias.unsqueeze(0) 85 | else: 86 | result = normalized_x 87 | return result 88 | 89 | def reset_parameters(self): 90 | if 
self.affine: 91 | init.ones_(self.weight) 92 | init.zeros_(self.bias) 93 | 94 | 95 | 96 | class MeanOp(torch.autograd.Function): # pylint: disable = abstract-method 97 | @staticmethod 98 | # pylint: disable = arguments-differ 99 | def forward(ctx, x): 100 | own_sum = torch.empty(x.size(1)+1, device=comm_device()) 101 | own_sum[:-1] = x.sum(0).data.to(comm_device()) 102 | own_sum[-1] = x.size(0) 103 | all_reduce(own_sum, op=dist.ReduceOp.SUM,move_to_comm_device = True) 104 | mean = (own_sum[:-1]/own_sum[-1]).to(x.device) 105 | ctx.n_points = torch.round(own_sum[-1]).long().item() 106 | ctx.inp_size = x.size(0) 107 | return mean 108 | 109 | @staticmethod 110 | # pylint: disable = arguments-differ 111 | def backward(ctx, grad): 112 | grad_comm = grad.to(comm_device()) 113 | all_reduce(grad_comm, op=dist.ReduceOp.SUM,move_to_comm_device = True) 114 | return grad_comm.repeat(ctx.inp_size, 1).to(grad.device) / ctx.n_points 115 | 116 | 117 | class VarOp(torch.autograd.Function): # pylint: disable = abstract-method 118 | @staticmethod 119 | # pylint: disable = arguments-differ 120 | def forward(ctx, features): 121 | own_sum = torch.empty(features.size(1)+1, device=comm_device()) 122 | own_sum[:-1] = (features**2).sum(0).data.to(comm_device()) 123 | own_sum[-1] = features.size(0) 124 | all_reduce(own_sum, op=dist.ReduceOp.SUM,move_to_comm_device = True) 125 | variance = (own_sum[:-1]/own_sum[-1]).to(features.device) 126 | 127 | ctx.n_points = torch.round(own_sum[-1]).long().item() 128 | ctx.save_for_backward(features) 129 | return variance 130 | 131 | @staticmethod 132 | # pylint: disable = arguments-differ 133 | def backward(ctx, grad): 134 | features, = ctx.saved_tensors 135 | grad_comm = grad.to(comm_device()) 136 | all_reduce(grad_comm, op=dist.ReduceOp.SUM,move_to_comm_device = True) 137 | return (grad_comm.to(grad.device).unsqueeze(0) * 2 * features) / ctx.n_points 138 | 139 | 140 | mean_op = MeanOp.apply 141 | var_op = VarOp.apply 142 | 
-------------------------------------------------------------------------------- /sar/edge_softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import torch 22 | import dgl.function as fn # type: ignore 23 | import dgl # type: ignore 24 | from . import GraphShardManager 25 | 26 | 27 | def edge_softmax(graph: GraphShardManager, 28 | logits: torch.Tensor, eids=dgl.ALL, norm_by: str = 'dst') -> torch.Tensor: 29 | """ 30 | Implements a similar functionality as DGL's ``dgl.nn.edge_softmax`` on distributed graphs. 31 | 32 | Only supports a subset of the possible argument values. 33 | 34 | :param graph: The distributed graph 35 | :type graph: GraphShardManager 36 | :param logits: The edge logits. 
The size of the first dimension should be the same as the number of edges in the ``graph`` argument 37 | :type logits: torch.Tensor 38 | :param eids: must be ``dgl.ALL`` 39 | :type eids: the ``dgl.ALL`` sentinel (no other value is accepted) 40 | :param norm_by: must be ``'dst'`` 41 | :type norm_by: str 42 | :returns: A tensor with the same size as logits containing the softmax-normalized logits, where each edge is normalized over all edges that share its destination node 43 | 44 | """ 45 | 46 | assert eids == dgl.ALL, \ 47 | 'edge_softmax on GraphShardManager only supported when eids==dgl.ALL' 48 | 49 | assert norm_by == 'dst', \ 50 | 'edge_softmax on GraphShardManager only supported when norm_by==dst' 51 | 52 | with graph.local_scope(): 53 | graph.edata['logits'] = logits 54 | with torch.no_grad(): # the per-destination max is only a stabilizing shift that cancels in the final ratio, so no gradient is tracked through it 55 | graph.update_all(fn.copy_e('logits', 'temp'), 56 | fn.max('temp', 'max_logits')) # pylint: disable=no-member 57 | 58 | graph.apply_edges( 59 | fn.e_sub_v('logits', 'max_logits', 'adjusted_logits')) # pylint: disable=no-member 60 | 61 | graph.edata['exp_logits'] = torch.exp(graph.edata.pop('adjusted_logits')) # exp(logit - max over the edges entering the same destination node) 62 | 63 | graph.update_all(fn.copy_e('exp_logits', 'temp'), 64 | fn.sum('temp', 'normalization')) # pylint: disable=no-member 65 | 66 | graph.apply_edges( 67 | fn.e_div_v('exp_logits', 'normalization', 'sm_output')) # pylint: disable=no-member 68 | 69 | sm_output = graph.edata.pop('sm_output') 70 | 71 | return sm_output 72 | -------------------------------------------------------------------------------- /sar/logging_setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the
following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from typing import Optional, Union 22 | import logging 23 | 24 | logger = logging.getLogger('sar') # package logger: with only a NullHandler and propagate=False it emits nothing until logging_setup() attaches a handler 25 | logger.addHandler(logging.NullHandler()) 26 | logger.propagate = False 27 | 28 | 29 | def logging_setup(log_level: int, _rank: int = -1, _world_size: int = -1, log_file: Optional[str] = None): # attach a file handler (if log_file is given) or a stream handler whose records are prefixed with "<_rank+1> / <_world_size>" 30 | formatter = logging.Formatter( 31 | f'{_rank+1} / {_world_size} - %(name)s - %(levelname)s - %(message)s') 32 | 33 | handler: Union[logging.FileHandler, logging.StreamHandler] 34 | if log_file is not None: 35 | handler = logging.FileHandler(log_file, 'w') # mode 'w' truncates an existing file at log_file 36 | else: 37 | handler = logging.StreamHandler() 38 | 39 | handler.setFormatter(formatter) 40 | handler.setLevel(log_level) 41 | 42 | logger.setLevel(log_level) 43 | logger.addHandler(handler) # NOTE(review): handlers accumulate, so calling logging_setup() twice emits every record twice 44 | -------------------------------------------------------------------------------- /sar/patch_dgl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Intel Corporation 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from typing import Union, List, Optional 22 | import dgl # type: ignore 23 | from .edge_softmax import edge_softmax 24 | from .core import GraphShardManager 25 | 26 | from . import message_has_parameters 27 | 28 | 29 | 30 | # SAR-aware replacements for DGL nn components; patch_dgl() installs them by rebinding attributes of DGL modules 31 | 32 | class RelGraphConv(dgl.nn.pytorch.conv.RelGraphConv): # same layer as DGL's, but message() declares the parameters it uses via message_has_parameters 33 | def __init__(self, *args, **kwargs): 34 | super().__init__(*args, **kwargs) 35 | 36 | @message_has_parameters(lambda self: tuple(self.linear_r.parameters())) 37 | def message(self, edges): 38 | return super().message(edges) 39 | 40 | 41 | def patched_edge_softmax(graph, *args, **kwargs): # dispatch on graph type: SAR's edge_softmax for GraphShardManager, DGL's own edge_softmax otherwise 42 | if isinstance(graph, GraphShardManager): 43 | return edge_softmax(graph, *args, **kwargs) 44 | 45 | return dgl.nn.edge_softmax(graph, *args, **kwargs) # pylint: disable=no-member 46 | 47 | 48 | def patch_dgl(): 49 | """Patches DGL so that attention layers (``gatconv``, ``dotgatconv``, 50 | ``agnnconv``) use a different ``edge_softmax`` function 51 | that supports :class:`sar.core.GraphShardManager`.
Also modifies DGL's 52 | ``RelGraphConv`` to add a decorator to its ``message`` function to tell 53 | SAR how to find the parameters used to create edge messages. 54 | 55 | """ 56 | dgl.nn.pytorch.conv.gatconv.edge_softmax = patched_edge_softmax 57 | dgl.nn.pytorch.conv.dotgatconv.edge_softmax = patched_edge_softmax 58 | dgl.nn.pytorch.conv.agnnconv.edge_softmax = patched_edge_softmax 59 | 60 | dgl.nn.pytorch.conv.RelGraphConv = RelGraphConv 61 | dgl.nn.RelGraphConv = RelGraphConv 62 | -------------------------------------------------------------------------------- /security.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation. 3 | 4 | ## Reporting a Vulnerability 5 | Please report any security vulnerabilities in this project [utilizing the guidelines here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). 
6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | from pathlib import Path 4 | this_directory = Path(__file__).parent 5 | long_description = (this_directory / "README.md").read_text() 6 | 7 | setup( 8 | name='sar-gnn', 9 | version='0.1.0', 10 | python_requires='>=3.8', 11 | install_requires=[ 12 | 'dgl>=1.0.0', 13 | 'numpy>=1.22.0', 14 | 'torch>=1.10.0', 15 | 'ifaddr>=0.1.7', 16 | 'packaging>=23.1' 17 | ], 18 | packages=find_packages(), 19 | author='Hesham Mostafa', 20 | author_email='hesham.mostafa@intel.com', 21 | maintainer='Kacper Pietkun', 22 | maintainer_email='kacper.pietkun@intel.com', 23 | description='A Python library for distributed training of Graph Neural Networks (GNNs) on large graphs, ' 24 | 'supporting both full-batch and sampling-based training, and utilizing a sequential aggregation' 25 | 'and rematerialization technique for linear memory scaling.', 26 | long_description=long_description, 27 | long_description_content_type='text/markdown', 28 | project_urls={ 29 | 'GitHub': 'https://github.com/IntelLabs/SAR/', 30 | 'Documentation': 'https://sar.readthedocs.io/en/latest/', 31 | }, 32 | license='MIT', 33 | classifiers=[ 34 | 'License :: OSI Approved :: MIT License', 35 | "Operating System :: OS Independent", 36 | "Programming Language :: Python :: 3.8" 37 | ] 38 | ) -------------------------------------------------------------------------------- /tests/base_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sar 3 | import torch 4 | import torch.distributed as dist 5 | import dgl 6 | from constants import * 7 | # IMPORTANT - This module should be imported independently 8 | # only by the child processes - i.e. 
separate workers 9 | 10 | 11 | def initialize_worker(rank, world_size, tmp_dir, backend="ccl"): 12 | """ 13 | Boilerplate code for setting up connection between workers 14 | 15 | :param rank: Rank of the current machine 16 | :type rank: int 17 | :param world_size: Number of workers. The same as the number of graph partitions 18 | :type world_size: int 19 | :param tmp_dir: Path to the directory where ip file will be created 20 | :type tmp_dir: str 21 | """ 22 | torch.seed() 23 | ip_file = os.path.join(tmp_dir, 'ip_file') 24 | master_ip_address = sar.nfs_ip_init(rank, ip_file) 25 | sar.initialize_comms(rank, world_size, master_ip_address, backend) 26 | 27 | 28 | def load_partition_data(rank, graph_name, tmp_dir): 29 | """ 30 | Boilerplate code for loading partition data with standard `full_graph_manager` (FGM) 31 | 32 | :param rank: Rank of the current machine 33 | :type rank: int 34 | :param graph_name: Name of the partitioned graph 35 | :type graph_name: str 36 | :param tmp_dir: Path to the directory where partition data is located 37 | :type tmp_dir: str 38 | :returns: Tuple consisting of GraphShardManager object, partition features and labels 39 | """ 40 | partition_file = os.path.join(tmp_dir, f'{graph_name}.json') 41 | partition_data = sar.load_dgl_partition_data(partition_file, rank, "cpu") 42 | full_graph_manager = sar.construct_full_graph(partition_data).to('cpu') 43 | features = sar.suffix_key_lookup(partition_data.node_features, 'features') 44 | labels = sar.suffix_key_lookup(partition_data.node_features, 'labels') 45 | return full_graph_manager, features, labels 46 | 47 | 48 | def load_partition_data_mfg(rank, graph_name, tmp_dir): 49 | """ 50 | Boilerplate code for loading partition data with message flow graph (MFG) 51 | 52 | :param rank: Rank of the current machine 53 | :type rank: int 54 | :param graph_name: Name of the partitioned graph 55 | :type graph_name: str 56 | :param tmp_dir: Path to the directory where partition data is located 57 | :type 
tmp_dir: str 58 | :returns: Tuple consisting of GraphShardManager object, partition features and labels 59 | """ 60 | partition_file = os.path.join(tmp_dir, f'{graph_name}.json') 61 | partition_data = sar.load_dgl_partition_data(partition_file, rank, "cpu") 62 | blocks = sar.construct_mfgs(partition_data, 63 | (partition_data.node_features[dgl.NTYPE] == 0).nonzero(as_tuple=True)[0] + 64 | partition_data.node_ranges[sar.comm.rank()][0], 65 | 3, True) 66 | blocks = [block.to('cpu') for block in blocks] 67 | features = sar.suffix_key_lookup(partition_data.node_features, 'features') 68 | labels = sar.suffix_key_lookup(partition_data.node_features, 'labels') 69 | return blocks, features, labels 70 | 71 | 72 | def synchronize_processes(): 73 | """ 74 | Function that simulates dist.barrier (using all_reduce because there is an issue with dist.barrier() in ccl) 75 | """ 76 | dummy_tensor = torch.tensor(1) 77 | dist.all_reduce(dummy_tensor, dist.ReduceOp.MAX) 78 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dgl 3 | from dgl.heterograph import DGLGraph 4 | import pytest 5 | import tempfile 6 | from constants import * 7 | import torch 8 | from torch import Tensor 9 | import multiprocessing as mp 10 | from typing import NamedTuple, Union, Dict 11 | 12 | 13 | class FixtureEnv(NamedTuple): 14 | """ 15 | Stores information about variables needed by tests 16 | 17 | .. py:attribute:: temp_dir : str 18 | 19 | Path to a temporary directory which will be 20 | deleted after tests are finished 21 | 22 | .. py:attribute:: homo_graph : DGLGraph 23 | 24 | DGLGraph object representing homogenous graph 25 | 26 | .. py:attribute:: hetero_graph : DGLGraph 27 | 28 | DGLGraph object representing heterogeneous graph 29 | 30 | .. 
py:attribute:: node_map : Dict[str, Union[Tensor, Dict[str, Tensor]]] 31 | 32 | Dictionary of tensors representing mapping between shuffled node IDs 33 | and the original node IDs for a homogeneous graph. 34 | or, dict of dicts of tensors whose key is the node type and value 35 | is a tensor mapping between shuffled node IDs and the 36 | original node IDs for each node type for a heterogeneous graph. 37 | Each dict element represents node_mapping for different graph. 38 | Homogeneous/heterogeneous and different world_sizes 39 | """ 40 | temp_dir: str 41 | homo_graph: DGLGraph 42 | hetero_graph: DGLGraph 43 | node_map: Dict[str, Union[Tensor, Dict[str, Tensor]]] 44 | 45 | 46 | @pytest.fixture(autouse=True, scope="session") 47 | def fixture_env(): 48 | """ 49 | Create temp directory that will be used by every test. 50 | Create and save partitioned graphs in that directory. 51 | """ 52 | manager = mp.Manager() 53 | mp_dict = manager.dict() 54 | with tempfile.TemporaryDirectory() as temp_dir: 55 | p = mp.Process(target=graph_partitioning, args=(mp_dict, temp_dir,)) 56 | p.start() 57 | p.join() 58 | yield FixtureEnv(temp_dir, 59 | mp_dict["homo_graph"], 60 | mp_dict["hetero_graph"], 61 | mp_dict["node_map"]) 62 | 63 | 64 | def graph_partitioning(mp_dict, temp_dir): 65 | """ 66 | Create and partition both homogeneous and heterogeneous 67 | graphs for different world_sizes 68 | 69 | :param temp_dir: Path to the directory where graphs will be partitioned 70 | :type temp_dir: str 71 | """ 72 | homo_g = get_random_graph() 73 | hetero_g = get_random_hetero_graph() 74 | 75 | node_mappings = {} 76 | world_sizes = [1, 2, 4, 8] 77 | for world_size in world_sizes: 78 | partition_homo_dir = os.path.join(temp_dir, f"homogeneous_{world_size}") 79 | os.makedirs(partition_homo_dir) 80 | node_map, _ = dgl.distributed.partition_graph(homo_g, HOMOGENEOUS_GRAPH_NAME, 81 | world_size, partition_homo_dir, 82 | num_hops=1, balance_edges=True, 83 | return_mapping=True) 84 | 
node_mappings[f"homogeneous_{world_size}"] = node_map 85 | 86 | partition_hetero_dir = os.path.join(temp_dir, f"heterogeneous_{world_size}") 87 | os.makedirs(partition_hetero_dir) 88 | node_map, _ =dgl.distributed.partition_graph(hetero_g, HETEROGENEOUS_GRAPH_NAME, 89 | world_size, partition_hetero_dir, 90 | num_hops=1, balance_edges=True, 91 | return_mapping=True) 92 | node_mappings[f"heterogeneous_{world_size}"] = node_map 93 | 94 | mp_dict["homo_graph"] = homo_g 95 | mp_dict["hetero_graph"] = hetero_g 96 | mp_dict["node_map"] = node_mappings 97 | 98 | 99 | def get_random_graph(): 100 | """ 101 | Generates small homogenous graph with features and labels 102 | 103 | :returns: dgl graph 104 | """ 105 | graph = dgl.rand_graph(1000, 2500) 106 | graph = dgl.add_self_loop(graph) 107 | graph.ndata.clear() 108 | graph.ndata['features'] = torch.rand((graph.num_nodes(), 10)) 109 | graph.ndata['labels'] = torch.randint(0, 10, (graph.num_nodes(),)) 110 | return graph 111 | 112 | 113 | def get_random_hetero_graph(): 114 | """ 115 | Generates small heterogenous graph with node features and labels only for the first node type 116 | 117 | :returns: dgl graph 118 | """ 119 | graph_data = { 120 | ("n_type_1", "rel_1", "n_type_1"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 121 | ("n_type_1", "rel_2", "n_type_3"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 122 | ("n_type_4", "rel_3", "n_type_2"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 123 | ("n_type_4", "rel_4", "n_type_1"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 124 | ("n_type_1", "rev-rel_1", "n_type_1"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 125 | ("n_type_3", "rev-rel_2", "n_type_1"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 126 | ("n_type_2", "rev-rel_3", "n_type_4"): (torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))), 127 | ("n_type_1", "rev-rel_4", "n_type_4"): 
(torch.randint(0, 800, (1000,)), torch.randint(0, 800, (1000,))) 128 | } 129 | hetero_graph = dgl.heterograph(graph_data) 130 | hetero_graph.nodes["n_type_1"].data["features"] = torch.rand((hetero_graph.num_nodes("n_type_1"), 10)) 131 | hetero_graph.nodes["n_type_1"].data["labels"] = torch.randint(0, 10, (hetero_graph.num_nodes("n_type_1"),)) 132 | return hetero_graph 133 | -------------------------------------------------------------------------------- /tests/constants.py: -------------------------------------------------------------------------------- 1 | HOMOGENEOUS_GRAPH_NAME = "dummy_homogeneous_graph" 2 | HETEROGENEOUS_GRAPH_NAME = "dummy_heterogeneous_graph" 3 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # File's content was partially taken from 3 | # https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py 4 | ######################################################################################## 5 | 6 | import dgl 7 | import dgl.nn as dglnn 8 | from dgl.nn import HeteroEmbedding 9 | from dgl import DGLGraph 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | class GNNModel(nn.Module): 14 | def __init__(self, in_dim: int, h_dim: int, out_dim: int): 15 | super().__init__() 16 | 17 | self.convs = nn.ModuleList([ 18 | dgl.nn.GraphConv(in_dim, h_dim, activation=F.relu, bias=False), 19 | dgl.nn.GraphConv(h_dim, h_dim, activation=F.relu, bias=False), 20 | dgl.nn.GraphConv(h_dim, out_dim, activation=None, bias=False), 21 | ]) 22 | 23 | def forward(self, graph, features): 24 | if isinstance(graph, list): 25 | # Message Flow Graph 26 | for conv, block in zip(self.convs, graph): 27 | features = conv(block, features) 28 | else: 29 | # Whole graph 30 | for conv in self.convs: 31 | features 
= conv(graph, features) 32 | return features 33 | 34 | 35 | class HeteroGNNModel(nn.Module): 36 | def __init__(self, g: DGLGraph, in_dim: int, h_dim: int, out_dim: int): 37 | super().__init__() 38 | self.rel_names = list(set(g.etypes)) 39 | self.rel_names.sort() 40 | 41 | self.layers = nn.ModuleList([ 42 | RelGraphConvLayer(in_dim, h_dim, g.ntypes, self.rel_names, activation=F.relu), 43 | RelGraphConvLayer(h_dim, h_dim, g.ntypes, self.rel_names, activation=F.relu), 44 | RelGraphConvLayer(h_dim, out_dim, g.ntypes, self.rel_names, activation=None) 45 | ]) 46 | 47 | def forward(self, graph, h): 48 | if isinstance(graph, list): 49 | # Message Flow Graph 50 | for layer, block in zip(self.layers, graph): 51 | h = layer(block, h) 52 | else: 53 | # Whole graph 54 | for layer in self.layers: 55 | h = layer(graph, h) 56 | return h 57 | 58 | class RelGraphConvLayer(nn.Module): 59 | def __init__(self, in_feat, out_feat, ntypes, rel_names, activation=None): 60 | super().__init__() 61 | self.rel_names = rel_names 62 | self.activation = activation 63 | 64 | self.conv = dglnn.HeteroGraphConv( 65 | { 66 | rel: dglnn.GraphConv(in_feat, out_feat, norm="right", weight=False, bias=False) 67 | for rel in rel_names 68 | } 69 | ) 70 | self.weight = nn.ModuleDict({ 71 | rel_name: nn.Linear(in_feat, out_feat, bias=False) for rel_name in rel_names 72 | } 73 | ) 74 | self.loop_weights = nn.ModuleDict({ 75 | ntype: nn.Linear(in_feat, out_feat, bias=True) for ntype in ntypes 76 | } 77 | ) 78 | 79 | def forward(self, g, inputs): 80 | with g.local_scope(): 81 | wdict = { 82 | rel_name: {"weight": self.weight[rel_name].weight.T} 83 | for rel_name in self.rel_names 84 | } 85 | inputs_dst = { 86 | k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items() 87 | } 88 | hs = self.conv(g, inputs, mod_kwargs=wdict) 89 | 90 | def _apply(ntype, h): 91 | h = h + self.loop_weights[ntype](inputs_dst[ntype]) 92 | if self.activation: 93 | h = self.activation(h) 94 | return h 95 | 96 | return {ntype: 
_apply(ntype, h) for ntype, h in hs.items()} 97 | 98 | 99 | def extract_embed(node_embed, input_nodes, skip_type=None): 100 | emb = node_embed( 101 | {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != skip_type} 102 | ) 103 | return emb 104 | 105 | 106 | def rel_graph_embed(graph, embed_size, num_nodes_dict=None, skip_type=None): 107 | node_num = {} 108 | for ntype in graph.ntypes: 109 | if ntype == skip_type: 110 | continue 111 | if num_nodes_dict != None: 112 | node_num[ntype] = num_nodes_dict[ntype] 113 | else: 114 | node_num[ntype] = graph.num_nodes(ntype) 115 | embeds = HeteroEmbedding(node_num, embed_size) 116 | return embeds 117 | -------------------------------------------------------------------------------- /tests/multiprocessing_utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import traceback 3 | import pytest 4 | import functools 5 | import tempfile 6 | 7 | 8 | def handle_mp_exception(mp_dict): 9 | """ 10 | Used to handle exceptions that occurred in child processes 11 | 12 | :param mp_dict: Dictionary that is shared between different processes 13 | :type mp_dict: multiprocessing.managers.DictProxy 14 | """ 15 | msg = mp_dict.get('traceback', "") 16 | for e_arg in mp_dict['exception'].args: 17 | msg += str(e_arg) 18 | print(str(msg), flush=True) 19 | pytest.fail(str(msg), pytrace=False) 20 | 21 | 22 | def run_workers(func, fixture_env, world_size, *args, **kwargs): 23 | """ 24 | Starts `world_size` number of processes, where each of them 25 | behaves as a separate worker and invokes function specified 26 | by the parameter. This function should be an entry point to the 27 | 'independent' process. It has to simulate behaviour of SAR which 28 | will be spawned across different machines independently from other 29 | instances. Each process have individual memory space so it is 30 | suitable environment for testing SAR. 
31 | 32 | :param func: The function that will be invoked by each process. It should take four 33 | parameters: mp_dict - shared dictionary between different processes, rank - of the current machine, 34 | world_size - number of workers, tmp_dir - path to the working directory (additionaly one can pass args and kwargs) 35 | :type func: function 36 | :param fixture_env: named tuple with all of the necessary information about preapred environment for the tests 37 | :type fixture_env: FixtureEnv 38 | :param world_size: number of workers 39 | :type world_size: int 40 | :returns: mp_dict which can be used by workers to return 41 | results from `func` 42 | """ 43 | manager = mp.Manager() 44 | mp_dict = manager.dict() 45 | processes = [] 46 | for rank in range(1, world_size): 47 | my_args = (mp_dict, rank, world_size, fixture_env) + args 48 | p = mp.Process(target=func, args=my_args, kwargs=kwargs) 49 | p.daemon = True 50 | p.start() 51 | processes.append(p) 52 | func(mp_dict, 0, world_size, fixture_env, *args, **kwargs) 53 | 54 | for p in processes: 55 | p.join() 56 | if 'exception' in mp_dict: 57 | handle_mp_exception(mp_dict) 58 | return mp_dict 59 | 60 | 61 | def sar_test(func): 62 | """ 63 | A decorator function that wraps all SAR tests with the primary objective 64 | of facilitating module imports in tests without affecting other tests. 65 | 66 | :param func: The function that serves as the entry point to the test. 67 | :type func: function 68 | :returns: A function that encapsulates the pytest function. 69 | """ 70 | @functools.wraps(func) 71 | def test_wrapper(*args, **kwargs): 72 | """ 73 | The wrapping process involves defining another nested function, which is then invoked by a newly spawned process. 74 | function spawns a new process and uses the "join" method to wait for the results. 75 | Upon completion of the process, error and result handling are performed. 
76 | """ 77 | def process_wrapper(func, mp_dict, *args, **kwargs): 78 | try: 79 | result = func(*args, **kwargs) 80 | mp_dict["result"] = result 81 | except Exception as e: 82 | mp_dict['traceback'] = str(traceback.format_exc()) 83 | mp_dict["exception"] = e 84 | 85 | manager = mp.Manager() 86 | mp_dict = manager.dict() 87 | 88 | mp_args = (func, mp_dict) + args 89 | p = mp.Process(target=process_wrapper, args=mp_args, kwargs=kwargs) 90 | p.start() 91 | p.join() 92 | 93 | if 'exception' in mp_dict: 94 | handle_mp_exception(mp_dict) 95 | 96 | return mp_dict["result"] 97 | return test_wrapper 98 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests/ 3 | python_files = test_*.py 4 | python_classes = Test* 5 | python_functions = test_* -------------------------------------------------------------------------------- /tests/test_comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from multiprocessing_utils import * 4 | from constants import * 5 | # Do not import DGL and SAR - these modules should be 6 | # independently loaded inside each process 7 | 8 | 9 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 10 | @pytest.mark.parametrize('world_size', [2, 4, 8]) 11 | @sar_test 12 | def test_sync_params(world_size, backend, fixture_env): 13 | """ 14 | Checks whether model's parameters are the same across all 15 | workers after calling sync_params function. 
Parameters of worker 0 16 | should be copied to all workers, so its parameters before and after 17 | sync_params should be the same 18 | """ 19 | def sync_params(mp_dict, rank, world_size, fixture_env, **kwargs): 20 | import torch 21 | import sar 22 | from base_utils import initialize_worker, synchronize_processes 23 | from models import GNNModel 24 | 25 | temp_dir = fixture_env.temp_dir 26 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 27 | model = GNNModel(16, 8, 4) 28 | if rank == 0: 29 | mp_dict[f"result_{rank}"] = deepcopy(model.state_dict()) 30 | sar.sync_params(model) 31 | if rank != 0: 32 | mp_dict[f"result_{rank}"] = model.state_dict() 33 | 34 | synchronize_processes() 35 | for rank in range(1, world_size): 36 | for key in mp_dict[f"result_0"].keys(): 37 | assert torch.all(torch.eq(mp_dict[f"result_0"][key], mp_dict[f"result_{rank}"][key])) 38 | 39 | run_workers(sync_params, fixture_env, world_size, backend=backend) 40 | 41 | 42 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 43 | @pytest.mark.parametrize('world_size', [2, 4, 8]) 44 | @sar_test 45 | def test_gather_grads(world_size, backend, fixture_env): 46 | """ 47 | Checks whether parameter's gradients are the same across all 48 | workers after calling gather_grads function 49 | """ 50 | def gather_grads(mp_dict, rank, world_size, fixture_env, **kwargs): 51 | import torch 52 | import sar 53 | import torch.nn.functional as F 54 | from models import GNNModel 55 | from base_utils import initialize_worker, synchronize_processes, load_partition_data 56 | 57 | temp_dir = fixture_env.temp_dir 58 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 59 | fgm, feat, labels = load_partition_data(rank, HOMOGENEOUS_GRAPH_NAME, 60 | os.path.join(temp_dir, f"homogeneous_{world_size}")) 61 | model = GNNModel(feat.shape[1], feat.shape[1], labels.max()+1) 62 | sar.sync_params(model) 63 | sar_logits = model(fgm, feat) 64 | sar_loss = F.cross_entropy(sar_logits, 
labels) 65 | sar_loss.backward() 66 | sar.gather_grads(model) 67 | mp_dict[f"result_{rank}"] = [torch.tensor(x.grad) for x in model.parameters()] 68 | 69 | synchronize_processes() 70 | for rank in range(1, world_size): 71 | for i in range(len(mp_dict["result_0"])): 72 | assert torch.all(torch.eq(mp_dict["result_0"][i], mp_dict[f"result_{rank}"][i])) 73 | 74 | run_workers(gather_grads, fixture_env, world_size, backend=backend) 75 | 76 | 77 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 78 | @pytest.mark.parametrize("world_size", [2, 4, 8]) 79 | @sar_test 80 | def test_all_to_all(world_size, backend, fixture_env): 81 | """ 82 | Checks whether all_to_all operation works as expected. Test is 83 | designed is such a way, that after calling all_to_all, each worker 84 | should receive a list of tensors with values equal to their rank 85 | """ 86 | def all_to_all(mp_dict, rank, world_size, fixture_env, **kwargs): 87 | import torch 88 | import sar 89 | from base_utils import initialize_worker, synchronize_processes 90 | 91 | temp_dir = fixture_env.temp_dir 92 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 93 | send_tensors_list = [torch.tensor([x] * world_size) for x in range(world_size)] 94 | recv_tensors_list = [torch.tensor([-1] * world_size) for _ in range(world_size)] 95 | sar.comm.all_to_all(recv_tensors_list, send_tensors_list) 96 | mp_dict[f"result_{rank}"] = recv_tensors_list 97 | 98 | synchronize_processes() 99 | for rank in range(world_size): 100 | for tensor in mp_dict[f"result_{rank}"]: 101 | assert torch.all(torch.eq(tensor, torch.tensor([rank] * world_size))) 102 | 103 | run_workers(all_to_all, fixture_env, world_size, backend=backend) 104 | 105 | 106 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 107 | @pytest.mark.parametrize("world_size", [2, 4, 8]) 108 | @sar_test 109 | def test_exchange_single_tensor(world_size, backend, fixture_env): 110 | """ 111 | Checks whether exchange_single_tensor operation works as 
expected. Test is 112 | designed is such a way, that after calling exchange_single_tensor between two machines, 113 | machine should recive a tensor with values equal to their rank 114 | """ 115 | def exchange_single_tensor(mp_dict, rank, world_size, fixture_env, **kwargs): 116 | import torch 117 | import sar 118 | from base_utils import initialize_worker, synchronize_processes 119 | 120 | temp_dir = fixture_env.temp_dir 121 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 122 | send_idx = rank 123 | recv_idx = rank 124 | results = [] 125 | for _ in range(world_size): 126 | send_tensor = torch.tensor([send_idx] * world_size) 127 | recv_tensor = torch.tensor([-1] * world_size) 128 | sar.comm.exchange_single_tensor(recv_idx, send_idx, recv_tensor, send_tensor) 129 | results.append(recv_tensor) 130 | send_idx = (send_idx + 1) % world_size 131 | recv_idx = (recv_idx - 1) % world_size 132 | mp_dict[f"result_{rank}"] = results 133 | 134 | synchronize_processes() 135 | for recv_tensor in mp_dict[f"result_{rank}"]: 136 | assert torch.all(torch.eq(recv_tensor, torch.tensor([rank] * world_size))) 137 | 138 | run_workers(exchange_single_tensor, fixture_env, world_size, backend=backend) 139 | -------------------------------------------------------------------------------- /tests/test_patch_dgl.py: -------------------------------------------------------------------------------- 1 | from multiprocessing_utils import * 2 | # Do not import DGL and SAR - these modules should be 3 | # independently loaded inside each process 4 | 5 | 6 | @sar_test 7 | def test_patch_dgl(): 8 | """ 9 | Import DGL library and SAR and check whether `patch_dgl` function 10 | overrides edge_softmax function in specific GNN layers implementation. 
11 | """ 12 | import dgl 13 | original_gat_edge_softmax = dgl.nn.pytorch.conv.gatconv.edge_softmax 14 | original_dotgat_edge_softmax = dgl.nn.pytorch.conv.dotgatconv.edge_softmax 15 | original_agnn_edge_softmax = dgl.nn.pytorch.conv.agnnconv.edge_softmax 16 | 17 | import sar 18 | sar.patch_dgl() 19 | 20 | assert original_gat_edge_softmax == dgl.nn.functional.edge_softmax 21 | assert original_dotgat_edge_softmax == dgl.nn.functional.edge_softmax 22 | assert original_agnn_edge_softmax == dgl.nn.functional.edge_softmax 23 | 24 | assert dgl.nn.pytorch.conv.gatconv.edge_softmax == sar.patched_edge_softmax 25 | assert dgl.nn.pytorch.conv.dotgatconv.edge_softmax == sar.patched_edge_softmax 26 | assert dgl.nn.pytorch.conv.RelGraphConv == sar.RelGraphConv 27 | assert dgl.nn.RelGraphConv == sar.RelGraphConv 28 | -------------------------------------------------------------------------------- /tests/test_sar.py: -------------------------------------------------------------------------------- 1 | from multiprocessing_utils import * 2 | from constants import * 3 | import os 4 | # Do not import DGL and SAR - these modules should be 5 | # independently loaded inside each process 6 | 7 | 8 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 9 | @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) 10 | @sar_test 11 | def test_homogeneous_fgm(world_size, backend, fixture_env): 12 | """ 13 | Performs full-graph inference with the SAR algorithm on a homogeneous graph. 14 | The test compares the row-wise mean (axis=1) of the concatenated results from all processes 15 | with the row-wise mean of native DGL full-graph inference.
16 | """ 17 | import torch 18 | def homogeneous_fgm(mp_dict, rank, world_size, fixture_env, **kwargs): 19 | import sar 20 | from models import GNNModel 21 | from base_utils import initialize_worker, load_partition_data 22 | 23 | temp_dir = fixture_env.temp_dir 24 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 25 | fgm, feat, labels = load_partition_data(rank, HOMOGENEOUS_GRAPH_NAME, 26 | os.path.join(temp_dir, f"homogeneous_{world_size}")) 27 | model = GNNModel(feat.shape[1], feat.shape[1], labels.max()+1).to('cpu') 28 | sar.sync_params(model) 29 | model.eval() 30 | sar_logits = model(fgm, feat) 31 | 32 | mp_dict[f"result_{rank}"] = sar_logits.detach() 33 | if rank == 0: 34 | mp_dict["model"] = model 35 | mp_dict["graph"] = fixture_env.homo_graph 36 | mp_dict["node_map"] = fixture_env.node_map[f"homogeneous_{world_size}"] 37 | 38 | mp_dict = run_workers(homogeneous_fgm, fixture_env, world_size, backend=backend) 39 | 40 | model = mp_dict["model"] 41 | graph = mp_dict["graph"] 42 | dgl_logits = model(graph, graph.ndata['features']).detach() 43 | dgl_logits_mean = dgl_logits.mean(axis=1) 44 | 45 | sar_logits = torch.tensor([]) 46 | for rank in range(world_size): 47 | sar_logits = torch.cat((sar_logits, mp_dict[f"result_{rank}"])) 48 | sar_logits[mp_dict["node_map"]] = sar_logits.clone() 49 | sar_logits_mean = sar_logits.mean(axis=1) 50 | 51 | assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6)) 52 | 53 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 54 | @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) 55 | @sar_test 56 | def test_heterogeneous_fgm(world_size, backend, fixture_env): 57 | """ 58 | Performs full-graph inference with the SAR algorithm on a heterogeneous graph. 59 | The test compares the row-wise mean (axis=1) of the concatenated "n_type_1" results from all processes 60 | with the row-wise mean of native DGL full-graph inference.
61 | """ 62 | import torch 63 | from models import rel_graph_embed 64 | def heterogeneous_fgm(mp_dict, rank, world_size, fixture_env, **kwargs): 65 | import sar 66 | from models import HeteroGNNModel, extract_embed 67 | from base_utils import initialize_worker, load_partition_data 68 | 69 | temp_dir = fixture_env.temp_dir 70 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 71 | fgm, feats, labels = load_partition_data(rank, HETEROGENEOUS_GRAPH_NAME, 72 | os.path.join(temp_dir, f"heterogeneous_{world_size}")) 73 | model = HeteroGNNModel(fgm, feats.shape[1], feats.shape[1], labels.max()+1).to('cpu') 74 | model.eval() 75 | sar.sync_params(model) 76 | 77 | to_extract = {} 78 | node_map = fixture_env.node_map[f"heterogeneous_{world_size}"] 79 | for ntype in fgm.srctypes: 80 | if ntype == "n_type_1": 81 | continue 82 | down_lim = fgm._partition_book.partid2nids(rank, ntype).min() 83 | up_lim = fgm._partition_book.partid2nids(rank+1, ntype).min() if rank+1 < world_size else None 84 | ids = node_map[ntype][down_lim:up_lim] 85 | to_extract[ntype] = ids 86 | 87 | embed_layer = kwargs["embed_layer"] 88 | embeds = extract_embed(embed_layer, to_extract, skip_type="n_type_1") 89 | embeds.update({"n_type_1": feats[fgm.srcnodes("n_type_1")]}) 90 | embeds = {k: e.to("cpu") for k, e in embeds.items()} 91 | 92 | sar_logits = model(fgm, embeds) 93 | sar_logits = sar_logits["n_type_1"] 94 | 95 | mp_dict[f"result_{rank}"] = sar_logits.detach() 96 | if rank == 0: 97 | mp_dict["model"] = model 98 | mp_dict["node_map"] = fixture_env.node_map[f"heterogeneous_{world_size}"] 99 | 100 | graph = fixture_env.hetero_graph 101 | max_num_nodes = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes} 102 | embed_layer = rel_graph_embed(graph, graph.ndata["features"]["n_type_1"].shape[1], 103 | num_nodes_dict=max_num_nodes, 104 | skip_type="n_type_1").to('cpu') 105 | mp_dict = run_workers(heterogeneous_fgm, fixture_env, world_size, backend=backend, 106 | 
embed_layer=embed_layer) 107 | 108 | embeds = embed_layer.weight 109 | embeds.update({"n_type_1": graph.ndata["features"]["n_type_1"]}) 110 | embeds = {k: e.to("cpu") for k, e in embeds.items()} 111 | 112 | model = mp_dict["model"] 113 | dgl_logits = model(graph, embeds) 114 | dgl_logits = dgl_logits["n_type_1"].detach() 115 | dgl_logits_mean = dgl_logits.mean(axis=1) 116 | 117 | sar_logits = torch.tensor([]) 118 | for rank in range(world_size): 119 | sar_logits = torch.cat((sar_logits, mp_dict[f"result_{rank}"])) 120 | sar_logits[mp_dict["node_map"]["n_type_1"]] = sar_logits.clone() 121 | sar_logits_mean = sar_logits.mean(axis=1) 122 | 123 | assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6)) 124 | 125 | 126 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 127 | @pytest.mark.parametrize("world_size", [1, 2, 4, 8]) 128 | @sar_test 129 | def test_homogeneous_mfg(world_size, backend, fixture_env): 130 | """ 131 | Performs full-graph inference with the SAR algorithm on a homogeneous graph 132 | using Message Flow Graphs (MFG). The test compares the row-wise mean of the 133 | concatenated results from all processes with the row-wise mean of native DGL 134 | full-graph inference.
135 | """ 136 | import torch 137 | def homogeneous_mfg(mp_dict, rank, world_size, fixture_env, **kwargs): 138 | import sar 139 | from models import GNNModel 140 | from base_utils import initialize_worker, load_partition_data_mfg 141 | 142 | temp_dir = fixture_env.temp_dir 143 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 144 | blocks, feat, labels = load_partition_data_mfg(rank, HOMOGENEOUS_GRAPH_NAME, 145 | os.path.join(temp_dir, f"homogeneous_{world_size}")) 146 | model = GNNModel(feat.shape[1], feat.shape[1], labels.max()+1).to('cpu') 147 | sar.sync_params(model) 148 | model.eval() 149 | sar_logits = model(blocks, feat) 150 | 151 | mp_dict[f"result_{rank}"] = sar_logits.detach() 152 | if rank == 0: 153 | mp_dict["model"] = model 154 | mp_dict["graph"] = fixture_env.homo_graph 155 | mp_dict["node_map"] = fixture_env.node_map[f"homogeneous_{world_size}"] 156 | 157 | mp_dict = run_workers(homogeneous_mfg, fixture_env, world_size, backend=backend) 158 | 159 | model = mp_dict["model"] 160 | graph = mp_dict["graph"] 161 | dgl_logits = model(graph, graph.ndata['features']).detach() 162 | dgl_logits_mean = dgl_logits.mean(axis=1) 163 | 164 | sar_logits = torch.tensor([]) 165 | for rank in range(world_size): 166 | sar_logits = torch.cat((sar_logits, mp_dict[f"result_{rank}"])) 167 | sar_logits[mp_dict["node_map"]] = sar_logits.clone() 168 | sar_logits_mean = sar_logits.mean(axis=1) 169 | 170 | assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6)) 171 | 172 | 173 | @pytest.mark.parametrize("backend", ["ccl", "gloo"]) 174 | @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) 175 | @sar_test 176 | def test_heterogeneous_mfg(world_size, backend, fixture_env): 177 | """ 178 | Performs full-graph inference with the SAR algorithm on a heterogeneous graph 179 | using Message Flow Graphs (MFG).
The test compares the row-wise mean of the 180 | concatenated "n_type_1" results from all processes with the row-wise mean of native DGL 181 | full-graph inference. 182 | """ 183 | import torch 184 | from models import rel_graph_embed 185 | def heterogeneous_mfg(mp_dict, rank, world_size, fixture_env, **kwargs): 186 | import sar 187 | from models import HeteroGNNModel, extract_embed 188 | from base_utils import initialize_worker, load_partition_data_mfg 189 | 190 | temp_dir = fixture_env.temp_dir 191 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 192 | 193 | blocks, feats, labels = load_partition_data_mfg(rank, HETEROGENEOUS_GRAPH_NAME, 194 | os.path.join(temp_dir, f"heterogeneous_{world_size}")) 195 | model = HeteroGNNModel(blocks[0], feats.shape[1], feats.shape[1], labels.max()+1).to('cpu') 196 | model.eval() 197 | sar.sync_params(model) 198 | 199 | to_extract = {} 200 | node_map = fixture_env.node_map[f"heterogeneous_{world_size}"] 201 | for ntype in blocks[0].srctypes: 202 | if ntype == "n_type_1": 203 | continue 204 | down_lim = blocks[0]._partition_book.partid2nids(rank, ntype).min() 205 | up_lim = blocks[0]._partition_book.partid2nids(rank+1, ntype).min() if rank+1 < world_size else None 206 | ids = node_map[ntype][down_lim:up_lim] 207 | to_extract[ntype] = ids[blocks[0].srcnodes(ntype)] 208 | 209 | embed_layer = kwargs["embed_layer"] 210 | embeds = extract_embed(embed_layer, to_extract, skip_type="n_type_1") 211 | embeds.update({"n_type_1": feats[blocks[0].srcnodes("n_type_1")]}) 212 | embeds = {k: e.to("cpu") for k, e in embeds.items()} 213 | 214 | sar_logits = model(blocks, embeds) 215 | sar_logits = sar_logits["n_type_1"] 216 | 217 | mp_dict[f"result_{rank}"] = sar_logits.detach() 218 | if rank == 0: 219 | mp_dict["model"] = model 220 | mp_dict["node_map"] = fixture_env.node_map[f"heterogeneous_{world_size}"] 221 | 222 | graph = fixture_env.hetero_graph 223 | max_num_nodes = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes} 224 | embed_layer = 
rel_graph_embed(graph, graph.ndata["features"]["n_type_1"].shape[1], 225 | num_nodes_dict=max_num_nodes, 226 | skip_type="n_type_1").to('cpu') 227 | mp_dict = run_workers(heterogeneous_mfg, fixture_env, world_size, backend=backend, 228 | embed_layer=embed_layer) 229 | 230 | embeds = embed_layer.weight 231 | embeds.update({"n_type_1": graph.ndata["features"]["n_type_1"]}) 232 | embeds = {k: e.to("cpu") for k, e in embeds.items()} 233 | 234 | model = mp_dict["model"] 235 | dgl_logits = model(graph, embeds) 236 | dgl_logits = dgl_logits["n_type_1"].detach() 237 | dgl_logits_mean = dgl_logits.mean(axis=1) 238 | 239 | sar_logits = torch.tensor([]) 240 | for rank in range(world_size): 241 | sar_logits = torch.cat((sar_logits, mp_dict[f"result_{rank}"])) 242 | sar_logits[mp_dict["node_map"]["n_type_1"]] = sar_logits.clone() 243 | sar_logits_mean = sar_logits.mean(axis=1) 244 | 245 | assert torch.all(torch.isclose(dgl_logits_mean, sar_logits_mean, atol=1e-6, rtol=1e-6)) 246 | 247 | 248 | @pytest.mark.parametrize("backend", ["ccl"]) 249 | @pytest.mark.parametrize('world_size', [1]) 250 | @sar_test 251 | def test_convert_dist_graph(world_size, backend, fixture_env): 252 | """ 253 | Creates a DGL DistGraph object from a random graph partitioned into 254 | one part (the only way to test DistGraph locally). Then converts the 255 | DistGraph into a SAR GraphShardManager and checks relevant properties.
256 | """ 257 | def convert_dist_graph(mp_dict, rank, world_size, fixture_env, **kwargs): 258 | import dgl 259 | import sar 260 | from base_utils import initialize_worker 261 | 262 | temp_dir = fixture_env.temp_dir 263 | partition_file = os.path.join(temp_dir, f'homogeneous_{world_size}', f"{HOMOGENEOUS_GRAPH_NAME}.json") 264 | initialize_worker(rank, world_size, temp_dir, backend=kwargs["backend"]) 265 | dgl.distributed.initialize("kv_ip_config.txt") 266 | dist_g = dgl.distributed.DistGraph( 267 | HOMOGENEOUS_GRAPH_NAME, part_config=partition_file) 268 | 269 | sar_g = sar.convert_dist_graph(dist_g) 270 | assert len(sar_g.graph_shard_managers[0].graph_shards) == dist_g.get_partition_book().num_partitions() 271 | assert dist_g.num_edges() == sar_g.num_edges() 272 | assert dist_g.num_nodes() == sar_g.num_nodes() 273 | assert dist_g.ntypes == sar_g.ntypes 274 | assert dist_g.etypes == sar_g.etypes 275 | 276 | run_workers(convert_dist_graph, fixture_env, world_size, backend=backend) 277 | --------------------------------------------------------------------------------